diff --git a/lib/std/Url.zig b/lib/std/Url.zig new file mode 100644 index 000000000000..8887f5de923a --- /dev/null +++ b/lib/std/Url.zig @@ -0,0 +1,98 @@ +scheme: []const u8, +host: []const u8, +path: []const u8, +port: ?u16, + +/// TODO: redo this implementation according to RFC 1738. This code is only a +/// placeholder for now. +pub fn parse(s: []const u8) !Url { + var scheme_end: usize = 0; + var host_start: usize = 0; + var host_end: usize = 0; + var path_start: usize = 0; + var port_start: usize = 0; + var port_end: usize = 0; + var state: enum { + scheme, + scheme_slash1, + scheme_slash2, + host, + port, + path, + } = .scheme; + + for (s) |b, i| switch (state) { + .scheme => switch (b) { + ':' => { + state = .scheme_slash1; + scheme_end = i; + }, + else => {}, + }, + .scheme_slash1 => switch (b) { + '/' => { + state = .scheme_slash2; + }, + else => return error.InvalidUrl, + }, + .scheme_slash2 => switch (b) { + '/' => { + state = .host; + host_start = i + 1; + }, + else => return error.InvalidUrl, + }, + .host => switch (b) { + ':' => { + state = .port; + host_end = i; + port_start = i + 1; + }, + '/' => { + state = .path; + host_end = i; + path_start = i; + }, + else => {}, + }, + .port => switch (b) { + '/' => { + port_end = i; + state = .path; + path_start = i; + }, + else => {}, + }, + .path => {}, + }; + + const port_slice = s[port_start..port_end]; + const port = if (port_slice.len == 0) null else try std.fmt.parseInt(u16, port_slice, 10); + + return .{ + .scheme = s[0..scheme_end], + .host = s[host_start..host_end], + .path = s[path_start..], + .port = port, + }; +} + +const Url = @This(); +const std = @import("std.zig"); +const testing = std.testing; + +test "basic" { + const parsed = try parse("https://ziglang.org/download"); + try testing.expectEqualStrings("https", parsed.scheme); + try testing.expectEqualStrings("ziglang.org", parsed.host); + try testing.expectEqualStrings("/download", parsed.path); + try testing.expectEqual(@as(?u16, null), parsed.port); +} + +test "with port" { + const parsed = try parse("http://example:1337/"); + try testing.expectEqualStrings("http", parsed.scheme); + try testing.expectEqualStrings("example", parsed.host); + try testing.expectEqualStrings("/", parsed.path); + try testing.expectEqual(@as(?u16, 1337), parsed.port); +} diff --git a/lib/std/c.zig b/lib/std/c.zig index 5f03f1c61902..212b8e2d4d4b 100644 --- a/lib/std/c.zig +++ b/lib/std/c.zig @@ -206,7 +206,7 @@ pub extern "c" fn sendto( dest_addr: ?*const c.sockaddr, addrlen: c.socklen_t, ) isize; -pub extern "c" fn sendmsg(sockfd: c.fd_t, msg: *const std.x.os.Socket.Message, flags: c_int) isize; +pub extern "c" fn sendmsg(sockfd: c.fd_t, msg: *const c.msghdr_const, flags: u32) isize; pub extern "c" fn recv(sockfd: c.fd_t, arg1: ?*anyopaque, arg2: usize, arg3: c_int) isize; pub extern "c" fn recvfrom( @@ -217,7 +217,7 @@ pub extern "c" fn recvfrom( noalias src_addr: ?*c.sockaddr, noalias addrlen: ?*c.socklen_t, ) isize; -pub extern "c" fn recvmsg(sockfd: c.fd_t, msg: *std.x.os.Socket.Message, flags: c_int) isize; +pub extern "c" fn recvmsg(sockfd: c.fd_t, msg: *c.msghdr, flags: u32) isize; pub extern "c" fn kill(pid: c.pid_t, sig: c_int) c_int; pub extern "c" fn getdirentries(fd: c.fd_t, buf_ptr: [*]u8, nbytes: usize, basep: *i64) isize; diff --git a/lib/std/c/darwin.zig b/lib/std/c/darwin.zig index b68f04379f7c..9c5ac1e93a21 100644 --- a/lib/std/c/darwin.zig +++ b/lib/std/c/darwin.zig @@ -1007,7 +1007,16 @@ pub const sockaddr = extern struct { data: [14]u8, pub const SS_MAXSIZE = 128; - 
pub const storage = std.x.os.Socket.Address.Native.Storage; + pub const storage = extern struct { + len: u8 align(8), + family: sa_family_t, + padding: [126]u8 = undefined, + + comptime { + assert(@sizeOf(storage) == SS_MAXSIZE); + assert(@alignOf(storage) == 8); + } + }; pub const in = extern struct { len: u8 = @sizeOf(in), family: sa_family_t = AF.INET, diff --git a/lib/std/c/dragonfly.zig b/lib/std/c/dragonfly.zig index 26c0b34abe4e..b632211307a6 100644 --- a/lib/std/c/dragonfly.zig +++ b/lib/std/c/dragonfly.zig @@ -1,5 +1,6 @@ const builtin = @import("builtin"); const std = @import("../std.zig"); +const assert = std.debug.assert; const maxInt = std.math.maxInt; const iovec = std.os.iovec; @@ -478,11 +479,20 @@ pub const CLOCK = struct { pub const sockaddr = extern struct { len: u8, - family: u8, + family: sa_family_t, data: [14]u8, pub const SS_MAXSIZE = 128; - pub const storage = std.x.os.Socket.Address.Native.Storage; + pub const storage = extern struct { + len: u8 align(8), + family: sa_family_t, + padding: [126]u8 = undefined, + + comptime { + assert(@sizeOf(storage) == SS_MAXSIZE); + assert(@alignOf(storage) == 8); + } + }; pub const in = extern struct { len: u8 = @sizeOf(in), diff --git a/lib/std/c/freebsd.zig b/lib/std/c/freebsd.zig index 28d759ddc75d..7a4e30b909ec 100644 --- a/lib/std/c/freebsd.zig +++ b/lib/std/c/freebsd.zig @@ -1,4 +1,5 @@ const std = @import("../std.zig"); +const assert = std.debug.assert; const builtin = @import("builtin"); const maxInt = std.math.maxInt; const iovec = std.os.iovec; @@ -404,7 +405,16 @@ pub const sockaddr = extern struct { data: [14]u8, pub const SS_MAXSIZE = 128; - pub const storage = std.x.os.Socket.Address.Native.Storage; + pub const storage = extern struct { + len: u8 align(8), + family: sa_family_t, + padding: [126]u8 = undefined, + + comptime { + assert(@sizeOf(storage) == SS_MAXSIZE); + assert(@alignOf(storage) == 8); + } + }; pub const in = extern struct { len: u8 = @sizeOf(in), diff --git a/lib/std/c/haiku.zig b/lib/std/c/haiku.zig index 86b9f25902cb..9c4f8460deb4 100644 --- a/lib/std/c/haiku.zig +++ b/lib/std/c/haiku.zig @@ -1,4 +1,5 @@ const std = @import("../std.zig"); +const assert = std.debug.assert; const builtin = @import("builtin"); const maxInt = std.math.maxInt; const iovec = std.os.iovec; @@ -339,7 +340,16 @@ pub const sockaddr = extern struct { data: [14]u8, pub const SS_MAXSIZE = 128; - pub const storage = std.x.os.Socket.Address.Native.Storage; + pub const storage = extern struct { + len: u8 align(8), + family: sa_family_t, + padding: [126]u8 = undefined, + + comptime { + assert(@sizeOf(storage) == SS_MAXSIZE); + assert(@alignOf(storage) == 8); + } + }; pub const in = extern struct { len: u8 = @sizeOf(in), diff --git a/lib/std/c/netbsd.zig b/lib/std/c/netbsd.zig index a96a7c983342..b963b2e2b114 100644 --- a/lib/std/c/netbsd.zig +++ b/lib/std/c/netbsd.zig @@ -1,4 +1,5 @@ const std = @import("../std.zig"); +const assert = std.debug.assert; const builtin = @import("builtin"); const maxInt = std.math.maxInt; const iovec = std.os.iovec; @@ -481,7 +482,16 @@ pub const sockaddr = extern struct { data: [14]u8, pub const SS_MAXSIZE = 128; - pub const storage = std.x.os.Socket.Address.Native.Storage; + pub const storage = extern struct { + len: u8 align(8), + family: sa_family_t, + padding: [126]u8 = undefined, + + comptime { + assert(@sizeOf(storage) == SS_MAXSIZE); + assert(@alignOf(storage) == 8); + } + }; pub const in = extern struct { len: u8 = @sizeOf(in), diff --git a/lib/std/c/openbsd.zig b/lib/std/c/openbsd.zig index 
6796f83139fc..51c4bcb6dd07 100644 --- a/lib/std/c/openbsd.zig +++ b/lib/std/c/openbsd.zig @@ -1,4 +1,5 @@ const std = @import("../std.zig"); +const assert = std.debug.assert; const maxInt = std.math.maxInt; const builtin = @import("builtin"); const iovec = std.os.iovec; @@ -372,7 +373,16 @@ pub const sockaddr = extern struct { data: [14]u8, pub const SS_MAXSIZE = 256; - pub const storage = std.x.os.Socket.Address.Native.Storage; + pub const storage = extern struct { + len: u8 align(8), + family: sa_family_t, + padding: [254]u8 = undefined, + + comptime { + assert(@sizeOf(storage) == SS_MAXSIZE); + assert(@alignOf(storage) == 8); + } + }; pub const in = extern struct { len: u8 = @sizeOf(in), diff --git a/lib/std/c/solaris.zig b/lib/std/c/solaris.zig index cbeeb5fb429a..fe60c426e5e4 100644 --- a/lib/std/c/solaris.zig +++ b/lib/std/c/solaris.zig @@ -1,4 +1,5 @@ const std = @import("../std.zig"); +const assert = std.debug.assert; const builtin = @import("builtin"); const maxInt = std.math.maxInt; const iovec = std.os.iovec; @@ -435,7 +436,15 @@ pub const sockaddr = extern struct { data: [14]u8, pub const SS_MAXSIZE = 256; - pub const storage = std.x.os.Socket.Address.Native.Storage; + pub const storage = extern struct { + family: sa_family_t align(8), + padding: [254]u8 = undefined, + + comptime { + assert(@sizeOf(storage) == SS_MAXSIZE); + assert(@alignOf(storage) == 8); + } + }; pub const in = extern struct { family: sa_family_t = AF.INET, diff --git a/lib/std/crypto.zig b/lib/std/crypto.zig index 8aaf305143a7..20522c175d89 100644 --- a/lib/std/crypto.zig +++ b/lib/std/crypto.zig @@ -176,6 +176,9 @@ const std = @import("std.zig"); pub const errors = @import("crypto/errors.zig"); +pub const tls = @import("crypto/tls.zig"); +pub const Certificate = @import("crypto/Certificate.zig"); + test { _ = aead.aegis.Aegis128L; _ = aead.aegis.Aegis256; @@ -264,6 +267,8 @@ test { _ = utils; _ = random; _ = errors; + _ = tls; + _ = Certificate; } test "CSPRNG" { diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig new file mode 100644 index 000000000000..fe211c614671 --- /dev/null +++ b/lib/std/crypto/Certificate.zig @@ -0,0 +1,1115 @@ +buffer: []const u8, +index: u32, + +pub const Bundle = @import("Certificate/Bundle.zig"); + +pub const Algorithm = enum { + sha1WithRSAEncryption, + sha224WithRSAEncryption, + sha256WithRSAEncryption, + sha384WithRSAEncryption, + sha512WithRSAEncryption, + ecdsa_with_SHA224, + ecdsa_with_SHA256, + ecdsa_with_SHA384, + ecdsa_with_SHA512, + + pub const map = std.ComptimeStringMap(Algorithm, .{ + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x05 }, .sha1WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0B }, .sha256WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0C }, .sha384WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0D }, .sha512WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0E }, .sha224WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x01 }, .ecdsa_with_SHA224 }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x02 }, .ecdsa_with_SHA256 }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x03 }, .ecdsa_with_SHA384 }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x04 }, .ecdsa_with_SHA512 }, + }); + + pub fn Hash(comptime algorithm: Algorithm) type { + return switch (algorithm) { + .sha1WithRSAEncryption => crypto.hash.Sha1, + 
.ecdsa_with_SHA224, .sha224WithRSAEncryption => crypto.hash.sha2.Sha224, + .ecdsa_with_SHA256, .sha256WithRSAEncryption => crypto.hash.sha2.Sha256, + .ecdsa_with_SHA384, .sha384WithRSAEncryption => crypto.hash.sha2.Sha384, + .ecdsa_with_SHA512, .sha512WithRSAEncryption => crypto.hash.sha2.Sha512, + }; + } +}; + +pub const AlgorithmCategory = enum { + rsaEncryption, + X9_62_id_ecPublicKey, + + pub const map = std.ComptimeStringMap(AlgorithmCategory, .{ + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01 }, .rsaEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01 }, .X9_62_id_ecPublicKey }, + }); +}; + +pub const Attribute = enum { + commonName, + serialNumber, + countryName, + localityName, + stateOrProvinceName, + organizationName, + organizationalUnitName, + organizationIdentifier, + pkcs9_emailAddress, + + pub const map = std.ComptimeStringMap(Attribute, .{ + .{ &[_]u8{ 0x55, 0x04, 0x03 }, .commonName }, + .{ &[_]u8{ 0x55, 0x04, 0x05 }, .serialNumber }, + .{ &[_]u8{ 0x55, 0x04, 0x06 }, .countryName }, + .{ &[_]u8{ 0x55, 0x04, 0x07 }, .localityName }, + .{ &[_]u8{ 0x55, 0x04, 0x08 }, .stateOrProvinceName }, + .{ &[_]u8{ 0x55, 0x04, 0x0A }, .organizationName }, + .{ &[_]u8{ 0x55, 0x04, 0x0B }, .organizationalUnitName }, + .{ &[_]u8{ 0x55, 0x04, 0x61 }, .organizationIdentifier }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x09, 0x01 }, .pkcs9_emailAddress }, + }); +}; + +pub const NamedCurve = enum { + secp384r1, + X9_62_prime256v1, + + pub const map = std.ComptimeStringMap(NamedCurve, .{ + .{ &[_]u8{ 0x2B, 0x81, 0x04, 0x00, 0x22 }, .secp384r1 }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x03, 0x01, 0x07 }, .X9_62_prime256v1 }, + }); +}; + +pub const ExtensionId = enum { + subject_key_identifier, + key_usage, + private_key_usage_period, + subject_alt_name, + issuer_alt_name, + basic_constraints, + crl_number, + certificate_policies, + authority_key_identifier, + + pub const map = std.ComptimeStringMap(ExtensionId, .{ + .{ &[_]u8{ 0x55, 0x1D, 0x0E }, .subject_key_identifier }, + .{ &[_]u8{ 0x55, 0x1D, 0x0F }, .key_usage }, + .{ &[_]u8{ 0x55, 0x1D, 0x10 }, .private_key_usage_period }, + .{ &[_]u8{ 0x55, 0x1D, 0x11 }, .subject_alt_name }, + .{ &[_]u8{ 0x55, 0x1D, 0x12 }, .issuer_alt_name }, + .{ &[_]u8{ 0x55, 0x1D, 0x13 }, .basic_constraints }, + .{ &[_]u8{ 0x55, 0x1D, 0x14 }, .crl_number }, + .{ &[_]u8{ 0x55, 0x1D, 0x20 }, .certificate_policies }, + .{ &[_]u8{ 0x55, 0x1D, 0x23 }, .authority_key_identifier }, + }); +}; + +pub const GeneralNameTag = enum(u5) { + otherName = 0, + rfc822Name = 1, + dNSName = 2, + x400Address = 3, + directoryName = 4, + ediPartyName = 5, + uniformResourceIdentifier = 6, + iPAddress = 7, + registeredID = 8, + _, +}; + +pub const Parsed = struct { + certificate: Certificate, + issuer_slice: Slice, + subject_slice: Slice, + common_name_slice: Slice, + signature_slice: Slice, + signature_algorithm: Algorithm, + pub_key_algo: PubKeyAlgo, + pub_key_slice: Slice, + message_slice: Slice, + subject_alt_name_slice: Slice, + validity: Validity, + + pub const PubKeyAlgo = union(AlgorithmCategory) { + rsaEncryption: void, + X9_62_id_ecPublicKey: NamedCurve, + }; + + pub const Validity = struct { + not_before: u64, + not_after: u64, + }; + + pub const Slice = der.Element.Slice; + + pub fn slice(p: Parsed, s: Slice) []const u8 { + return p.certificate.buffer[s.start..s.end]; + } + + pub fn issuer(p: Parsed) []const u8 { + return p.slice(p.issuer_slice); + } + + pub fn subject(p: Parsed) []const u8 { + return p.slice(p.subject_slice); + } 
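The ComptimeStringMap tables above are keyed on the raw DER-encoded OID bytes rather than on dotted-decimal strings, so identifying an algorithm, attribute, or extension during parsing is a single map lookup over the OID contents. A minimal test sketch of that lookup, reusing the sha256WithRSAEncryption entry already present in the Algorithm table (the test name and the unknown-OID bytes are illustrative choices, not part of the patch):

test "DER-encoded OID bytes map directly to an Algorithm tag" {
    // 1.2.840.113549.1.1.11 (sha256WithRSAEncryption), as raw DER OID contents.
    const oid = [_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0B };
    try std.testing.expectEqual(Algorithm.sha256WithRSAEncryption, Algorithm.map.get(&oid).?);
    // Unrecognized OIDs simply miss the map; parseEnum later in this file turns
    // such a miss into error.CertificateHasUnrecognizedObjectId.
    try std.testing.expect(Algorithm.map.get(&[_]u8{ 0x2A, 0x03 }) == null);
}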
+ + pub fn commonName(p: Parsed) []const u8 { + return p.slice(p.common_name_slice); + } + + pub fn signature(p: Parsed) []const u8 { + return p.slice(p.signature_slice); + } + + pub fn pubKey(p: Parsed) []const u8 { + return p.slice(p.pub_key_slice); + } + + pub fn pubKeySigAlgo(p: Parsed) []const u8 { + return p.slice(p.pub_key_signature_algorithm_slice); + } + + pub fn message(p: Parsed) []const u8 { + return p.slice(p.message_slice); + } + + pub fn subjectAltName(p: Parsed) []const u8 { + return p.slice(p.subject_alt_name_slice); + } + + pub const VerifyError = error{ + CertificateIssuerMismatch, + CertificateNotYetValid, + CertificateExpired, + CertificateSignatureAlgorithmUnsupported, + CertificateSignatureAlgorithmMismatch, + CertificateFieldHasInvalidLength, + CertificateFieldHasWrongDataType, + CertificatePublicKeyInvalid, + CertificateSignatureInvalidLength, + CertificateSignatureInvalid, + CertificateSignatureUnsupportedBitCount, + CertificateSignatureNamedCurveUnsupported, + }; + + /// This function verifies: + /// * That the subject's issuer is indeed the provided issuer. + /// * The time validity of the subject. + /// * The signature. + pub fn verify(parsed_subject: Parsed, parsed_issuer: Parsed, now_sec: i64) VerifyError!void { + // Check that the subject's issuer name matches the issuer's + // subject name. + if (!mem.eql(u8, parsed_subject.issuer(), parsed_issuer.subject())) { + return error.CertificateIssuerMismatch; + } + + if (now_sec < parsed_subject.validity.not_before) + return error.CertificateNotYetValid; + if (now_sec > parsed_subject.validity.not_after) + return error.CertificateExpired; + + switch (parsed_subject.signature_algorithm) { + inline .sha1WithRSAEncryption, + .sha224WithRSAEncryption, + .sha256WithRSAEncryption, + .sha384WithRSAEncryption, + .sha512WithRSAEncryption, + => |algorithm| return verifyRsa( + algorithm.Hash(), + parsed_subject.message(), + parsed_subject.signature(), + parsed_issuer.pub_key_algo, + parsed_issuer.pubKey(), + ), + + inline .ecdsa_with_SHA224, + .ecdsa_with_SHA256, + .ecdsa_with_SHA384, + .ecdsa_with_SHA512, + => |algorithm| return verify_ecdsa( + algorithm.Hash(), + parsed_subject.message(), + parsed_subject.signature(), + parsed_issuer.pub_key_algo, + parsed_issuer.pubKey(), + ), + } + } + + pub const VerifyHostNameError = error{ + CertificateHostMismatch, + CertificateFieldHasInvalidLength, + }; + + pub fn verifyHostName(parsed_subject: Parsed, host_name: []const u8) VerifyHostNameError!void { + // If the Subject Alternative Names extension is present, this is + // what to check. Otherwise, only the common name is checked. 
+ const subject_alt_name = parsed_subject.subjectAltName(); + if (subject_alt_name.len == 0) { + if (checkHostName(host_name, parsed_subject.commonName())) { + return; + } else { + return error.CertificateHostMismatch; + } + } + + const general_names = try der.Element.parse(subject_alt_name, 0); + var name_i = general_names.slice.start; + while (name_i < general_names.slice.end) { + const general_name = try der.Element.parse(subject_alt_name, name_i); + name_i = general_name.slice.end; + switch (@intToEnum(GeneralNameTag, @enumToInt(general_name.identifier.tag))) { + .dNSName => { + const dns_name = subject_alt_name[general_name.slice.start..general_name.slice.end]; + if (checkHostName(host_name, dns_name)) return; + }, + else => {}, + } + } + + return error.CertificateHostMismatch; + } + + fn checkHostName(host_name: []const u8, dns_name: []const u8) bool { + if (mem.eql(u8, dns_name, host_name)) { + return true; // exact match + } + + if (mem.startsWith(u8, dns_name, "*.")) { + // wildcard certificate, matches any subdomain + // TODO: I think wildcards are not supposed to match any prefix but + // only match exactly one subdomain. + if (mem.endsWith(u8, host_name, dns_name[1..])) { + // The host_name has a subdomain, but the important part matches. + return true; + } + if (mem.eql(u8, dns_name[2..], host_name)) { + // The host_name has no subdomain and matches exactly. + return true; + } + } + + return false; + } +}; + +pub fn parse(cert: Certificate) !Parsed { + const cert_bytes = cert.buffer; + const certificate = try der.Element.parse(cert_bytes, cert.index); + const tbs_certificate = try der.Element.parse(cert_bytes, certificate.slice.start); + const version = try der.Element.parse(cert_bytes, tbs_certificate.slice.start); + try checkVersion(cert_bytes, version); + const serial_number = try der.Element.parse(cert_bytes, version.slice.end); + // RFC 5280, section 4.1.2.3: + // "This field MUST contain the same algorithm identifier as + // the signatureAlgorithm field in the sequence Certificate." 
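Returning briefly to the host-name check defined above: checkHostName accepts either an exact match or a leading "*." wildcard, and, as its TODO notes, the wildcard currently matches any depth of subdomain rather than exactly one label. A small illustrative test sketch of that behaviour (it relies on same-file access to the private checkHostName; the host names are examples only):

test "checkHostName exact and wildcard matching" {
    try std.testing.expect(Parsed.checkHostName("ziglang.org", "ziglang.org"));
    try std.testing.expect(Parsed.checkHostName("api.ziglang.org", "*.ziglang.org"));
    // The bare domain also matches its own wildcard entry.
    try std.testing.expect(Parsed.checkHostName("ziglang.org", "*.ziglang.org"));
    try std.testing.expect(!Parsed.checkHostName("ziglang.org", "example.org"));
}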
+ const tbs_signature = try der.Element.parse(cert_bytes, serial_number.slice.end); + const issuer = try der.Element.parse(cert_bytes, tbs_signature.slice.end); + const validity = try der.Element.parse(cert_bytes, issuer.slice.end); + const not_before = try der.Element.parse(cert_bytes, validity.slice.start); + const not_before_utc = try parseTime(cert, not_before); + const not_after = try der.Element.parse(cert_bytes, not_before.slice.end); + const not_after_utc = try parseTime(cert, not_after); + const subject = try der.Element.parse(cert_bytes, validity.slice.end); + + const pub_key_info = try der.Element.parse(cert_bytes, subject.slice.end); + const pub_key_signature_algorithm = try der.Element.parse(cert_bytes, pub_key_info.slice.start); + const pub_key_algo_elem = try der.Element.parse(cert_bytes, pub_key_signature_algorithm.slice.start); + const pub_key_algo_tag = try parseAlgorithmCategory(cert_bytes, pub_key_algo_elem); + var pub_key_algo: Parsed.PubKeyAlgo = undefined; + switch (pub_key_algo_tag) { + .rsaEncryption => { + pub_key_algo = .{ .rsaEncryption = {} }; + }, + .X9_62_id_ecPublicKey => { + // RFC 5480 Section 2.1.1.1 Named Curve + // ECParameters ::= CHOICE { + // namedCurve OBJECT IDENTIFIER + // -- implicitCurve NULL + // -- specifiedCurve SpecifiedECDomain + // } + const params_elem = try der.Element.parse(cert_bytes, pub_key_algo_elem.slice.end); + const named_curve = try parseNamedCurve(cert_bytes, params_elem); + pub_key_algo = .{ .X9_62_id_ecPublicKey = named_curve }; + }, + } + const pub_key_elem = try der.Element.parse(cert_bytes, pub_key_signature_algorithm.slice.end); + const pub_key = try parseBitString(cert, pub_key_elem); + + var common_name = der.Element.Slice.empty; + var name_i = subject.slice.start; + while (name_i < subject.slice.end) { + const rdn = try der.Element.parse(cert_bytes, name_i); + var rdn_i = rdn.slice.start; + while (rdn_i < rdn.slice.end) { + const atav = try der.Element.parse(cert_bytes, rdn_i); + var atav_i = atav.slice.start; + while (atav_i < atav.slice.end) { + const ty_elem = try der.Element.parse(cert_bytes, atav_i); + const ty = try parseAttribute(cert_bytes, ty_elem); + const val = try der.Element.parse(cert_bytes, ty_elem.slice.end); + switch (ty) { + .commonName => common_name = val.slice, + else => {}, + } + atav_i = val.slice.end; + } + rdn_i = atav.slice.end; + } + name_i = rdn.slice.end; + } + + const sig_algo = try der.Element.parse(cert_bytes, tbs_certificate.slice.end); + const algo_elem = try der.Element.parse(cert_bytes, sig_algo.slice.start); + const signature_algorithm = try parseAlgorithm(cert_bytes, algo_elem); + const sig_elem = try der.Element.parse(cert_bytes, sig_algo.slice.end); + const signature = try parseBitString(cert, sig_elem); + + // Extensions + var subject_alt_name_slice = der.Element.Slice.empty; + ext: { + if (pub_key_info.slice.end >= tbs_certificate.slice.end) + break :ext; + + const outer_extensions = try der.Element.parse(cert_bytes, pub_key_info.slice.end); + if (outer_extensions.identifier.tag != .bitstring) + break :ext; + + const extensions = try der.Element.parse(cert_bytes, outer_extensions.slice.start); + + var ext_i = extensions.slice.start; + while (ext_i < extensions.slice.end) { + const extension = try der.Element.parse(cert_bytes, ext_i); + ext_i = extension.slice.end; + const oid_elem = try der.Element.parse(cert_bytes, extension.slice.start); + const ext_id = parseExtensionId(cert_bytes, oid_elem) catch |err| switch (err) { + error.CertificateHasUnrecognizedObjectId => continue, + 
else => |e| return e, + }; + const critical_elem = try der.Element.parse(cert_bytes, oid_elem.slice.end); + const ext_bytes_elem = if (critical_elem.identifier.tag != .boolean) + critical_elem + else + try der.Element.parse(cert_bytes, critical_elem.slice.end); + switch (ext_id) { + .subject_alt_name => subject_alt_name_slice = ext_bytes_elem.slice, + else => continue, + } + } + } + + return .{ + .certificate = cert, + .common_name_slice = common_name, + .issuer_slice = issuer.slice, + .subject_slice = subject.slice, + .signature_slice = signature, + .signature_algorithm = signature_algorithm, + .message_slice = .{ .start = certificate.slice.start, .end = tbs_certificate.slice.end }, + .pub_key_algo = pub_key_algo, + .pub_key_slice = pub_key, + .validity = .{ + .not_before = not_before_utc, + .not_after = not_after_utc, + }, + .subject_alt_name_slice = subject_alt_name_slice, + }; +} + +pub fn verify(subject: Certificate, issuer: Certificate, now_sec: i64) !void { + const parsed_subject = try subject.parse(); + const parsed_issuer = try issuer.parse(); + return parsed_subject.verify(parsed_issuer, now_sec); +} + +pub fn contents(cert: Certificate, elem: der.Element) []const u8 { + return cert.buffer[elem.slice.start..elem.slice.end]; +} + +pub fn parseBitString(cert: Certificate, elem: der.Element) !der.Element.Slice { + if (elem.identifier.tag != .bitstring) return error.CertificateFieldHasWrongDataType; + if (cert.buffer[elem.slice.start] != 0) return error.CertificateHasInvalidBitString; + return .{ .start = elem.slice.start + 1, .end = elem.slice.end }; +} + +/// Returns number of seconds since epoch. +pub fn parseTime(cert: Certificate, elem: der.Element) !u64 { + const bytes = cert.contents(elem); + switch (elem.identifier.tag) { + .utc_time => { + // Example: "YYMMDD000000Z" + if (bytes.len != 13) + return error.CertificateTimeInvalid; + if (bytes[12] != 'Z') + return error.CertificateTimeInvalid; + + return Date.toSeconds(.{ + .year = @as(u16, 2000) + try parseTimeDigits(bytes[0..2].*, 0, 99), + .month = try parseTimeDigits(bytes[2..4].*, 1, 12), + .day = try parseTimeDigits(bytes[4..6].*, 1, 31), + .hour = try parseTimeDigits(bytes[6..8].*, 0, 23), + .minute = try parseTimeDigits(bytes[8..10].*, 0, 59), + .second = try parseTimeDigits(bytes[10..12].*, 0, 59), + }); + }, + .generalized_time => { + // Examples: + // "19920521000000Z" + // "19920622123421Z" + // "19920722132100.3Z" + if (bytes.len < 15) + return error.CertificateTimeInvalid; + return Date.toSeconds(.{ + .year = try parseYear4(bytes[0..4]), + .month = try parseTimeDigits(bytes[4..6].*, 1, 12), + .day = try parseTimeDigits(bytes[6..8].*, 1, 31), + .hour = try parseTimeDigits(bytes[8..10].*, 0, 23), + .minute = try parseTimeDigits(bytes[10..12].*, 0, 59), + .second = try parseTimeDigits(bytes[12..14].*, 0, 59), + }); + }, + else => return error.CertificateFieldHasWrongDataType, + } +} + +const Date = struct { + /// example: 1999 + year: u16, + /// range: 1 to 12 + month: u8, + /// range: 1 to 31 + day: u8, + /// range: 0 to 59 + hour: u8, + /// range: 0 to 59 + minute: u8, + /// range: 0 to 59 + second: u8, + + /// Convert to number of seconds since epoch. 
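As a worked check of parseTime above (note that the two-digit UTCTime year is mapped to 2000 + YY here), a small test sketch over a standalone UTCTime element; the element bytes are constructed for illustration and are not part of the patch:

test "parseTime on a UTCTime element" {
    // Tag 0x17 (UTCTime), length 13, then "YYMMDDhhmmssZ".
    const bytes = "\x17\x0d" ++ "200101000000Z";
    const cert: Certificate = .{ .buffer = bytes, .index = 0 };
    const elem = try der.Element.parse(bytes, 0);
    // 2020-01-01T00:00:00Z
    try std.testing.expectEqual(@as(u64, 1577836800), try parseTime(cert, elem));
}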
+ pub fn toSeconds(date: Date) u64 { + var sec: u64 = 0; + + { + var year: u16 = 1970; + while (year < date.year) : (year += 1) { + const days: u64 = std.time.epoch.getDaysInYear(year); + sec += days * std.time.epoch.secs_per_day; + } + } + + { + const is_leap = std.time.epoch.isLeapYear(date.year); + var month: u4 = 1; + while (month < date.month) : (month += 1) { + const days: u64 = std.time.epoch.getDaysInMonth( + @intToEnum(std.time.epoch.YearLeapKind, @boolToInt(is_leap)), + @intToEnum(std.time.epoch.Month, month), + ); + sec += days * std.time.epoch.secs_per_day; + } + } + + sec += (date.day - 1) * @as(u64, std.time.epoch.secs_per_day); + sec += date.hour * @as(u64, 60 * 60); + sec += date.minute * @as(u64, 60); + sec += date.second; + + return sec; + } +}; + +pub fn parseTimeDigits(nn: @Vector(2, u8), min: u8, max: u8) !u8 { + const zero: @Vector(2, u8) = .{ '0', '0' }; + const mm: @Vector(2, u8) = .{ 10, 1 }; + const result = @reduce(.Add, (nn -% zero) *% mm); + if (result < min) return error.CertificateTimeInvalid; + if (result > max) return error.CertificateTimeInvalid; + return result; +} + +test parseTimeDigits { + const expectEqual = std.testing.expectEqual; + try expectEqual(@as(u8, 0), try parseTimeDigits("00".*, 0, 99)); + try expectEqual(@as(u8, 99), try parseTimeDigits("99".*, 0, 99)); + try expectEqual(@as(u8, 42), try parseTimeDigits("42".*, 0, 99)); + + const expectError = std.testing.expectError; + try expectError(error.CertificateTimeInvalid, parseTimeDigits("13".*, 1, 12)); + try expectError(error.CertificateTimeInvalid, parseTimeDigits("00".*, 1, 12)); +} + +pub fn parseYear4(text: *const [4]u8) !u16 { + const nnnn: @Vector(4, u16) = .{ text[0], text[1], text[2], text[3] }; + const zero: @Vector(4, u16) = .{ '0', '0', '0', '0' }; + const mmmm: @Vector(4, u16) = .{ 1000, 100, 10, 1 }; + const result = @reduce(.Add, (nnnn -% zero) *% mmmm); + if (result > 9999) return error.CertificateTimeInvalid; + return result; +} + +test parseYear4 { + const expectEqual = std.testing.expectEqual; + try expectEqual(@as(u16, 0), try parseYear4("0000")); + try expectEqual(@as(u16, 9999), try parseYear4("9999")); + try expectEqual(@as(u16, 1988), try parseYear4("1988")); + + const expectError = std.testing.expectError; + try expectError(error.CertificateTimeInvalid, parseYear4("999b")); + try expectError(error.CertificateTimeInvalid, parseYear4("crap")); +} + +pub fn parseAlgorithm(bytes: []const u8, element: der.Element) !Algorithm { + return parseEnum(Algorithm, bytes, element); +} + +pub fn parseAlgorithmCategory(bytes: []const u8, element: der.Element) !AlgorithmCategory { + return parseEnum(AlgorithmCategory, bytes, element); +} + +pub fn parseAttribute(bytes: []const u8, element: der.Element) !Attribute { + return parseEnum(Attribute, bytes, element); +} + +pub fn parseNamedCurve(bytes: []const u8, element: der.Element) !NamedCurve { + return parseEnum(NamedCurve, bytes, element); +} + +pub fn parseExtensionId(bytes: []const u8, element: der.Element) !ExtensionId { + return parseEnum(ExtensionId, bytes, element); +} + +fn parseEnum(comptime E: type, bytes: []const u8, element: der.Element) !E { + if (element.identifier.tag != .object_identifier) + return error.CertificateFieldHasWrongDataType; + const oid_bytes = bytes[element.slice.start..element.slice.end]; + return E.map.get(oid_bytes) orelse return error.CertificateHasUnrecognizedObjectId; +} + +pub fn checkVersion(bytes: []const u8, version: der.Element) !void { + if (@bitCast(u8, version.identifier) != 0xa0 or + 
!mem.eql(u8, bytes[version.slice.start..version.slice.end], "\x02\x01\x02")) + { + return error.UnsupportedCertificateVersion; + } +} + +fn verifyRsa( + comptime Hash: type, + message: []const u8, + sig: []const u8, + pub_key_algo: Parsed.PubKeyAlgo, + pub_key: []const u8, +) !void { + if (pub_key_algo != .rsaEncryption) return error.CertificateSignatureAlgorithmMismatch; + const pk_components = try rsa.PublicKey.parseDer(pub_key); + const exponent = pk_components.exponent; + const modulus = pk_components.modulus; + if (exponent.len > modulus.len) return error.CertificatePublicKeyInvalid; + if (sig.len != modulus.len) return error.CertificateSignatureInvalidLength; + + const hash_der = switch (Hash) { + crypto.hash.Sha1 => [_]u8{ + 0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, + 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14, + }, + crypto.hash.sha2.Sha224 => [_]u8{ + 0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04, 0x05, + 0x00, 0x04, 0x1c, + }, + crypto.hash.sha2.Sha256 => [_]u8{ + 0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, + 0x00, 0x04, 0x20, + }, + crypto.hash.sha2.Sha384 => [_]u8{ + 0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, + 0x00, 0x04, 0x30, + }, + crypto.hash.sha2.Sha512 => [_]u8{ + 0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, + 0x00, 0x04, 0x40, + }, + else => @compileError("unreachable"), + }; + + var msg_hashed: [Hash.digest_length]u8 = undefined; + Hash.hash(message, &msg_hashed, .{}); + + var rsa_mem_buf: [512 * 64]u8 = undefined; + var fba = std.heap.FixedBufferAllocator.init(&rsa_mem_buf); + const ally = fba.allocator(); + + switch (modulus.len) { + inline 128, 256, 512 => |modulus_len| { + const ps_len = modulus_len - (hash_der.len + msg_hashed.len) - 3; + const em: [modulus_len]u8 = + [2]u8{ 0, 1 } ++ + ([1]u8{0xff} ** ps_len) ++ + [1]u8{0} ++ + hash_der ++ + msg_hashed; + + const public_key = rsa.PublicKey.fromBytes(exponent, modulus, ally) catch |err| switch (err) { + error.OutOfMemory => unreachable, // rsa_mem_buf is big enough + }; + const em_dec = rsa.encrypt(modulus_len, sig[0..modulus_len].*, public_key, ally) catch |err| switch (err) { + error.OutOfMemory => unreachable, // rsa_mem_buf is big enough + + error.MessageTooLong => unreachable, + error.NegativeIntoUnsigned => @panic("TODO make RSA not emit this error"), + error.TargetTooSmall => @panic("TODO make RSA not emit this error"), + error.BufferTooSmall => @panic("TODO make RSA not emit this error"), + }; + + if (!mem.eql(u8, &em, &em_dec)) { + return error.CertificateSignatureInvalid; + } + }, + else => { + return error.CertificateSignatureUnsupportedBitCount; + }, + } +} + +fn verify_ecdsa( + comptime Hash: type, + message: []const u8, + encoded_sig: []const u8, + pub_key_algo: Parsed.PubKeyAlgo, + sec1_pub_key: []const u8, +) !void { + const sig_named_curve = switch (pub_key_algo) { + .X9_62_id_ecPublicKey => |named_curve| named_curve, + else => return error.CertificateSignatureAlgorithmMismatch, + }; + + switch (sig_named_curve) { + .secp384r1 => { + const P = crypto.ecc.P384; + const Ecdsa = crypto.sign.ecdsa.Ecdsa(P, Hash); + const sig = Ecdsa.Signature.fromDer(encoded_sig) catch |err| switch (err) { + error.InvalidEncoding => return error.CertificateSignatureInvalid, + }; + const pub_key = Ecdsa.PublicKey.fromSec1(sec1_pub_key) catch |err| switch (err) { + error.InvalidEncoding => return 
error.CertificateSignatureInvalid, + error.NonCanonical => return error.CertificateSignatureInvalid, + error.NotSquare => return error.CertificateSignatureInvalid, + }; + sig.verify(message, pub_key) catch |err| switch (err) { + error.IdentityElement => return error.CertificateSignatureInvalid, + error.NonCanonical => return error.CertificateSignatureInvalid, + error.SignatureVerificationFailed => return error.CertificateSignatureInvalid, + }; + }, + .X9_62_prime256v1 => { + return error.CertificateSignatureNamedCurveUnsupported; + }, + } +} + +const std = @import("../std.zig"); +const crypto = std.crypto; +const mem = std.mem; +const Certificate = @This(); + +pub const der = struct { + pub const Class = enum(u2) { + universal, + application, + context_specific, + private, + }; + + pub const PC = enum(u1) { + primitive, + constructed, + }; + + pub const Identifier = packed struct(u8) { + tag: Tag, + pc: PC, + class: Class, + }; + + pub const Tag = enum(u5) { + boolean = 1, + integer = 2, + bitstring = 3, + octetstring = 4, + null = 5, + object_identifier = 6, + sequence = 16, + sequence_of = 17, + utc_time = 23, + generalized_time = 24, + _, + }; + + pub const Element = struct { + identifier: Identifier, + slice: Slice, + + pub const Slice = struct { + start: u32, + end: u32, + + pub const empty: Slice = .{ .start = 0, .end = 0 }; + }; + + pub const ParseError = error{CertificateFieldHasInvalidLength}; + + pub fn parse(bytes: []const u8, index: u32) ParseError!Element { + var i = index; + const identifier = @bitCast(Identifier, bytes[i]); + i += 1; + const size_byte = bytes[i]; + i += 1; + if ((size_byte >> 7) == 0) { + return .{ + .identifier = identifier, + .slice = .{ + .start = i, + .end = i + size_byte, + }, + }; + } + + const len_size = @truncate(u7, size_byte); + if (len_size > @sizeOf(u32)) { + return error.CertificateFieldHasInvalidLength; + } + + const end_i = i + len_size; + var long_form_size: u32 = 0; + while (i < end_i) : (i += 1) { + long_form_size = (long_form_size << 8) | bytes[i]; + } + + return .{ + .identifier = identifier, + .slice = .{ + .start = i, + .end = i + long_form_size, + }, + }; + } + }; +}; + +test { + _ = Bundle; +} + +/// TODO: replace this with Frank's upcoming RSA implementation. the verify +/// function won't have the possibility of failure - it will either identify a +/// valid signature or an invalid signature. +/// This code is borrowed from https://github.com/shiguredo/tls13-zig +/// which is licensed under the Apache License Version 2.0, January 2004 +/// http://www.apache.org/licenses/ +/// The code has been modified. +pub const rsa = struct { + const BigInt = std.math.big.int.Managed; + + pub const PSSSignature = struct { + pub fn fromBytes(comptime modulus_len: usize, msg: []const u8) [modulus_len]u8 { + var result = [1]u8{0} ** modulus_len; + std.mem.copy(u8, &result, msg); + return result; + } + + pub fn verify(comptime modulus_len: usize, sig: [modulus_len]u8, msg: []const u8, public_key: PublicKey, comptime Hash: type, allocator: std.mem.Allocator) !void { + const mod_bits = try countBits(public_key.n.toConst(), allocator); + const em_dec = try encrypt(modulus_len, sig, public_key, allocator); + + try EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash, allocator); + } + + fn EMSA_PSS_VERIFY(msg: []const u8, em: []const u8, emBit: usize, sLen: usize, comptime Hash: type, allocator: std.mem.Allocator) !void { + // TODO + // 1. 
If the length of M is greater than the input limitation for + // the hash function (2^61 - 1 octets for SHA-1), output + // "inconsistent" and stop. + + // emLen = \ceil(emBits/8) + const emLen = ((emBit - 1) / 8) + 1; + std.debug.assert(emLen == em.len); + + // 2. Let mHash = Hash(M), an octet string of length hLen. + var mHash: [Hash.digest_length]u8 = undefined; + Hash.hash(msg, &mHash, .{}); + + // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. + if (emLen < Hash.digest_length + sLen + 2) { + return error.InvalidSignature; + } + + // 4. If the rightmost octet of EM does not have hexadecimal value + // 0xbc, output "inconsistent" and stop. + if (em[em.len - 1] != 0xbc) { + return error.InvalidSignature; + } + + // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, + // and let H be the next hLen octets. + const maskedDB = em[0..(emLen - Hash.digest_length - 1)]; + const h = em[(emLen - Hash.digest_length - 1)..(emLen - 1)]; + + // 6. If the leftmost 8emLen - emBits bits of the leftmost octet in + // maskedDB are not all equal to zero, output "inconsistent" and + // stop. + const zero_bits = emLen * 8 - emBit; + var mask: u8 = maskedDB[0]; + var i: usize = 0; + while (i < 8 - zero_bits) : (i += 1) { + mask = mask >> 1; + } + if (mask != 0) { + return error.InvalidSignature; + } + + // 7. Let dbMask = MGF(H, emLen - hLen - 1). + const mgf_len = emLen - Hash.digest_length - 1; + var mgf_out = try allocator.alloc(u8, ((mgf_len - 1) / Hash.digest_length + 1) * Hash.digest_length); + defer allocator.free(mgf_out); + var dbMask = try MGF1(mgf_out, h, mgf_len, Hash, allocator); + + // 8. Let DB = maskedDB \xor dbMask. + i = 0; + while (i < dbMask.len) : (i += 1) { + dbMask[i] = maskedDB[i] ^ dbMask[i]; + } + + // 9. Set the leftmost 8emLen - emBits bits of the leftmost octet + // in DB to zero. + i = 0; + mask = 0; + while (i < 8 - zero_bits) : (i += 1) { + mask = mask << 1; + mask += 1; + } + dbMask[0] = dbMask[0] & mask; + + // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not + // zero or if the octet at position emLen - hLen - sLen - 1 (the + // leftmost position is "position 1") does not have hexadecimal + // value 0x01, output "inconsistent" and stop. + if (dbMask[mgf_len - sLen - 2] != 0x00) { + return error.InvalidSignature; + } + + if (dbMask[mgf_len - sLen - 1] != 0x01) { + return error.InvalidSignature; + } + + // 11. Let salt be the last sLen octets of DB. + const salt = dbMask[(mgf_len - sLen)..]; + + // 12. Let + // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; + // M' is an octet string of length 8 + hLen + sLen with eight + // initial zero octets. + var m_p = try allocator.alloc(u8, 8 + Hash.digest_length + sLen); + defer allocator.free(m_p); + std.mem.copy(u8, m_p, &([_]u8{0} ** 8)); + std.mem.copy(u8, m_p[8..], &mHash); + std.mem.copy(u8, m_p[(8 + Hash.digest_length)..], salt); + + // 13. Let H' = Hash(M'), an octet string of length hLen. + var h_p: [Hash.digest_length]u8 = undefined; + Hash.hash(m_p, &h_p, .{}); + + // 14. If H = H', output "consistent". Otherwise, output + // "inconsistent". 
+ if (!std.mem.eql(u8, h, &h_p)) { + return error.InvalidSignature; + } + } + + fn MGF1(out: []u8, seed: []const u8, len: usize, comptime Hash: type, allocator: std.mem.Allocator) ![]u8 { + var counter: usize = 0; + var idx: usize = 0; + var c: [4]u8 = undefined; + + var hash = try allocator.alloc(u8, seed.len + c.len); + defer allocator.free(hash); + std.mem.copy(u8, hash, seed); + var hashed: [Hash.digest_length]u8 = undefined; + + while (idx < len) { + c[0] = @intCast(u8, (counter >> 24) & 0xFF); + c[1] = @intCast(u8, (counter >> 16) & 0xFF); + c[2] = @intCast(u8, (counter >> 8) & 0xFF); + c[3] = @intCast(u8, counter & 0xFF); + + std.mem.copy(u8, hash[seed.len..], &c); + Hash.hash(hash, &hashed, .{}); + + std.mem.copy(u8, out[idx..], &hashed); + idx += hashed.len; + + counter += 1; + } + + return out[0..len]; + } + }; + + pub const PublicKey = struct { + n: BigInt, + e: BigInt, + + pub fn deinit(self: *PublicKey) void { + self.n.deinit(); + self.e.deinit(); + } + + pub fn fromBytes(pub_bytes: []const u8, modulus_bytes: []const u8, allocator: std.mem.Allocator) !PublicKey { + var _n = try BigInt.init(allocator); + errdefer _n.deinit(); + try setBytes(&_n, modulus_bytes, allocator); + + var _e = try BigInt.init(allocator); + errdefer _e.deinit(); + try setBytes(&_e, pub_bytes, allocator); + + return .{ + .n = _n, + .e = _e, + }; + } + + pub fn parseDer(pub_key: []const u8) !struct { modulus: []const u8, exponent: []const u8 } { + const pub_key_seq = try der.Element.parse(pub_key, 0); + if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType; + const modulus_elem = try der.Element.parse(pub_key, pub_key_seq.slice.start); + if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; + const exponent_elem = try der.Element.parse(pub_key, modulus_elem.slice.end); + if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; + // Skip over meaningless zeroes in the modulus. 
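        // (DER INTEGERs are two's-complement signed, so a modulus whose most
        // significant bit is set is encoded with an extra leading 0x00 octet.
        // Stripping any leading zero octets here leaves the raw big-endian
        // magnitude that the RSA arithmetic below expects.)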
+ const modulus_raw = pub_key[modulus_elem.slice.start..modulus_elem.slice.end]; + const modulus_offset = for (modulus_raw) |byte, i| { + if (byte != 0) break i; + } else modulus_raw.len; + return .{ + .modulus = modulus_raw[modulus_offset..], + .exponent = pub_key[exponent_elem.slice.start..exponent_elem.slice.end], + }; + } + }; + + fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey, allocator: std.mem.Allocator) ![modulus_len]u8 { + var m = try BigInt.init(allocator); + defer m.deinit(); + + try setBytes(&m, &msg, allocator); + + if (m.order(public_key.n) != .lt) { + return error.MessageTooLong; + } + + var e = try BigInt.init(allocator); + defer e.deinit(); + + try pow_montgomery(&e, &m, &public_key.e, &public_key.n, allocator); + + var res: [modulus_len]u8 = undefined; + + try toBytes(&res, &e, allocator); + + return res; + } + + fn setBytes(r: *BigInt, bytes: []const u8, allcator: std.mem.Allocator) !void { + try r.set(0); + var tmp = try BigInt.init(allcator); + defer tmp.deinit(); + for (bytes) |b| { + try r.shiftLeft(r, 8); + try tmp.set(b); + try r.add(r, &tmp); + } + } + + fn pow_montgomery(r: *BigInt, a: *const BigInt, x: *const BigInt, n: *const BigInt, allocator: std.mem.Allocator) !void { + var bin_raw: [512]u8 = undefined; + try toBytes(&bin_raw, x, allocator); + + var i: usize = 0; + while (bin_raw[i] == 0x00) : (i += 1) {} + const bin = bin_raw[i..]; + + try r.set(1); + var r1 = try BigInt.init(allocator); + defer r1.deinit(); + try BigInt.copy(&r1, a.toConst()); + i = 0; + while (i < bin.len * 8) : (i += 1) { + if (((bin[i / 8] >> @intCast(u3, (7 - (i % 8)))) & 0x1) == 0) { + try BigInt.mul(&r1, r, &r1); + try mod(&r1, &r1, n, allocator); + try BigInt.sqr(r, r); + try mod(r, r, n, allocator); + } else { + try BigInt.mul(r, r, &r1); + try mod(r, r, n, allocator); + try BigInt.sqr(&r1, &r1); + try mod(&r1, &r1, n, allocator); + } + } + } + + fn toBytes(out: []u8, a: *const BigInt, allocator: std.mem.Allocator) !void { + const Error = error{ + BufferTooSmall, + }; + + var mask = try BigInt.initSet(allocator, 0xFF); + defer mask.deinit(); + var tmp = try BigInt.init(allocator); + defer tmp.deinit(); + + var a_copy = try BigInt.init(allocator); + defer a_copy.deinit(); + try a_copy.copy(a.toConst()); + + // Encoding into big-endian bytes + var i: usize = 0; + while (i < out.len) : (i += 1) { + try tmp.bitAnd(&a_copy, &mask); + const b = try tmp.to(u8); + out[out.len - i - 1] = b; + try a_copy.shiftRight(&a_copy, 8); + } + + if (!a_copy.eqZero()) { + return Error.BufferTooSmall; + } + } + + fn mod(rem: *BigInt, a: *const BigInt, n: *const BigInt, allocator: std.mem.Allocator) !void { + var q = try BigInt.init(allocator); + defer q.deinit(); + + try BigInt.divFloor(&q, rem, a, n); + } + + fn countBits(a: std.math.big.int.Const, allocator: std.mem.Allocator) !usize { + var i: usize = 0; + var a_copy = try BigInt.init(allocator); + defer a_copy.deinit(); + try a_copy.copy(a); + + while (!a_copy.eqZero()) { + try a_copy.shiftRight(&a_copy, 1); + i += 1; + } + + return i; + } +}; diff --git a/lib/std/crypto/Certificate/Bundle.zig b/lib/std/crypto/Certificate/Bundle.zig new file mode 100644 index 000000000000..a1684fda734b --- /dev/null +++ b/lib/std/crypto/Certificate/Bundle.zig @@ -0,0 +1,189 @@ +//! A set of certificates. Typically pre-installed on every operating system, +//! these are "Certificate Authorities" used to validate SSL certificates. +//! This data structure stores certificates in DER-encoded form, all of them +//! 
concatenated together in the `bytes` array. The `map` field contains an +//! index from the DER-encoded subject name to the index of the containing +//! certificate within `bytes`. + +/// The key is the contents slice of the subject. +map: std.HashMapUnmanaged(der.Element.Slice, u32, MapContext, std.hash_map.default_max_load_percentage) = .{}, +bytes: std.ArrayListUnmanaged(u8) = .{}, + +pub const VerifyError = Certificate.Parsed.VerifyError || error{ + CertificateIssuerNotFound, +}; + +pub fn verify(cb: Bundle, subject: Certificate.Parsed, now_sec: i64) VerifyError!void { + const bytes_index = cb.find(subject.issuer()) orelse return error.CertificateIssuerNotFound; + const issuer_cert: Certificate = .{ + .buffer = cb.bytes.items, + .index = bytes_index, + }; + // Every certificate in the bundle is pre-parsed before adding it, ensuring + // that parsing will succeed here. + const issuer = issuer_cert.parse() catch unreachable; + try subject.verify(issuer, now_sec); +} + +/// The returned bytes become invalid after calling any of the rescan functions +/// or add functions. +pub fn find(cb: Bundle, subject_name: []const u8) ?u32 { + const Adapter = struct { + cb: Bundle, + + pub fn hash(ctx: @This(), k: []const u8) u64 { + _ = ctx; + return std.hash_map.hashString(k); + } + + pub fn eql(ctx: @This(), a: []const u8, b_key: der.Element.Slice) bool { + const b = ctx.cb.bytes.items[b_key.start..b_key.end]; + return mem.eql(u8, a, b); + } + }; + return cb.map.getAdapted(subject_name, Adapter{ .cb = cb }); +} + +pub fn deinit(cb: *Bundle, gpa: Allocator) void { + cb.map.deinit(gpa); + cb.bytes.deinit(gpa); + cb.* = undefined; +} + +/// Clears the set of certificates and then scans the host operating system +/// file system standard locations for certificates. +/// For operating systems that do not have standard CA installations to be +/// found, this function clears the set of certificates. +pub fn rescan(cb: *Bundle, gpa: Allocator) !void { + switch (builtin.os.tag) { + .linux => return rescanLinux(cb, gpa), + .windows => { + // TODO + }, + .macos => { + // TODO + }, + else => {}, + } +} + +pub fn rescanLinux(cb: *Bundle, gpa: Allocator) !void { + var dir = fs.openIterableDirAbsolute("/etc/ssl/certs", .{}) catch |err| switch (err) { + error.FileNotFound => return, + else => |e| return e, + }; + defer dir.close(); + + cb.bytes.clearRetainingCapacity(); + cb.map.clearRetainingCapacity(); + + var it = dir.iterate(); + while (try it.next()) |entry| { + switch (entry.kind) { + .File, .SymLink => {}, + else => continue, + } + + try addCertsFromFile(cb, gpa, dir.dir, entry.name); + } + + cb.bytes.shrinkAndFree(gpa, cb.bytes.items.len); +} + +pub fn addCertsFromFile( + cb: *Bundle, + gpa: Allocator, + dir: fs.Dir, + sub_file_path: []const u8, +) !void { + var file = try dir.openFile(sub_file_path, .{}); + defer file.close(); + + const size = try file.getEndPos(); + + // We borrow `bytes` as a temporary buffer for the base64-encoded data. + // This is possible by computing the decoded length and reserving the space + // for the decoded bytes first. 
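    // Illustratively (the sizes are an assumed example, not taken from the patch):
    // for a 4096-byte PEM file the allocated slice is laid out as
    //
    //   [0 .. items.len)               already-added DER certificates
    //   [items.len .. end_reserved)    space reserved for the newly decoded DER
    //                                  (at most size / 4 * 3 = 3072 bytes)
    //   [end_reserved .. capacity)     scratch holding the raw PEM text (4096 bytes)
    //
    // Base64 turns every 4 encoded bytes into at most 3 decoded bytes, so the
    // decoded data appended at items.len can never catch up with the encoded
    // text it is being read from.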
+ const decoded_size_upper_bound = size / 4 * 3; + const needed_capacity = std.math.cast(u32, decoded_size_upper_bound + size) orelse + return error.CertificateAuthorityBundleTooBig; + try cb.bytes.ensureUnusedCapacity(gpa, needed_capacity); + const end_reserved = @intCast(u32, cb.bytes.items.len + decoded_size_upper_bound); + const buffer = cb.bytes.allocatedSlice()[end_reserved..]; + const end_index = try file.readAll(buffer); + const encoded_bytes = buffer[0..end_index]; + + const begin_marker = "-----BEGIN CERTIFICATE-----"; + const end_marker = "-----END CERTIFICATE-----"; + + const now_sec = std.time.timestamp(); + + var start_index: usize = 0; + while (mem.indexOfPos(u8, encoded_bytes, start_index, begin_marker)) |begin_marker_start| { + const cert_start = begin_marker_start + begin_marker.len; + const cert_end = mem.indexOfPos(u8, encoded_bytes, cert_start, end_marker) orelse + return error.MissingEndCertificateMarker; + start_index = cert_end + end_marker.len; + const encoded_cert = mem.trim(u8, encoded_bytes[cert_start..cert_end], " \t\r\n"); + const decoded_start = @intCast(u32, cb.bytes.items.len); + const dest_buf = cb.bytes.allocatedSlice()[decoded_start..]; + cb.bytes.items.len += try base64.decode(dest_buf, encoded_cert); + // Even though we could only partially parse the certificate to find + // the subject name, we pre-parse all of them to make sure and only + // include in the bundle ones that we know will parse. This way we can + // use `catch unreachable` later. + const parsed_cert = try Certificate.parse(.{ + .buffer = cb.bytes.items, + .index = decoded_start, + }); + if (now_sec > parsed_cert.validity.not_after) { + // Ignore expired cert. + cb.bytes.items.len = decoded_start; + continue; + } + const gop = try cb.map.getOrPutContext(gpa, parsed_cert.subject_slice, .{ .cb = cb }); + if (gop.found_existing) { + cb.bytes.items.len = decoded_start; + } else { + gop.value_ptr.* = decoded_start; + } + } +} + +const builtin = @import("builtin"); +const std = @import("../../std.zig"); +const fs = std.fs; +const mem = std.mem; +const crypto = std.crypto; +const Allocator = std.mem.Allocator; +const Certificate = std.crypto.Certificate; +const der = Certificate.der; +const Bundle = @This(); + +const base64 = std.base64.standard.decoderWithIgnore(" \t\r\n"); + +const MapContext = struct { + cb: *const Bundle, + + pub fn hash(ctx: MapContext, k: der.Element.Slice) u64 { + return std.hash_map.hashString(ctx.cb.bytes.items[k.start..k.end]); + } + + pub fn eql(ctx: MapContext, a: der.Element.Slice, b: der.Element.Slice) bool { + const bytes = ctx.cb.bytes.items; + return mem.eql( + u8, + bytes[a.start..a.end], + bytes[b.start..b.end], + ); + } +}; + +test "scan for OS-provided certificates" { + if (builtin.os.tag == .wasi) return error.SkipZigTest; + + var bundle: Bundle = .{}; + defer bundle.deinit(std.testing.allocator); + + try bundle.rescan(std.testing.allocator); +} diff --git a/lib/std/crypto/aegis.zig b/lib/std/crypto/aegis.zig index 01dd5d547bb3..da09aca351d8 100644 --- a/lib/std/crypto/aegis.zig +++ b/lib/std/crypto/aegis.zig @@ -174,7 +174,7 @@ pub const Aegis128L = struct { acc |= (computed_tag[j] ^ tag[j]); } if (acc != 0) { - mem.set(u8, m, 0xaa); + @memset(m.ptr, undefined, m.len); return error.AuthenticationFailed; } } @@ -343,7 +343,7 @@ pub const Aegis256 = struct { acc |= (computed_tag[j] ^ tag[j]); } if (acc != 0) { - mem.set(u8, m, 0xaa); + @memset(m.ptr, undefined, m.len); return error.AuthenticationFailed; } } diff --git a/lib/std/crypto/aes_gcm.zig 
b/lib/std/crypto/aes_gcm.zig index 30fd37e6a0ab..6eadcdee2f67 100644 --- a/lib/std/crypto/aes_gcm.zig +++ b/lib/std/crypto/aes_gcm.zig @@ -91,7 +91,7 @@ fn AesGcm(comptime Aes: anytype) type { acc |= (computed_tag[p] ^ tag[p]); } if (acc != 0) { - mem.set(u8, m, 0xaa); + @memset(m.ptr, undefined, m.len); return error.AuthenticationFailed; } diff --git a/lib/std/crypto/sha2.zig b/lib/std/crypto/sha2.zig index 9cdf8edcf180..217dea37231d 100644 --- a/lib/std/crypto/sha2.zig +++ b/lib/std/crypto/sha2.zig @@ -142,6 +142,11 @@ fn Sha2x32(comptime params: Sha2Params32) type { d.total_len += b.len; } + pub fn peek(d: Self) [digest_length]u8 { + var copy = d; + return copy.finalResult(); + } + pub fn final(d: *Self, out: *[digest_length]u8) void { // The buffer here will never be completely full. mem.set(u8, d.buf[d.buf_len..], 0); @@ -175,6 +180,12 @@ fn Sha2x32(comptime params: Sha2Params32) type { } } + pub fn finalResult(d: *Self) [digest_length]u8 { + var result: [digest_length]u8 = undefined; + d.final(&result); + return result; + } + const W = [64]u32{ 0x428A2F98, 0x71374491, 0xB5C0FBCF, 0xE9B5DBA5, 0x3956C25B, 0x59F111F1, 0x923F82A4, 0xAB1C5ED5, 0xD807AA98, 0x12835B01, 0x243185BE, 0x550C7DC3, 0x72BE5D74, 0x80DEB1FE, 0x9BDC06A7, 0xC19BF174, @@ -621,6 +632,11 @@ fn Sha2x64(comptime params: Sha2Params64) type { d.total_len += b.len; } + pub fn peek(d: Self) [digest_length]u8 { + var copy = d; + return copy.finalResult(); + } + pub fn final(d: *Self, out: *[digest_length]u8) void { // The buffer here will never be completely full. mem.set(u8, d.buf[d.buf_len..], 0); @@ -654,6 +670,12 @@ fn Sha2x64(comptime params: Sha2Params64) type { } } + pub fn finalResult(d: *Self) [digest_length]u8 { + var result: [digest_length]u8 = undefined; + d.final(&result); + return result; + } + fn round(d: *Self, b: *const [128]u8) void { var s: [80]u64 = undefined; diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig new file mode 100644 index 000000000000..7d89da892908 --- /dev/null +++ b/lib/std/crypto/tls.zig @@ -0,0 +1,494 @@ +//! Plaintext: +//! * type: ContentType +//! * legacy_record_version: u16 = 0x0303, +//! * length: u16, +//! - The length (in bytes) of the following TLSPlaintext.fragment. The +//! length MUST NOT exceed 2^14 bytes. +//! * fragment: opaque +//! - the data being transmitted +//! +//! Ciphertext +//! * ContentType opaque_type = application_data; /* 23 */ +//! * ProtocolVersion legacy_record_version = 0x0303; /* TLS v1.2 */ +//! * uint16 length; +//! * opaque encrypted_record[TLSCiphertext.length]; +//! +//! Handshake: +//! * type: HandshakeType +//! * length: u24 +//! * data: opaque +//! +//! ServerHello: +//! * ProtocolVersion legacy_version = 0x0303; +//! * Random random; +//! * opaque legacy_session_id_echo<0..32>; +//! * CipherSuite cipher_suite; +//! * uint8 legacy_compression_method = 0; +//! * Extension extensions<6..2^16-1>; +//! +//! Extension: +//! * ExtensionType extension_type; +//! 
* opaque extension_data<0..2^16-1>; + +const std = @import("../std.zig"); +const Tls = @This(); +const net = std.net; +const mem = std.mem; +const crypto = std.crypto; +const assert = std.debug.assert; + +pub const Client = @import("tls/Client.zig"); + +pub const record_header_len = 5; +pub const max_ciphertext_len = (1 << 14) + 256; +pub const max_ciphertext_record_len = max_ciphertext_len + record_header_len; +pub const hello_retry_request_sequence = [32]u8{ + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, + 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, +}; + +pub const close_notify_alert = [_]u8{ + @enumToInt(AlertLevel.warning), + @enumToInt(AlertDescription.close_notify), +}; + +pub const ProtocolVersion = enum(u16) { + tls_1_2 = 0x0303, + tls_1_3 = 0x0304, + _, +}; + +pub const ContentType = enum(u8) { + invalid = 0, + change_cipher_spec = 20, + alert = 21, + handshake = 22, + application_data = 23, + _, +}; + +pub const HandshakeType = enum(u8) { + client_hello = 1, + server_hello = 2, + new_session_ticket = 4, + end_of_early_data = 5, + encrypted_extensions = 8, + certificate = 11, + certificate_request = 13, + certificate_verify = 15, + finished = 20, + key_update = 24, + message_hash = 254, + _, +}; + +pub const ExtensionType = enum(u16) { + /// RFC 6066 + server_name = 0, + /// RFC 6066 + max_fragment_length = 1, + /// RFC 6066 + status_request = 5, + /// RFC 8422, 7919 + supported_groups = 10, + /// RFC 8446 + signature_algorithms = 13, + /// RFC 5764 + use_srtp = 14, + /// RFC 6520 + heartbeat = 15, + /// RFC 7301 + application_layer_protocol_negotiation = 16, + /// RFC 6962 + signed_certificate_timestamp = 18, + /// RFC 7250 + client_certificate_type = 19, + /// RFC 7250 + server_certificate_type = 20, + /// RFC 7685 + padding = 21, + /// RFC 8446 + pre_shared_key = 41, + /// RFC 8446 + early_data = 42, + /// RFC 8446 + supported_versions = 43, + /// RFC 8446 + cookie = 44, + /// RFC 8446 + psk_key_exchange_modes = 45, + /// RFC 8446 + certificate_authorities = 47, + /// RFC 8446 + oid_filters = 48, + /// RFC 8446 + post_handshake_auth = 49, + /// RFC 8446 + signature_algorithms_cert = 50, + /// RFC 8446 + key_share = 51, + + _, +}; + +pub const AlertLevel = enum(u8) { + warning = 1, + fatal = 2, + _, +}; + +pub const AlertDescription = enum(u8) { + close_notify = 0, + unexpected_message = 10, + bad_record_mac = 20, + record_overflow = 22, + handshake_failure = 40, + bad_certificate = 42, + unsupported_certificate = 43, + certificate_revoked = 44, + certificate_expired = 45, + certificate_unknown = 46, + illegal_parameter = 47, + unknown_ca = 48, + access_denied = 49, + decode_error = 50, + decrypt_error = 51, + protocol_version = 70, + insufficient_security = 71, + internal_error = 80, + inappropriate_fallback = 86, + user_canceled = 90, + missing_extension = 109, + unsupported_extension = 110, + unrecognized_name = 112, + bad_certificate_status_response = 113, + unknown_psk_identity = 115, + certificate_required = 116, + no_application_protocol = 120, + _, +}; + +pub const SignatureScheme = enum(u16) { + // RSASSA-PKCS1-v1_5 algorithms + rsa_pkcs1_sha256 = 0x0401, + rsa_pkcs1_sha384 = 0x0501, + rsa_pkcs1_sha512 = 0x0601, + + // ECDSA algorithms + ecdsa_secp256r1_sha256 = 0x0403, + ecdsa_secp384r1_sha384 = 0x0503, + ecdsa_secp521r1_sha512 = 0x0603, + + // RSASSA-PSS algorithms with public key OID rsaEncryption + rsa_pss_rsae_sha256 = 0x0804, + rsa_pss_rsae_sha384 = 0x0805, + 
rsa_pss_rsae_sha512 = 0x0806, + + // EdDSA algorithms + ed25519 = 0x0807, + ed448 = 0x0808, + + // RSASSA-PSS algorithms with public key OID RSASSA-PSS + rsa_pss_pss_sha256 = 0x0809, + rsa_pss_pss_sha384 = 0x080a, + rsa_pss_pss_sha512 = 0x080b, + + // Legacy algorithms + rsa_pkcs1_sha1 = 0x0201, + ecdsa_sha1 = 0x0203, + + _, +}; + +pub const NamedGroup = enum(u16) { + // Elliptic Curve Groups (ECDHE) + secp256r1 = 0x0017, + secp384r1 = 0x0018, + secp521r1 = 0x0019, + x25519 = 0x001D, + x448 = 0x001E, + + // Finite Field Groups (DHE) + ffdhe2048 = 0x0100, + ffdhe3072 = 0x0101, + ffdhe4096 = 0x0102, + ffdhe6144 = 0x0103, + ffdhe8192 = 0x0104, + + _, +}; + +pub const CipherSuite = enum(u16) { + AES_128_GCM_SHA256 = 0x1301, + AES_256_GCM_SHA384 = 0x1302, + CHACHA20_POLY1305_SHA256 = 0x1303, + AES_128_CCM_SHA256 = 0x1304, + AES_128_CCM_8_SHA256 = 0x1305, + AEGIS_256_SHA384 = 0x1306, + AEGIS_128L_SHA256 = 0x1307, + _, +}; + +pub const CertificateType = enum(u8) { + X509 = 0, + RawPublicKey = 2, + _, +}; + +pub const KeyUpdateRequest = enum(u8) { + update_not_requested = 0, + update_requested = 1, + _, +}; + +pub fn HandshakeCipherT(comptime AeadType: type, comptime HashType: type) type { + return struct { + pub const AEAD = AeadType; + pub const Hash = HashType; + pub const Hmac = crypto.auth.hmac.Hmac(Hash); + pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + + handshake_secret: [Hkdf.prk_length]u8, + master_secret: [Hkdf.prk_length]u8, + client_handshake_key: [AEAD.key_length]u8, + server_handshake_key: [AEAD.key_length]u8, + client_finished_key: [Hmac.key_length]u8, + server_finished_key: [Hmac.key_length]u8, + client_handshake_iv: [AEAD.nonce_length]u8, + server_handshake_iv: [AEAD.nonce_length]u8, + transcript_hash: Hash, + }; +} + +pub const HandshakeCipher = union(enum) { + AES_128_GCM_SHA256: HandshakeCipherT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256), + AES_256_GCM_SHA384: HandshakeCipherT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384), + CHACHA20_POLY1305_SHA256: HandshakeCipherT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256), + AEGIS_256_SHA384: HandshakeCipherT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha384), + AEGIS_128L_SHA256: HandshakeCipherT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256), +}; + +pub fn ApplicationCipherT(comptime AeadType: type, comptime HashType: type) type { + return struct { + pub const AEAD = AeadType; + pub const Hash = HashType; + pub const Hmac = crypto.auth.hmac.Hmac(Hash); + pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + + client_secret: [Hash.digest_length]u8, + server_secret: [Hash.digest_length]u8, + client_key: [AEAD.key_length]u8, + server_key: [AEAD.key_length]u8, + client_iv: [AEAD.nonce_length]u8, + server_iv: [AEAD.nonce_length]u8, + }; +} + +/// Encryption parameters for application traffic. 
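`HandshakeCipherT` (below) and `ApplicationCipherT` pair an AEAD with a hash at comptime so that all key, IV, and PRK sizes fall out of the chosen suite. A sketch of what one instantiation resolves to, assuming the new file is reachable as `std.crypto.tls`; the sizes are the standard AES-128-GCM/SHA-256 ones:

```zig
const std = @import("std");
const tls = std.crypto.tls;

test "HandshakeCipherT derives its field sizes from the AEAD/hash pair (sketch)" {
    const C = tls.HandshakeCipherT(std.crypto.aead.aes_gcm.Aes128Gcm, std.crypto.hash.sha2.Sha256);
    try std.testing.expectEqual(@as(usize, 16), C.AEAD.key_length);
    try std.testing.expectEqual(@as(usize, 12), C.AEAD.nonce_length);
    try std.testing.expectEqual(@as(usize, 32), C.Hkdf.prk_length);
}
```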
+pub const ApplicationCipher = union(enum) { + AES_128_GCM_SHA256: ApplicationCipherT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256), + AES_256_GCM_SHA384: ApplicationCipherT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384), + CHACHA20_POLY1305_SHA256: ApplicationCipherT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256), + AEGIS_256_SHA384: ApplicationCipherT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha384), + AEGIS_128L_SHA256: ApplicationCipherT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256), +}; + +pub fn hkdfExpandLabel( + comptime Hkdf: type, + key: [Hkdf.prk_length]u8, + label: []const u8, + context: []const u8, + comptime len: usize, +) [len]u8 { + const max_label_len = 255; + const max_context_len = 255; + const tls13 = "tls13 "; + var buf: [2 + 1 + tls13.len + max_label_len + 1 + max_context_len]u8 = undefined; + mem.writeIntBig(u16, buf[0..2], len); + buf[2] = @intCast(u8, tls13.len + label.len); + buf[3..][0..tls13.len].* = tls13.*; + var i: usize = 3 + tls13.len; + mem.copy(u8, buf[i..], label); + i += label.len; + buf[i] = @intCast(u8, context.len); + i += 1; + mem.copy(u8, buf[i..], context); + i += context.len; + + var result: [len]u8 = undefined; + Hkdf.expand(&result, buf[0..i], key); + return result; +} + +pub fn emptyHash(comptime Hash: type) [Hash.digest_length]u8 { + var result: [Hash.digest_length]u8 = undefined; + Hash.hash(&.{}, &result, .{}); + return result; +} + +pub fn hmac(comptime Hmac: type, message: []const u8, key: [Hmac.key_length]u8) [Hmac.mac_length]u8 { + var result: [Hmac.mac_length]u8 = undefined; + Hmac.create(&result, message, &key); + return result; +} + +pub inline fn extension(comptime et: ExtensionType, bytes: anytype) [2 + 2 + bytes.len]u8 { + return int2(@enumToInt(et)) ++ array(1, bytes); +} + +pub inline fn array(comptime elem_size: comptime_int, bytes: anytype) [2 + bytes.len]u8 { + comptime assert(bytes.len % elem_size == 0); + return int2(bytes.len) ++ bytes; +} + +pub inline fn enum_array(comptime E: type, comptime tags: []const E) [2 + @sizeOf(E) * tags.len]u8 { + assert(@sizeOf(E) == 2); + var result: [tags.len * 2]u8 = undefined; + for (tags) |elem, i| { + result[i * 2] = @truncate(u8, @enumToInt(elem) >> 8); + result[i * 2 + 1] = @truncate(u8, @enumToInt(elem)); + } + return array(2, result); +} + +pub inline fn int2(x: u16) [2]u8 { + return .{ + @truncate(u8, x >> 8), + @truncate(u8, x), + }; +} + +pub inline fn int3(x: u24) [3]u8 { + return .{ + @truncate(u8, x >> 16), + @truncate(u8, x >> 8), + @truncate(u8, x), + }; +} + +/// An abstraction to ensure that protocol-parsing code does not perform an +/// out-of-bounds read. +pub const Decoder = struct { + buf: []u8, + /// Points to the next byte in buffer that will be decoded. + idx: usize = 0, + /// Up to this point in `buf` we have already checked that `cap` is greater than it. + our_end: usize = 0, + /// Beyond this point in `buf` is extra tag-along bytes beyond the amount we + /// requested with `readAtLeast`. + their_end: usize = 0, + /// Points to the end within buffer that has been filled. Beyond this point + /// in buf is undefined bytes. + cap: usize = 0, + /// Debug helper to prevent illegal calls to read functions. + disable_reads: bool = false, + + pub fn fromTheirSlice(buf: []u8) Decoder { + return .{ + .buf = buf, + .their_end = buf.len, + .cap = buf.len, + .disable_reads = true, + }; + } + + /// Use this function to increase `their_end`. 
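The `int2`/`array`/`extension` helpers above build the length-prefixed TLV encodings used throughout the ClientHello. A sketch of the bytes they produce, assuming the module is exposed as `std.crypto.tls` (as the imports in `Client.zig` below suggest):

```zig
const std = @import("std");
const tls = std.crypto.tls;

test "extension() emits type, length, then payload (sketch)" {
    // supported_versions is extension type 43; the payload is a 1-byte-counted
    // list containing 0x0304 (TLS 1.3).
    const ext = tls.extension(.supported_versions, [_]u8{ 0x02, 0x03, 0x04 });
    try std.testing.expectEqualSlices(
        u8,
        &[_]u8{ 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04 },
        &ext,
    );
}
```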
+ pub fn readAtLeast(d: *Decoder, stream: anytype, their_amt: usize) !void { + assert(!d.disable_reads); + const existing_amt = d.cap - d.idx; + d.their_end = d.idx + their_amt; + if (their_amt <= existing_amt) return; + const request_amt = their_amt - existing_amt; + const dest = d.buf[d.cap..]; + if (request_amt > dest.len) return error.TlsRecordOverflow; + const actual_amt = try stream.readAtLeast(dest, request_amt); + if (actual_amt < request_amt) return error.TlsConnectionTruncated; + d.cap += actual_amt; + } + + /// Same as `readAtLeast` but also increases `our_end` by exactly `our_amt`. + /// Use when `our_amt` is calculated by us, not by them. + pub fn readAtLeastOurAmt(d: *Decoder, stream: anytype, our_amt: usize) !void { + assert(!d.disable_reads); + try readAtLeast(d, stream, our_amt); + d.our_end = d.idx + our_amt; + } + + /// Use this function to increase `our_end`. + /// This should always be called with an amount provided by us, not them. + pub fn ensure(d: *Decoder, amt: usize) !void { + d.our_end = @max(d.idx + amt, d.our_end); + if (d.our_end > d.their_end) return error.TlsDecodeError; + } + + /// Use this function to increase `idx`. + pub fn decode(d: *Decoder, comptime T: type) T { + switch (@typeInfo(T)) { + .Int => |info| switch (info.bits) { + 8 => { + skip(d, 1); + return d.buf[d.idx - 1]; + }, + 16 => { + skip(d, 2); + const b0: u16 = d.buf[d.idx - 2]; + const b1: u16 = d.buf[d.idx - 1]; + return (b0 << 8) | b1; + }, + 24 => { + skip(d, 3); + const b0: u24 = d.buf[d.idx - 3]; + const b1: u24 = d.buf[d.idx - 2]; + const b2: u24 = d.buf[d.idx - 1]; + return (b0 << 16) | (b1 << 8) | b2; + }, + else => @compileError("unsupported int type: " ++ @typeName(T)), + }, + .Enum => |info| { + const int = d.decode(info.tag_type); + if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); + return @intToEnum(T, int); + }, + else => @compileError("unsupported type: " ++ @typeName(T)), + } + } + + /// Use this function to increase `idx`. + pub fn array(d: *Decoder, comptime len: usize) *[len]u8 { + skip(d, len); + return d.buf[d.idx - len ..][0..len]; + } + + /// Use this function to increase `idx`. + pub fn slice(d: *Decoder, len: usize) []u8 { + skip(d, len); + return d.buf[d.idx - len ..][0..len]; + } + + /// Use this function to increase `idx`. + pub fn skip(d: *Decoder, amt: usize) void { + d.idx += amt; + assert(d.idx <= d.our_end); // insufficient ensured bytes + } + + pub fn eof(d: Decoder) bool { + assert(d.our_end <= d.their_end); + assert(d.idx <= d.our_end); + return d.idx == d.their_end; + } + + /// Provide the length they claim, and receive a sub-decoder specific to that slice. + /// The parent decoder is advanced to the end. 
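A sketch of how the `Decoder` API above is meant to be driven: every `decode`/`slice`/`skip` must be preceded by an `ensure` (or `readAtLeast`) covering the bytes it consumes, and `sub` (defined just below) scopes a length-prefixed region. The record bytes here are fabricated.

```zig
const std = @import("std");
const tls = std.crypto.tls;

test "Decoder: bounded reads over a fixed buffer (sketch)" {
    // handshake(22), legacy version 0x0303, length 4, then 4 payload bytes.
    var buf = [_]u8{ 22, 0x03, 0x03, 0x00, 0x04, 1, 2, 3, 4 };
    var d = tls.Decoder.fromTheirSlice(&buf);
    try d.ensure(tls.record_header_len);
    try std.testing.expectEqual(tls.ContentType.handshake, d.decode(tls.ContentType));
    d.skip(2); // legacy_record_version
    const len = d.decode(u16);
    try std.testing.expectEqual(@as(u16, 4), len);
    var body = try d.sub(len);
    try body.ensure(len);
    try std.testing.expectEqualSlices(u8, &[_]u8{ 1, 2, 3, 4 }, body.slice(len));
    try std.testing.expect(d.eof());
}
```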
+ pub fn sub(d: *Decoder, their_len: usize) !Decoder { + const end = d.idx + their_len; + if (end > d.their_end) return error.TlsDecodeError; + const sub_buf = d.buf[d.idx..end]; + d.idx = end; + d.our_end = end; + return fromTheirSlice(sub_buf); + } + + pub fn rest(d: Decoder) []u8 { + return d.buf[d.idx..d.cap]; + } +}; diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig new file mode 100644 index 000000000000..44891a197395 --- /dev/null +++ b/lib/std/crypto/tls/Client.zig @@ -0,0 +1,1308 @@ +const std = @import("../../std.zig"); +const tls = std.crypto.tls; +const Client = @This(); +const net = std.net; +const mem = std.mem; +const crypto = std.crypto; +const assert = std.debug.assert; +const Certificate = std.crypto.Certificate; + +const max_ciphertext_len = tls.max_ciphertext_len; +const hkdfExpandLabel = tls.hkdfExpandLabel; +const int2 = tls.int2; +const int3 = tls.int3; +const array = tls.array; +const enum_array = tls.enum_array; + +read_seq: u64, +write_seq: u64, +/// The starting index of cleartext bytes inside `partially_read_buffer`. +partial_cleartext_idx: u15, +/// The ending index of cleartext bytes inside `partially_read_buffer` as well +/// as the starting index of ciphertext bytes. +partial_ciphertext_idx: u15, +/// The ending index of ciphertext bytes inside `partially_read_buffer`. +partial_ciphertext_end: u15, +/// When this is true, the stream may still not be at the end because there +/// may be data in `partially_read_buffer`. +received_close_notify: bool, +/// By default, reaching the end-of-stream when reading from the server will +/// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify +/// message has been received. By setting this flag to `true`, instead, the +/// end-of-stream will be forwarded to the application layer above TLS. +/// This makes the application vulnerable to truncation attacks unless the +/// application layer itself verifies that the amount of data received equals +/// the amount of data expected, such as HTTP with the Content-Length header. +allow_truncation_attacks: bool = false, +application_cipher: tls.ApplicationCipher, +/// The size is enough to contain exactly one TLSCiphertext record. +/// This buffer is segmented into four parts: +/// 0. unused +/// 1. cleartext +/// 2. ciphertext +/// 3. unused +/// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and +/// `partial_ciphertext_end` describe the span of the segments. +partially_read_buffer: [tls.max_ciphertext_record_len]u8, + +/// This is an example of the type that is needed by the read and write +/// functions. It can have any fields but it must at least have these +/// functions. +/// +/// Note that `std.net.Stream` conforms to this interface. +/// +/// This declaration serves as documentation only. +pub const StreamInterface = struct { + /// Can be any error set. + pub const ReadError = error{}; + + /// Returns the number of bytes read. The number read may be less than the + /// buffer space provided. End-of-stream is indicated by a return value of 0. + /// + /// The `iovecs` parameter is mutable because this function may need to + /// mutate the fields in order to handle partial reads from the underlying + /// stream layer. + pub fn readv(this: @This(), iovecs: []std.os.iovec) ReadError!usize { + _ = .{ this, iovecs }; + @panic("unimplemented"); + } + + /// Can be any error set. + pub const WriteError = error{}; + + /// Returns the number of bytes read, which may be less than the buffer + /// space provided.
A short read does not indicate end-of-stream. + pub fn writev(this: @This(), iovecs: []const std.os.iovec_const) WriteError!usize { + _ = .{ this, iovecs }; + @panic("unimplemented"); + } + + /// Returns the number of bytes read, which may be less than the buffer + /// space provided, indicating end-of-stream. + /// The `iovecs` parameter is mutable in case this function needs to mutate + /// the fields in order to handle partial writes from the underlying layer. + pub fn writevAll(this: @This(), iovecs: []std.os.iovec_const) WriteError!usize { + // This can be implemented in terms of writev, or specialized if desired. + _ = .{ this, iovecs }; + @panic("unimplemented"); + } +}; + +/// Initiates a TLS handshake and establishes a TLSv1.3 session with `stream`, which +/// must conform to `StreamInterface`. +/// +/// `host` is only borrowed during this function call. +pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) !Client { + const host_len = @intCast(u16, host.len); + + var random_buffer: [128]u8 = undefined; + crypto.random.bytes(&random_buffer); + const hello_rand = random_buffer[0..32].*; + const legacy_session_id = random_buffer[32..64].*; + const x25519_kp_seed = random_buffer[64..96].*; + const secp256r1_kp_seed = random_buffer[96..128].*; + + const x25519_kp = crypto.dh.X25519.KeyPair.create(x25519_kp_seed) catch |err| switch (err) { + // Only possible to happen if the private key is all zeroes. + error.IdentityElement => return error.InsufficientEntropy, + }; + const secp256r1_kp = crypto.sign.ecdsa.EcdsaP256Sha256.KeyPair.create(secp256r1_kp_seed) catch |err| switch (err) { + // Only possible to happen if the private key is all zeroes. + error.IdentityElement => return error.InsufficientEntropy, + }; + + const extensions_payload = + tls.extension(.supported_versions, [_]u8{ + 0x02, // byte length of supported versions + 0x03, 0x04, // TLS 1.3 + }) ++ tls.extension(.signature_algorithms, enum_array(tls.SignatureScheme, &.{ + .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + .ecdsa_secp521r1_sha512, + .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + .rsa_pkcs1_sha256, + .rsa_pkcs1_sha384, + .rsa_pkcs1_sha512, + .ed25519, + })) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{ + .secp256r1, + .x25519, + })) ++ tls.extension( + .key_share, + array(1, int2(@enumToInt(tls.NamedGroup.x25519)) ++ + array(1, x25519_kp.public_key) ++ + int2(@enumToInt(tls.NamedGroup.secp256r1)) ++ + array(1, secp256r1_kp.public_key.toUncompressedSec1())), + ) ++ + int2(@enumToInt(tls.ExtensionType.server_name)) ++ + int2(host_len + 5) ++ // byte length of this extension payload + int2(host_len + 3) ++ // server_name_list byte count + [1]u8{0x00} ++ // name_type + int2(host_len); + + const extensions_header = + int2(@intCast(u16, extensions_payload.len + host_len)) ++ + extensions_payload; + + const legacy_compression_methods = 0x0100; + + const client_hello = + int2(@enumToInt(tls.ProtocolVersion.tls_1_2)) ++ + hello_rand ++ + [1]u8{32} ++ legacy_session_id ++ + cipher_suites ++ + int2(legacy_compression_methods) ++ + extensions_header; + + const out_handshake = + [_]u8{@enumToInt(tls.HandshakeType.client_hello)} ++ + int3(@intCast(u24, client_hello.len + host_len)) ++ + client_hello; + + const plaintext_header = [_]u8{ + @enumToInt(tls.ContentType.handshake), + 0x03, 0x01, // legacy_record_version + } ++ int2(@intCast(u16, out_handshake.len + host_len)) ++ out_handshake; + + { + var iovecs = [_]std.os.iovec_const{ + .{ + .iov_base = 
&plaintext_header, + .iov_len = plaintext_header.len, + }, + .{ + .iov_base = host.ptr, + .iov_len = host.len, + }, + }; + try stream.writevAll(&iovecs); + } + + const client_hello_bytes1 = plaintext_header[5..]; + + var handshake_cipher: tls.HandshakeCipher = undefined; + var handshake_buffer: [8000]u8 = undefined; + var d: tls.Decoder = .{ .buf = &handshake_buffer }; + { + try d.readAtLeastOurAmt(stream, tls.record_header_len); + const ct = d.decode(tls.ContentType); + d.skip(2); // legacy_record_version + const record_len = d.decode(u16); + try d.readAtLeast(stream, record_len); + const server_hello_fragment = d.buf[d.idx..][0..record_len]; + var ptd = try d.sub(record_len); + switch (ct) { + .alert => { + try ptd.ensure(2); + const level = ptd.decode(tls.AlertLevel); + const desc = ptd.decode(tls.AlertDescription); + _ = level; + _ = desc; + return error.TlsAlert; + }, + .handshake => { + try ptd.ensure(4); + const handshake_type = ptd.decode(tls.HandshakeType); + if (handshake_type != .server_hello) return error.TlsUnexpectedMessage; + const length = ptd.decode(u24); + var hsd = try ptd.sub(length); + try hsd.ensure(2 + 32 + 1 + 32 + 2 + 1 + 2); + const legacy_version = hsd.decode(u16); + const random = hsd.array(32); + if (mem.eql(u8, random, &tls.hello_retry_request_sequence)) { + // This is a HelloRetryRequest message. This client implementation + // does not expect to get one. + return error.TlsUnexpectedMessage; + } + const legacy_session_id_echo_len = hsd.decode(u8); + if (legacy_session_id_echo_len != 32) return error.TlsIllegalParameter; + const legacy_session_id_echo = hsd.array(32); + if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id)) + return error.TlsIllegalParameter; + const cipher_suite_tag = hsd.decode(tls.CipherSuite); + hsd.skip(1); // legacy_compression_method + const extensions_size = hsd.decode(u16); + var all_extd = try hsd.sub(extensions_size); + var supported_version: u16 = 0; + var shared_key: [32]u8 = undefined; + var have_shared_key = false; + while (!all_extd.eof()) { + try all_extd.ensure(2 + 2); + const et = all_extd.decode(tls.ExtensionType); + const ext_size = all_extd.decode(u16); + var extd = try all_extd.sub(ext_size); + switch (et) { + .supported_versions => { + if (supported_version != 0) return error.TlsIllegalParameter; + try extd.ensure(2); + supported_version = extd.decode(u16); + }, + .key_share => { + if (have_shared_key) return error.TlsIllegalParameter; + have_shared_key = true; + try extd.ensure(4); + const named_group = extd.decode(tls.NamedGroup); + const key_size = extd.decode(u16); + try extd.ensure(key_size); + switch (named_group) { + .x25519 => { + if (key_size != 32) return error.TlsIllegalParameter; + const server_pub_key = extd.array(32); + + shared_key = crypto.dh.X25519.scalarmult( + x25519_kp.secret_key, + server_pub_key.*, + ) catch return error.TlsDecryptFailure; + }, + .secp256r1 => { + const server_pub_key = extd.slice(key_size); + + const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey; + const pk = PublicKey.fromSec1(server_pub_key) catch { + return error.TlsDecryptFailure; + }; + const mul = pk.p.mulPublic(secp256r1_kp.secret_key.bytes, .Big) catch { + return error.TlsDecryptFailure; + }; + shared_key = mul.affineCoordinates().x.toBytes(.Big); + }, + else => { + return error.TlsIllegalParameter; + }, + } + }, + else => {}, + } + } + if (!have_shared_key) return error.TlsIllegalParameter; + + const tls_version = if (supported_version == 0) legacy_version else supported_version; + if (tls_version != 
@enumToInt(tls.ProtocolVersion.tls_1_3)) + return error.TlsIllegalParameter; + + switch (cipher_suite_tag) { + inline .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + .CHACHA20_POLY1305_SHA256, + .AEGIS_256_SHA384, + .AEGIS_128L_SHA256, + => |tag| { + const P = std.meta.TagPayloadByName(tls.HandshakeCipher, @tagName(tag)); + handshake_cipher = @unionInit(tls.HandshakeCipher, @tagName(tag), .{ + .handshake_secret = undefined, + .master_secret = undefined, + .client_handshake_key = undefined, + .server_handshake_key = undefined, + .client_finished_key = undefined, + .server_finished_key = undefined, + .client_handshake_iv = undefined, + .server_handshake_iv = undefined, + .transcript_hash = P.Hash.init(.{}), + }); + const p = &@field(handshake_cipher, @tagName(tag)); + p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1 + p.transcript_hash.update(host); // Client Hello part 2 + p.transcript_hash.update(server_hello_fragment); + const hello_hash = p.transcript_hash.peek(); + const zeroes = [1]u8{0} ** P.Hash.digest_length; + const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes); + const empty_hash = tls.emptyHash(P.Hash); + const hs_derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length); + p.handshake_secret = P.Hkdf.extract(&hs_derived_secret, &shared_key); + const ap_derived_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "derived", &empty_hash, P.Hash.digest_length); + p.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes); + const client_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length); + const server_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length); + p.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length); + p.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length); + p.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); + p.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); + p.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); + p.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); + }, + else => { + return error.TlsIllegalParameter; + }, + } + }, + else => return error.TlsUnexpectedMessage, + } + } + + // This is used for two purposes: + // * Detect whether a certificate is the first one presented, in which case + // we need to verify the host name. + // * Flip back and forth between the two cleartext buffers in order to keep + // the previous certificate in memory so that it can be verified by the + // next one. + var cert_index: usize = 0; + var read_seq: u64 = 0; + var prev_cert: Certificate.Parsed = undefined; + // Set to true once a trust chain has been established from the first + // certificate to a root CA. + const HandshakeState = enum { + /// In this state we expect only an encrypted_extensions message. + encrypted_extensions, + /// In this state we expect certificate messages. + certificate, + /// In this state we expect certificate or certificate_verify messages. + /// certificate messages are ignored since the trust chain is already + /// established. + trust_chain_established, + /// In this state, we expect only the finished message. 
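The block above is the TLS 1.3 key schedule: extract an early secret from zeroes, expand a "derived" secret over the empty-transcript hash, extract the handshake secret from the (EC)DHE shared key, and continue down to per-direction keys and IVs. A minimal sketch of the first links in that chain, with a dummy shared secret standing in for the X25519/secp256r1 output and assuming `std.crypto.tls` exposure:

```zig
const std = @import("std");
const tls = std.crypto.tls;

test "early secret -> derived -> handshake secret (sketch)" {
    const Hkdf = std.crypto.kdf.hkdf.HkdfSha256;
    const Hash = std.crypto.hash.sha2.Sha256;
    const zeroes = [1]u8{0} ** Hash.digest_length;
    const shared_key = [1]u8{0x42} ** 32; // stand-in for the key-share result

    const early_secret = Hkdf.extract(&[1]u8{0}, &zeroes);
    const empty_hash = tls.emptyHash(Hash);
    const derived = tls.hkdfExpandLabel(Hkdf, early_secret, "derived", &empty_hash, Hash.digest_length);
    const handshake_secret = Hkdf.extract(&derived, &shared_key);
    // Traffic secrets would next be expanded from handshake_secret with the
    // "c hs traffic"/"s hs traffic" labels over the transcript hash.
    try std.testing.expect(!std.mem.eql(u8, &handshake_secret, &early_secret));
}
```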
+ finished, + }; + var handshake_state: HandshakeState = .encrypted_extensions; + var cleartext_bufs: [2][8000]u8 = undefined; + var main_cert_pub_key_algo: Certificate.AlgorithmCategory = undefined; + var main_cert_pub_key_buf: [300]u8 = undefined; + var main_cert_pub_key_len: u16 = undefined; + const now_sec = std.time.timestamp(); + + while (true) { + try d.readAtLeastOurAmt(stream, tls.record_header_len); + const record_header = d.buf[d.idx..][0..5]; + const ct = d.decode(tls.ContentType); + d.skip(2); // legacy_version + const record_len = d.decode(u16); + try d.readAtLeast(stream, record_len); + var record_decoder = try d.sub(record_len); + switch (ct) { + .change_cipher_spec => { + try record_decoder.ensure(1); + if (record_decoder.decode(u8) != 0x01) return error.TlsIllegalParameter; + }, + .application_data => { + const cleartext_buf = &cleartext_bufs[cert_index % 2]; + + const cleartext = switch (handshake_cipher) { + inline else => |*p| c: { + const P = @TypeOf(p.*); + const ciphertext_len = record_len - P.AEAD.tag_length; + try record_decoder.ensure(ciphertext_len + P.AEAD.tag_length); + const ciphertext = record_decoder.slice(ciphertext_len); + if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow; + const cleartext = cleartext_buf[0..ciphertext.len]; + const auth_tag = record_decoder.array(P.AEAD.tag_length).*; + const V = @Vector(P.AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ @bitCast([8]u8, big(read_seq)); + read_seq += 1; + const nonce = @as(V, p.server_handshake_iv) ^ operand; + P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, p.server_handshake_key) catch + return error.TlsBadRecordMac; + break :c cleartext; + }, + }; + + const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]); + if (inner_ct != .handshake) return error.TlsUnexpectedMessage; + + var ctd = tls.Decoder.fromTheirSlice(cleartext[0 .. cleartext.len - 1]); + while (true) { + try ctd.ensure(4); + const handshake_type = ctd.decode(tls.HandshakeType); + const handshake_len = ctd.decode(u24); + var hsd = try ctd.sub(handshake_len); + const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx]; + const handshake = ctd.buf[ctd.idx - handshake_len .. 
ctd.idx]; + switch (handshake_type) { + .encrypted_extensions => { + if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage; + handshake_state = .certificate; + switch (handshake_cipher) { + inline else => |*p| p.transcript_hash.update(wrapped_handshake), + } + try hsd.ensure(2); + const total_ext_size = hsd.decode(u16); + var all_extd = try hsd.sub(total_ext_size); + while (!all_extd.eof()) { + try all_extd.ensure(4); + const et = all_extd.decode(tls.ExtensionType); + const ext_size = all_extd.decode(u16); + var extd = try all_extd.sub(ext_size); + _ = extd; + switch (et) { + .server_name => {}, + else => {}, + } + } + }, + .certificate => cert: { + switch (handshake_cipher) { + inline else => |*p| p.transcript_hash.update(wrapped_handshake), + } + switch (handshake_state) { + .certificate => {}, + .trust_chain_established => break :cert, + else => return error.TlsUnexpectedMessage, + } + try hsd.ensure(1 + 4); + const cert_req_ctx_len = hsd.decode(u8); + if (cert_req_ctx_len != 0) return error.TlsIllegalParameter; + const certs_size = hsd.decode(u24); + var certs_decoder = try hsd.sub(certs_size); + while (!certs_decoder.eof()) { + try certs_decoder.ensure(3); + const cert_size = certs_decoder.decode(u24); + var certd = try certs_decoder.sub(cert_size); + + const subject_cert: Certificate = .{ + .buffer = certd.buf, + .index = @intCast(u32, certd.idx), + }; + const subject = try subject_cert.parse(); + if (cert_index == 0) { + // Verify the host on the first certificate. + try subject.verifyHostName(host); + + // Keep track of the public key for the + // certificate_verify message later. + main_cert_pub_key_algo = subject.pub_key_algo; + const pub_key = subject.pubKey(); + if (pub_key.len > main_cert_pub_key_buf.len) + return error.CertificatePublicKeyInvalid; + @memcpy(&main_cert_pub_key_buf, pub_key.ptr, pub_key.len); + main_cert_pub_key_len = @intCast(@TypeOf(main_cert_pub_key_len), pub_key.len); + } else { + try prev_cert.verify(subject, now_sec); + } + + if (ca_bundle.verify(subject, now_sec)) |_| { + handshake_state = .trust_chain_established; + break :cert; + } else |err| switch (err) { + error.CertificateIssuerNotFound => {}, + else => |e| return e, + } + + prev_cert = subject; + cert_index += 1; + + try certs_decoder.ensure(2); + const total_ext_size = certs_decoder.decode(u16); + var all_extd = try certs_decoder.sub(total_ext_size); + _ = all_extd; + } + }, + .certificate_verify => { + switch (handshake_state) { + .trust_chain_established => handshake_state = .finished, + .certificate => return error.TlsCertificateNotVerified, + else => return error.TlsUnexpectedMessage, + } + + try hsd.ensure(4); + const scheme = hsd.decode(tls.SignatureScheme); + const sig_len = hsd.decode(u16); + try hsd.ensure(sig_len); + const encoded_sig = hsd.slice(sig_len); + const max_digest_len = 64; + var verify_buffer = + ([1]u8{0x20} ** 64) ++ + "TLS 1.3, server CertificateVerify\x00".* ++ + @as([max_digest_len]u8, undefined); + + const verify_bytes = switch (handshake_cipher) { + inline else => |*p| v: { + const transcript_digest = p.transcript_hash.peek(); + verify_buffer[verify_buffer.len - max_digest_len ..][0..transcript_digest.len].* = transcript_digest; + p.transcript_hash.update(wrapped_handshake); + break :v verify_buffer[0 .. 
verify_buffer.len - max_digest_len + transcript_digest.len]; + }, + }; + const main_cert_pub_key = main_cert_pub_key_buf[0..main_cert_pub_key_len]; + + switch (scheme) { + inline .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + => |comptime_scheme| { + if (main_cert_pub_key_algo != .X9_62_id_ecPublicKey) + return error.TlsBadSignatureScheme; + const Ecdsa = SchemeEcdsa(comptime_scheme); + const sig = try Ecdsa.Signature.fromDer(encoded_sig); + const key = try Ecdsa.PublicKey.fromSec1(main_cert_pub_key); + try sig.verify(verify_bytes, key); + }, + .rsa_pss_rsae_sha256 => { + if (main_cert_pub_key_algo != .rsaEncryption) + return error.TlsBadSignatureScheme; + + const Hash = crypto.hash.sha2.Sha256; + const rsa = Certificate.rsa; + const components = try rsa.PublicKey.parseDer(main_cert_pub_key); + const exponent = components.exponent; + const modulus = components.modulus; + var rsa_mem_buf: [512 * 32]u8 = undefined; + var fba = std.heap.FixedBufferAllocator.init(&rsa_mem_buf); + const ally = fba.allocator(); + switch (modulus.len) { + inline 128, 256, 512 => |modulus_len| { + const key = try rsa.PublicKey.fromBytes(exponent, modulus, ally); + const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig); + try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash, ally); + }, + else => { + return error.TlsBadRsaSignatureBitCount; + }, + } + }, + else => { + return error.TlsBadSignatureScheme; + }, + } + }, + .finished => { + if (handshake_state != .finished) return error.TlsUnexpectedMessage; + // This message is to trick buggy proxies into behaving correctly. + const client_change_cipher_spec_msg = [_]u8{ + @enumToInt(tls.ContentType.change_cipher_spec), + 0x03, 0x03, // legacy protocol version + 0x00, 0x01, // length + 0x01, + }; + const app_cipher = switch (handshake_cipher) { + inline else => |*p, tag| c: { + const P = @TypeOf(p.*); + const finished_digest = p.transcript_hash.peek(); + p.transcript_hash.update(wrapped_handshake); + const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key); + if (!mem.eql(u8, &expected_server_verify_data, handshake)) + return error.TlsDecryptError; + const handshake_hash = p.transcript_hash.finalResult(); + const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key); + const out_cleartext = [_]u8{ + @enumToInt(tls.HandshakeType.finished), + 0, 0, verify_data.len, // length + } ++ verify_data ++ [1]u8{@enumToInt(tls.ContentType.handshake)}; + + const wrapped_len = out_cleartext.len + P.AEAD.tag_length; + + var finished_msg = [_]u8{ + @enumToInt(tls.ContentType.application_data), + 0x03, 0x03, // legacy protocol version + 0, wrapped_len, // byte length of encrypted record + } ++ @as([wrapped_len]u8, undefined); + + const ad = finished_msg[0..5]; + const ciphertext = finished_msg[5..][0..out_cleartext.len]; + const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..]; + const nonce = p.client_handshake_iv; + P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key); + + const both_msgs = client_change_cipher_spec_msg ++ finished_msg; + try stream.writeAll(&both_msgs); + + const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); + const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); + break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{ + .client_secret = client_secret, + .server_secret = server_secret, + 
.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length), + .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length), + .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length), + .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length), + }); + }, + }; + const leftover = d.rest(); + var client: Client = .{ + .read_seq = 0, + .write_seq = 0, + .partial_cleartext_idx = 0, + .partial_ciphertext_idx = 0, + .partial_ciphertext_end = @intCast(u15, leftover.len), + .received_close_notify = false, + .application_cipher = app_cipher, + .partially_read_buffer = undefined, + }; + mem.copy(u8, &client.partially_read_buffer, leftover); + return client; + }, + else => { + return error.TlsUnexpectedMessage; + }, + } + if (ctd.eof()) break; + } + }, + else => { + return error.TlsUnexpectedMessage; + }, + } + } +} + +/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. +/// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. +pub fn write(c: *Client, stream: anytype, bytes: []const u8) !usize { + return writeEnd(c, stream, bytes, false); +} + +/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. +pub fn writeAll(c: *Client, stream: anytype, bytes: []const u8) !void { + var index: usize = 0; + while (index < bytes.len) { + index += try c.write(stream, bytes[index..]); + } +} + +/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. +/// If `end` is true, then this function additionally sends a `close_notify` alert, +/// which is necessary for the server to distinguish between a properly finished +/// TLS session, or a truncation attack. +pub fn writeAllEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !void { + var index: usize = 0; + while (index < bytes.len) { + index += try c.writeEnd(stream, bytes[index..], end); + } +} + +/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. +/// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. +/// If `end` is true, then this function additionally sends a `close_notify` alert, +/// which is necessary for the server to distinguish between a properly finished +/// TLS session, or a truncation attack. +pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usize { + var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined; + var iovecs_buf: [6]std.os.iovec_const = undefined; + var prepared = prepareCiphertextRecord(c, &iovecs_buf, &ciphertext_buf, bytes, .application_data); + if (end) { + prepared.iovec_end += prepareCiphertextRecord( + c, + iovecs_buf[prepared.iovec_end..], + ciphertext_buf[prepared.ciphertext_end..], + &tls.close_notify_alert, + .alert, + ).iovec_end; + } + + const iovec_end = prepared.iovec_end; + const overhead_len = prepared.overhead_len; + + // Ideally we would call writev exactly once here, however, we must ensure + // that we don't return with a record partially written. 
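A hypothetical caller of the write API documented above; `sendRequest`, the request string, and the buffer size are invented, and `std.net.Stream` is assumed to satisfy `StreamInterface` as stated earlier:

```zig
const std = @import("std");

fn sendRequest(c: *std.crypto.tls.Client, stream: std.net.Stream, host: []const u8) !void {
    var buf: [256]u8 = undefined;
    const req = try std.fmt.bufPrint(
        &buf,
        "GET / HTTP/1.1\r\nHost: {s}\r\nConnection: close\r\n\r\n",
        .{host},
    );
    // Passing `end = true` appends the close_notify alert after the last record,
    // which lets the peer distinguish a clean shutdown from truncation.
    try c.writeAllEnd(stream, req, true);
}
```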
+ var i: usize = 0; + var total_amt: usize = 0; + while (true) { + var amt = try stream.writev(iovecs_buf[i..iovec_end]); + while (amt >= iovecs_buf[i].iov_len) { + const encrypted_amt = iovecs_buf[i].iov_len; + total_amt += encrypted_amt - overhead_len; + amt -= encrypted_amt; + i += 1; + // Rely on the property that iovecs delineate records, meaning that + // if amt equals zero here, we have fortunately found ourselves + // with a short read that aligns at the record boundary. + if (i >= iovec_end) return total_amt; + // We also cannot return on a vector boundary if the final close_notify is + // not sent; otherwise the caller would not know to retry the call. + if (amt == 0 and (!end or i < iovec_end - 1)) return total_amt; + } + iovecs_buf[i].iov_base += amt; + iovecs_buf[i].iov_len -= amt; + } +} + +fn prepareCiphertextRecord( + c: *Client, + iovecs: []std.os.iovec_const, + ciphertext_buf: []u8, + bytes: []const u8, + inner_content_type: tls.ContentType, +) struct { + iovec_end: usize, + ciphertext_end: usize, + /// How many bytes are taken up by overhead per record. + overhead_len: usize, +} { + // Due to the trailing inner content type byte in the ciphertext, we need + // an additional buffer for storing the cleartext into before encrypting. + var cleartext_buf: [max_ciphertext_len]u8 = undefined; + var ciphertext_end: usize = 0; + var iovec_end: usize = 0; + var bytes_i: usize = 0; + switch (c.application_cipher) { + inline else => |*p| { + const P = @TypeOf(p.*); + const V = @Vector(P.AEAD.nonce_length, u8); + const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1; + const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len; + while (true) { + const encrypted_content_len = @intCast(u16, @min( + @min(bytes.len - bytes_i, max_ciphertext_len - 1), + ciphertext_buf.len - close_notify_alert_reserved - + overhead_len - ciphertext_end, + )); + if (encrypted_content_len == 0) return .{ + .iovec_end = iovec_end, + .ciphertext_end = ciphertext_end, + .overhead_len = overhead_len, + }; + + mem.copy(u8, &cleartext_buf, bytes[bytes_i..][0..encrypted_content_len]); + cleartext_buf[encrypted_content_len] = @enumToInt(inner_content_type); + bytes_i += encrypted_content_len; + const ciphertext_len = encrypted_content_len + 1; + const cleartext = cleartext_buf[0..ciphertext_len]; + + const record_start = ciphertext_end; + const ad = ciphertext_buf[ciphertext_end..][0..5]; + ad.* = + [_]u8{@enumToInt(tls.ContentType.application_data)} ++ + int2(@enumToInt(tls.ProtocolVersion.tls_1_2)) ++ + int2(ciphertext_len + P.AEAD.tag_length); + ciphertext_end += ad.len; + const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len]; + ciphertext_end += ciphertext_len; + const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length]; + ciphertext_end += auth_tag.len; + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ @bitCast([8]u8, big(c.write_seq)); + c.write_seq += 1; // TODO send key_update on overflow + const nonce = @as(V, p.client_iv) ^ operand; + P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key); + + const record = ciphertext_buf[record_start..ciphertext_end]; + iovecs[iovec_end] = .{ + .iov_base = record.ptr, + .iov_len = record.len, + }; + iovec_end += 1; + } + }, + } +} + +pub fn eof(c: Client) bool { + return c.received_close_notify and + c.partial_cleartext_idx >= c.partial_ciphertext_idx and + c.partial_ciphertext_idx >= c.partial_ciphertext_end; +} + +/// Receives TLS-encrypted data from 
`stream`, which must conform to `StreamInterface`. +/// Returns the number of bytes read, calling the underlying read function the +/// minimal number of times until the buffer has at least `len` bytes filled. +/// If the number read is less than `len` it means the stream reached the end. +/// Reaching the end of the stream is not an error condition. +pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize { + var iovecs = [1]std.os.iovec{.{ .iov_base = buffer.ptr, .iov_len = buffer.len }}; + return readvAtLeast(c, stream, &iovecs, len); +} + +/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. +pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize { + return readAtLeast(c, stream, buffer, 1); +} + +/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. +/// Returns the number of bytes read. If the number read is smaller than +/// `buffer.len`, it means the stream reached the end. Reaching the end of the +/// stream is not an error condition. +pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize { + return readAtLeast(c, stream, buffer, buffer.len); +} + +/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. +/// Returns the number of bytes read. If the number read is less than the space +/// provided it means the stream reached the end. Reaching the end of the +/// stream is not an error condition. +/// The `iovecs` parameter is mutable because this function needs to mutate the fields in +/// order to handle partial reads from the underlying stream layer. +pub fn readv(c: *Client, stream: anytype, iovecs: []std.os.iovec) !usize { + return readvAtLeast(c, stream, iovecs, 1); +} + +/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. +/// Returns the number of bytes read, calling the underlying read function the +/// minimal number of times until the iovecs have at least `len` bytes filled. +/// If the number read is less than `len` it means the stream reached the end. +/// Reaching the end of the stream is not an error condition. +/// The `iovecs` parameter is mutable because this function needs to mutate the fields in +/// order to handle partial reads from the underlying stream layer. +pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.os.iovec, len: usize) !usize { + if (c.eof()) return 0; + + var off_i: usize = 0; + var vec_i: usize = 0; + while (true) { + var amt = try c.readvAdvanced(stream, iovecs[vec_i..]); + off_i += amt; + if (c.eof() or off_i >= len) return off_i; + while (amt >= iovecs[vec_i].iov_len) { + amt -= iovecs[vec_i].iov_len; + vec_i += 1; + } + iovecs[vec_i].iov_base += amt; + iovecs[vec_i].iov_len -= amt; + } +} + +/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. +/// Returns the number of bytes that have been read, populated inside `iovecs`. A +/// return value of zero bytes does not mean end of stream. Instead, check `eof()` +/// for the end of stream. `eof()` may be true after any call to +/// `read`, including when greater than zero bytes are returned, and this +/// function asserts that `eof()` is `false`. +/// See `readv` for a higher level function that has the same, familiar API as +/// other read functions, such as `std.fs.File.read`. +pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec) !usize { + var vp: VecPut = .{ .iovecs = iovecs }; + + // Give away the buffered cleartext we have, if any.
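The matching read side, as a hypothetical helper (`readResponse` and the buffer size are invented): per the doc comments above, a zero return from `read` means the stream ended, and `eof()` is the authoritative signal once close_notify has arrived.

```zig
const std = @import("std");

fn readResponse(c: *std.crypto.tls.Client, stream: std.net.Stream, out: *std.ArrayList(u8)) !void {
    var buf: [4096]u8 = undefined;
    while (!c.eof()) {
        const n = try c.read(stream, &buf);
        if (n == 0) break; // stream ended without (or after) close_notify
        try out.appendSlice(buf[0..n]);
    }
}
```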
+ const partial_cleartext = c.partially_read_buffer[c.partial_cleartext_idx..c.partial_ciphertext_idx]; + if (partial_cleartext.len > 0) { + const amt = @intCast(u15, vp.put(partial_cleartext)); + c.partial_cleartext_idx += amt; + if (amt < partial_cleartext.len) { + // We still have cleartext left so we cannot issue another read() call yet. + assert(vp.total == amt); + return amt; + } + if (c.received_close_notify) { + c.partial_ciphertext_end = 0; + assert(vp.total == amt); + return amt; + } + if (c.partial_ciphertext_end == c.partial_ciphertext_idx) { + c.partial_cleartext_idx = 0; + c.partial_ciphertext_idx = 0; + c.partial_ciphertext_end = 0; + } + } + + assert(!c.received_close_notify); + + // Ideally, this buffer would never be used. It is needed when `iovecs` are + // too small to fit the cleartext, which may be as large as `max_ciphertext_len`. + var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined; + // Temporarily stores ciphertext before decrypting it and giving it to `iovecs`. + var in_stack_buffer: [max_ciphertext_len * 4]u8 = undefined; + // How many bytes left in the user's buffer. + const free_size = vp.freeSize(); + // The amount of the user's buffer that we need to repurpose for storing + // ciphertext. The end of the buffer will be used for such purposes. + const ciphertext_buf_len = (free_size / 2) -| in_stack_buffer.len; + // The amount of the user's buffer that will be used to give cleartext. The + // beginning of the buffer will be used for such purposes. + const cleartext_buf_len = free_size - ciphertext_buf_len; + const first_iov = c.partially_read_buffer[c.partial_ciphertext_end..]; + + var ask_iovecs_buf: [2]std.os.iovec = .{ + .{ + .iov_base = first_iov.ptr, + .iov_len = first_iov.len, + }, + .{ + .iov_base = &in_stack_buffer, + .iov_len = in_stack_buffer.len, + }, + }; + + // Cleartext capacity of output buffer, in records, rounded up. + const buf_cap = (cleartext_buf_len +| (max_ciphertext_len - 1)) / max_ciphertext_len; + const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len); + const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len); + const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len); + const actual_read_len = try stream.readv(ask_iovecs); + if (actual_read_len == 0) { + // This is either a truncation attack, a bug in the server, or an + // intentional omission of the close_notify message due to truncation + // detection handled above the TLS layer. + if (c.allow_truncation_attacks) { + c.received_close_notify = true; + } else { + return error.TlsConnectionTruncated; + } + } + + // There might be more bytes inside `in_stack_buffer` that need to be processed, + // but at least frag0 will have one complete ciphertext record. + const frag0_end = @min(c.partially_read_buffer.len, c.partial_ciphertext_end + actual_read_len); + const frag0 = c.partially_read_buffer[c.partial_ciphertext_idx..frag0_end]; + var frag1 = in_stack_buffer[0..actual_read_len -| first_iov.len]; + // We need to decipher frag0 and frag1 but there may be a ciphertext record + // straddling the boundary. We can handle this with two memcpy() calls to + // assemble the straddling record in between handling the two sides. + var frag = frag0; + var in: usize = 0; + while (true) { + if (in == frag.len) { + // Perfect split. 
+ if (frag.ptr == frag1.ptr) { + c.partial_ciphertext_end = c.partial_ciphertext_idx; + return vp.total; + } + frag = frag1; + in = 0; + continue; + } + + if (in + tls.record_header_len > frag.len) { + if (frag.ptr == frag1.ptr) + return finishRead(c, frag, in, vp.total); + + const first = frag[in..]; + + if (frag1.len < tls.record_header_len) + return finishRead2(c, first, frag1, vp.total); + + // A record straddles the two fragments. Copy into the now-empty first fragment. + const record_len_byte_0: u16 = straddleByte(frag, frag1, in + 3); + const record_len_byte_1: u16 = straddleByte(frag, frag1, in + 4); + const record_len = (record_len_byte_0 << 8) | record_len_byte_1; + if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; + + const full_record_len = record_len + tls.record_header_len; + const second_len = full_record_len - first.len; + if (frag1.len < second_len) + return finishRead2(c, first, frag1, vp.total); + + mem.copy(u8, frag[0..in], first); + mem.copy(u8, frag[first.len..], frag1[0..second_len]); + frag = frag[0..full_record_len]; + frag1 = frag1[second_len..]; + in = 0; + continue; + } + const ct = @intToEnum(tls.ContentType, frag[in]); + in += 1; + const legacy_version = mem.readIntBig(u16, frag[in..][0..2]); + in += 2; + _ = legacy_version; + const record_len = mem.readIntBig(u16, frag[in..][0..2]); + if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; + in += 2; + const end = in + record_len; + if (end > frag.len) { + // We need the record header on the next iteration of the loop. + in -= tls.record_header_len; + + if (frag.ptr == frag1.ptr) + return finishRead(c, frag, in, vp.total); + + // A record straddles the two fragments. Copy into the now-empty first fragment. + const first = frag[in..]; + const full_record_len = record_len + tls.record_header_len; + const second_len = full_record_len - first.len; + if (frag1.len < second_len) + return finishRead2(c, first, frag1, vp.total); + + mem.copy(u8, frag[0..in], first); + mem.copy(u8, frag[first.len..], frag1[0..second_len]); + frag = frag[0..full_record_len]; + frag1 = frag1[second_len..]; + in = 0; + continue; + } + switch (ct) { + .alert => { + if (in + 2 > frag.len) return error.TlsDecodeError; + const level = @intToEnum(tls.AlertLevel, frag[in]); + const desc = @intToEnum(tls.AlertDescription, frag[in + 1]); + _ = level; + _ = desc; + return error.TlsAlert; + }, + .application_data => { + const cleartext = switch (c.application_cipher) { + inline else => |*p| c: { + const P = @TypeOf(p.*); + const V = @Vector(P.AEAD.nonce_length, u8); + const ad = frag[in - 5 ..][0..5]; + const ciphertext_len = record_len - P.AEAD.tag_length; + const ciphertext = frag[in..][0..ciphertext_len]; + in += ciphertext_len; + const auth_tag = frag[in..][0..P.AEAD.tag_length].*; + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ @bitCast([8]u8, big(c.read_seq)); + const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_iv) ^ operand; + const out_buf = vp.peek(); + const cleartext_buf = if (ciphertext.len <= out_buf.len) + out_buf + else + &cleartext_stack_buffer; + const cleartext = cleartext_buf[0..ciphertext.len]; + P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch + return error.TlsBadRecordMac; + break :c cleartext; + }, + }; + + c.read_seq = try std.math.add(u64, c.read_seq, 1); + + const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]); + switch (inner_ct) { + .alert => { + const level = @intToEnum(tls.AlertLevel, 
cleartext[0]); + const desc = @intToEnum(tls.AlertDescription, cleartext[1]); + if (desc == .close_notify) { + c.received_close_notify = true; + c.partial_ciphertext_end = c.partial_ciphertext_idx; + return vp.total; + } + _ = level; + return error.TlsAlert; + }, + .handshake => { + var ct_i: usize = 0; + while (true) { + const handshake_type = @intToEnum(tls.HandshakeType, cleartext[ct_i]); + ct_i += 1; + const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]); + ct_i += 3; + const next_handshake_i = ct_i + handshake_len; + if (next_handshake_i > cleartext.len - 1) + return error.TlsBadLength; + const handshake = cleartext[ct_i..next_handshake_i]; + switch (handshake_type) { + .new_session_ticket => { + // This client implementation ignores new session tickets. + }, + .key_update => { + switch (c.application_cipher) { + inline else => |*p| { + const P = @TypeOf(p.*); + const server_secret = hkdfExpandLabel(P.Hkdf, p.server_secret, "traffic upd", "", P.Hash.digest_length); + p.server_secret = server_secret; + p.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); + p.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); + }, + } + c.read_seq = 0; + + switch (@intToEnum(tls.KeyUpdateRequest, handshake[0])) { + .update_requested => { + switch (c.application_cipher) { + inline else => |*p| { + const P = @TypeOf(p.*); + const client_secret = hkdfExpandLabel(P.Hkdf, p.client_secret, "traffic upd", "", P.Hash.digest_length); + p.client_secret = client_secret; + p.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); + p.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); + }, + } + c.write_seq = 0; + }, + .update_not_requested => {}, + _ => return error.TlsIllegalParameter, + } + }, + else => { + return error.TlsUnexpectedMessage; + }, + } + ct_i = next_handshake_i; + if (ct_i >= cleartext.len - 1) break; + } + }, + .application_data => { + // Determine whether the output buffer or a stack + // buffer was used for storing the cleartext. + if (cleartext.ptr == &cleartext_stack_buffer) { + // Stack buffer was used, so we must copy to the output buffer. + const msg = cleartext[0 .. cleartext.len - 1]; + if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { + // We have already run out of room in iovecs. Continue + // appending to `partially_read_buffer`. + const dest = c.partially_read_buffer[c.partial_ciphertext_idx..]; + mem.copy(u8, dest, msg); + c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), c.partial_ciphertext_idx + msg.len); + } else { + const amt = vp.put(msg); + if (amt < msg.len) { + const rest = msg[amt..]; + c.partial_cleartext_idx = 0; + c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), rest.len); + mem.copy(u8, &c.partially_read_buffer, rest); + } + } + } else { + // Output buffer was used directly which means no + // memory copying needs to occur, and we can move + // on to the next ciphertext record. + vp.next(cleartext.len - 1); + } + }, + else => { + return error.TlsUnexpectedMessage; + }, + } + }, + else => { + return error.TlsUnexpectedMessage; + }, + } + in = end; + } +} + +fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize { + const saved_buf = frag[in..]; + if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { + // There is cleartext at the beginning already which we need to preserve. 
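The key_update branch above re-derives the next generation of traffic secrets with the "traffic upd" label and resets the corresponding sequence number. A sketch of that derivation with a dummy current secret and the AES-128-GCM key/IV sizes, assuming `std.crypto.tls` exposure:

```zig
const std = @import("std");
const tls = std.crypto.tls;

test "key_update rolls the traffic secret forward (sketch)" {
    const Hkdf = std.crypto.kdf.hkdf.HkdfSha256;
    const Hash = std.crypto.hash.sha2.Sha256;
    const current = [1]u8{0x11} ** Hash.digest_length; // stand-in for p.server_secret
    const next = tls.hkdfExpandLabel(Hkdf, current, "traffic upd", "", Hash.digest_length);
    const next_key = tls.hkdfExpandLabel(Hkdf, next, "key", "", 16);
    const next_iv = tls.hkdfExpandLabel(Hkdf, next, "iv", "", 12);
    try std.testing.expect(!std.mem.eql(u8, &next, &current));
    try std.testing.expect(next_key.len == 16 and next_iv.len == 12);
}
```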
+ c.partial_ciphertext_end = @intCast(@TypeOf(c.partial_ciphertext_end), c.partial_ciphertext_idx + saved_buf.len); + mem.copy(u8, c.partially_read_buffer[c.partial_ciphertext_idx..], saved_buf); + } else { + c.partial_cleartext_idx = 0; + c.partial_ciphertext_idx = 0; + c.partial_ciphertext_end = @intCast(@TypeOf(c.partial_ciphertext_end), saved_buf.len); + mem.copy(u8, &c.partially_read_buffer, saved_buf); + } + return out; +} + +fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usize { + if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { + // There is cleartext at the beginning already which we need to preserve. + c.partial_ciphertext_end = @intCast(@TypeOf(c.partial_ciphertext_end), c.partial_ciphertext_idx + first.len + frag1.len); + mem.copy(u8, c.partially_read_buffer[c.partial_ciphertext_idx..], first); + mem.copy(u8, c.partially_read_buffer[c.partial_ciphertext_idx + first.len ..], frag1); + } else { + c.partial_cleartext_idx = 0; + c.partial_ciphertext_idx = 0; + c.partial_ciphertext_end = @intCast(@TypeOf(c.partial_ciphertext_end), first.len + frag1.len); + mem.copy(u8, &c.partially_read_buffer, first); + mem.copy(u8, c.partially_read_buffer[first.len..], frag1); + } + return out; +} + +fn straddleByte(s1: []const u8, s2: []const u8, index: usize) u8 { + if (index < s1.len) { + return s1[index]; + } else { + return s2[index - s1.len]; + } +} + +const builtin = @import("builtin"); +const native_endian = builtin.cpu.arch.endian(); + +inline fn big(x: anytype) @TypeOf(x) { + return switch (native_endian) { + .Big => x, + .Little => @byteSwap(x), + }; +} + +fn SchemeEcdsa(comptime scheme: tls.SignatureScheme) type { + return switch (scheme) { + .ecdsa_secp256r1_sha256 => crypto.sign.ecdsa.EcdsaP256Sha256, + .ecdsa_secp384r1_sha384 => crypto.sign.ecdsa.EcdsaP384Sha384, + .ecdsa_secp521r1_sha512 => crypto.sign.ecdsa.EcdsaP512Sha512, + else => @compileError("bad scheme"), + }; +} + +/// Abstraction for sending multiple byte buffers to a slice of iovecs. +const VecPut = struct { + iovecs: []const std.os.iovec, + idx: usize = 0, + off: usize = 0, + total: usize = 0, + + /// Returns the amount actually put which is always equal to bytes.len + /// unless the vectors ran out of space. + fn put(vp: *VecPut, bytes: []const u8) usize { + var bytes_i: usize = 0; + while (true) { + const v = vp.iovecs[vp.idx]; + const dest = v.iov_base[vp.off..v.iov_len]; + const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)]; + mem.copy(u8, dest, src); + bytes_i += src.len; + vp.off += src.len; + if (vp.off >= v.iov_len) { + vp.off = 0; + vp.idx += 1; + if (vp.idx >= vp.iovecs.len) { + vp.total += bytes_i; + return bytes_i; + } + } + if (bytes_i >= bytes.len) { + vp.total += bytes_i; + return bytes_i; + } + } + } + + /// Returns the next buffer that consecutive bytes can go into. + fn peek(vp: VecPut) []u8 { + if (vp.idx >= vp.iovecs.len) return &.{}; + const v = vp.iovecs[vp.idx]; + return v.iov_base[vp.off..v.iov_len]; + } + + // After writing to the result of peek(), one can call next() to + // advance the cursor. 
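Both the handshake and application record paths build the per-record nonce the same way: the 64-bit sequence number is byte-swapped to big-endian with `big()` above, left-padded to the nonce length, and XORed into the static IV. A standalone sketch of that construction with made-up values:

```zig
const std = @import("std");

test "per-record nonce is IV xor big-endian sequence number (sketch)" {
    const nonce_length = 12;
    const iv = [_]u8{0xAA} ** nonce_length;
    const seq: u64 = 1;
    var operand = [_]u8{0} ** nonce_length;
    std.mem.writeIntBig(u64, operand[nonce_length - 8 ..][0..8], seq);
    var nonce: [nonce_length]u8 = undefined;
    for (nonce) |*b, i| b.* = iv[i] ^ operand[i];
    // Only the low byte changes for the first few records.
    try std.testing.expectEqual(@as(u8, 0xAA ^ 1), nonce[nonce_length - 1]);
    try std.testing.expectEqual(@as(u8, 0xAA), nonce[0]);
}
```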
+ fn next(vp: *VecPut, len: usize) void { + vp.total += len; + vp.off += len; + if (vp.off >= vp.iovecs[vp.idx].iov_len) { + vp.off = 0; + vp.idx += 1; + } + } + + fn freeSize(vp: VecPut) usize { + if (vp.idx >= vp.iovecs.len) return 0; + var total: usize = 0; + total += vp.iovecs[vp.idx].iov_len - vp.off; + if (vp.idx + 1 >= vp.iovecs.len) return total; + for (vp.iovecs[vp.idx + 1 ..]) |v| total += v.iov_len; + return total; + } +}; + +/// Limit iovecs to a specific byte size. +fn limitVecs(iovecs: []std.os.iovec, len: usize) []std.os.iovec { + var vec_i: usize = 0; + var bytes_left: usize = len; + while (true) { + if (bytes_left >= iovecs[vec_i].iov_len) { + bytes_left -= iovecs[vec_i].iov_len; + vec_i += 1; + if (vec_i == iovecs.len or bytes_left == 0) return iovecs[0..vec_i]; + continue; + } + iovecs[vec_i].iov_len = bytes_left; + return iovecs[0..vec_i]; + } +} + +/// The priority order here is chosen based on what crypto algorithms Zig has +/// available in the standard library as well as what is faster. Following are +/// a few data points on the relative performance of these algorithms. +/// +/// Measurement taken with 0.11.0-dev.810+c2f5848fe +/// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz: +/// zig run .lib/std/crypto/benchmark.zig -OReleaseFast +/// aegis-128l: 15382 MiB/s +/// aegis-256: 9553 MiB/s +/// aes128-gcm: 3721 MiB/s +/// aes256-gcm: 3010 MiB/s +/// chacha20Poly1305: 597 MiB/s +/// +/// Measurement taken with 0.11.0-dev.810+c2f5848fe +/// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz: +/// zig run .lib/std/crypto/benchmark.zig -OReleaseFast -mcpu=baseline +/// aegis-128l: 629 MiB/s +/// chacha20Poly1305: 529 MiB/s +/// aegis-256: 461 MiB/s +/// aes128-gcm: 138 MiB/s +/// aes256-gcm: 120 MiB/s +const cipher_suites = enum_array(tls.CipherSuite, &.{ + .AEGIS_128L_SHA256, + .AEGIS_256_SHA384, + .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + .CHACHA20_POLY1305_SHA256, +}); + +test { + _ = StreamInterface; +} diff --git a/lib/std/http.zig b/lib/std/http.zig index 8da696840374..944271df274c 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -1,8 +1,301 @@ -const std = @import("std.zig"); +pub const Client = @import("http/Client.zig"); + +/// https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods +/// https://datatracker.ietf.org/doc/html/rfc7231#section-4 Initial definiton +/// https://datatracker.ietf.org/doc/html/rfc5789#section-2 PATCH +pub const Method = enum { + GET, + HEAD, + POST, + PUT, + DELETE, + CONNECT, + OPTIONS, + TRACE, + PATCH, + + /// Returns true if a request of this method is allowed to have a body + /// Actual behavior from servers may vary and should still be checked + pub fn requestHasBody(self: Method) bool { + return switch (self) { + .POST, .PUT, .PATCH => true, + .GET, .HEAD, .DELETE, .CONNECT, .OPTIONS, .TRACE => false, + }; + } + + /// Returns true if a response to this method is allowed to have a body + /// Actual behavior from clients may vary and should still be checked + pub fn responseHasBody(self: Method) bool { + return switch (self) { + .GET, .POST, .DELETE, .CONNECT, .OPTIONS, .PATCH => true, + .HEAD, .PUT, .TRACE => false, + }; + } + + /// An HTTP method is safe if it doesn't alter the state of the server. 
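The `cipher_suites` constant above is what goes on the wire in the ClientHello, so the priority order is literally the byte order. A sketch of the encoding produced by `enum_array` for a two-suite list, assuming `std.crypto.tls` exposure:

```zig
const std = @import("std");
const tls = std.crypto.tls;

test "enum_array: length prefix, then big-endian tags in priority order (sketch)" {
    const cs = tls.enum_array(tls.CipherSuite, &.{ .AEGIS_128L_SHA256, .AES_128_GCM_SHA256 });
    try std.testing.expectEqualSlices(
        u8,
        &[_]u8{ 0x00, 0x04, 0x13, 0x07, 0x13, 0x01 },
        &cs,
    );
}
```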
+ /// https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP + /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1 + pub fn safe(self: Method) bool { + return switch (self) { + .GET, .HEAD, .OPTIONS, .TRACE => true, + .POST, .PUT, .DELETE, .CONNECT, .PATCH => false, + }; + } + + /// An HTTP method is idempotent if an identical request can be made once or several times in a row with the same effect while leaving the server in the same state. + /// https://developer.mozilla.org/en-US/docs/Glossary/Idempotent + /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.2 + pub fn idempotent(self: Method) bool { + return switch (self) { + .GET, .HEAD, .PUT, .DELETE, .OPTIONS, .TRACE => true, + .CONNECT, .POST, .PATCH => false, + }; + } + + /// A cacheable response is an HTTP response that can be cached, that is stored to be retrieved and used later, saving a new request to the server. + /// https://developer.mozilla.org/en-US/docs/Glossary/cacheable + /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.3 + pub fn cacheable(self: Method) bool { + return switch (self) { + .GET, .HEAD => true, + .POST, .PUT, .DELETE, .CONNECT, .OPTIONS, .TRACE, .PATCH => false, + }; + } +}; + +/// https://developer.mozilla.org/en-US/docs/Web/HTTP/Status +pub const Status = enum(u10) { + @"continue" = 100, // RFC7231, Section 6.2.1 + switching_protocols = 101, // RFC7231, Section 6.2.2 + processing = 102, // RFC2518 + early_hints = 103, // RFC8297 + + ok = 200, // RFC7231, Section 6.3.1 + created = 201, // RFC7231, Section 6.3.2 + accepted = 202, // RFC7231, Section 6.3.3 + non_authoritative_info = 203, // RFC7231, Section 6.3.4 + no_content = 204, // RFC7231, Section 6.3.5 + reset_content = 205, // RFC7231, Section 6.3.6 + partial_content = 206, // RFC7233, Section 4.1 + multi_status = 207, // RFC4918 + already_reported = 208, // RFC5842 + im_used = 226, // RFC3229 + + multiple_choice = 300, // RFC7231, Section 6.4.1 + moved_permanently = 301, // RFC7231, Section 6.4.2 + found = 302, // RFC7231, Section 6.4.3 + see_other = 303, // RFC7231, Section 6.4.4 + not_modified = 304, // RFC7232, Section 4.1 + use_proxy = 305, // RFC7231, Section 6.4.5 + temporary_redirect = 307, // RFC7231, Section 6.4.7 + permanent_redirect = 308, // RFC7538 + + bad_request = 400, // RFC7231, Section 6.5.1 + unauthorized = 401, // RFC7235, Section 3.1 + payment_required = 402, // RFC7231, Section 6.5.2 + forbidden = 403, // RFC7231, Section 6.5.3 + not_found = 404, // RFC7231, Section 6.5.4 + method_not_allowed = 405, // RFC7231, Section 6.5.5 + not_acceptable = 406, // RFC7231, Section 6.5.6 + proxy_auth_required = 407, // RFC7235, Section 3.2 + request_timeout = 408, // RFC7231, Section 6.5.7 + conflict = 409, // RFC7231, Section 6.5.8 + gone = 410, // RFC7231, Section 6.5.9 + length_required = 411, // RFC7231, Section 6.5.10 + precondition_failed = 412, // RFC7232, Section 4.2][RFC8144, Section 3.2 + payload_too_large = 413, // RFC7231, Section 6.5.11 + uri_too_long = 414, // RFC7231, Section 6.5.12 + unsupported_media_type = 415, // RFC7231, Section 6.5.13][RFC7694, Section 3 + range_not_satisfiable = 416, // RFC7233, Section 4.4 + expectation_failed = 417, // RFC7231, Section 6.5.14 + teapot = 418, // RFC 7168, 2.3.3 + misdirected_request = 421, // RFC7540, Section 9.1.2 + unprocessable_entity = 422, // RFC4918 + locked = 423, // RFC4918 + failed_dependency = 424, // RFC4918 + too_early = 425, // RFC8470 + upgrade_required = 426, // RFC7231, Section 6.5.15 + precondition_required = 428, // RFC6585 + 
too_many_requests = 429, // RFC6585 + header_fields_too_large = 431, // RFC6585 + unavailable_for_legal_reasons = 451, // RFC7725 + + internal_server_error = 500, // RFC7231, Section 6.6.1 + not_implemented = 501, // RFC7231, Section 6.6.2 + bad_gateway = 502, // RFC7231, Section 6.6.3 + service_unavailable = 503, // RFC7231, Section 6.6.4 + gateway_timeout = 504, // RFC7231, Section 6.6.5 + http_version_not_supported = 505, // RFC7231, Section 6.6.6 + variant_also_negotiates = 506, // RFC2295 + insufficient_storage = 507, // RFC4918 + loop_detected = 508, // RFC5842 + not_extended = 510, // RFC2774 + network_authentication_required = 511, // RFC6585 + + _, + + pub fn phrase(self: Status) ?[]const u8 { + return switch (self) { + // 1xx statuses + .@"continue" => "Continue", + .switching_protocols => "Switching Protocols", + .processing => "Processing", + .early_hints => "Early Hints", -pub const Method = @import("http/method.zig").Method; -pub const Status = @import("http/status.zig").Status; + // 2xx statuses + .ok => "OK", + .created => "Created", + .accepted => "Accepted", + .non_authoritative_info => "Non-Authoritative Information", + .no_content => "No Content", + .reset_content => "Reset Content", + .partial_content => "Partial Content", + .multi_status => "Multi-Status", + .already_reported => "Already Reported", + .im_used => "IM Used", + + // 3xx statuses + .multiple_choice => "Multiple Choice", + .moved_permanently => "Moved Permanently", + .found => "Found", + .see_other => "See Other", + .not_modified => "Not Modified", + .use_proxy => "Use Proxy", + .temporary_redirect => "Temporary Redirect", + .permanent_redirect => "Permanent Redirect", + + // 4xx statuses + .bad_request => "Bad Request", + .unauthorized => "Unauthorized", + .payment_required => "Payment Required", + .forbidden => "Forbidden", + .not_found => "Not Found", + .method_not_allowed => "Method Not Allowed", + .not_acceptable => "Not Acceptable", + .proxy_auth_required => "Proxy Authentication Required", + .request_timeout => "Request Timeout", + .conflict => "Conflict", + .gone => "Gone", + .length_required => "Length Required", + .precondition_failed => "Precondition Failed", + .payload_too_large => "Payload Too Large", + .uri_too_long => "URI Too Long", + .unsupported_media_type => "Unsupported Media Type", + .range_not_satisfiable => "Range Not Satisfiable", + .expectation_failed => "Expectation Failed", + .teapot => "I'm a teapot", + .misdirected_request => "Misdirected Request", + .unprocessable_entity => "Unprocessable Entity", + .locked => "Locked", + .failed_dependency => "Failed Dependency", + .too_early => "Too Early", + .upgrade_required => "Upgrade Required", + .precondition_required => "Precondition Required", + .too_many_requests => "Too Many Requests", + .header_fields_too_large => "Request Header Fields Too Large", + .unavailable_for_legal_reasons => "Unavailable For Legal Reasons", + + // 5xx statuses + .internal_server_error => "Internal Server Error", + .not_implemented => "Not Implemented", + .bad_gateway => "Bad Gateway", + .service_unavailable => "Service Unavailable", + .gateway_timeout => "Gateway Timeout", + .http_version_not_supported => "HTTP Version Not Supported", + .variant_also_negotiates => "Variant Also Negotiates", + .insufficient_storage => "Insufficient Storage", + .loop_detected => "Loop Detected", + .not_extended => "Not Extended", + .network_authentication_required => "Network Authentication Required", + + else => return null, + }; + } + + pub const Class = enum { + 
informational, + success, + redirect, + client_error, + server_error, + }; + + pub fn class(self: Status) ?Class { + return switch (@enumToInt(self)) { + 100...199 => .informational, + 200...299 => .success, + 300...399 => .redirect, + 400...499 => .client_error, + 500...599 => .server_error, + else => null, + }; + } + + test { + try std.testing.expectEqualStrings("OK", Status.ok.phrase().?); + try std.testing.expectEqualStrings("Not Found", Status.not_found.phrase().?); + } + + test { + try std.testing.expectEqual(@as(?Status.Class, Status.Class.success), Status.ok.class()); + try std.testing.expectEqual(@as(?Status.Class, Status.Class.client_error), Status.not_found.class()); + } +}; + +pub const Headers = struct { + state: State = .start, + invalid_index: u32 = undefined, + + pub const State = enum { invalid, start, line, nl_r, nl_n, nl2_r, finished }; + + /// Returns how many bytes are processed into headers. Always less than or + /// equal to bytes.len. If the amount returned is less than bytes.len, it + /// means the headers ended and the first byte after the double \r\n\r\n is + /// located at `bytes[result]`. + pub fn feed(h: *Headers, bytes: []const u8) usize { + for (bytes) |b, i| { + switch (h.state) { + .start => switch (b) { + '\r' => h.state = .nl_r, + '\n' => return invalid(h, i), + else => {}, + }, + .nl_r => switch (b) { + '\n' => h.state = .nl_n, + else => return invalid(h, i), + }, + .nl_n => switch (b) { + '\r' => h.state = .nl2_r, + else => h.state = .line, + }, + .nl2_r => switch (b) { + '\n' => h.state = .finished, + else => return invalid(h, i), + }, + .line => switch (b) { + '\r' => h.state = .nl_r, + '\n' => return invalid(h, i), + else => {}, + }, + .invalid => return i, + .finished => return i, + } + } + return bytes.len; + } + + fn invalid(h: *Headers, i: usize) usize { + h.invalid_index = @intCast(u32, i); + h.state = .invalid; + return i; + } +}; + +const std = @import("std.zig"); test { - std.testing.refAllDecls(@This()); + _ = Client; + _ = Method; + _ = Status; + _ = Headers; } diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig new file mode 100644 index 000000000000..8a4a771416ae --- /dev/null +++ b/lib/std/http/Client.zig @@ -0,0 +1,181 @@ +//! This API is a barely-touched, barely-functional http client, just the +//! absolute minimum thing I needed in order to test `std.crypto.tls`. Bear +//! with me and I promise the API will become useful and streamlined. + +const std = @import("../std.zig"); +const assert = std.debug.assert; +const http = std.http; +const net = std.net; +const Client = @This(); +const Url = std.Url; + +allocator: std.mem.Allocator, +headers: std.ArrayListUnmanaged(u8) = .{}, +active_requests: usize = 0, +ca_bundle: std.crypto.Certificate.Bundle = .{}, + +/// TODO: emit error.UnexpectedEndOfStream or something like that when the read +/// data does not match the content length. This is necessary since HTTPS disables +/// close_notify protection on underlying TLS streams. 
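+///
+/// Rough usage sketch (not normative; assumes an already-initialized `Client`
+/// named `client` and a hypothetical `url_text` string):
+///
+///     var req = try client.request(try std.Url.parse(url_text), .{});
+///     defer req.deinit();
+///     try req.end();
+///     var buf: [4096]u8 = undefined;
+///     const n = try req.readAll(&buf);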
+pub const Request = struct { + client: *Client, + stream: net.Stream, + headers: std.ArrayListUnmanaged(u8) = .{}, + tls_client: std.crypto.tls.Client, + protocol: Protocol, + response_headers: http.Headers = .{}, + + pub const Protocol = enum { http, https }; + + pub const Options = struct { + method: http.Method = .GET, + }; + + pub fn deinit(req: *Request) void { + req.client.active_requests -= 1; + req.headers.deinit(req.client.allocator); + req.* = undefined; + } + + pub fn addHeader(req: *Request, name: []const u8, value: []const u8) !void { + const gpa = req.client.allocator; + // Ensure an extra +2 for the \r\n in end() + try req.headers.ensureUnusedCapacity(gpa, name.len + value.len + 6); + req.headers.appendSliceAssumeCapacity(name); + req.headers.appendSliceAssumeCapacity(": "); + req.headers.appendSliceAssumeCapacity(value); + req.headers.appendSliceAssumeCapacity("\r\n"); + } + + pub fn end(req: *Request) !void { + req.headers.appendSliceAssumeCapacity("\r\n"); + switch (req.protocol) { + .http => { + try req.stream.writeAll(req.headers.items); + }, + .https => { + try req.tls_client.writeAll(req.stream, req.headers.items); + }, + } + } + + pub fn readAll(req: *Request, buffer: []u8) !usize { + return readAtLeast(req, buffer, buffer.len); + } + + pub fn read(req: *Request, buffer: []u8) !usize { + return readAtLeast(req, buffer, 1); + } + + pub fn readAtLeast(req: *Request, buffer: []u8, len: usize) !usize { + assert(len <= buffer.len); + var index: usize = 0; + while (index < len) { + const headers_finished = req.response_headers.state == .finished; + const amt = try readAdvanced(req, buffer[index..]); + if (amt == 0 and headers_finished) break; + index += amt; + } + return index; + } + + /// This one can return 0 without meaning EOF. + /// TODO change to readvAdvanced + pub fn readAdvanced(req: *Request, buffer: []u8) !usize { + if (req.response_headers.state == .finished) return readRaw(req, buffer); + + const amt = try readRaw(req, buffer); + const data = buffer[0..amt]; + const i = req.response_headers.feed(data); + if (req.response_headers.state == .invalid) return error.InvalidHttpHeaders; + if (i < data.len) { + const rest = data[i..]; + std.mem.copy(u8, buffer, rest); + return rest.len; + } + return 0; + } + + /// Only abstracts over http/https. + fn readRaw(req: *Request, buffer: []u8) !usize { + switch (req.protocol) { + .http => return req.stream.read(buffer), + .https => return req.tls_client.read(req.stream, buffer), + } + } + + /// Only abstracts over http/https. 
+ fn readAtLeastRaw(req: *Request, buffer: []u8, len: usize) !usize { + switch (req.protocol) { + .http => return req.stream.readAtLeast(buffer, len), + .https => return req.tls_client.readAtLeast(req.stream, buffer, len), + } + } +}; + +pub fn deinit(client: *Client) void { + assert(client.active_requests == 0); + client.headers.deinit(client.allocator); + client.* = undefined; +} + +pub fn request(client: *Client, url: Url, options: Request.Options) !Request { + const protocol = std.meta.stringToEnum(Request.Protocol, url.scheme) orelse + return error.UnsupportedUrlScheme; + const port: u16 = url.port orelse switch (protocol) { + .http => 80, + .https => 443, + }; + + var req: Request = .{ + .client = client, + .stream = try net.tcpConnectToHost(client.allocator, url.host, port), + .protocol = protocol, + .tls_client = undefined, + }; + client.active_requests += 1; + errdefer req.deinit(); + + switch (protocol) { + .http => {}, + .https => { + req.tls_client = try std.crypto.tls.Client.init(req.stream, client.ca_bundle, url.host); + // This is appropriate for HTTPS because the HTTP headers contain + // the content length which is used to detect truncation attacks. + req.tls_client.allow_truncation_attacks = true; + }, + } + + try req.headers.ensureUnusedCapacity( + client.allocator, + @tagName(options.method).len + + 1 + + url.path.len + + " HTTP/1.1\r\nHost: ".len + + url.host.len + + "\r\nUpgrade-Insecure-Requests: 1\r\n".len + + client.headers.items.len + + 2, // for the \r\n at the end of headers + ); + req.headers.appendSliceAssumeCapacity(@tagName(options.method)); + req.headers.appendSliceAssumeCapacity(" "); + req.headers.appendSliceAssumeCapacity(url.path); + req.headers.appendSliceAssumeCapacity(" HTTP/1.1\r\nHost: "); + req.headers.appendSliceAssumeCapacity(url.host); + switch (protocol) { + .https => req.headers.appendSliceAssumeCapacity("\r\nUpgrade-Insecure-Requests: 1\r\n"), + .http => req.headers.appendSliceAssumeCapacity("\r\n"), + } + req.headers.appendSliceAssumeCapacity(client.headers.items); + + return req; +} + +pub fn addHeader(client: *Client, name: []const u8, value: []const u8) !void { + const gpa = client.allocator; + try client.headers.ensureUnusedCapacity(gpa, name.len + value.len + 4); + client.headers.appendSliceAssumeCapacity(name); + client.headers.appendSliceAssumeCapacity(": "); + client.headers.appendSliceAssumeCapacity(value); + client.headers.appendSliceAssumeCapacity("\r\n"); +} diff --git a/lib/std/http/method.zig b/lib/std/http/method.zig deleted file mode 100644 index c118ca9a475c..000000000000 --- a/lib/std/http/method.zig +++ /dev/null @@ -1,65 +0,0 @@ -//! HTTP Methods -//! 
https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods - -// Style guide is violated here so that @tagName can be used effectively -/// https://datatracker.ietf.org/doc/html/rfc7231#section-4 Initial definiton -/// https://datatracker.ietf.org/doc/html/rfc5789#section-2 PATCH -pub const Method = enum { - GET, - HEAD, - POST, - PUT, - DELETE, - CONNECT, - OPTIONS, - TRACE, - PATCH, - - /// Returns true if a request of this method is allowed to have a body - /// Actual behavior from servers may vary and should still be checked - pub fn requestHasBody(self: Method) bool { - return switch (self) { - .POST, .PUT, .PATCH => true, - .GET, .HEAD, .DELETE, .CONNECT, .OPTIONS, .TRACE => false, - }; - } - - /// Returns true if a response to this method is allowed to have a body - /// Actual behavior from clients may vary and should still be checked - pub fn responseHasBody(self: Method) bool { - return switch (self) { - .GET, .POST, .DELETE, .CONNECT, .OPTIONS, .PATCH => true, - .HEAD, .PUT, .TRACE => false, - }; - } - - /// An HTTP method is safe if it doesn't alter the state of the server. - /// https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP - /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1 - pub fn safe(self: Method) bool { - return switch (self) { - .GET, .HEAD, .OPTIONS, .TRACE => true, - .POST, .PUT, .DELETE, .CONNECT, .PATCH => false, - }; - } - - /// An HTTP method is idempotent if an identical request can be made once or several times in a row with the same effect while leaving the server in the same state. - /// https://developer.mozilla.org/en-US/docs/Glossary/Idempotent - /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.2 - pub fn idempotent(self: Method) bool { - return switch (self) { - .GET, .HEAD, .PUT, .DELETE, .OPTIONS, .TRACE => true, - .CONNECT, .POST, .PATCH => false, - }; - } - - /// A cacheable response is an HTTP response that can be cached, that is stored to be retrieved and used later, saving a new request to the server. - /// https://developer.mozilla.org/en-US/docs/Glossary/cacheable - /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.3 - pub fn cacheable(self: Method) bool { - return switch (self) { - .GET, .HEAD => true, - .POST, .PUT, .DELETE, .CONNECT, .OPTIONS, .TRACE, .PATCH => false, - }; - } -}; diff --git a/lib/std/http/status.zig b/lib/std/http/status.zig deleted file mode 100644 index 91738e0533c1..000000000000 --- a/lib/std/http/status.zig +++ /dev/null @@ -1,182 +0,0 @@ -//! HTTP Status -//! 
https://developer.mozilla.org/en-US/docs/Web/HTTP/Status - -const std = @import("../std.zig"); - -pub const Status = enum(u10) { - @"continue" = 100, // RFC7231, Section 6.2.1 - switching_protocols = 101, // RFC7231, Section 6.2.2 - processing = 102, // RFC2518 - early_hints = 103, // RFC8297 - - ok = 200, // RFC7231, Section 6.3.1 - created = 201, // RFC7231, Section 6.3.2 - accepted = 202, // RFC7231, Section 6.3.3 - non_authoritative_info = 203, // RFC7231, Section 6.3.4 - no_content = 204, // RFC7231, Section 6.3.5 - reset_content = 205, // RFC7231, Section 6.3.6 - partial_content = 206, // RFC7233, Section 4.1 - multi_status = 207, // RFC4918 - already_reported = 208, // RFC5842 - im_used = 226, // RFC3229 - - multiple_choice = 300, // RFC7231, Section 6.4.1 - moved_permanently = 301, // RFC7231, Section 6.4.2 - found = 302, // RFC7231, Section 6.4.3 - see_other = 303, // RFC7231, Section 6.4.4 - not_modified = 304, // RFC7232, Section 4.1 - use_proxy = 305, // RFC7231, Section 6.4.5 - temporary_redirect = 307, // RFC7231, Section 6.4.7 - permanent_redirect = 308, // RFC7538 - - bad_request = 400, // RFC7231, Section 6.5.1 - unauthorized = 401, // RFC7235, Section 3.1 - payment_required = 402, // RFC7231, Section 6.5.2 - forbidden = 403, // RFC7231, Section 6.5.3 - not_found = 404, // RFC7231, Section 6.5.4 - method_not_allowed = 405, // RFC7231, Section 6.5.5 - not_acceptable = 406, // RFC7231, Section 6.5.6 - proxy_auth_required = 407, // RFC7235, Section 3.2 - request_timeout = 408, // RFC7231, Section 6.5.7 - conflict = 409, // RFC7231, Section 6.5.8 - gone = 410, // RFC7231, Section 6.5.9 - length_required = 411, // RFC7231, Section 6.5.10 - precondition_failed = 412, // RFC7232, Section 4.2][RFC8144, Section 3.2 - payload_too_large = 413, // RFC7231, Section 6.5.11 - uri_too_long = 414, // RFC7231, Section 6.5.12 - unsupported_media_type = 415, // RFC7231, Section 6.5.13][RFC7694, Section 3 - range_not_satisfiable = 416, // RFC7233, Section 4.4 - expectation_failed = 417, // RFC7231, Section 6.5.14 - teapot = 418, // RFC 7168, 2.3.3 - misdirected_request = 421, // RFC7540, Section 9.1.2 - unprocessable_entity = 422, // RFC4918 - locked = 423, // RFC4918 - failed_dependency = 424, // RFC4918 - too_early = 425, // RFC8470 - upgrade_required = 426, // RFC7231, Section 6.5.15 - precondition_required = 428, // RFC6585 - too_many_requests = 429, // RFC6585 - header_fields_too_large = 431, // RFC6585 - unavailable_for_legal_reasons = 451, // RFC7725 - - internal_server_error = 500, // RFC7231, Section 6.6.1 - not_implemented = 501, // RFC7231, Section 6.6.2 - bad_gateway = 502, // RFC7231, Section 6.6.3 - service_unavailable = 503, // RFC7231, Section 6.6.4 - gateway_timeout = 504, // RFC7231, Section 6.6.5 - http_version_not_supported = 505, // RFC7231, Section 6.6.6 - variant_also_negotiates = 506, // RFC2295 - insufficient_storage = 507, // RFC4918 - loop_detected = 508, // RFC5842 - not_extended = 510, // RFC2774 - network_authentication_required = 511, // RFC6585 - - _, - - pub fn phrase(self: Status) ?[]const u8 { - return switch (self) { - // 1xx statuses - .@"continue" => "Continue", - .switching_protocols => "Switching Protocols", - .processing => "Processing", - .early_hints => "Early Hints", - - // 2xx statuses - .ok => "OK", - .created => "Created", - .accepted => "Accepted", - .non_authoritative_info => "Non-Authoritative Information", - .no_content => "No Content", - .reset_content => "Reset Content", - .partial_content => "Partial Content", - .multi_status => 
"Multi-Status", - .already_reported => "Already Reported", - .im_used => "IM Used", - - // 3xx statuses - .multiple_choice => "Multiple Choice", - .moved_permanently => "Moved Permanently", - .found => "Found", - .see_other => "See Other", - .not_modified => "Not Modified", - .use_proxy => "Use Proxy", - .temporary_redirect => "Temporary Redirect", - .permanent_redirect => "Permanent Redirect", - - // 4xx statuses - .bad_request => "Bad Request", - .unauthorized => "Unauthorized", - .payment_required => "Payment Required", - .forbidden => "Forbidden", - .not_found => "Not Found", - .method_not_allowed => "Method Not Allowed", - .not_acceptable => "Not Acceptable", - .proxy_auth_required => "Proxy Authentication Required", - .request_timeout => "Request Timeout", - .conflict => "Conflict", - .gone => "Gone", - .length_required => "Length Required", - .precondition_failed => "Precondition Failed", - .payload_too_large => "Payload Too Large", - .uri_too_long => "URI Too Long", - .unsupported_media_type => "Unsupported Media Type", - .range_not_satisfiable => "Range Not Satisfiable", - .expectation_failed => "Expectation Failed", - .teapot => "I'm a teapot", - .misdirected_request => "Misdirected Request", - .unprocessable_entity => "Unprocessable Entity", - .locked => "Locked", - .failed_dependency => "Failed Dependency", - .too_early => "Too Early", - .upgrade_required => "Upgrade Required", - .precondition_required => "Precondition Required", - .too_many_requests => "Too Many Requests", - .header_fields_too_large => "Request Header Fields Too Large", - .unavailable_for_legal_reasons => "Unavailable For Legal Reasons", - - // 5xx statuses - .internal_server_error => "Internal Server Error", - .not_implemented => "Not Implemented", - .bad_gateway => "Bad Gateway", - .service_unavailable => "Service Unavailable", - .gateway_timeout => "Gateway Timeout", - .http_version_not_supported => "HTTP Version Not Supported", - .variant_also_negotiates => "Variant Also Negotiates", - .insufficient_storage => "Insufficient Storage", - .loop_detected => "Loop Detected", - .not_extended => "Not Extended", - .network_authentication_required => "Network Authentication Required", - - else => return null, - }; - } - - pub const Class = enum { - informational, - success, - redirect, - client_error, - server_error, - }; - - pub fn class(self: Status) ?Class { - return switch (@enumToInt(self)) { - 100...199 => .informational, - 200...299 => .success, - 300...399 => .redirect, - 400...499 => .client_error, - 500...599 => .server_error, - else => null, - }; - } -}; - -test { - try std.testing.expectEqualStrings("OK", Status.ok.phrase().?); - try std.testing.expectEqualStrings("Not Found", Status.not_found.phrase().?); -} - -test { - try std.testing.expectEqual(@as(?Status.Class, Status.Class.success), Status.ok.class()); - try std.testing.expectEqual(@as(?Status.Class, Status.Class.client_error), Status.not_found.class()); -} diff --git a/lib/std/meta.zig b/lib/std/meta.zig index 39d561469f79..db284f8b61c4 100644 --- a/lib/std/meta.zig +++ b/lib/std/meta.zig @@ -810,21 +810,25 @@ test "std.meta.activeTag" { const TagPayloadType = TagPayload; -///Given a tagged union type, and an enum, return the type of the union -/// field corresponding to the enum tag. 
-pub fn TagPayload(comptime U: type, comptime tag: Tag(U)) type { +pub fn TagPayloadByName(comptime U: type, comptime tag_name: []const u8) type { comptime debug.assert(trait.is(.Union)(U)); const info = @typeInfo(U).Union; inline for (info.fields) |field_info| { - if (comptime mem.eql(u8, field_info.name, @tagName(tag))) + if (comptime mem.eql(u8, field_info.name, tag_name)) return field_info.type; } unreachable; } +/// Given a tagged union type, and an enum, return the type of the union field +/// corresponding to the enum tag. +pub fn TagPayload(comptime U: type, comptime tag: Tag(U)) type { + return TagPayloadByName(U, @tagName(tag)); +} + test "std.meta.TagPayload" { const Event = union(enum) { Moved: struct { diff --git a/lib/std/net.zig b/lib/std/net.zig index 4a0582e7f5ef..aa51176184e4 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -1672,6 +1672,40 @@ pub const Stream = struct { } } + pub fn readv(s: Stream, iovecs: []const os.iovec) ReadError!usize { + if (builtin.os.tag == .windows) { + // TODO improve this to use ReadFileScatter + if (iovecs.len == 0) return @as(usize, 0); + const first = iovecs[0]; + return os.windows.ReadFile(s.handle, first.iov_base[0..first.iov_len], null, io.default_mode); + } + + return os.readv(s.handle, iovecs); + } + + /// Returns the number of bytes read. If the number read is smaller than + /// `buffer.len`, it means the stream reached the end. Reaching the end of + /// a stream is not an error condition. + pub fn readAll(s: Stream, buffer: []u8) ReadError!usize { + return readAtLeast(s, buffer, buffer.len); + } + + /// Returns the number of bytes read, calling the underlying read function + /// the minimal number of times until the buffer has at least `len` bytes + /// filled. If the number read is less than `len` it means the stream + /// reached the end. Reaching the end of the stream is not an error + /// condition. + pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize { + assert(len <= buffer.len); + var index: usize = 0; + while (index < len) { + const amt = try s.read(buffer[index..]); + if (amt == 0) break; + index += amt; + } + return index; + } + /// TODO in evented I/O mode, this implementation incorrectly uses the event loop's /// file system thread instead of non-blocking. It needs to be reworked to properly /// use non-blocking I/O. @@ -1687,6 +1721,13 @@ pub const Stream = struct { } } + pub fn writeAll(self: Stream, bytes: []const u8) WriteError!void { + var index: usize = 0; + while (index < bytes.len) { + index += try self.write(bytes[index..]); + } + } + /// See https://github.com/ziglang/zig/issues/7699 /// See equivalent function: `std.fs.File.writev`. pub fn writev(self: Stream, iovecs: []const os.iovec_const) WriteError!usize { diff --git a/lib/std/os.zig b/lib/std/os.zig index 9f96c671514d..ffc294f0e6f6 100644 --- a/lib/std/os.zig +++ b/lib/std/os.zig @@ -767,6 +767,7 @@ pub fn readv(fd: fd_t, iov: []const iovec) ReadError!usize { .ISDIR => return error.IsDir, .NOBUFS => return error.SystemResources, .NOMEM => return error.SystemResources, + .CONNRESET => return error.ConnectionResetByPeer, else => |err| return unexpectedErrno(err), } } @@ -5685,11 +5686,11 @@ pub fn sendmsg( /// The file descriptor of the sending socket. 
sockfd: socket_t, /// Message header and iovecs - msg: msghdr_const, + msg: *const msghdr_const, flags: u32, ) SendMsgError!usize { while (true) { - const rc = system.sendmsg(sockfd, @ptrCast(*const std.x.os.Socket.Message, &msg), @intCast(c_int, flags)); + const rc = system.sendmsg(sockfd, msg, flags); if (builtin.os.tag == .windows) { if (rc == windows.ws2_32.SOCKET_ERROR) { switch (windows.ws2_32.WSAGetLastError()) { diff --git a/lib/std/os/linux.zig b/lib/std/os/linux.zig index ecb8a21d7a6b..d9d5fb32043d 100644 --- a/lib/std/os/linux.zig +++ b/lib/std/os/linux.zig @@ -1226,11 +1226,14 @@ pub fn getsockopt(fd: i32, level: u32, optname: u32, noalias optval: [*]u8, noal return syscall5(.getsockopt, @bitCast(usize, @as(isize, fd)), level, optname, @ptrToInt(optval), @ptrToInt(optlen)); } -pub fn sendmsg(fd: i32, msg: *const std.x.os.Socket.Message, flags: c_int) usize { +pub fn sendmsg(fd: i32, msg: *const msghdr_const, flags: u32) usize { + const fd_usize = @bitCast(usize, @as(isize, fd)); + const msg_usize = @ptrToInt(msg); if (native_arch == .x86) { - return socketcall(SC.sendmsg, &[3]usize{ @bitCast(usize, @as(isize, fd)), @ptrToInt(msg), @bitCast(usize, @as(isize, flags)) }); + return socketcall(SC.sendmsg, &[3]usize{ fd_usize, msg_usize, flags }); + } else { + return syscall3(.sendmsg, fd_usize, msg_usize, flags); } - return syscall3(.sendmsg, @bitCast(usize, @as(isize, fd)), @ptrToInt(msg), @bitCast(usize, @as(isize, flags))); } pub fn sendmmsg(fd: i32, msgvec: [*]mmsghdr_const, vlen: u32, flags: u32) usize { @@ -1274,24 +1277,42 @@ pub fn sendmmsg(fd: i32, msgvec: [*]mmsghdr_const, vlen: u32, flags: u32) usize } pub fn connect(fd: i32, addr: *const anyopaque, len: socklen_t) usize { + const fd_usize = @bitCast(usize, @as(isize, fd)); + const addr_usize = @ptrToInt(addr); if (native_arch == .x86) { - return socketcall(SC.connect, &[3]usize{ @bitCast(usize, @as(isize, fd)), @ptrToInt(addr), len }); + return socketcall(SC.connect, &[3]usize{ fd_usize, addr_usize, len }); + } else { + return syscall3(.connect, fd_usize, addr_usize, len); } - return syscall3(.connect, @bitCast(usize, @as(isize, fd)), @ptrToInt(addr), len); } -pub fn recvmsg(fd: i32, msg: *std.x.os.Socket.Message, flags: c_int) usize { +pub fn recvmsg(fd: i32, msg: *msghdr, flags: u32) usize { + const fd_usize = @bitCast(usize, @as(isize, fd)); + const msg_usize = @ptrToInt(msg); if (native_arch == .x86) { - return socketcall(SC.recvmsg, &[3]usize{ @bitCast(usize, @as(isize, fd)), @ptrToInt(msg), @bitCast(usize, @as(isize, flags)) }); + return socketcall(SC.recvmsg, &[3]usize{ fd_usize, msg_usize, flags }); + } else { + return syscall3(.recvmsg, fd_usize, msg_usize, flags); } - return syscall3(.recvmsg, @bitCast(usize, @as(isize, fd)), @ptrToInt(msg), @bitCast(usize, @as(isize, flags))); } -pub fn recvfrom(fd: i32, noalias buf: [*]u8, len: usize, flags: u32, noalias addr: ?*sockaddr, noalias alen: ?*socklen_t) usize { +pub fn recvfrom( + fd: i32, + noalias buf: [*]u8, + len: usize, + flags: u32, + noalias addr: ?*sockaddr, + noalias alen: ?*socklen_t, +) usize { + const fd_usize = @bitCast(usize, @as(isize, fd)); + const buf_usize = @ptrToInt(buf); + const addr_usize = @ptrToInt(addr); + const alen_usize = @ptrToInt(alen); if (native_arch == .x86) { - return socketcall(SC.recvfrom, &[6]usize{ @bitCast(usize, @as(isize, fd)), @ptrToInt(buf), len, flags, @ptrToInt(addr), @ptrToInt(alen) }); + return socketcall(SC.recvfrom, &[6]usize{ fd_usize, buf_usize, len, flags, addr_usize, alen_usize }); + } else { + return 
syscall6(.recvfrom, fd_usize, buf_usize, len, flags, addr_usize, alen_usize); } - return syscall6(.recvfrom, @bitCast(usize, @as(isize, fd)), @ptrToInt(buf), len, flags, @ptrToInt(addr), @ptrToInt(alen)); } pub fn shutdown(fd: i32, how: i32) usize { @@ -3219,7 +3240,15 @@ pub const sockaddr = extern struct { data: [14]u8, pub const SS_MAXSIZE = 128; - pub const storage = std.x.os.Socket.Address.Native.Storage; + pub const storage = extern struct { + family: sa_family_t align(8), + padding: [SS_MAXSIZE - @sizeOf(sa_family_t)]u8 = undefined, + + comptime { + assert(@sizeOf(storage) == SS_MAXSIZE); + assert(@alignOf(storage) == 8); + } + }; /// IPv4 socket address pub const in = extern struct { diff --git a/lib/std/os/linux/seccomp.zig b/lib/std/os/linux/seccomp.zig index fd002e741679..03a96633f895 100644 --- a/lib/std/os/linux/seccomp.zig +++ b/lib/std/os/linux/seccomp.zig @@ -6,16 +6,14 @@ //! isn't that useful for general-purpose applications, and so a mode that //! utilizes user-supplied filters mode was added. //! -//! Seccomp filters are classic BPF programs, which means that all the -//! information under `std.x.net.bpf` applies here as well. Conceptually, a -//! seccomp program is attached to the kernel and is executed on each syscall. -//! The "packet" being validated is the `data` structure, and the verdict is an -//! action that the kernel performs on the calling process. The actions are -//! variations on a "pass" or "fail" result, where a pass allows the syscall to -//! continue and a fail blocks the syscall and returns some sort of error value. -//! See the full list of actions under ::RET for more information. Finally, only -//! word-sized, absolute loads (`ld [k]`) are supported to read from the `data` -//! structure. +//! Seccomp filters are classic BPF programs. Conceptually, a seccomp program +//! is attached to the kernel and is executed on each syscall. The "packet" +//! being validated is the `data` structure, and the verdict is an action that +//! the kernel performs on the calling process. The actions are variations on a +//! "pass" or "fail" result, where a pass allows the syscall to continue and a +//! fail blocks the syscall and returns some sort of error value. See the full +//! list of actions under ::RET for more information. Finally, only word-sized, +//! absolute loads (`ld [k]`) are supported to read from the `data` structure. //! //! There are some issues with the filter API that have traditionally made //! 
writing them a pain: diff --git a/lib/std/os/windows/ws2_32.zig b/lib/std/os/windows/ws2_32.zig index 90e1422fd263..b4d18264f3c0 100644 --- a/lib/std/os/windows/ws2_32.zig +++ b/lib/std/os/windows/ws2_32.zig @@ -1,4 +1,5 @@ const std = @import("../../std.zig"); +const assert = std.debug.assert; const windows = std.os.windows; const WINAPI = windows.WINAPI; @@ -1106,7 +1107,15 @@ pub const sockaddr = extern struct { data: [14]u8, pub const SS_MAXSIZE = 128; - pub const storage = std.x.os.Socket.Address.Native.Storage; + pub const storage = extern struct { + family: ADDRESS_FAMILY align(8), + padding: [SS_MAXSIZE - @sizeOf(ADDRESS_FAMILY)]u8 = undefined, + + comptime { + assert(@sizeOf(storage) == SS_MAXSIZE); + assert(@alignOf(storage) == 8); + } + }; /// IPv4 socket address pub const in = extern struct { @@ -1207,7 +1216,7 @@ pub const LPFN_GETACCEPTEXSOCKADDRS = *const fn ( pub const LPFN_WSASENDMSG = *const fn ( s: SOCKET, - lpMsg: *const std.x.os.Socket.Message, + lpMsg: *const WSAMSG_const, dwFlags: u32, lpNumberOfBytesSent: ?*u32, lpOverlapped: ?*OVERLAPPED, @@ -1216,7 +1225,7 @@ pub const LPFN_WSASENDMSG = *const fn ( pub const LPFN_WSARECVMSG = *const fn ( s: SOCKET, - lpMsg: *std.x.os.Socket.Message, + lpMsg: *WSAMSG, lpdwNumberOfBytesRecv: ?*u32, lpOverlapped: ?*OVERLAPPED, lpCompletionRoutine: ?LPWSAOVERLAPPED_COMPLETION_ROUTINE, @@ -2090,7 +2099,7 @@ pub extern "ws2_32" fn WSASend( pub extern "ws2_32" fn WSASendMsg( s: SOCKET, - lpMsg: *const std.x.os.Socket.Message, + lpMsg: *WSAMSG_const, dwFlags: u32, lpNumberOfBytesSent: ?*u32, lpOverlapped: ?*OVERLAPPED, @@ -2099,7 +2108,7 @@ pub extern "ws2_32" fn WSASendMsg( pub extern "ws2_32" fn WSARecvMsg( s: SOCKET, - lpMsg: *std.x.os.Socket.Message, + lpMsg: *WSAMSG, lpdwNumberOfBytesRecv: ?*u32, lpOverlapped: ?*OVERLAPPED, lpCompletionRoutine: ?LPWSAOVERLAPPED_COMPLETION_ROUTINE, diff --git a/lib/std/std.zig b/lib/std/std.zig index 1b4217b50640..1cbcd6bad70e 100644 --- a/lib/std/std.zig +++ b/lib/std/std.zig @@ -42,6 +42,7 @@ pub const Target = @import("target.zig").Target; pub const Thread = @import("Thread.zig"); pub const Treap = @import("treap.zig").Treap; pub const Tz = tz.Tz; +pub const Url = @import("Url.zig"); pub const array_hash_map = @import("array_hash_map.zig"); pub const atomic = @import("atomic.zig"); @@ -90,7 +91,6 @@ pub const tz = @import("tz.zig"); pub const unicode = @import("unicode.zig"); pub const valgrind = @import("valgrind.zig"); pub const wasm = @import("wasm.zig"); -pub const x = @import("x.zig"); pub const zig = @import("zig.zig"); pub const start = @import("start.zig"); diff --git a/lib/std/x.zig b/lib/std/x.zig deleted file mode 100644 index 64caf324ed2f..000000000000 --- a/lib/std/x.zig +++ /dev/null @@ -1,19 +0,0 @@ -const std = @import("std.zig"); - -pub const os = struct { - pub const Socket = @import("x/os/socket.zig").Socket; - pub usingnamespace @import("x/os/io.zig"); - pub usingnamespace @import("x/os/net.zig"); -}; - -pub const net = struct { - pub const ip = @import("x/net/ip.zig"); - pub const tcp = @import("x/net/tcp.zig"); - pub const bpf = @import("x/net/bpf.zig"); -}; - -test { - inline for (.{ os, net }) |module| { - std.testing.refAllDecls(module); - } -} diff --git a/lib/std/x/net/bpf.zig b/lib/std/x/net/bpf.zig deleted file mode 100644 index bee930c332bf..000000000000 --- a/lib/std/x/net/bpf.zig +++ /dev/null @@ -1,1003 +0,0 @@ -//! This package provides instrumentation for creating Berkeley Packet Filter[1] -//! (BPF) programs, along with a simulator for running them. -//! -//! 
BPF is a mechanism for cheap, in-kernel packet filtering. Programs are -//! attached to a network device and executed for every packet that flows -//! through it. The program must then return a verdict: the amount of packet -//! bytes that the kernel should copy into userspace. Execution speed is -//! achieved by having programs run in a limited virtual machine, which has the -//! added benefit of graceful failure in the face of buggy programs. -//! -//! The BPF virtual machine has a 32-bit word length and a small number of -//! word-sized registers: -//! -//! - The accumulator, `a`: The source/destination of arithmetic and logic -//! operations. -//! - The index register, `x`: Used as an offset for indirect memory access and -//! as a comparison value for conditional jumps. -//! - The scratch memory store, `M[0]..M[15]`: Used for saving the value of a/x -//! for later use. -//! -//! The packet being examined is an array of bytes, and is addressed using plain -//! array subscript notation, e.g. [10] for the byte at offset 10. An implicit -//! program counter, `pc`, is intialized to zero and incremented for each instruction. -//! -//! The machine has a fixed instruction set with the following form, where the -//! numbers represent bit length: -//! -//! ``` -//! ┌───────────┬──────┬──────┐ -//! │ opcode:16 │ jt:8 │ jt:8 │ -//! ├───────────┴──────┴──────┤ -//! │ k:32 │ -//! └─────────────────────────┘ -//! ``` -//! -//! The `opcode` indicates the instruction class and its addressing mode. -//! Opcodes are generated by performing binary addition on the 8-bit class and -//! mode constants. For example, the opcode for loading a byte from the packet -//! at X + 2, (`ldb [x + 2]`), is: -//! -//! ``` -//! LD | IND | B = 0x00 | 0x40 | 0x20 -//! = 0x60 -//! ``` -//! -//! `jt` is an offset used for conditional jumps, and increments the program -//! counter by its amount if the comparison was true. Conversely, `jf` -//! increments the counter if it was false. These fields are ignored in all -//! other cases. `k` is a generic variable used for various purposes, most -//! commonly as some sort of constant. -//! -//! This package contains opcode extensions used by different implementations, -//! where "extension" is anything outside of the original that was imported into -//! 4.4BSD[2]. These are marked with "EXTENSION", along with a list of -//! implementations that use them. -//! -//! Most of the doc-comments use the BPF assembly syntax as described in the -//! original paper[1]. For the sake of completeness, here is the complete -//! instruction set, along with the extensions: -//! -//!``` -//! opcode addressing modes -//! ld #k #len M[k] [k] [x + k] -//! ldh [k] [x + k] -//! ldb [k] [x + k] -//! ldx #k #len M[k] 4 * ([k] & 0xf) arc4random() -//! st M[k] -//! stx M[k] -//! jmp L -//! jeq #k, Lt, Lf -//! jgt #k, Lt, Lf -//! jge #k, Lt, Lf -//! jset #k, Lt, Lf -//! add #k x -//! sub #k x -//! mul #k x -//! div #k x -//! or #k x -//! and #k x -//! lsh #k x -//! rsh #k x -//! neg #k x -//! mod #k x -//! xor #k x -//! ret #k a -//! tax -//! txa -//! ``` -//! -//! Finally, a note on program design. The lack of backwards jumps leads to a -//! "return early, return often" control flow. Take for example the program -//! generated from the tcpdump filter `ip`: -//! -//! ``` -//! (000) ldh [12] ; Ethernet Packet Type -//! (001) jeq #0x86dd, 2, 7 ; ETHERTYPE_IPV6 -//! (002) ldb [20] ; IPv6 Next Header -//! (003) jeq #0x6, 10, 4 ; TCP -//! (004) jeq #0x2c, 5, 11 ; IPv6 Fragment Header -//! 
(005) ldb [54] ; TCP Source Port -//! (006) jeq #0x6, 10, 11 ; IPPROTO_TCP -//! (007) jeq #0x800, 8, 11 ; ETHERTYPE_IP -//! (008) ldb [23] ; IPv4 Protocol -//! (009) jeq #0x6, 10, 11 ; IPPROTO_TCP -//! (010) ret #262144 ; copy 0x40000 -//! (011) ret #0 ; skip packet -//! ``` -//! -//! Here we can make a few observations: -//! -//! - The problem "filter only tcp packets" has essentially been transformed -//! into a series of layer checks. -//! - There are two distinct branches in the code, one for validating IPv4 -//! headers and one for IPv6 headers. -//! - Most conditional jumps in these branches lead directly to the last two -//! instructions, a pass or fail. Thus the goal of a program is to find the -//! fastest route to a pass/fail comparison. -//! -//! [1]: S. McCanne and V. Jacobson, "The BSD Packet Filter: A New Architecture -//! for User-level Packet Capture", Proceedings of the 1993 Winter USENIX. -//! [2]: https://minnie.tuhs.org/cgi-bin/utree.pl?file=4.4BSD/usr/src/sys/net/bpf.h -const std = @import("std"); -const builtin = @import("builtin"); -const native_endian = builtin.target.cpu.arch.endian(); -const mem = std.mem; -const math = std.math; -const random = std.crypto.random; -const assert = std.debug.assert; -const expectEqual = std.testing.expectEqual; -const expectError = std.testing.expectError; -const expect = std.testing.expect; - -// instruction classes -/// ld, ldh, ldb: Load data into a. -pub const LD = 0x00; -/// ldx: Load data into x. -pub const LDX = 0x01; -/// st: Store into scratch memory the value of a. -pub const ST = 0x02; -/// st: Store into scratch memory the value of x. -pub const STX = 0x03; -/// alu: Wrapping arithmetic/bitwise operations on a using the value of k/x. -pub const ALU = 0x04; -/// jmp, jeq, jgt, je, jset: Increment the program counter based on a comparison -/// between k/x and the accumulator. -pub const JMP = 0x05; -/// ret: Return a verdict using the value of k/the accumulator. -pub const RET = 0x06; -/// tax, txa: Register value copying between X and a. -pub const MISC = 0x07; - -// Size of data to be loaded from the packet. -/// ld: 32-bit full word. -pub const W = 0x00; -/// ldh: 16-bit half word. -pub const H = 0x08; -/// ldb: Single byte. -pub const B = 0x10; - -// Addressing modes used for loads to a/x. -/// #k: The immediate value stored in k. -pub const IMM = 0x00; -/// [k]: The value at offset k in the packet. -pub const ABS = 0x20; -/// [x + k]: The value at offset x + k in the packet. -pub const IND = 0x40; -/// M[k]: The value of the k'th scratch memory register. -pub const MEM = 0x60; -/// #len: The size of the packet. -pub const LEN = 0x80; -/// 4 * ([k] & 0xf): Four times the low four bits of the byte at offset k in the -/// packet. This is used for efficiently loading the header length of an IP -/// packet. -pub const MSH = 0xa0; -/// arc4random: 32-bit integer generated from a CPRNG (see arc4random(3)) loaded into a. -/// EXTENSION. Defined for: -/// - OpenBSD. -pub const RND = 0xc0; - -// Modifiers for different instruction classes. -/// Use the value of k for alu operations (add #k). -/// Compare against the value of k for jumps (jeq #k, Lt, Lf). -/// Return the value of k for returns (ret #k). -pub const K = 0x00; -/// Use the value of x for alu operations (add x). -/// Compare against the value of X for jumps (jeq x, Lt, Lf). -pub const X = 0x08; -/// Return the value of a for returns (ret a). -pub const A = 0x10; - -// ALU Operations on a using the value of k/x. 
-// All arithmetic operations are defined to overflow the value of a. -/// add: a = a + k -/// a = a + x. -pub const ADD = 0x00; -/// sub: a = a - k -/// a = a - x. -pub const SUB = 0x10; -/// mul: a = a * k -/// a = a * x. -pub const MUL = 0x20; -/// div: a = a / k -/// a = a / x. -/// Truncated division. -pub const DIV = 0x30; -/// or: a = a | k -/// a = a | x. -pub const OR = 0x40; -/// and: a = a & k -/// a = a & x. -pub const AND = 0x50; -/// lsh: a = a << k -/// a = a << x. -/// a = a << k, a = a << x. -pub const LSH = 0x60; -/// rsh: a = a >> k -/// a = a >> x. -pub const RSH = 0x70; -/// neg: a = -a. -/// Note that this isn't a binary negation, rather the value of `~a + 1`. -pub const NEG = 0x80; -/// mod: a = a % k -/// a = a % x. -/// EXTENSION. Defined for: -/// - Linux. -/// - NetBSD + Minix 3. -/// - FreeBSD and derivitives. -pub const MOD = 0x90; -/// xor: a = a ^ k -/// a = a ^ x. -/// EXTENSION. Defined for: -/// - Linux. -/// - NetBSD + Minix 3. -/// - FreeBSD and derivitives. -pub const XOR = 0xa0; - -// Jump operations using a comparison between a and x/k. -/// jmp L: pc += k. -/// No comparison done here. -pub const JA = 0x00; -/// jeq #k, Lt, Lf: pc += (a == k) ? jt : jf. -/// jeq x, Lt, Lf: pc += (a == x) ? jt : jf. -pub const JEQ = 0x10; -/// jgt #k, Lt, Lf: pc += (a > k) ? jt : jf. -/// jgt x, Lt, Lf: pc += (a > x) ? jt : jf. -pub const JGT = 0x20; -/// jge #k, Lt, Lf: pc += (a >= k) ? jt : jf. -/// jge x, Lt, Lf: pc += (a >= x) ? jt : jf. -pub const JGE = 0x30; -/// jset #k, Lt, Lf: pc += (a & k > 0) ? jt : jf. -/// jset x, Lt, Lf: pc += (a & x > 0) ? jt : jf. -pub const JSET = 0x40; - -// Miscellaneous operations/register copy. -/// tax: x = a. -pub const TAX = 0x00; -/// txa: a = x. -pub const TXA = 0x80; - -/// The 16 registers in the scratch memory store as named enums. -pub const Scratch = enum(u4) { m0, m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14, m15 }; -pub const MEMWORDS = 16; -pub const MAXINSNS = switch (builtin.os.tag) { - .linux => 4096, - else => 512, -}; -pub const MINBUFSIZE = 32; -pub const MAXBUFSIZE = 1 << 21; - -pub const Insn = extern struct { - opcode: u16, - jt: u8, - jf: u8, - k: u32, - - /// Implements the `std.fmt.format` API. - /// The formatting is similar to the output of tcpdump -dd. 
- pub fn format( - self: Insn, - comptime layout: []const u8, - opts: std.fmt.FormatOptions, - writer: anytype, - ) !void { - _ = opts; - if (layout.len != 0) std.fmt.invalidFmtError(layout, self); - - try std.fmt.format( - writer, - "Insn{{ 0x{X:0<2}, {d}, {d}, 0x{X:0<8} }}", - .{ self.opcode, self.jt, self.jf, self.k }, - ); - } - - const Size = enum(u8) { - word = W, - half_word = H, - byte = B, - }; - - fn stmt(opcode: u16, k: u32) Insn { - return .{ - .opcode = opcode, - .jt = 0, - .jf = 0, - .k = k, - }; - } - - pub fn ld_imm(value: u32) Insn { - return stmt(LD | IMM, value); - } - - pub fn ld_abs(size: Size, offset: u32) Insn { - return stmt(LD | ABS | @enumToInt(size), offset); - } - - pub fn ld_ind(size: Size, offset: u32) Insn { - return stmt(LD | IND | @enumToInt(size), offset); - } - - pub fn ld_mem(reg: Scratch) Insn { - return stmt(LD | MEM, @enumToInt(reg)); - } - - pub fn ld_len() Insn { - return stmt(LD | LEN | W, 0); - } - - pub fn ld_rnd() Insn { - return stmt(LD | RND | W, 0); - } - - pub fn ldx_imm(value: u32) Insn { - return stmt(LDX | IMM, value); - } - - pub fn ldx_mem(reg: Scratch) Insn { - return stmt(LDX | MEM, @enumToInt(reg)); - } - - pub fn ldx_len() Insn { - return stmt(LDX | LEN | W, 0); - } - - pub fn ldx_msh(offset: u32) Insn { - return stmt(LDX | MSH | B, offset); - } - - pub fn st(reg: Scratch) Insn { - return stmt(ST, @enumToInt(reg)); - } - pub fn stx(reg: Scratch) Insn { - return stmt(STX, @enumToInt(reg)); - } - - const AluOp = enum(u16) { - add = ADD, - sub = SUB, - mul = MUL, - div = DIV, - @"or" = OR, - @"and" = AND, - lsh = LSH, - rsh = RSH, - mod = MOD, - xor = XOR, - }; - - const Source = enum(u16) { - k = K, - x = X, - }; - const KOrX = union(Source) { - k: u32, - x: void, - }; - - pub fn alu_neg() Insn { - return stmt(ALU | NEG, 0); - } - - pub fn alu(op: AluOp, source: KOrX) Insn { - return stmt( - ALU | @enumToInt(op) | @enumToInt(source), - if (source == .k) source.k else 0, - ); - } - - const JmpOp = enum(u16) { - jeq = JEQ, - jgt = JGT, - jge = JGE, - jset = JSET, - }; - - pub fn jmp_ja(location: u32) Insn { - return stmt(JMP | JA, location); - } - - pub fn jmp(op: JmpOp, source: KOrX, jt: u8, jf: u8) Insn { - return Insn{ - .opcode = JMP | @enumToInt(op) | @enumToInt(source), - .jt = jt, - .jf = jf, - .k = if (source == .k) source.k else 0, - }; - } - - const Verdict = enum(u16) { - k = K, - a = A, - }; - const KOrA = union(Verdict) { - k: u32, - a: void, - }; - - pub fn ret(verdict: KOrA) Insn { - return stmt( - RET | @enumToInt(verdict), - if (verdict == .k) verdict.k else 0, - ); - } - - pub fn tax() Insn { - return stmt(MISC | TAX, 0); - } - - pub fn txa() Insn { - return stmt(MISC | TXA, 0); - } -}; - -fn opcodeEqual(opcode: u16, insn: Insn) !void { - try expectEqual(opcode, insn.opcode); -} - -test "opcodes" { - try opcodeEqual(0x00, Insn.ld_imm(0)); - try opcodeEqual(0x20, Insn.ld_abs(.word, 0)); - try opcodeEqual(0x28, Insn.ld_abs(.half_word, 0)); - try opcodeEqual(0x30, Insn.ld_abs(.byte, 0)); - try opcodeEqual(0x40, Insn.ld_ind(.word, 0)); - try opcodeEqual(0x48, Insn.ld_ind(.half_word, 0)); - try opcodeEqual(0x50, Insn.ld_ind(.byte, 0)); - try opcodeEqual(0x60, Insn.ld_mem(.m0)); - try opcodeEqual(0x80, Insn.ld_len()); - try opcodeEqual(0xc0, Insn.ld_rnd()); - - try opcodeEqual(0x01, Insn.ldx_imm(0)); - try opcodeEqual(0x61, Insn.ldx_mem(.m0)); - try opcodeEqual(0x81, Insn.ldx_len()); - try opcodeEqual(0xb1, Insn.ldx_msh(0)); - - try opcodeEqual(0x02, Insn.st(.m0)); - try opcodeEqual(0x03, Insn.stx(.m0)); - - try 
opcodeEqual(0x04, Insn.alu(.add, .{ .k = 0 })); - try opcodeEqual(0x14, Insn.alu(.sub, .{ .k = 0 })); - try opcodeEqual(0x24, Insn.alu(.mul, .{ .k = 0 })); - try opcodeEqual(0x34, Insn.alu(.div, .{ .k = 0 })); - try opcodeEqual(0x44, Insn.alu(.@"or", .{ .k = 0 })); - try opcodeEqual(0x54, Insn.alu(.@"and", .{ .k = 0 })); - try opcodeEqual(0x64, Insn.alu(.lsh, .{ .k = 0 })); - try opcodeEqual(0x74, Insn.alu(.rsh, .{ .k = 0 })); - try opcodeEqual(0x94, Insn.alu(.mod, .{ .k = 0 })); - try opcodeEqual(0xa4, Insn.alu(.xor, .{ .k = 0 })); - try opcodeEqual(0x84, Insn.alu_neg()); - try opcodeEqual(0x0c, Insn.alu(.add, .x)); - try opcodeEqual(0x1c, Insn.alu(.sub, .x)); - try opcodeEqual(0x2c, Insn.alu(.mul, .x)); - try opcodeEqual(0x3c, Insn.alu(.div, .x)); - try opcodeEqual(0x4c, Insn.alu(.@"or", .x)); - try opcodeEqual(0x5c, Insn.alu(.@"and", .x)); - try opcodeEqual(0x6c, Insn.alu(.lsh, .x)); - try opcodeEqual(0x7c, Insn.alu(.rsh, .x)); - try opcodeEqual(0x9c, Insn.alu(.mod, .x)); - try opcodeEqual(0xac, Insn.alu(.xor, .x)); - - try opcodeEqual(0x05, Insn.jmp_ja(0)); - try opcodeEqual(0x15, Insn.jmp(.jeq, .{ .k = 0 }, 0, 0)); - try opcodeEqual(0x25, Insn.jmp(.jgt, .{ .k = 0 }, 0, 0)); - try opcodeEqual(0x35, Insn.jmp(.jge, .{ .k = 0 }, 0, 0)); - try opcodeEqual(0x45, Insn.jmp(.jset, .{ .k = 0 }, 0, 0)); - try opcodeEqual(0x1d, Insn.jmp(.jeq, .x, 0, 0)); - try opcodeEqual(0x2d, Insn.jmp(.jgt, .x, 0, 0)); - try opcodeEqual(0x3d, Insn.jmp(.jge, .x, 0, 0)); - try opcodeEqual(0x4d, Insn.jmp(.jset, .x, 0, 0)); - - try opcodeEqual(0x06, Insn.ret(.{ .k = 0 })); - try opcodeEqual(0x16, Insn.ret(.a)); - - try opcodeEqual(0x07, Insn.tax()); - try opcodeEqual(0x87, Insn.txa()); -} - -pub const Error = error{ - InvalidOpcode, - InvalidOffset, - InvalidLocation, - DivisionByZero, - NoReturn, -}; - -/// A simple implementation of the BPF virtual-machine. -/// Use this to run/debug programs. -pub fn simulate( - packet: []const u8, - filter: []const Insn, - byte_order: std.builtin.Endian, -) Error!u32 { - assert(filter.len > 0 and filter.len < MAXINSNS); - assert(packet.len < MAXBUFSIZE); - const len = @intCast(u32, packet.len); - - var a: u32 = 0; - var x: u32 = 0; - var m = mem.zeroes([MEMWORDS]u32); - var pc: usize = 0; - - while (pc < filter.len) : (pc += 1) { - const i = filter[pc]; - // Cast to a wider type to protect against overflow. - const k = @as(u64, i.k); - const remaining = filter.len - (pc + 1); - - // Do validation/error checking here to compress the second switch. 
- switch (i.opcode) { - LD | ABS | W => if (k + @sizeOf(u32) - 1 >= packet.len) return error.InvalidOffset, - LD | ABS | H => if (k + @sizeOf(u16) - 1 >= packet.len) return error.InvalidOffset, - LD | ABS | B => if (k >= packet.len) return error.InvalidOffset, - LD | IND | W => if (k + x + @sizeOf(u32) - 1 >= packet.len) return error.InvalidOffset, - LD | IND | H => if (k + x + @sizeOf(u16) - 1 >= packet.len) return error.InvalidOffset, - LD | IND | B => if (k + x >= packet.len) return error.InvalidOffset, - - LDX | MSH | B => if (k >= packet.len) return error.InvalidOffset, - ST, STX, LD | MEM, LDX | MEM => if (i.k >= MEMWORDS) return error.InvalidOffset, - - JMP | JA => if (remaining <= i.k) return error.InvalidOffset, - JMP | JEQ | K, - JMP | JGT | K, - JMP | JGE | K, - JMP | JSET | K, - JMP | JEQ | X, - JMP | JGT | X, - JMP | JGE | X, - JMP | JSET | X, - => if (remaining <= i.jt or remaining <= i.jf) return error.InvalidLocation, - else => {}, - } - switch (i.opcode) { - LD | IMM => a = i.k, - LD | MEM => a = m[i.k], - LD | LEN | W => a = len, - LD | RND | W => a = random.int(u32), - LD | ABS | W => a = mem.readInt(u32, packet[i.k..][0..@sizeOf(u32)], byte_order), - LD | ABS | H => a = mem.readInt(u16, packet[i.k..][0..@sizeOf(u16)], byte_order), - LD | ABS | B => a = packet[i.k], - LD | IND | W => a = mem.readInt(u32, packet[i.k + x ..][0..@sizeOf(u32)], byte_order), - LD | IND | H => a = mem.readInt(u16, packet[i.k + x ..][0..@sizeOf(u16)], byte_order), - LD | IND | B => a = packet[i.k + x], - - LDX | IMM => x = i.k, - LDX | MEM => x = m[i.k], - LDX | LEN | W => x = len, - LDX | MSH | B => x = @as(u32, @truncate(u4, packet[i.k])) << 2, - - ST => m[i.k] = a, - STX => m[i.k] = x, - - ALU | ADD | K => a +%= i.k, - ALU | SUB | K => a -%= i.k, - ALU | MUL | K => a *%= i.k, - ALU | DIV | K => a = try math.divTrunc(u32, a, i.k), - ALU | OR | K => a |= i.k, - ALU | AND | K => a &= i.k, - ALU | LSH | K => a = math.shl(u32, a, i.k), - ALU | RSH | K => a = math.shr(u32, a, i.k), - ALU | MOD | K => a = try math.mod(u32, a, i.k), - ALU | XOR | K => a ^= i.k, - ALU | ADD | X => a +%= x, - ALU | SUB | X => a -%= x, - ALU | MUL | X => a *%= x, - ALU | DIV | X => a = try math.divTrunc(u32, a, x), - ALU | OR | X => a |= x, - ALU | AND | X => a &= x, - ALU | LSH | X => a = math.shl(u32, a, x), - ALU | RSH | X => a = math.shr(u32, a, x), - ALU | MOD | X => a = try math.mod(u32, a, x), - ALU | XOR | X => a ^= x, - ALU | NEG => a = @bitCast(u32, -%@bitCast(i32, a)), - - JMP | JA => pc += i.k, - JMP | JEQ | K => pc += if (a == i.k) i.jt else i.jf, - JMP | JGT | K => pc += if (a > i.k) i.jt else i.jf, - JMP | JGE | K => pc += if (a >= i.k) i.jt else i.jf, - JMP | JSET | K => pc += if (a & i.k > 0) i.jt else i.jf, - JMP | JEQ | X => pc += if (a == x) i.jt else i.jf, - JMP | JGT | X => pc += if (a > x) i.jt else i.jf, - JMP | JGE | X => pc += if (a >= x) i.jt else i.jf, - JMP | JSET | X => pc += if (a & x > 0) i.jt else i.jf, - - RET | K => return i.k, - RET | A => return a, - - MISC | TAX => x = a, - MISC | TXA => a = x, - else => return error.InvalidOpcode, - } - } - - return error.NoReturn; -} - -// This program is the BPF form of the tcpdump filter: -// -// tcpdump -dd 'ip host mirror.internode.on.net and tcp port ftp-data' -// -// As of January 2022, mirror.internode.on.net resolves to 150.101.135.3 -// -// For reference, here's what it looks like in BPF assembler. -// Note that the jumps are used for TCP/IP layer checks. 
-// -// ``` -// ldh [12] (#proto) -// jeq #0x0800 (ETHERTYPE_IP), L1, fail -// L1: ld [26] -// jeq #150.101.135.3, L2, dest -// dest: ld [30] -// jeq #150.101.135.3, L2, fail -// L2: ldb [23] -// jeq #0x6 (IPPROTO_TCP), L3, fail -// L3: ldh [20] -// jset #0x1fff, fail, plen -// plen: ldx 4 * ([14] & 0xf) -// ldh [x + 14] -// jeq #0x14 (FTP), pass, dstp -// dstp: ldh [x + 16] -// jeq #0x14 (FTP), pass, fail -// pass: ret #0x40000 -// fail: ret #0 -// ``` -const tcpdump_filter = [_]Insn{ - Insn.ld_abs(.half_word, 12), - Insn.jmp(.jeq, .{ .k = 0x800 }, 0, 14), - Insn.ld_abs(.word, 26), - Insn.jmp(.jeq, .{ .k = 0x96658703 }, 2, 0), - Insn.ld_abs(.word, 30), - Insn.jmp(.jeq, .{ .k = 0x96658703 }, 0, 10), - Insn.ld_abs(.byte, 23), - Insn.jmp(.jeq, .{ .k = 0x6 }, 0, 8), - Insn.ld_abs(.half_word, 20), - Insn.jmp(.jset, .{ .k = 0x1fff }, 6, 0), - Insn.ldx_msh(14), - Insn.ld_ind(.half_word, 14), - Insn.jmp(.jeq, .{ .k = 0x14 }, 2, 0), - Insn.ld_ind(.half_word, 16), - Insn.jmp(.jeq, .{ .k = 0x14 }, 0, 1), - Insn.ret(.{ .k = 0x40000 }), - Insn.ret(.{ .k = 0 }), -}; - -// This packet is the output of `ls` on mirror.internode.on.net:/, captured -// using the filter above. -// -// zig fmt: off -const ftp_data = [_]u8{ - // ethernet - 14 bytes: IPv4(0x0800) from a4:71:74:ad:4b:f0 -> de:ad:be:ef:f0:0f - 0xde, 0xad, 0xbe, 0xef, 0xf0, 0x0f, 0xa4, 0x71, 0x74, 0xad, 0x4b, 0xf0, 0x08, 0x00, - // IPv4 - 20 bytes: TCP data from 150.101.135.3 -> 192.168.1.3 - 0x45, 0x00, 0x01, 0xf2, 0x70, 0x3b, 0x40, 0x00, 0x37, 0x06, 0xf2, 0xb6, - 0x96, 0x65, 0x87, 0x03, 0xc0, 0xa8, 0x01, 0x03, - // TCP - 32 bytes: Source port: 20 (FTP). Payload = 446 bytes - 0x00, 0x14, 0x80, 0x6d, 0x35, 0x81, 0x2d, 0x40, 0x4f, 0x8a, 0x29, 0x9e, 0x80, 0x18, 0x00, 0x2e, - 0x88, 0x8d, 0x00, 0x00, 0x01, 0x01, 0x08, 0x0a, 0x0b, 0x59, 0x5d, 0x09, 0x32, 0x8b, 0x51, 0xa0 -} ++ - // Raw line-based FTP data - 446 bytes - "lrwxrwxrwx 1 root root 12 Feb 14 2012 debian -> .pub2/debian\r\n" ++ - "lrwxrwxrwx 1 root root 15 Feb 14 2012 debian-cd -> .pub2/debian-cd\r\n" ++ - "lrwxrwxrwx 1 root root 9 Mar 9 2018 linux -> pub/linux\r\n" ++ - "drwxr-xr-X 3 mirror mirror 4096 Sep 20 08:10 pub\r\n" ++ - "lrwxrwxrwx 1 root root 12 Feb 14 2012 ubuntu -> .pub2/ubuntu\r\n" ++ - "-rw-r--r-- 1 root root 1044 Jan 20 2015 welcome.msg\r\n"; -// zig fmt: on - -test "tcpdump filter" { - try expectEqual( - @as(u32, 0x40000), - try simulate(ftp_data, &tcpdump_filter, .Big), - ); -} - -fn expectPass(data: anytype, filter: []const Insn) !void { - try expectEqual( - @as(u32, 0), - try simulate(mem.asBytes(data), filter, .Big), - ); -} - -fn expectFail(expected_error: anyerror, data: anytype, filter: []const Insn) !void { - try expectError( - expected_error, - simulate(mem.asBytes(data), filter, native_endian), - ); -} - -test "simulator coverage" { - const some_data = [_]u8{ - 0xaa, 0xbb, 0xcc, 0xdd, 0x7f, - }; - - try expectPass(&some_data, &.{ - // ld #10 - // ldx #1 - // st M[0] - // stx M[1] - // fail if A != 10 - Insn.ld_imm(10), - Insn.ldx_imm(1), - Insn.st(.m0), - Insn.stx(.m1), - Insn.jmp(.jeq, .{ .k = 10 }, 1, 0), - Insn.ret(.{ .k = 1 }), - // ld [0] - // fail if A != 0xaabbccdd - Insn.ld_abs(.word, 0), - Insn.jmp(.jeq, .{ .k = 0xaabbccdd }, 1, 0), - Insn.ret(.{ .k = 2 }), - // ldh [0] - // fail if A != 0xaabb - Insn.ld_abs(.half_word, 0), - Insn.jmp(.jeq, .{ .k = 0xaabb }, 1, 0), - Insn.ret(.{ .k = 3 }), - // ldb [0] - // fail if A != 0xaa - Insn.ld_abs(.byte, 0), - Insn.jmp(.jeq, .{ .k = 0xaa }, 1, 0), - Insn.ret(.{ .k = 4 }), - // ld [x + 0] - // fail if A != 0xbbccdd7f 
- Insn.ld_ind(.word, 0), - Insn.jmp(.jeq, .{ .k = 0xbbccdd7f }, 1, 0), - Insn.ret(.{ .k = 5 }), - // ldh [x + 0] - // fail if A != 0xbbcc - Insn.ld_ind(.half_word, 0), - Insn.jmp(.jeq, .{ .k = 0xbbcc }, 1, 0), - Insn.ret(.{ .k = 6 }), - // ldb [x + 0] - // fail if A != 0xbb - Insn.ld_ind(.byte, 0), - Insn.jmp(.jeq, .{ .k = 0xbb }, 1, 0), - Insn.ret(.{ .k = 7 }), - // ld M[0] - // fail if A != 10 - Insn.ld_mem(.m0), - Insn.jmp(.jeq, .{ .k = 10 }, 1, 0), - Insn.ret(.{ .k = 8 }), - // ld #len - // fail if A != 5 - Insn.ld_len(), - Insn.jmp(.jeq, .{ .k = some_data.len }, 1, 0), - Insn.ret(.{ .k = 9 }), - // ld #0 - // ld arc4random() - // fail if A == 0 - Insn.ld_imm(0), - Insn.ld_rnd(), - Insn.jmp(.jgt, .{ .k = 0 }, 1, 0), - Insn.ret(.{ .k = 10 }), - // ld #3 - // ldx #10 - // st M[2] - // txa - // fail if a != x - Insn.ld_imm(3), - Insn.ldx_imm(10), - Insn.st(.m2), - Insn.txa(), - Insn.jmp(.jeq, .x, 1, 0), - Insn.ret(.{ .k = 11 }), - // ldx M[2] - // fail if A <= X - Insn.ldx_mem(.m2), - Insn.jmp(.jgt, .x, 1, 0), - Insn.ret(.{ .k = 12 }), - // ldx #len - // fail if a <= x - Insn.ldx_len(), - Insn.jmp(.jgt, .x, 1, 0), - Insn.ret(.{ .k = 13 }), - // a = 4 * (0x7f & 0xf) - // x = 4 * ([4] & 0xf) - // fail if a != x - Insn.ld_imm(4 * (0x7f & 0xf)), - Insn.ldx_msh(4), - Insn.jmp(.jeq, .x, 1, 0), - Insn.ret(.{ .k = 14 }), - // ld #(u32)-1 - // ldx #2 - // add #1 - // fail if a != 0 - Insn.ld_imm(0xffffffff), - Insn.ldx_imm(2), - Insn.alu(.add, .{ .k = 1 }), - Insn.jmp(.jeq, .{ .k = 0 }, 1, 0), - Insn.ret(.{ .k = 15 }), - // sub #1 - // fail if a != (u32)-1 - Insn.alu(.sub, .{ .k = 1 }), - Insn.jmp(.jeq, .{ .k = 0xffffffff }, 1, 0), - Insn.ret(.{ .k = 16 }), - // add x - // fail if a != 1 - Insn.alu(.add, .x), - Insn.jmp(.jeq, .{ .k = 1 }, 1, 0), - Insn.ret(.{ .k = 17 }), - // sub x - // fail if a != (u32)-1 - Insn.alu(.sub, .x), - Insn.jmp(.jeq, .{ .k = 0xffffffff }, 1, 0), - Insn.ret(.{ .k = 18 }), - // ld #16 - // mul #2 - // fail if a != 32 - Insn.ld_imm(16), - Insn.alu(.mul, .{ .k = 2 }), - Insn.jmp(.jeq, .{ .k = 32 }, 1, 0), - Insn.ret(.{ .k = 19 }), - // mul x - // fail if a != 64 - Insn.alu(.mul, .x), - Insn.jmp(.jeq, .{ .k = 64 }, 1, 0), - Insn.ret(.{ .k = 20 }), - // div #2 - // fail if a != 32 - Insn.alu(.div, .{ .k = 2 }), - Insn.jmp(.jeq, .{ .k = 32 }, 1, 0), - Insn.ret(.{ .k = 21 }), - // div x - // fail if a != 16 - Insn.alu(.div, .x), - Insn.jmp(.jeq, .{ .k = 16 }, 1, 0), - Insn.ret(.{ .k = 22 }), - // or #4 - // fail if a != 20 - Insn.alu(.@"or", .{ .k = 4 }), - Insn.jmp(.jeq, .{ .k = 20 }, 1, 0), - Insn.ret(.{ .k = 23 }), - // or x - // fail if a != 22 - Insn.alu(.@"or", .x), - Insn.jmp(.jeq, .{ .k = 22 }, 1, 0), - Insn.ret(.{ .k = 24 }), - // and #6 - // fail if a != 6 - Insn.alu(.@"and", .{ .k = 0b110 }), - Insn.jmp(.jeq, .{ .k = 6 }, 1, 0), - Insn.ret(.{ .k = 25 }), - // and x - // fail if a != 2 - Insn.alu(.@"and", .x), - Insn.jmp(.jeq, .x, 1, 0), - Insn.ret(.{ .k = 26 }), - // xor #15 - // fail if a != 13 - Insn.alu(.xor, .{ .k = 0b1111 }), - Insn.jmp(.jeq, .{ .k = 0b1101 }, 1, 0), - Insn.ret(.{ .k = 27 }), - // xor x - // fail if a != 15 - Insn.alu(.xor, .x), - Insn.jmp(.jeq, .{ .k = 0b1111 }, 1, 0), - Insn.ret(.{ .k = 28 }), - // rsh #1 - // fail if a != 7 - Insn.alu(.rsh, .{ .k = 1 }), - Insn.jmp(.jeq, .{ .k = 0b0111 }, 1, 0), - Insn.ret(.{ .k = 29 }), - // rsh x - // fail if a != 1 - Insn.alu(.rsh, .x), - Insn.jmp(.jeq, .{ .k = 0b0001 }, 1, 0), - Insn.ret(.{ .k = 30 }), - // lsh #1 - // fail if a != 2 - Insn.alu(.lsh, .{ .k = 1 }), - Insn.jmp(.jeq, .{ .k = 0b0010 }, 1, 
0), - Insn.ret(.{ .k = 31 }), - // lsh x - // fail if a != 8 - Insn.alu(.lsh, .x), - Insn.jmp(.jeq, .{ .k = 0b1000 }, 1, 0), - Insn.ret(.{ .k = 32 }), - // mod 6 - // fail if a != 2 - Insn.alu(.mod, .{ .k = 6 }), - Insn.jmp(.jeq, .{ .k = 2 }, 1, 0), - Insn.ret(.{ .k = 33 }), - // mod x - // fail if a != 0 - Insn.alu(.mod, .x), - Insn.jmp(.jeq, .{ .k = 0 }, 1, 0), - Insn.ret(.{ .k = 34 }), - // tax - // neg - // fail if a != (u32)-2 - Insn.txa(), - Insn.alu_neg(), - Insn.jmp(.jeq, .{ .k = ~@as(u32, 2) + 1 }, 1, 0), - Insn.ret(.{ .k = 35 }), - // ja #1 (skip the next instruction) - Insn.jmp_ja(1), - Insn.ret(.{ .k = 36 }), - // ld #20 - // tax - // fail if a != 20 - // fail if a != x - Insn.ld_imm(20), - Insn.tax(), - Insn.jmp(.jeq, .{ .k = 20 }, 1, 0), - Insn.ret(.{ .k = 37 }), - Insn.jmp(.jeq, .x, 1, 0), - Insn.ret(.{ .k = 38 }), - // ld #19 - // fail if a == 20 - // fail if a == x - // fail if a >= 20 - // fail if a >= X - Insn.ld_imm(19), - Insn.jmp(.jeq, .{ .k = 20 }, 0, 1), - Insn.ret(.{ .k = 39 }), - Insn.jmp(.jeq, .x, 0, 1), - Insn.ret(.{ .k = 40 }), - Insn.jmp(.jgt, .{ .k = 20 }, 0, 1), - Insn.ret(.{ .k = 41 }), - Insn.jmp(.jgt, .x, 0, 1), - Insn.ret(.{ .k = 42 }), - // ld #21 - // fail if a < 20 - // fail if a < x - Insn.ld_imm(21), - Insn.jmp(.jgt, .{ .k = 20 }, 1, 0), - Insn.ret(.{ .k = 43 }), - Insn.jmp(.jgt, .x, 1, 0), - Insn.ret(.{ .k = 44 }), - // ldx #22 - // fail if a < 22 - // fail if a < x - Insn.ldx_imm(22), - Insn.jmp(.jge, .{ .k = 22 }, 0, 1), - Insn.ret(.{ .k = 45 }), - Insn.jmp(.jge, .x, 0, 1), - Insn.ret(.{ .k = 46 }), - // ld #23 - // fail if a >= 22 - // fail if a >= x - Insn.ld_imm(23), - Insn.jmp(.jge, .{ .k = 22 }, 1, 0), - Insn.ret(.{ .k = 47 }), - Insn.jmp(.jge, .x, 1, 0), - Insn.ret(.{ .k = 48 }), - // ldx #0b10100 - // fail if a & 0b10100 == 0 - // fail if a & x == 0 - Insn.ldx_imm(0b10100), - Insn.jmp(.jset, .{ .k = 0b10100 }, 1, 0), - Insn.ret(.{ .k = 47 }), - Insn.jmp(.jset, .x, 1, 0), - Insn.ret(.{ .k = 48 }), - // ldx #0 - // fail if a & 0 > 0 - // fail if a & x > 0 - Insn.ldx_imm(0), - Insn.jmp(.jset, .{ .k = 0 }, 0, 1), - Insn.ret(.{ .k = 49 }), - Insn.jmp(.jset, .x, 0, 1), - Insn.ret(.{ .k = 50 }), - Insn.ret(.{ .k = 0 }), - }); - try expectPass(&some_data, &.{ - Insn.ld_imm(35), - Insn.ld_imm(0), - Insn.ret(.a), - }); - - // Errors - try expectFail(error.NoReturn, &some_data, &.{ - Insn.ld_imm(10), - }); - try expectFail(error.InvalidOpcode, &some_data, &.{ - Insn.stmt(0x7f, 0xdeadbeef), - }); - try expectFail(error.InvalidOffset, &some_data, &.{ - Insn.stmt(LD | ABS | W, 10), - }); - try expectFail(error.InvalidLocation, &some_data, &.{ - Insn.jmp(.jeq, .{ .k = 0 }, 10, 0), - }); - try expectFail(error.InvalidLocation, &some_data, &.{ - Insn.jmp(.jeq, .{ .k = 0 }, 0, 10), - }); -} diff --git a/lib/std/x/net/ip.zig b/lib/std/x/net/ip.zig deleted file mode 100644 index b3da9725d8cf..000000000000 --- a/lib/std/x/net/ip.zig +++ /dev/null @@ -1,57 +0,0 @@ -const std = @import("../../std.zig"); - -const fmt = std.fmt; - -const IPv4 = std.x.os.IPv4; -const IPv6 = std.x.os.IPv6; -const Socket = std.x.os.Socket; - -/// A generic IP abstraction. -const ip = @This(); - -/// A union of all eligible types of IP addresses. -pub const Address = union(enum) { - ipv4: IPv4.Address, - ipv6: IPv6.Address, - - /// Instantiate a new address with a IPv4 host and port. - pub fn initIPv4(host: IPv4, port: u16) Address { - return .{ .ipv4 = .{ .host = host, .port = port } }; - } - - /// Instantiate a new address with a IPv6 host and port. 
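For reference, the removed `ip.Address` wrapper was typically used along these lines (illustrative sketch; the port number is arbitrary):

```zig
const std = @import("std");
const ip = std.x.net.ip;
const IPv4 = std.x.os.IPv4;

test "ip.Address construction and conversion" {
    const addr = ip.Address.initIPv4(IPv4.localhost, 8080);
    // format() renders "host:port" for either family.
    try std.testing.expectFmt("127.0.0.1:8080", "{}", .{addr});
    // into()/from() re-interpret to and from a generic Socket.Address.
    const generic = addr.into();
    try std.testing.expectEqual(addr, ip.Address.from(generic));
}
```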
- pub fn initIPv6(host: IPv6, port: u16) Address { - return .{ .ipv6 = .{ .host = host, .port = port } }; - } - - /// Re-interpret a generic socket address into an IP address. - pub fn from(address: Socket.Address) ip.Address { - return switch (address) { - .ipv4 => |ipv4_address| .{ .ipv4 = ipv4_address }, - .ipv6 => |ipv6_address| .{ .ipv6 = ipv6_address }, - }; - } - - /// Re-interpret an IP address into a generic socket address. - pub fn into(self: ip.Address) Socket.Address { - return switch (self) { - .ipv4 => |ipv4_address| .{ .ipv4 = ipv4_address }, - .ipv6 => |ipv6_address| .{ .ipv6 = ipv6_address }, - }; - } - - /// Implements the `std.fmt.format` API. - pub fn format( - self: ip.Address, - comptime layout: []const u8, - opts: fmt.FormatOptions, - writer: anytype, - ) !void { - if (layout.len != 0) std.fmt.invalidFmtError(layout, self); - _ = opts; - switch (self) { - .ipv4 => |address| try fmt.format(writer, "{}:{}", .{ address.host, address.port }), - .ipv6 => |address| try fmt.format(writer, "{}:{}", .{ address.host, address.port }), - } - } -}; diff --git a/lib/std/x/net/tcp.zig b/lib/std/x/net/tcp.zig deleted file mode 100644 index 0293deb9db32..000000000000 --- a/lib/std/x/net/tcp.zig +++ /dev/null @@ -1,447 +0,0 @@ -const std = @import("../../std.zig"); -const builtin = @import("builtin"); - -const io = std.io; -const os = std.os; -const ip = std.x.net.ip; - -const fmt = std.fmt; -const mem = std.mem; -const testing = std.testing; -const native_os = builtin.os; - -const IPv4 = std.x.os.IPv4; -const IPv6 = std.x.os.IPv6; -const Socket = std.x.os.Socket; -const Buffer = std.x.os.Buffer; - -/// A generic TCP socket abstraction. -const tcp = @This(); - -/// A TCP client-address pair. -pub const Connection = struct { - client: tcp.Client, - address: ip.Address, - - /// Enclose a TCP client and address into a client-address pair. - pub fn from(conn: Socket.Connection) tcp.Connection { - return .{ - .client = tcp.Client.from(conn.socket), - .address = ip.Address.from(conn.address), - }; - } - - /// Unravel a TCP client-address pair into a socket-address pair. - pub fn into(self: tcp.Connection) Socket.Connection { - return .{ - .socket = self.client.socket, - .address = self.address.into(), - }; - } - - /// Closes the underlying client of the connection. - pub fn deinit(self: tcp.Connection) void { - self.client.deinit(); - } -}; - -/// Possible domains that a TCP client/listener may operate over. -pub const Domain = enum(u16) { - ip = os.AF.INET, - ipv6 = os.AF.INET6, -}; - -/// A TCP client. -pub const Client = struct { - socket: Socket, - - /// Implements `std.io.Reader`. - pub const Reader = struct { - client: Client, - flags: u32, - - /// Implements `readFn` for `std.io.Reader`. - pub fn read(self: Client.Reader, buffer: []u8) !usize { - return self.client.read(buffer, self.flags); - } - }; - - /// Implements `std.io.Writer`. - pub const Writer = struct { - client: Client, - flags: u32, - - /// Implements `writeFn` for `std.io.Writer`. - pub fn write(self: Client.Writer, buffer: []const u8) !usize { - return self.client.write(buffer, self.flags); - } - }; - - /// Opens a new client. - pub fn init(domain: tcp.Domain, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Client { - return Client{ - .socket = try Socket.init( - @enumToInt(domain), - os.SOCK.STREAM, - os.IPPROTO.TCP, - flags, - ), - }; - } - - /// Enclose a TCP client over an existing socket. - pub fn from(socket: Socket) Client { - return Client{ .socket = socket }; - } - - /// Closes the client. 
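The `Reader`/`Writer` adapters above plug a `tcp.Client` into the generic `std.io` interfaces. A hedged sketch of how they were wired up (the helper name and message are made up for illustration):

```zig
const std = @import("std");
const tcp = std.x.net.tcp;

// Hypothetical helper: send a line and read whatever comes back, using the
// std.io adapters exposed by Client. A `flags` value of 0 means no send/recv flags.
fn greet(client: tcp.Client, reply_buf: []u8) !usize {
    try client.writer(0).writeAll("hello\n");
    return client.reader(0).read(reply_buf);
}
```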
- pub fn deinit(self: Client) void { - self.socket.deinit(); - } - - /// Shutdown either the read side, write side, or all sides of the client's underlying socket. - pub fn shutdown(self: Client, how: os.ShutdownHow) !void { - return self.socket.shutdown(how); - } - - /// Have the client attempt to the connect to an address. - pub fn connect(self: Client, address: ip.Address) !void { - return self.socket.connect(address.into()); - } - - /// Extracts the error set of a function. - /// TODO: remove after Socket.{read, write} error unions are well-defined across different platforms - fn ErrorSetOf(comptime Function: anytype) type { - return @typeInfo(@typeInfo(@TypeOf(Function)).Fn.return_type.?).ErrorUnion.error_set; - } - - /// Wrap `tcp.Client` into `std.io.Reader`. - pub fn reader(self: Client, flags: u32) io.Reader(Client.Reader, ErrorSetOf(Client.Reader.read), Client.Reader.read) { - return .{ .context = .{ .client = self, .flags = flags } }; - } - - /// Wrap `tcp.Client` into `std.io.Writer`. - pub fn writer(self: Client, flags: u32) io.Writer(Client.Writer, ErrorSetOf(Client.Writer.write), Client.Writer.write) { - return .{ .context = .{ .client = self, .flags = flags } }; - } - - /// Read data from the socket into the buffer provided with a set of flags - /// specified. It returns the number of bytes read into the buffer provided. - pub fn read(self: Client, buf: []u8, flags: u32) !usize { - return self.socket.read(buf, flags); - } - - /// Write a buffer of data provided to the socket with a set of flags specified. - /// It returns the number of bytes that are written to the socket. - pub fn write(self: Client, buf: []const u8, flags: u32) !usize { - return self.socket.write(buf, flags); - } - - /// Writes multiple I/O vectors with a prepended message header to the socket - /// with a set of flags specified. It returns the number of bytes that are - /// written to the socket. - pub fn writeMessage(self: Client, msg: Socket.Message, flags: u32) !usize { - return self.socket.writeMessage(msg, flags); - } - - /// Read multiple I/O vectors with a prepended message header from the socket - /// with a set of flags specified. It returns the number of bytes that were - /// read into the buffer provided. - pub fn readMessage(self: Client, msg: *Socket.Message, flags: u32) !usize { - return self.socket.readMessage(msg, flags); - } - - /// Query and return the latest cached error on the client's underlying socket. - pub fn getError(self: Client) !void { - return self.socket.getError(); - } - - /// Query the read buffer size of the client's underlying socket. - pub fn getReadBufferSize(self: Client) !u32 { - return self.socket.getReadBufferSize(); - } - - /// Query the write buffer size of the client's underlying socket. - pub fn getWriteBufferSize(self: Client) !u32 { - return self.socket.getWriteBufferSize(); - } - - /// Query the address that the client's socket is locally bounded to. - pub fn getLocalAddress(self: Client) !ip.Address { - return ip.Address.from(try self.socket.getLocalAddress()); - } - - /// Query the address that the socket is connected to. - pub fn getRemoteAddress(self: Client) !ip.Address { - return ip.Address.from(try self.socket.getRemoteAddress()); - } - - /// Have close() or shutdown() syscalls block until all queued messages in the client have been successfully - /// sent, or if the timeout specified in seconds has been reached. 
It returns `error.UnsupportedSocketOption` - /// if the host does not support the option for a socket to linger around up until a timeout specified in - /// seconds. - pub fn setLinger(self: Client, timeout_seconds: ?u16) !void { - return self.socket.setLinger(timeout_seconds); - } - - /// Have keep-alive messages be sent periodically. The timing in which keep-alive messages are sent are - /// dependant on operating system settings. It returns `error.UnsupportedSocketOption` if the host does - /// not support periodically sending keep-alive messages on connection-oriented sockets. - pub fn setKeepAlive(self: Client, enabled: bool) !void { - return self.socket.setKeepAlive(enabled); - } - - /// Disable Nagle's algorithm on a TCP socket. It returns `error.UnsupportedSocketOption` if - /// the host does not support sockets disabling Nagle's algorithm. - pub fn setNoDelay(self: Client, enabled: bool) !void { - if (@hasDecl(os.TCP, "NODELAY")) { - const bytes = mem.asBytes(&@as(usize, @boolToInt(enabled))); - return self.socket.setOption(os.IPPROTO.TCP, os.TCP.NODELAY, bytes); - } - return error.UnsupportedSocketOption; - } - - /// Enables TCP Quick ACK on a TCP socket to immediately send rather than delay ACKs when necessary. It returns - /// `error.UnsupportedSocketOption` if the host does not support TCP Quick ACK. - pub fn setQuickACK(self: Client, enabled: bool) !void { - if (@hasDecl(os.TCP, "QUICKACK")) { - return self.socket.setOption(os.IPPROTO.TCP, os.TCP.QUICKACK, mem.asBytes(&@as(u32, @boolToInt(enabled)))); - } - return error.UnsupportedSocketOption; - } - - /// Set the write buffer size of the socket. - pub fn setWriteBufferSize(self: Client, size: u32) !void { - return self.socket.setWriteBufferSize(size); - } - - /// Set the read buffer size of the socket. - pub fn setReadBufferSize(self: Client, size: u32) !void { - return self.socket.setReadBufferSize(size); - } - - /// Set a timeout on the socket that is to occur if no messages are successfully written - /// to its bound destination after a specified number of milliseconds. A subsequent write - /// to the socket will thereafter return `error.WouldBlock` should the timeout be exceeded. - pub fn setWriteTimeout(self: Client, milliseconds: u32) !void { - return self.socket.setWriteTimeout(milliseconds); - } - - /// Set a timeout on the socket that is to occur if no messages are successfully read - /// from its bound destination after a specified number of milliseconds. A subsequent - /// read from the socket will thereafter return `error.WouldBlock` should the timeout be - /// exceeded. - pub fn setReadTimeout(self: Client, milliseconds: u32) !void { - return self.socket.setReadTimeout(milliseconds); - } -}; - -/// A TCP listener. -pub const Listener = struct { - socket: Socket, - - /// Opens a new listener. - pub fn init(domain: tcp.Domain, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Listener { - return Listener{ - .socket = try Socket.init( - @enumToInt(domain), - os.SOCK.STREAM, - os.IPPROTO.TCP, - flags, - ), - }; - } - - /// Closes the listener. - pub fn deinit(self: Listener) void { - self.socket.deinit(); - } - - /// Shuts down the underlying listener's socket. The next subsequent call, or - /// a current pending call to accept() after shutdown is called will return - /// an error. - pub fn shutdown(self: Listener) !void { - return self.socket.shutdown(.recv); - } - - /// Binds the listener's socket to an address. 
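Taken together, the client option setters above were typically applied right after connecting, with `error.UnsupportedSocketOption` treated as non-fatal on platforms that lack a given option. A hedged sketch (the helper name and the timeout value are arbitrary):

```zig
const std = @import("std");
const tcp = std.x.net.tcp;

// Hypothetical tuning routine for a freshly connected client.
fn tuneClient(client: tcp.Client) !void {
    client.setNoDelay(true) catch |err| switch (err) {
        error.UnsupportedSocketOption => {}, // acceptable to ignore here
        else => return err,
    };
    try client.setKeepAlive(true);
    try client.setReadTimeout(5 * std.time.ms_per_s);
}
```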
- pub fn bind(self: Listener, address: ip.Address) !void { - return self.socket.bind(address.into()); - } - - /// Start listening for incoming connections. - pub fn listen(self: Listener, max_backlog_size: u31) !void { - return self.socket.listen(max_backlog_size); - } - - /// Accept a pending incoming connection queued to the kernel backlog - /// of the listener's socket. - pub fn accept(self: Listener, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !tcp.Connection { - return tcp.Connection.from(try self.socket.accept(flags)); - } - - /// Query and return the latest cached error on the listener's underlying socket. - pub fn getError(self: Client) !void { - return self.socket.getError(); - } - - /// Query the address that the listener's socket is locally bounded to. - pub fn getLocalAddress(self: Listener) !ip.Address { - return ip.Address.from(try self.socket.getLocalAddress()); - } - - /// Allow multiple sockets on the same host to listen on the same address. It returns `error.UnsupportedSocketOption` if - /// the host does not support sockets listening the same address. - pub fn setReuseAddress(self: Listener, enabled: bool) !void { - return self.socket.setReuseAddress(enabled); - } - - /// Allow multiple sockets on the same host to listen on the same port. It returns `error.UnsupportedSocketOption` if - /// the host does not supports sockets listening on the same port. - pub fn setReusePort(self: Listener, enabled: bool) !void { - return self.socket.setReusePort(enabled); - } - - /// Enables TCP Fast Open (RFC 7413) on a TCP socket. It returns `error.UnsupportedSocketOption` if the host does not - /// support TCP Fast Open. - pub fn setFastOpen(self: Listener, enabled: bool) !void { - if (@hasDecl(os.TCP, "FASTOPEN")) { - return self.socket.setOption(os.IPPROTO.TCP, os.TCP.FASTOPEN, mem.asBytes(&@as(u32, @boolToInt(enabled)))); - } - return error.UnsupportedSocketOption; - } - - /// Set a timeout on the listener that is to occur if no new incoming connections come in - /// after a specified number of milliseconds. A subsequent accept call to the listener - /// will thereafter return `error.WouldBlock` should the timeout be exceeded. 
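The accept timeout described above surfaces as `error.WouldBlock` on an otherwise idle listener, which callers can treat as 'no connection yet'. A hedged sketch of that pattern (the helper name and the 250 ms value are arbitrary):

```zig
const std = @import("std");
const tcp = std.x.net.tcp;

// Hypothetical non-blocking accept step: returns null when the timeout
// elapses without a pending connection.
fn acceptOne(listener: tcp.Listener) !?tcp.Connection {
    try listener.setAcceptTimeout(250);
    const conn = listener.accept(.{ .close_on_exec = true }) catch |err| switch (err) {
        error.WouldBlock => return null,
        else => return err,
    };
    return conn;
}
```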
- pub fn setAcceptTimeout(self: Listener, milliseconds: usize) !void { - return self.socket.setReadTimeout(milliseconds); - } -}; - -test "tcp: create client/listener pair" { - if (native_os.tag == .wasi) return error.SkipZigTest; - - const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true }); - defer listener.deinit(); - - try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0)); - try listener.listen(128); - - var binded_address = try listener.getLocalAddress(); - switch (binded_address) { - .ipv4 => |*ipv4| ipv4.host = IPv4.localhost, - .ipv6 => |*ipv6| ipv6.host = IPv6.localhost, - } - - const client = try tcp.Client.init(.ip, .{ .close_on_exec = true }); - defer client.deinit(); - - try client.connect(binded_address); - - const conn = try listener.accept(.{ .close_on_exec = true }); - defer conn.deinit(); -} - -test "tcp/client: 1ms read timeout" { - if (native_os.tag == .wasi) return error.SkipZigTest; - - const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true }); - defer listener.deinit(); - - try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0)); - try listener.listen(128); - - var binded_address = try listener.getLocalAddress(); - switch (binded_address) { - .ipv4 => |*ipv4| ipv4.host = IPv4.localhost, - .ipv6 => |*ipv6| ipv6.host = IPv6.localhost, - } - - const client = try tcp.Client.init(.ip, .{ .close_on_exec = true }); - defer client.deinit(); - - try client.connect(binded_address); - try client.setReadTimeout(1); - - const conn = try listener.accept(.{ .close_on_exec = true }); - defer conn.deinit(); - - var buf: [1]u8 = undefined; - try testing.expectError(error.WouldBlock, client.reader(0).read(&buf)); -} - -test "tcp/client: read and write multiple vectors" { - if (native_os.tag == .wasi) return error.SkipZigTest; - - if (builtin.os.tag == .windows) { - // https://github.com/ziglang/zig/issues/13893 - return error.SkipZigTest; - } - - const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true }); - defer listener.deinit(); - - try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0)); - try listener.listen(128); - - var binded_address = try listener.getLocalAddress(); - switch (binded_address) { - .ipv4 => |*ipv4| ipv4.host = IPv4.localhost, - .ipv6 => |*ipv6| ipv6.host = IPv6.localhost, - } - - const client = try tcp.Client.init(.ip, .{ .close_on_exec = true }); - defer client.deinit(); - - try client.connect(binded_address); - - const conn = try listener.accept(.{ .close_on_exec = true }); - defer conn.deinit(); - - const message = "hello world"; - _ = try conn.client.writeMessage(Socket.Message.fromBuffers(&[_]Buffer{ - Buffer.from(message[0 .. message.len / 2]), - Buffer.from(message[message.len / 2 ..]), - }), 0); - - var buf: [message.len + 1]u8 = undefined; - var msg = Socket.Message.fromBuffers(&[_]Buffer{ - Buffer.from(buf[0 .. 
message.len / 2]), - Buffer.from(buf[message.len / 2 ..]), - }); - _ = try client.readMessage(&msg, 0); - - try testing.expectEqualStrings(message, buf[0..message.len]); -} - -test "tcp/listener: bind to unspecified ipv4 address" { - if (native_os.tag == .wasi) return error.SkipZigTest; - - const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true }); - defer listener.deinit(); - - try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0)); - try listener.listen(128); - - const address = try listener.getLocalAddress(); - try testing.expect(address == .ipv4); -} - -test "tcp/listener: bind to unspecified ipv6 address" { - if (native_os.tag == .wasi) return error.SkipZigTest; - - if (builtin.os.tag == .windows) { - // https://github.com/ziglang/zig/issues/13893 - return error.SkipZigTest; - } - - const listener = try tcp.Listener.init(.ipv6, .{ .close_on_exec = true }); - defer listener.deinit(); - - try listener.bind(ip.Address.initIPv6(IPv6.unspecified, 0)); - try listener.listen(128); - - const address = try listener.getLocalAddress(); - try testing.expect(address == .ipv6); -} diff --git a/lib/std/x/os/io.zig b/lib/std/x/os/io.zig deleted file mode 100644 index 6c4763df659f..000000000000 --- a/lib/std/x/os/io.zig +++ /dev/null @@ -1,224 +0,0 @@ -const std = @import("../../std.zig"); -const builtin = @import("builtin"); - -const os = std.os; -const mem = std.mem; -const testing = std.testing; -const native_os = builtin.os; -const linux = std.os.linux; - -/// POSIX `iovec`, or Windows `WSABUF`. The difference between the two are the ordering -/// of fields, alongside the length being represented as either a ULONG or a size_t. -pub const Buffer = if (native_os.tag == .windows) - extern struct { - len: c_ulong, - ptr: usize, - - pub fn from(slice: []const u8) Buffer { - return .{ .len = @intCast(c_ulong, slice.len), .ptr = @ptrToInt(slice.ptr) }; - } - - pub fn into(self: Buffer) []const u8 { - return @intToPtr([*]const u8, self.ptr)[0..self.len]; - } - - pub fn intoMutable(self: Buffer) []u8 { - return @intToPtr([*]u8, self.ptr)[0..self.len]; - } - } -else - extern struct { - ptr: usize, - len: usize, - - pub fn from(slice: []const u8) Buffer { - return .{ .ptr = @ptrToInt(slice.ptr), .len = slice.len }; - } - - pub fn into(self: Buffer) []const u8 { - return @intToPtr([*]const u8, self.ptr)[0..self.len]; - } - - pub fn intoMutable(self: Buffer) []u8 { - return @intToPtr([*]u8, self.ptr)[0..self.len]; - } - }; - -pub const Reactor = struct { - pub const InitFlags = enum { - close_on_exec, - }; - - pub const Event = struct { - data: usize, - is_error: bool, - is_hup: bool, - is_readable: bool, - is_writable: bool, - }; - - pub const Interest = struct { - hup: bool = false, - oneshot: bool = false, - readable: bool = false, - writable: bool = false, - }; - - fd: os.fd_t, - - pub fn init(flags: std.enums.EnumFieldStruct(Reactor.InitFlags, bool, false)) !Reactor { - var raw_flags: u32 = 0; - const set = std.EnumSet(Reactor.InitFlags).init(flags); - if (set.contains(.close_on_exec)) raw_flags |= linux.EPOLL.CLOEXEC; - return Reactor{ .fd = try os.epoll_create1(raw_flags) }; - } - - pub fn deinit(self: Reactor) void { - os.close(self.fd); - } - - pub fn update(self: Reactor, fd: os.fd_t, identifier: usize, interest: Reactor.Interest) !void { - var flags: u32 = 0; - flags |= if (interest.oneshot) linux.EPOLL.ONESHOT else linux.EPOLL.ET; - if (interest.hup) flags |= linux.EPOLL.RDHUP; - if (interest.readable) flags |= linux.EPOLL.IN; - if (interest.writable) flags |= linux.EPOLL.OUT; - 
- const event = &linux.epoll_event{ - .events = flags, - .data = .{ .ptr = identifier }, - }; - - os.epoll_ctl(self.fd, linux.EPOLL.CTL_MOD, fd, event) catch |err| switch (err) { - error.FileDescriptorNotRegistered => try os.epoll_ctl(self.fd, linux.EPOLL.CTL_ADD, fd, event), - else => return err, - }; - } - - pub fn remove(self: Reactor, fd: os.fd_t) !void { - // directly from man epoll_ctl BUGS section - // In kernel versions before 2.6.9, the EPOLL_CTL_DEL operation re‐ - // quired a non-null pointer in event, even though this argument is - // ignored. Since Linux 2.6.9, event can be specified as NULL when - // using EPOLL_CTL_DEL. Applications that need to be portable to - // kernels before 2.6.9 should specify a non-null pointer in event. - var event = linux.epoll_event{ - .events = 0, - .data = .{ .ptr = 0 }, - }; - - return os.epoll_ctl(self.fd, linux.EPOLL.CTL_DEL, fd, &event); - } - - pub fn poll(self: Reactor, comptime max_num_events: comptime_int, closure: anytype, timeout_milliseconds: ?u64) !void { - var events: [max_num_events]linux.epoll_event = undefined; - - const num_events = os.epoll_wait(self.fd, &events, if (timeout_milliseconds) |ms| @intCast(i32, ms) else -1); - for (events[0..num_events]) |ev| { - const is_error = ev.events & linux.EPOLL.ERR != 0; - const is_hup = ev.events & (linux.EPOLL.HUP | linux.EPOLL.RDHUP) != 0; - const is_readable = ev.events & linux.EPOLL.IN != 0; - const is_writable = ev.events & linux.EPOLL.OUT != 0; - - try closure.call(Reactor.Event{ - .data = ev.data.ptr, - .is_error = is_error, - .is_hup = is_hup, - .is_readable = is_readable, - .is_writable = is_writable, - }); - } - } -}; - -test "reactor/linux: drive async tcp client/listener pair" { - if (native_os.tag != .linux) return error.SkipZigTest; - - const ip = std.x.net.ip; - const tcp = std.x.net.tcp; - - const IPv4 = std.x.os.IPv4; - const IPv6 = std.x.os.IPv6; - - const reactor = try Reactor.init(.{ .close_on_exec = true }); - defer reactor.deinit(); - - const listener = try tcp.Listener.init(.ip, .{ - .close_on_exec = true, - .nonblocking = true, - }); - defer listener.deinit(); - - try reactor.update(listener.socket.fd, 0, .{ .readable = true }); - try reactor.poll(1, struct { - fn call(event: Reactor.Event) !void { - try testing.expectEqual(Reactor.Event{ - .data = 0, - .is_error = false, - .is_hup = true, - .is_readable = false, - .is_writable = false, - }, event); - } - }, null); - - try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0)); - try listener.listen(128); - - var binded_address = try listener.getLocalAddress(); - switch (binded_address) { - .ipv4 => |*ipv4| ipv4.host = IPv4.localhost, - .ipv6 => |*ipv6| ipv6.host = IPv6.localhost, - } - - const client = try tcp.Client.init(.ip, .{ - .close_on_exec = true, - .nonblocking = true, - }); - defer client.deinit(); - - try reactor.update(client.socket.fd, 1, .{ .readable = true, .writable = true }); - try reactor.poll(1, struct { - fn call(event: Reactor.Event) !void { - try testing.expectEqual(Reactor.Event{ - .data = 1, - .is_error = false, - .is_hup = true, - .is_readable = false, - .is_writable = true, - }, event); - } - }, null); - - client.connect(binded_address) catch |err| switch (err) { - error.WouldBlock => {}, - else => return err, - }; - - try reactor.poll(1, struct { - fn call(event: Reactor.Event) !void { - try testing.expectEqual(Reactor.Event{ - .data = 1, - .is_error = false, - .is_hup = false, - .is_readable = false, - .is_writable = true, - }, event); - } - }, null); - - try reactor.poll(1, struct { - 
fn call(event: Reactor.Event) !void { - try testing.expectEqual(Reactor.Event{ - .data = 0, - .is_error = false, - .is_hup = false, - .is_readable = true, - .is_writable = false, - }, event); - } - }, null); - - try reactor.remove(client.socket.fd); - try reactor.remove(listener.socket.fd); -} diff --git a/lib/std/x/os/net.zig b/lib/std/x/os/net.zig deleted file mode 100644 index e00299e24301..000000000000 --- a/lib/std/x/os/net.zig +++ /dev/null @@ -1,605 +0,0 @@ -const std = @import("../../std.zig"); -const builtin = @import("builtin"); - -const os = std.os; -const fmt = std.fmt; -const mem = std.mem; -const math = std.math; -const testing = std.testing; -const native_os = builtin.os; -const have_ifnamesize = @hasDecl(os.system, "IFNAMESIZE"); - -pub const ResolveScopeIdError = error{ - NameTooLong, - PermissionDenied, - AddressFamilyNotSupported, - ProtocolFamilyNotAvailable, - ProcessFdQuotaExceeded, - SystemFdQuotaExceeded, - SystemResources, - ProtocolNotSupported, - SocketTypeNotSupported, - InterfaceNotFound, - FileSystem, - Unexpected, -}; - -/// Resolves a network interface name into a scope/zone ID. It returns -/// an error if either resolution fails, or if the interface name is -/// too long. -pub fn resolveScopeId(name: []const u8) ResolveScopeIdError!u32 { - if (have_ifnamesize) { - if (name.len >= os.IFNAMESIZE) return error.NameTooLong; - - if (native_os.tag == .windows or comptime native_os.tag.isDarwin()) { - var interface_name: [os.IFNAMESIZE:0]u8 = undefined; - mem.copy(u8, &interface_name, name); - interface_name[name.len] = 0; - - const rc = blk: { - if (native_os.tag == .windows) { - break :blk os.windows.ws2_32.if_nametoindex(@ptrCast([*:0]const u8, &interface_name)); - } else { - const index = os.system.if_nametoindex(@ptrCast([*:0]const u8, &interface_name)); - break :blk @bitCast(u32, index); - } - }; - if (rc == 0) { - return error.InterfaceNotFound; - } - return rc; - } - - if (native_os.tag == .linux) { - const fd = try os.socket(os.AF.INET, os.SOCK.DGRAM, 0); - defer os.closeSocket(fd); - - var f: os.ifreq = undefined; - mem.copy(u8, &f.ifrn.name, name); - f.ifrn.name[name.len] = 0; - - try os.ioctl_SIOCGIFINDEX(fd, &f); - - return @bitCast(u32, f.ifru.ivalue); - } - } - - return error.InterfaceNotFound; -} - -/// An IPv4 address comprised of 4 bytes. -pub const IPv4 = extern struct { - /// A IPv4 host-port pair. - pub const Address = extern struct { - host: IPv4, - port: u16, - }; - - /// Octets of a IPv4 address designating the local host. - pub const localhost_octets = [_]u8{ 127, 0, 0, 1 }; - - /// The IPv4 address of the local host. - pub const localhost: IPv4 = .{ .octets = localhost_octets }; - - /// Octets of an unspecified IPv4 address. - pub const unspecified_octets = [_]u8{0} ** 4; - - /// An unspecified IPv4 address. - pub const unspecified: IPv4 = .{ .octets = unspecified_octets }; - - /// Octets of a broadcast IPv4 address. - pub const broadcast_octets = [_]u8{255} ** 4; - - /// An IPv4 broadcast address. - pub const broadcast: IPv4 = .{ .octets = broadcast_octets }; - - /// The prefix octet pattern of a link-local IPv4 address. - pub const link_local_prefix = [_]u8{ 169, 254 }; - - /// The prefix octet patterns of IPv4 addresses intended for - /// documentation. - pub const documentation_prefixes = [_][]const u8{ - &[_]u8{ 192, 0, 2 }, - &[_]u8{ 198, 51, 100 }, - &[_]u8{ 203, 0, 113 }, - }; - - octets: [4]u8, - - /// Returns whether or not the two addresses are equal to, less than, or - /// greater than each other. 
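A quick illustration of the removed `IPv4` helpers, combining `parse`, the classification predicates, and `format` (the address literal is arbitrary):

```zig
const std = @import("std");
const IPv4 = std.x.os.IPv4;

test "IPv4 parse, classify and format" {
    const addr = try IPv4.parse("192.168.1.3");
    try std.testing.expect(addr.isPrivate());
    try std.testing.expect(!addr.isLoopback());
    try std.testing.expectFmt("192.168.1.3", "{}", .{addr});
}
```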
- pub fn cmp(self: IPv4, other: IPv4) math.Order { - return mem.order(u8, &self.octets, &other.octets); - } - - /// Returns true if both addresses are semantically equivalent. - pub fn eql(self: IPv4, other: IPv4) bool { - return mem.eql(u8, &self.octets, &other.octets); - } - - /// Returns true if the address is a loopback address. - pub fn isLoopback(self: IPv4) bool { - return self.octets[0] == 127; - } - - /// Returns true if the address is an unspecified IPv4 address. - pub fn isUnspecified(self: IPv4) bool { - return mem.eql(u8, &self.octets, &unspecified_octets); - } - - /// Returns true if the address is a private IPv4 address. - pub fn isPrivate(self: IPv4) bool { - return self.octets[0] == 10 or - (self.octets[0] == 172 and self.octets[1] >= 16 and self.octets[1] <= 31) or - (self.octets[0] == 192 and self.octets[1] == 168); - } - - /// Returns true if the address is a link-local IPv4 address. - pub fn isLinkLocal(self: IPv4) bool { - return mem.startsWith(u8, &self.octets, &link_local_prefix); - } - - /// Returns true if the address is a multicast IPv4 address. - pub fn isMulticast(self: IPv4) bool { - return self.octets[0] >= 224 and self.octets[0] <= 239; - } - - /// Returns true if the address is a IPv4 broadcast address. - pub fn isBroadcast(self: IPv4) bool { - return mem.eql(u8, &self.octets, &broadcast_octets); - } - - /// Returns true if the address is in a range designated for documentation. Refer - /// to IETF RFC 5737 for more details. - pub fn isDocumentation(self: IPv4) bool { - inline for (documentation_prefixes) |prefix| { - if (mem.startsWith(u8, &self.octets, prefix)) { - return true; - } - } - return false; - } - - /// Implements the `std.fmt.format` API. - pub fn format( - self: IPv4, - comptime layout: []const u8, - opts: fmt.FormatOptions, - writer: anytype, - ) !void { - _ = opts; - if (layout.len != 0) std.fmt.invalidFmtError(layout, self); - - try fmt.format(writer, "{}.{}.{}.{}", .{ - self.octets[0], - self.octets[1], - self.octets[2], - self.octets[3], - }); - } - - /// Set of possible errors that may encountered when parsing an IPv4 - /// address. - pub const ParseError = error{ - UnexpectedEndOfOctet, - TooManyOctets, - OctetOverflow, - UnexpectedToken, - IncompleteAddress, - }; - - /// Parses an arbitrary IPv4 address. - pub fn parse(buf: []const u8) ParseError!IPv4 { - var octets: [4]u8 = undefined; - var octet: u8 = 0; - - var index: u8 = 0; - var saw_any_digits: bool = false; - - for (buf) |c| { - switch (c) { - '.' => { - if (!saw_any_digits) return error.UnexpectedEndOfOctet; - if (index == 3) return error.TooManyOctets; - octets[index] = octet; - index += 1; - octet = 0; - saw_any_digits = false; - }, - '0'...'9' => { - saw_any_digits = true; - octet = math.mul(u8, octet, 10) catch return error.OctetOverflow; - octet = math.add(u8, octet, c - '0') catch return error.OctetOverflow; - }, - else => return error.UnexpectedToken, - } - } - - if (index == 3 and saw_any_digits) { - octets[index] = octet; - return IPv4{ .octets = octets }; - } - - return error.IncompleteAddress; - } - - /// Maps the address to its IPv6 equivalent. In most cases, you would - /// want to map the address to its IPv6 equivalent rather than directly - /// re-interpreting the address. 
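The distinction drawn above is easiest to see with concrete values; this mirrors the conversion test further down in this file:

```zig
const std = @import("std");
const IPv4 = std.x.os.IPv4;

test "map vs re-interpret to IPv6" {
    // Mapping keeps the IPv4 identity of the address (::ffff:a.b.c.d)...
    try std.testing.expectFmt("::ffff:127.0.0.1", "{}", .{IPv4.localhost.mapToIPv6()});
    try std.testing.expect(IPv4.localhost.mapToIPv6().mapsToIPv4());
    // ...while re-interpreting only copies the four octets into the low bytes.
    try std.testing.expectFmt("::7f00:1", "{}", .{IPv4.localhost.toIPv6()});
    try std.testing.expect(!IPv4.localhost.toIPv6().mapsToIPv4());
}
```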
- pub fn mapToIPv6(self: IPv4) IPv6 { - var octets: [16]u8 = undefined; - mem.copy(u8, octets[0..12], &IPv6.v4_mapped_prefix); - mem.copy(u8, octets[12..], &self.octets); - return IPv6{ .octets = octets, .scope_id = IPv6.no_scope_id }; - } - - /// Directly re-interprets the address to its IPv6 equivalent. In most - /// cases, you would want to map the address to its IPv6 equivalent rather - /// than directly re-interpreting the address. - pub fn toIPv6(self: IPv4) IPv6 { - var octets: [16]u8 = undefined; - mem.set(u8, octets[0..12], 0); - mem.copy(u8, octets[12..], &self.octets); - return IPv6{ .octets = octets, .scope_id = IPv6.no_scope_id }; - } -}; - -/// An IPv6 address comprised of 16 bytes for an address, and 4 bytes -/// for a scope ID; cumulatively summing to 20 bytes in total. -pub const IPv6 = extern struct { - /// A IPv6 host-port pair. - pub const Address = extern struct { - host: IPv6, - port: u16, - }; - - /// Octets of a IPv6 address designating the local host. - pub const localhost_octets = [_]u8{0} ** 15 ++ [_]u8{0x01}; - - /// The IPv6 address of the local host. - pub const localhost: IPv6 = .{ - .octets = localhost_octets, - .scope_id = no_scope_id, - }; - - /// Octets of an unspecified IPv6 address. - pub const unspecified_octets = [_]u8{0} ** 16; - - /// An unspecified IPv6 address. - pub const unspecified: IPv6 = .{ - .octets = unspecified_octets, - .scope_id = no_scope_id, - }; - - /// The prefix of a IPv6 address that is mapped to a IPv4 address. - pub const v4_mapped_prefix = [_]u8{0} ** 10 ++ [_]u8{0xFF} ** 2; - - /// A marker value used to designate an IPv6 address with no - /// associated scope ID. - pub const no_scope_id = math.maxInt(u32); - - octets: [16]u8, - scope_id: u32, - - /// Returns whether or not the two addresses are equal to, less than, or - /// greater than each other. - pub fn cmp(self: IPv6, other: IPv6) math.Order { - return switch (mem.order(u8, self.octets, other.octets)) { - .eq => math.order(self.scope_id, other.scope_id), - else => |order| order, - }; - } - - /// Returns true if both addresses are semantically equivalent. - pub fn eql(self: IPv6, other: IPv6) bool { - return self.scope_id == other.scope_id and mem.eql(u8, &self.octets, &other.octets); - } - - /// Returns true if the address is an unspecified IPv6 address. - pub fn isUnspecified(self: IPv6) bool { - return mem.eql(u8, &self.octets, &unspecified_octets); - } - - /// Returns true if the address is a loopback address. - pub fn isLoopback(self: IPv6) bool { - return mem.eql(u8, self.octets[0..3], &[_]u8{ 0, 0, 0 }) and - mem.eql(u8, self.octets[12..], &[_]u8{ 0, 0, 0, 1 }); - } - - /// Returns true if the address maps to an IPv4 address. - pub fn mapsToIPv4(self: IPv6) bool { - return mem.startsWith(u8, &self.octets, &v4_mapped_prefix); - } - - /// Returns an IPv4 address representative of the address should - /// it the address be mapped to an IPv4 address. It returns null - /// otherwise. - pub fn toIPv4(self: IPv6) ?IPv4 { - if (!self.mapsToIPv4()) return null; - return IPv4{ .octets = self.octets[12..][0..4].* }; - } - - /// Returns true if the address is a multicast IPv6 address. - pub fn isMulticast(self: IPv6) bool { - return self.octets[0] == 0xFF; - } - - /// Returns true if the address is a unicast link local IPv6 address. - pub fn isLinkLocal(self: IPv6) bool { - return self.octets[0] == 0xFE and self.octets[1] & 0xC0 == 0x80; - } - - /// Returns true if the address is a deprecated unicast site local - /// IPv6 address. 
Refer to IETF RFC 3879 for more details as to - /// why they are deprecated. - pub fn isSiteLocal(self: IPv6) bool { - return self.octets[0] == 0xFE and self.octets[1] & 0xC0 == 0xC0; - } - - /// IPv6 multicast address scopes. - pub const Scope = enum(u8) { - interface = 1, - link = 2, - realm = 3, - admin = 4, - site = 5, - organization = 8, - global = 14, - unknown = 0xFF, - }; - - /// Returns the multicast scope of the address. - pub fn scope(self: IPv6) Scope { - if (!self.isMulticast()) return .unknown; - - return switch (self.octets[0] & 0x0F) { - 1 => .interface, - 2 => .link, - 3 => .realm, - 4 => .admin, - 5 => .site, - 8 => .organization, - 14 => .global, - else => .unknown, - }; - } - - /// Implements the `std.fmt.format` API. Specifying 'x' or 's' formats the - /// address lower-cased octets, while specifying 'X' or 'S' formats the - /// address using upper-cased ASCII octets. - /// - /// The default specifier is 'x'. - pub fn format( - self: IPv6, - comptime layout: []const u8, - opts: fmt.FormatOptions, - writer: anytype, - ) !void { - _ = opts; - const specifier = comptime &[_]u8{if (layout.len == 0) 'x' else switch (layout[0]) { - 'x', 'X' => |specifier| specifier, - 's' => 'x', - 'S' => 'X', - else => std.fmt.invalidFmtError(layout, self), - }}; - - if (mem.startsWith(u8, &self.octets, &v4_mapped_prefix)) { - return fmt.format(writer, "::{" ++ specifier ++ "}{" ++ specifier ++ "}:{}.{}.{}.{}", .{ - 0xFF, - 0xFF, - self.octets[12], - self.octets[13], - self.octets[14], - self.octets[15], - }); - } - - const zero_span: struct { from: usize, to: usize } = span: { - var i: usize = 0; - while (i < self.octets.len) : (i += 2) { - if (self.octets[i] == 0 and self.octets[i + 1] == 0) break; - } else break :span .{ .from = 0, .to = 0 }; - - const from = i; - - while (i < self.octets.len) : (i += 2) { - if (self.octets[i] != 0 or self.octets[i + 1] != 0) break; - } - - break :span .{ .from = from, .to = i }; - }; - - var i: usize = 0; - while (i != 16) : (i += 2) { - if (zero_span.from != zero_span.to and i == zero_span.from) { - try writer.writeAll("::"); - } else if (i >= zero_span.from and i < zero_span.to) {} else { - if (i != 0 and i != zero_span.to) try writer.writeAll(":"); - - const val = @as(u16, self.octets[i]) << 8 | self.octets[i + 1]; - try fmt.formatIntValue(val, specifier, .{}, writer); - } - } - - if (self.scope_id != no_scope_id and self.scope_id != 0) { - try fmt.format(writer, "%{d}", .{self.scope_id}); - } - } - - /// Set of possible errors that may encountered when parsing an IPv6 - /// address. - pub const ParseError = error{ - MalformedV4Mapping, - InterfaceNotFound, - UnknownScopeId, - } || IPv4.ParseError; - - /// Parses an arbitrary IPv6 address, including link-local addresses. - pub fn parse(buf: []const u8) ParseError!IPv6 { - if (mem.lastIndexOfScalar(u8, buf, '%')) |index| { - const ip_slice = buf[0..index]; - const scope_id_slice = buf[index + 1 ..]; - - if (scope_id_slice.len == 0) return error.UnknownScopeId; - - const scope_id: u32 = switch (scope_id_slice[0]) { - '0'...'9' => fmt.parseInt(u32, scope_id_slice, 10), - else => resolveScopeId(scope_id_slice) catch |err| switch (err) { - error.InterfaceNotFound => return error.InterfaceNotFound, - else => err, - }, - } catch return error.UnknownScopeId; - - return parseWithScopeID(ip_slice, scope_id); - } - - return parseWithScopeID(buf, no_scope_id); - } - - /// Parses an IPv6 address with a pre-specified scope ID. Presumes - /// that the address is not a link-local address. 
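Without a `%scope` suffix, `parse` falls through to `parseWithScopeID` with `no_scope_id`. A small illustration of the removed `IPv6` helpers (the address literal is arbitrary):

```zig
const std = @import("std");
const IPv6 = std.x.os.IPv6;

test "IPv6 parse, classify and format" {
    const addr = try IPv6.parse("fe80::1");
    try std.testing.expect(addr.isLinkLocal());
    try std.testing.expect(!addr.isMulticast());
    try std.testing.expectFmt("fe80::1", "{}", .{addr});
}
```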
- pub fn parseWithScopeID(buf: []const u8, scope_id: u32) ParseError!IPv6 { - var octets: [16]u8 = undefined; - var octet: u16 = 0; - var tail: [16]u8 = undefined; - - var out: []u8 = &octets; - var index: u8 = 0; - - var saw_any_digits: bool = false; - var abbrv: bool = false; - - for (buf) |c, i| { - switch (c) { - ':' => { - if (!saw_any_digits) { - if (abbrv) return error.UnexpectedToken; - if (i != 0) abbrv = true; - mem.set(u8, out[index..], 0); - out = &tail; - index = 0; - continue; - } - if (index == 14) return error.TooManyOctets; - - out[index] = @truncate(u8, octet >> 8); - index += 1; - out[index] = @truncate(u8, octet); - index += 1; - - octet = 0; - saw_any_digits = false; - }, - '.' => { - if (!abbrv or out[0] != 0xFF and out[1] != 0xFF) { - return error.MalformedV4Mapping; - } - const start_index = mem.lastIndexOfScalar(u8, buf[0..i], ':').? + 1; - const v4 = try IPv4.parse(buf[start_index..]); - octets[10] = 0xFF; - octets[11] = 0xFF; - mem.copy(u8, octets[12..], &v4.octets); - - return IPv6{ .octets = octets, .scope_id = scope_id }; - }, - else => { - saw_any_digits = true; - const digit = fmt.charToDigit(c, 16) catch return error.UnexpectedToken; - octet = math.mul(u16, octet, 16) catch return error.OctetOverflow; - octet = math.add(u16, octet, digit) catch return error.OctetOverflow; - }, - } - } - - if (!saw_any_digits and !abbrv) { - return error.IncompleteAddress; - } - - if (index == 14) { - out[14] = @truncate(u8, octet >> 8); - out[15] = @truncate(u8, octet); - } else { - out[index] = @truncate(u8, octet >> 8); - index += 1; - out[index] = @truncate(u8, octet); - index += 1; - mem.copy(u8, octets[16 - index ..], out[0..index]); - } - - return IPv6{ .octets = octets, .scope_id = scope_id }; - } -}; - -test { - testing.refAllDecls(@This()); -} - -test "ip: convert to and from ipv6" { - try testing.expectFmt("::7f00:1", "{}", .{IPv4.localhost.toIPv6()}); - try testing.expect(!IPv4.localhost.toIPv6().mapsToIPv4()); - - try testing.expectFmt("::ffff:127.0.0.1", "{}", .{IPv4.localhost.mapToIPv6()}); - try testing.expect(IPv4.localhost.mapToIPv6().mapsToIPv4()); - - try testing.expect(IPv4.localhost.toIPv6().toIPv4() == null); - try testing.expectFmt("127.0.0.1", "{?}", .{IPv4.localhost.mapToIPv6().toIPv4()}); -} - -test "ipv4: parse & format" { - const cases = [_][]const u8{ - "0.0.0.0", - "255.255.255.255", - "1.2.3.4", - "123.255.0.91", - "127.0.0.1", - }; - - for (cases) |case| { - try testing.expectFmt(case, "{}", .{try IPv4.parse(case)}); - } -} - -test "ipv6: parse & format" { - const inputs = [_][]const u8{ - "FF01:0:0:0:0:0:0:FB", - "FF01::Fb", - "::1", - "::", - "2001:db8::", - "::1234:5678", - "2001:db8::1234:5678", - "::ffff:123.5.123.5", - }; - - const outputs = [_][]const u8{ - "ff01::fb", - "ff01::fb", - "::1", - "::", - "2001:db8::", - "::1234:5678", - "2001:db8::1234:5678", - "::ffff:123.5.123.5", - }; - - for (inputs) |input, i| { - try testing.expectFmt(outputs[i], "{}", .{try IPv6.parse(input)}); - } -} - -test "ipv6: parse & format addresses with scope ids" { - if (!have_ifnamesize) return error.SkipZigTest; - const iface = if (native_os.tag == .linux) - "lo" - else - "lo0"; - const input = "FF01::FB%" ++ iface; - const output = "ff01::fb%1"; - - const parsed = IPv6.parse(input) catch |err| switch (err) { - error.InterfaceNotFound => return, - else => return err, - }; - - try testing.expectFmt(output, "{}", .{parsed}); -} diff --git a/lib/std/x/os/socket.zig b/lib/std/x/os/socket.zig deleted file mode 100644 index 99782710cbb3..000000000000 --- 
a/lib/std/x/os/socket.zig +++ /dev/null @@ -1,320 +0,0 @@ -const std = @import("../../std.zig"); -const builtin = @import("builtin"); -const net = @import("net.zig"); - -const os = std.os; -const fmt = std.fmt; -const mem = std.mem; -const time = std.time; -const meta = std.meta; -const native_os = builtin.os; -const native_endian = builtin.cpu.arch.endian(); - -const Buffer = std.x.os.Buffer; - -const assert = std.debug.assert; - -/// A generic, cross-platform socket abstraction. -pub const Socket = struct { - /// A socket-address pair. - pub const Connection = struct { - socket: Socket, - address: Socket.Address, - - /// Enclose a socket and address into a socket-address pair. - pub fn from(socket: Socket, address: Socket.Address) Socket.Connection { - return .{ .socket = socket, .address = address }; - } - }; - - /// A generic socket address abstraction. It is safe to directly access and modify - /// the fields of a `Socket.Address`. - pub const Address = union(enum) { - pub const Native = struct { - pub const requires_prepended_length = native_os.getVersionRange() == .semver; - pub const Length = if (requires_prepended_length) u8 else [0]u8; - - pub const Family = if (requires_prepended_length) u8 else c_ushort; - - /// POSIX `sockaddr.storage`. The expected size and alignment is specified in IETF RFC 2553. - pub const Storage = extern struct { - pub const expected_size = os.sockaddr.SS_MAXSIZE; - pub const expected_alignment = 8; - - pub const padding_size = expected_size - - mem.alignForward(@sizeOf(Address.Native.Length), expected_alignment) - - mem.alignForward(@sizeOf(Address.Native.Family), expected_alignment); - - len: Address.Native.Length align(expected_alignment) = undefined, - family: Address.Native.Family align(expected_alignment) = undefined, - padding: [padding_size]u8 align(expected_alignment) = undefined, - - comptime { - assert(@sizeOf(Storage) == Storage.expected_size); - assert(@alignOf(Storage) == Storage.expected_alignment); - } - }; - }; - - ipv4: net.IPv4.Address, - ipv6: net.IPv6.Address, - - /// Instantiate a new address with a IPv4 host and port. - pub fn initIPv4(host: net.IPv4, port: u16) Socket.Address { - return .{ .ipv4 = .{ .host = host, .port = port } }; - } - - /// Instantiate a new address with a IPv6 host and port. - pub fn initIPv6(host: net.IPv6, port: u16) Socket.Address { - return .{ .ipv6 = .{ .host = host, .port = port } }; - } - - /// Parses a `sockaddr` into a generic socket address. - pub fn fromNative(address: *align(4) const os.sockaddr) Socket.Address { - switch (address.family) { - os.AF.INET => { - const info = @ptrCast(*const os.sockaddr.in, address); - const host = net.IPv4{ .octets = @bitCast([4]u8, info.addr) }; - const port = mem.bigToNative(u16, info.port); - return Socket.Address.initIPv4(host, port); - }, - os.AF.INET6 => { - const info = @ptrCast(*const os.sockaddr.in6, address); - const host = net.IPv6{ .octets = info.addr, .scope_id = info.scope_id }; - const port = mem.bigToNative(u16, info.port); - return Socket.Address.initIPv6(host, port); - }, - else => unreachable, - } - } - - /// Encodes a generic socket address into an extern union that may be reliably - /// casted into a `sockaddr` which may be passed into socket syscalls. 
- pub fn toNative(self: Socket.Address) extern union { - ipv4: os.sockaddr.in, - ipv6: os.sockaddr.in6, - } { - return switch (self) { - .ipv4 => |address| .{ - .ipv4 = .{ - .addr = @bitCast(u32, address.host.octets), - .port = mem.nativeToBig(u16, address.port), - }, - }, - .ipv6 => |address| .{ - .ipv6 = .{ - .addr = address.host.octets, - .port = mem.nativeToBig(u16, address.port), - .scope_id = address.host.scope_id, - .flowinfo = 0, - }, - }, - }; - } - - /// Returns the number of bytes that make up the `sockaddr` equivalent to the address. - pub fn getNativeSize(self: Socket.Address) u32 { - return switch (self) { - .ipv4 => @sizeOf(os.sockaddr.in), - .ipv6 => @sizeOf(os.sockaddr.in6), - }; - } - - /// Implements the `std.fmt.format` API. - pub fn format( - self: Socket.Address, - comptime layout: []const u8, - opts: fmt.FormatOptions, - writer: anytype, - ) !void { - if (layout.len != 0) std.fmt.invalidFmtError(layout, self); - _ = opts; - switch (self) { - .ipv4 => |address| try fmt.format(writer, "{}:{}", .{ address.host, address.port }), - .ipv6 => |address| try fmt.format(writer, "{}:{}", .{ address.host, address.port }), - } - } - }; - - /// POSIX `msghdr`. Denotes a destination address, set of buffers, control data, and flags. Ported - /// directly from musl. - pub const Message = if (native_os.isAtLeast(.windows, .vista) != null and native_os.isAtLeast(.windows, .vista).?) - extern struct { - name: usize = @ptrToInt(@as(?[*]u8, null)), - name_len: c_int = 0, - - buffers: usize = undefined, - buffers_len: c_ulong = undefined, - - control: Buffer = .{ - .ptr = @ptrToInt(@as(?[*]u8, null)), - .len = 0, - }, - flags: c_ulong = 0, - - pub usingnamespace MessageMixin(Message); - } - else if (native_os.tag == .windows) - extern struct { - name: usize = @ptrToInt(@as(?[*]u8, null)), - name_len: c_int = 0, - - buffers: usize = undefined, - buffers_len: u32 = undefined, - - control: Buffer = .{ - .ptr = @ptrToInt(@as(?[*]u8, null)), - .len = 0, - }, - flags: u32 = 0, - - pub usingnamespace MessageMixin(Message); - } - else if (@sizeOf(usize) > 4 and native_endian == .Big) - extern struct { - name: usize = @ptrToInt(@as(?[*]u8, null)), - name_len: c_uint = 0, - - buffers: usize = undefined, - _pad_1: c_int = 0, - buffers_len: c_int = undefined, - - control: usize = @ptrToInt(@as(?[*]u8, null)), - _pad_2: c_int = 0, - control_len: c_uint = 0, - - flags: c_int = 0, - - pub usingnamespace MessageMixin(Message); - } - else if (@sizeOf(usize) > 4 and native_endian == .Little) - extern struct { - name: usize = @ptrToInt(@as(?[*]u8, null)), - name_len: c_uint = 0, - - buffers: usize = undefined, - buffers_len: c_int = undefined, - _pad_1: c_int = 0, - - control: usize = @ptrToInt(@as(?[*]u8, null)), - control_len: c_uint = 0, - _pad_2: c_int = 0, - - flags: c_int = 0, - - pub usingnamespace MessageMixin(Message); - } - else - extern struct { - name: usize = @ptrToInt(@as(?[*]u8, null)), - name_len: c_uint = 0, - - buffers: usize = undefined, - buffers_len: c_int = undefined, - - control: usize = @ptrToInt(@as(?[*]u8, null)), - control_len: c_uint = 0, - - flags: c_int = 0, - - pub usingnamespace MessageMixin(Message); - }; - - fn MessageMixin(comptime Self: type) type { - return struct { - pub fn fromBuffers(buffers: []const Buffer) Self { - var self: Self = .{}; - self.setBuffers(buffers); - return self; - } - - pub fn setName(self: *Self, name: []const u8) void { - self.name = @ptrToInt(name.ptr); - self.name_len = @intCast(meta.fieldInfo(Self, .name_len).type, name.len); - } - - pub fn 
setBuffers(self: *Self, buffers: []const Buffer) void { - self.buffers = @ptrToInt(buffers.ptr); - self.buffers_len = @intCast(meta.fieldInfo(Self, .buffers_len).type, buffers.len); - } - - pub fn setControl(self: *Self, control: []const u8) void { - if (native_os.tag == .windows) { - self.control = Buffer.from(control); - } else { - self.control = @ptrToInt(control.ptr); - self.control_len = @intCast(meta.fieldInfo(Self, .control_len).type, control.len); - } - } - - pub fn setFlags(self: *Self, flags: u32) void { - self.flags = @intCast(meta.fieldInfo(Self, .flags).type, flags); - } - - pub fn getName(self: Self) []const u8 { - return @intToPtr([*]const u8, self.name)[0..@intCast(usize, self.name_len)]; - } - - pub fn getBuffers(self: Self) []const Buffer { - return @intToPtr([*]const Buffer, self.buffers)[0..@intCast(usize, self.buffers_len)]; - } - - pub fn getControl(self: Self) []const u8 { - if (native_os.tag == .windows) { - return self.control.into(); - } else { - return @intToPtr([*]const u8, self.control)[0..@intCast(usize, self.control_len)]; - } - } - - pub fn getFlags(self: Self) u32 { - return @intCast(u32, self.flags); - } - }; - } - - /// POSIX `linger`, denoting the linger settings of a socket. - /// - /// Microsoft's documentation and glibc denote the fields to be unsigned - /// short's on Windows, whereas glibc and musl denote the fields to be - /// int's on every other platform. - pub const Linger = extern struct { - pub const Field = switch (native_os.tag) { - .windows => c_ushort, - else => c_int, - }; - - enabled: Field, - timeout_seconds: Field, - - pub fn init(timeout_seconds: ?u16) Socket.Linger { - return .{ - .enabled = @intCast(Socket.Linger.Field, @boolToInt(timeout_seconds != null)), - .timeout_seconds = if (timeout_seconds) |seconds| @intCast(Socket.Linger.Field, seconds) else 0, - }; - } - }; - - /// Possible set of flags to initialize a socket with. - pub const InitFlags = enum { - // Initialize a socket to be non-blocking. - nonblocking, - - // Have a socket close itself on exec syscalls. - close_on_exec, - }; - - /// The underlying handle of a socket. - fd: os.socket_t, - - /// Enclose a socket abstraction over an existing socket file descriptor. - pub fn from(fd: os.socket_t) Socket { - return Socket{ .fd = fd }; - } - - /// Mix in socket syscalls depending on the platform we are compiling against. - pub usingnamespace switch (native_os.tag) { - .windows => @import("socket_windows.zig"), - else => @import("socket_posix.zig"), - }.Mixin(Socket); -}; diff --git a/lib/std/x/os/socket_posix.zig b/lib/std/x/os/socket_posix.zig deleted file mode 100644 index 859075aa20bd..000000000000 --- a/lib/std/x/os/socket_posix.zig +++ /dev/null @@ -1,275 +0,0 @@ -const std = @import("../../std.zig"); - -const os = std.os; -const mem = std.mem; -const time = std.time; - -pub fn Mixin(comptime Socket: type) type { - return struct { - /// Open a new socket. - pub fn init(domain: u32, socket_type: u32, protocol: u32, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Socket { - var raw_flags: u32 = socket_type; - const set = std.EnumSet(Socket.InitFlags).init(flags); - if (set.contains(.close_on_exec)) raw_flags |= os.SOCK.CLOEXEC; - if (set.contains(.nonblocking)) raw_flags |= os.SOCK.NONBLOCK; - return Socket{ .fd = try os.socket(domain, raw_flags, protocol) }; - } - - /// Closes the socket. - pub fn deinit(self: Socket) void { - os.closeSocket(self.fd); - } - - /// Shutdown either the read side, write side, or all side of the socket. 
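The `init` above translates the portable `InitFlags` into `SOCK.CLOEXEC`/`SOCK.NONBLOCK` bits. A hedged sketch of opening a raw `Socket` without going through `tcp.Client` (the helper name is made up):

```zig
const std = @import("std");
const os = std.os;
const Socket = std.x.os.Socket;

// Hypothetical: open a non-blocking, close-on-exec TCP socket directly.
fn openNonBlockingTcp() !Socket {
    return Socket.init(os.AF.INET, os.SOCK.STREAM, os.IPPROTO.TCP, .{
        .close_on_exec = true,
        .nonblocking = true,
    });
}
```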
- pub fn shutdown(self: Socket, how: os.ShutdownHow) !void { - return os.shutdown(self.fd, how); - } - - /// Binds the socket to an address. - pub fn bind(self: Socket, address: Socket.Address) !void { - return os.bind(self.fd, @ptrCast(*const os.sockaddr, &address.toNative()), address.getNativeSize()); - } - - /// Start listening for incoming connections on the socket. - pub fn listen(self: Socket, max_backlog_size: u31) !void { - return os.listen(self.fd, max_backlog_size); - } - - /// Have the socket attempt to the connect to an address. - pub fn connect(self: Socket, address: Socket.Address) !void { - return os.connect(self.fd, @ptrCast(*const os.sockaddr, &address.toNative()), address.getNativeSize()); - } - - /// Accept a pending incoming connection queued to the kernel backlog - /// of the socket. - pub fn accept(self: Socket, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Socket.Connection { - var address: Socket.Address.Native.Storage = undefined; - var address_len: u32 = @sizeOf(Socket.Address.Native.Storage); - - var raw_flags: u32 = 0; - const set = std.EnumSet(Socket.InitFlags).init(flags); - if (set.contains(.close_on_exec)) raw_flags |= os.SOCK.CLOEXEC; - if (set.contains(.nonblocking)) raw_flags |= os.SOCK.NONBLOCK; - - const socket = Socket{ .fd = try os.accept(self.fd, @ptrCast(*os.sockaddr, &address), &address_len, raw_flags) }; - const socket_address = Socket.Address.fromNative(@ptrCast(*os.sockaddr, &address)); - - return Socket.Connection.from(socket, socket_address); - } - - /// Read data from the socket into the buffer provided with a set of flags - /// specified. It returns the number of bytes read into the buffer provided. - pub fn read(self: Socket, buf: []u8, flags: u32) !usize { - return os.recv(self.fd, buf, flags); - } - - /// Write a buffer of data provided to the socket with a set of flags specified. - /// It returns the number of bytes that are written to the socket. - pub fn write(self: Socket, buf: []const u8, flags: u32) !usize { - return os.send(self.fd, buf, flags); - } - - /// Writes multiple I/O vectors with a prepended message header to the socket - /// with a set of flags specified. It returns the number of bytes that are - /// written to the socket. - pub fn writeMessage(self: Socket, msg: Socket.Message, flags: u32) !usize { - while (true) { - const rc = os.system.sendmsg(self.fd, &msg, @intCast(c_int, flags)); - return switch (os.errno(rc)) { - .SUCCESS => return @intCast(usize, rc), - .ACCES => error.AccessDenied, - .AGAIN => error.WouldBlock, - .ALREADY => error.FastOpenAlreadyInProgress, - .BADF => unreachable, // always a race condition - .CONNRESET => error.ConnectionResetByPeer, - .DESTADDRREQ => unreachable, // The socket is not connection-mode, and no peer address is set. - .FAULT => unreachable, // An invalid user space address was specified for an argument. - .INTR => continue, - .INVAL => unreachable, // Invalid argument passed. - .ISCONN => unreachable, // connection-mode socket was connected already but a recipient was specified - .MSGSIZE => error.MessageTooBig, - .NOBUFS => error.SystemResources, - .NOMEM => error.SystemResources, - .NOTSOCK => unreachable, // The file descriptor sockfd does not refer to a socket. - .OPNOTSUPP => unreachable, // Some bit in the flags argument is inappropriate for the socket type. 
- .PIPE => error.BrokenPipe, - .AFNOSUPPORT => error.AddressFamilyNotSupported, - .LOOP => error.SymLinkLoop, - .NAMETOOLONG => error.NameTooLong, - .NOENT => error.FileNotFound, - .NOTDIR => error.NotDir, - .HOSTUNREACH => error.NetworkUnreachable, - .NETUNREACH => error.NetworkUnreachable, - .NOTCONN => error.SocketNotConnected, - .NETDOWN => error.NetworkSubsystemFailed, - else => |err| os.unexpectedErrno(err), - }; - } - } - - /// Read multiple I/O vectors with a prepended message header from the socket - /// with a set of flags specified. It returns the number of bytes that were - /// read into the buffer provided. - pub fn readMessage(self: Socket, msg: *Socket.Message, flags: u32) !usize { - while (true) { - const rc = os.system.recvmsg(self.fd, msg, @intCast(c_int, flags)); - return switch (os.errno(rc)) { - .SUCCESS => @intCast(usize, rc), - .BADF => unreachable, // always a race condition - .FAULT => unreachable, - .INVAL => unreachable, - .NOTCONN => unreachable, - .NOTSOCK => unreachable, - .INTR => continue, - .AGAIN => error.WouldBlock, - .NOMEM => error.SystemResources, - .CONNREFUSED => error.ConnectionRefused, - .CONNRESET => error.ConnectionResetByPeer, - else => |err| os.unexpectedErrno(err), - }; - } - } - - /// Query the address that the socket is locally bounded to. - pub fn getLocalAddress(self: Socket) !Socket.Address { - var address: Socket.Address.Native.Storage = undefined; - var address_len: u32 = @sizeOf(Socket.Address.Native.Storage); - try os.getsockname(self.fd, @ptrCast(*os.sockaddr, &address), &address_len); - return Socket.Address.fromNative(@ptrCast(*os.sockaddr, &address)); - } - - /// Query the address that the socket is connected to. - pub fn getRemoteAddress(self: Socket) !Socket.Address { - var address: Socket.Address.Native.Storage = undefined; - var address_len: u32 = @sizeOf(Socket.Address.Native.Storage); - try os.getpeername(self.fd, @ptrCast(*os.sockaddr, &address), &address_len); - return Socket.Address.fromNative(@ptrCast(*os.sockaddr, &address)); - } - - /// Query and return the latest cached error on the socket. - pub fn getError(self: Socket) !void { - return os.getsockoptError(self.fd); - } - - /// Query the read buffer size of the socket. - pub fn getReadBufferSize(self: Socket) !u32 { - var value: u32 = undefined; - var value_len: u32 = @sizeOf(u32); - - const rc = os.system.getsockopt(self.fd, os.SOL.SOCKET, os.SO.RCVBUF, mem.asBytes(&value), &value_len); - return switch (os.errno(rc)) { - .SUCCESS => value, - .BADF => error.BadFileDescriptor, - .FAULT => error.InvalidAddressSpace, - .INVAL => error.InvalidSocketOption, - .NOPROTOOPT => error.UnknownSocketOption, - .NOTSOCK => error.NotASocket, - else => |err| os.unexpectedErrno(err), - }; - } - - /// Query the write buffer size of the socket. - pub fn getWriteBufferSize(self: Socket) !u32 { - var value: u32 = undefined; - var value_len: u32 = @sizeOf(u32); - - const rc = os.system.getsockopt(self.fd, os.SOL.SOCKET, os.SO.SNDBUF, mem.asBytes(&value), &value_len); - return switch (os.errno(rc)) { - .SUCCESS => value, - .BADF => error.BadFileDescriptor, - .FAULT => error.InvalidAddressSpace, - .INVAL => error.InvalidSocketOption, - .NOPROTOOPT => error.UnknownSocketOption, - .NOTSOCK => error.NotASocket, - else => |err| os.unexpectedErrno(err), - }; - } - - /// Set a socket option. 
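The buffer-size getters deleted here wrap `getsockopt` in the same errno-switch shape used by the rest of this mixin, so callers that still need the value can query it directly. The sketch below assumes the POSIX `os.system.getsockopt` call as used above; `readBufferSize` is a hypothetical name, and the error mapping is illustrative:

```zig
const std = @import("std");
const os = std.os;
const mem = std.mem;

/// Hypothetical replacement for the removed getReadBufferSize: reads
/// SO_RCVBUF and maps the raw errno to a Zig error.
fn readBufferSize(fd: os.socket_t) !u32 {
    var value: u32 = undefined;
    var value_len: u32 = @sizeOf(u32);
    const rc = os.system.getsockopt(fd, os.SOL.SOCKET, os.SO.RCVBUF, mem.asBytes(&value), &value_len);
    return switch (os.errno(rc)) {
        .SUCCESS => value,
        .BADF, .NOTSOCK => error.NotASocket,
        .NOPROTOOPT => error.UnknownSocketOption,
        else => |err| os.unexpectedErrno(err),
    };
}
```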
- pub fn setOption(self: Socket, level: u32, code: u32, value: []const u8) !void { - return os.setsockopt(self.fd, level, code, value); - } - - /// Have close() or shutdown() syscalls block until all queued messages in the socket have been successfully - /// sent, or if the timeout specified in seconds has been reached. It returns `error.UnsupportedSocketOption` - /// if the host does not support the option for a socket to linger around up until a timeout specified in - /// seconds. - pub fn setLinger(self: Socket, timeout_seconds: ?u16) !void { - if (@hasDecl(os.SO, "LINGER")) { - const settings = Socket.Linger.init(timeout_seconds); - return self.setOption(os.SOL.SOCKET, os.SO.LINGER, mem.asBytes(&settings)); - } - - return error.UnsupportedSocketOption; - } - - /// On connection-oriented sockets, have keep-alive messages be sent periodically. The timing in which keep-alive - /// messages are sent are dependant on operating system settings. It returns `error.UnsupportedSocketOption` if - /// the host does not support periodically sending keep-alive messages on connection-oriented sockets. - pub fn setKeepAlive(self: Socket, enabled: bool) !void { - if (@hasDecl(os.SO, "KEEPALIVE")) { - return self.setOption(os.SOL.SOCKET, os.SO.KEEPALIVE, mem.asBytes(&@as(u32, @boolToInt(enabled)))); - } - return error.UnsupportedSocketOption; - } - - /// Allow multiple sockets on the same host to listen on the same address. It returns `error.UnsupportedSocketOption` if - /// the host does not support sockets listening the same address. - pub fn setReuseAddress(self: Socket, enabled: bool) !void { - if (@hasDecl(os.SO, "REUSEADDR")) { - return self.setOption(os.SOL.SOCKET, os.SO.REUSEADDR, mem.asBytes(&@as(u32, @boolToInt(enabled)))); - } - return error.UnsupportedSocketOption; - } - - /// Allow multiple sockets on the same host to listen on the same port. It returns `error.UnsupportedSocketOption` if - /// the host does not supports sockets listening on the same port. - pub fn setReusePort(self: Socket, enabled: bool) !void { - if (@hasDecl(os.SO, "REUSEPORT")) { - return self.setOption(os.SOL.SOCKET, os.SO.REUSEPORT, mem.asBytes(&@as(u32, @boolToInt(enabled)))); - } - return error.UnsupportedSocketOption; - } - - /// Set the write buffer size of the socket. - pub fn setWriteBufferSize(self: Socket, size: u32) !void { - return self.setOption(os.SOL.SOCKET, os.SO.SNDBUF, mem.asBytes(&size)); - } - - /// Set the read buffer size of the socket. - pub fn setReadBufferSize(self: Socket, size: u32) !void { - return self.setOption(os.SOL.SOCKET, os.SO.RCVBUF, mem.asBytes(&size)); - } - - /// WARNING: Timeouts only affect blocking sockets. It is undefined behavior if a timeout is - /// set on a non-blocking socket. - /// - /// Set a timeout on the socket that is to occur if no messages are successfully written - /// to its bound destination after a specified number of milliseconds. A subsequent write - /// to the socket will thereafter return `error.WouldBlock` should the timeout be exceeded. - pub fn setWriteTimeout(self: Socket, milliseconds: usize) !void { - const timeout = os.timeval{ - .tv_sec = @intCast(i32, milliseconds / time.ms_per_s), - .tv_usec = @intCast(i32, (milliseconds % time.ms_per_s) * time.us_per_ms), - }; - - return self.setOption(os.SOL.SOCKET, os.SO.SNDTIMEO, mem.asBytes(&timeout)); - } - - /// WARNING: Timeouts only affect blocking sockets. It is undefined behavior if a timeout is - /// set on a non-blocking socket. 
- /// - /// Set a timeout on the socket that is to occur if no messages are successfully read - /// from its bound destination after a specified number of milliseconds. A subsequent - /// read from the socket will thereafter return `error.WouldBlock` should the timeout be - /// exceeded. - pub fn setReadTimeout(self: Socket, milliseconds: usize) !void { - const timeout = os.timeval{ - .tv_sec = @intCast(i32, milliseconds / time.ms_per_s), - .tv_usec = @intCast(i32, (milliseconds % time.ms_per_s) * time.us_per_ms), - }; - - return self.setOption(os.SOL.SOCKET, os.SO.RCVTIMEO, mem.asBytes(&timeout)); - } - }; -} diff --git a/lib/std/x/os/socket_windows.zig b/lib/std/x/os/socket_windows.zig deleted file mode 100644 index 43b047dd109d..000000000000 --- a/lib/std/x/os/socket_windows.zig +++ /dev/null @@ -1,458 +0,0 @@ -const std = @import("../../std.zig"); -const net = @import("net.zig"); - -const os = std.os; -const mem = std.mem; - -const windows = std.os.windows; -const ws2_32 = windows.ws2_32; - -pub fn Mixin(comptime Socket: type) type { - return struct { - /// Open a new socket. - pub fn init(domain: u32, socket_type: u32, protocol: u32, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Socket { - var raw_flags: u32 = ws2_32.WSA_FLAG_OVERLAPPED; - const set = std.EnumSet(Socket.InitFlags).init(flags); - if (set.contains(.close_on_exec)) raw_flags |= ws2_32.WSA_FLAG_NO_HANDLE_INHERIT; - - const fd = ws2_32.WSASocketW( - @intCast(i32, domain), - @intCast(i32, socket_type), - @intCast(i32, protocol), - null, - 0, - raw_flags, - ); - if (fd == ws2_32.INVALID_SOCKET) { - return switch (ws2_32.WSAGetLastError()) { - .WSANOTINITIALISED => { - _ = try windows.WSAStartup(2, 2); - return init(domain, socket_type, protocol, flags); - }, - .WSAEAFNOSUPPORT => error.AddressFamilyNotSupported, - .WSAEMFILE => error.ProcessFdQuotaExceeded, - .WSAENOBUFS => error.SystemResources, - .WSAEPROTONOSUPPORT => error.ProtocolNotSupported, - else => |err| windows.unexpectedWSAError(err), - }; - } - - if (set.contains(.nonblocking)) { - var enabled: c_ulong = 1; - const rc = ws2_32.ioctlsocket(fd, ws2_32.FIONBIO, &enabled); - if (rc == ws2_32.SOCKET_ERROR) { - return windows.unexpectedWSAError(ws2_32.WSAGetLastError()); - } - } - - return Socket{ .fd = fd }; - } - - /// Closes the socket. - pub fn deinit(self: Socket) void { - _ = ws2_32.closesocket(self.fd); - } - - /// Shutdown either the read side, write side, or all side of the socket. - pub fn shutdown(self: Socket, how: os.ShutdownHow) !void { - const rc = ws2_32.shutdown(self.fd, switch (how) { - .recv => ws2_32.SD_RECEIVE, - .send => ws2_32.SD_SEND, - .both => ws2_32.SD_BOTH, - }); - if (rc == ws2_32.SOCKET_ERROR) { - return switch (ws2_32.WSAGetLastError()) { - .WSAECONNABORTED => return error.ConnectionAborted, - .WSAECONNRESET => return error.ConnectionResetByPeer, - .WSAEINPROGRESS => return error.BlockingOperationInProgress, - .WSAEINVAL => unreachable, - .WSAENETDOWN => return error.NetworkSubsystemFailed, - .WSAENOTCONN => return error.SocketNotConnected, - .WSAENOTSOCK => unreachable, - .WSANOTINITIALISED => unreachable, - else => |err| return windows.unexpectedWSAError(err), - }; - } - } - - /// Binds the socket to an address. 
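The read/write timeout helpers being deleted reduce to a single `setsockopt` with a `timeval`, and the same caveat applies: the timeout only affects blocking sockets. A minimal sketch for the read side, reusing the `os.SO.RCVTIMEO` and `os.timeval` declarations from the POSIX mixin above; `setReadTimeoutMs` is a hypothetical helper:

```zig
const std = @import("std");
const os = std.os;
const mem = std.mem;
const time = std.time;

/// Hypothetical stand-in for the removed setReadTimeout: a blocking read
/// returns error.WouldBlock once no data arrives within `milliseconds`.
fn setReadTimeoutMs(fd: os.socket_t, milliseconds: usize) !void {
    const timeout = os.timeval{
        .tv_sec = @intCast(i32, milliseconds / time.ms_per_s),
        .tv_usec = @intCast(i32, (milliseconds % time.ms_per_s) * time.us_per_ms),
    };
    return os.setsockopt(fd, os.SOL.SOCKET, os.SO.RCVTIMEO, mem.asBytes(&timeout));
}
```

Swapping `RCVTIMEO` for `SNDTIMEO` gives the write-side equivalent.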
- pub fn bind(self: Socket, address: Socket.Address) !void { - const rc = ws2_32.bind(self.fd, @ptrCast(*const ws2_32.sockaddr, &address.toNative()), @intCast(c_int, address.getNativeSize())); - if (rc == ws2_32.SOCKET_ERROR) { - return switch (ws2_32.WSAGetLastError()) { - .WSAENETDOWN => error.NetworkSubsystemFailed, - .WSAEACCES => error.AccessDenied, - .WSAEADDRINUSE => error.AddressInUse, - .WSAEADDRNOTAVAIL => error.AddressNotAvailable, - .WSAEFAULT => error.BadAddress, - .WSAEINPROGRESS => error.WouldBlock, - .WSAEINVAL => error.AlreadyBound, - .WSAENOBUFS => error.NoEphemeralPortsAvailable, - .WSAENOTSOCK => error.NotASocket, - else => |err| windows.unexpectedWSAError(err), - }; - } - } - - /// Start listening for incoming connections on the socket. - pub fn listen(self: Socket, max_backlog_size: u31) !void { - const rc = ws2_32.listen(self.fd, max_backlog_size); - if (rc == ws2_32.SOCKET_ERROR) { - return switch (ws2_32.WSAGetLastError()) { - .WSAENETDOWN => error.NetworkSubsystemFailed, - .WSAEADDRINUSE => error.AddressInUse, - .WSAEISCONN => error.AlreadyConnected, - .WSAEINVAL => error.SocketNotBound, - .WSAEMFILE, .WSAENOBUFS => error.SystemResources, - .WSAENOTSOCK => error.FileDescriptorNotASocket, - .WSAEOPNOTSUPP => error.OperationNotSupported, - .WSAEINPROGRESS => error.WouldBlock, - else => |err| windows.unexpectedWSAError(err), - }; - } - } - - /// Have the socket attempt to the connect to an address. - pub fn connect(self: Socket, address: Socket.Address) !void { - const rc = ws2_32.connect(self.fd, @ptrCast(*const ws2_32.sockaddr, &address.toNative()), @intCast(c_int, address.getNativeSize())); - if (rc == ws2_32.SOCKET_ERROR) { - return switch (ws2_32.WSAGetLastError()) { - .WSAEADDRINUSE => error.AddressInUse, - .WSAEADDRNOTAVAIL => error.AddressNotAvailable, - .WSAECONNREFUSED => error.ConnectionRefused, - .WSAETIMEDOUT => error.ConnectionTimedOut, - .WSAEFAULT => error.BadAddress, - .WSAEINVAL => error.ListeningSocket, - .WSAEISCONN => error.AlreadyConnected, - .WSAENOTSOCK => error.NotASocket, - .WSAEACCES => error.BroadcastNotEnabled, - .WSAENOBUFS => error.SystemResources, - .WSAEAFNOSUPPORT => error.AddressFamilyNotSupported, - .WSAEINPROGRESS, .WSAEWOULDBLOCK => error.WouldBlock, - .WSAEHOSTUNREACH, .WSAENETUNREACH => error.NetworkUnreachable, - else => |err| windows.unexpectedWSAError(err), - }; - } - } - - /// Accept a pending incoming connection queued to the kernel backlog - /// of the socket. 
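For contrast with the Windows wrappers in this file, the POSIX path that the removed bind/listen helpers forwarded to is a handful of direct `std.os` calls with a hand-built `sockaddr`. A minimal IPv4 sketch, assuming the `os.sockaddr.in` layout used by the removed `Address.toNative`; `listenOn` and the backlog of 128 are arbitrary choices for illustration:

```zig
const std = @import("std");
const os = std.os;
const mem = std.mem;

/// Hypothetical sketch of a listening TCP socket on all IPv4 interfaces,
/// using the same raw calls the removed POSIX mixin wrapped.
fn listenOn(port: u16) !os.socket_t {
    const fd = try os.socket(os.AF.INET, os.SOCK.STREAM, os.IPPROTO.TCP);
    errdefer os.closeSocket(fd);

    const addr = os.sockaddr.in{
        .addr = 0, // INADDR_ANY, already in network byte order
        .port = mem.nativeToBig(u16, port),
    };
    try os.bind(fd, @ptrCast(*const os.sockaddr, &addr), @sizeOf(os.sockaddr.in));
    try os.listen(fd, 128);
    return fd;
}
```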
- pub fn accept(self: Socket, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Socket.Connection { - var address: Socket.Address.Native.Storage = undefined; - var address_len: c_int = @sizeOf(Socket.Address.Native.Storage); - - const fd = ws2_32.accept(self.fd, @ptrCast(*ws2_32.sockaddr, &address), &address_len); - if (fd == ws2_32.INVALID_SOCKET) { - return switch (ws2_32.WSAGetLastError()) { - .WSANOTINITIALISED => unreachable, - .WSAECONNRESET => error.ConnectionResetByPeer, - .WSAEFAULT => unreachable, - .WSAEINVAL => error.SocketNotListening, - .WSAEMFILE => error.ProcessFdQuotaExceeded, - .WSAENETDOWN => error.NetworkSubsystemFailed, - .WSAENOBUFS => error.FileDescriptorNotASocket, - .WSAEOPNOTSUPP => error.OperationNotSupported, - .WSAEWOULDBLOCK => error.WouldBlock, - else => |err| windows.unexpectedWSAError(err), - }; - } - - const socket = Socket.from(fd); - errdefer socket.deinit(); - - const socket_address = Socket.Address.fromNative(@ptrCast(*ws2_32.sockaddr, &address)); - - const set = std.EnumSet(Socket.InitFlags).init(flags); - if (set.contains(.nonblocking)) { - var enabled: c_ulong = 1; - const rc = ws2_32.ioctlsocket(fd, ws2_32.FIONBIO, &enabled); - if (rc == ws2_32.SOCKET_ERROR) { - return windows.unexpectedWSAError(ws2_32.WSAGetLastError()); - } - } - - return Socket.Connection.from(socket, socket_address); - } - - /// Read data from the socket into the buffer provided with a set of flags - /// specified. It returns the number of bytes read into the buffer provided. - pub fn read(self: Socket, buf: []u8, flags: u32) !usize { - var bufs = &[_]ws2_32.WSABUF{.{ .len = @intCast(u32, buf.len), .buf = buf.ptr }}; - var num_bytes: u32 = undefined; - var flags_ = flags; - - const rc = ws2_32.WSARecv(self.fd, bufs, 1, &num_bytes, &flags_, null, null); - if (rc == ws2_32.SOCKET_ERROR) { - return switch (ws2_32.WSAGetLastError()) { - .WSAECONNABORTED => error.ConnectionAborted, - .WSAECONNRESET => error.ConnectionResetByPeer, - .WSAEDISCON => error.ConnectionClosedByPeer, - .WSAEFAULT => error.BadBuffer, - .WSAEINPROGRESS, - .WSAEWOULDBLOCK, - .WSA_IO_PENDING, - .WSAETIMEDOUT, - => error.WouldBlock, - .WSAEINTR => error.Cancelled, - .WSAEINVAL => error.SocketNotBound, - .WSAEMSGSIZE => error.MessageTooLarge, - .WSAENETDOWN => error.NetworkSubsystemFailed, - .WSAENETRESET => error.NetworkReset, - .WSAENOTCONN => error.SocketNotConnected, - .WSAENOTSOCK => error.FileDescriptorNotASocket, - .WSAEOPNOTSUPP => error.OperationNotSupported, - .WSAESHUTDOWN => error.AlreadyShutdown, - .WSA_OPERATION_ABORTED => error.OperationAborted, - else => |err| windows.unexpectedWSAError(err), - }; - } - - return @intCast(usize, num_bytes); - } - - /// Write a buffer of data provided to the socket with a set of flags specified. - /// It returns the number of bytes that are written to the socket. 
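Both the Windows `init` and `accept` paths above flip non-blocking mode with `ioctlsocket(FIONBIO)` rather than a `SOCK` flag, and that one step is easy to keep as a standalone helper once the wrapper is gone. A sketch, with `setNonBlocking` being a hypothetical name:

```zig
const std = @import("std");
const windows = std.os.windows;
const ws2_32 = windows.ws2_32;

/// Hypothetical helper: toggles non-blocking mode on a Winsock socket,
/// matching what the removed init/accept did for the .nonblocking flag.
fn setNonBlocking(fd: ws2_32.SOCKET, enabled: bool) !void {
    var mode: c_ulong = @boolToInt(enabled);
    if (ws2_32.ioctlsocket(fd, ws2_32.FIONBIO, &mode) == ws2_32.SOCKET_ERROR) {
        return windows.unexpectedWSAError(ws2_32.WSAGetLastError());
    }
}
```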
- pub fn write(self: Socket, buf: []const u8, flags: u32) !usize { - var bufs = &[_]ws2_32.WSABUF{.{ .len = @intCast(u32, buf.len), .buf = @intToPtr([*]u8, @ptrToInt(buf.ptr)) }}; - var num_bytes: u32 = undefined; - - const rc = ws2_32.WSASend(self.fd, bufs, 1, &num_bytes, flags, null, null); - if (rc == ws2_32.SOCKET_ERROR) { - return switch (ws2_32.WSAGetLastError()) { - .WSAECONNABORTED => error.ConnectionAborted, - .WSAECONNRESET => error.ConnectionResetByPeer, - .WSAEFAULT => error.BadBuffer, - .WSAEINPROGRESS, - .WSAEWOULDBLOCK, - .WSA_IO_PENDING, - .WSAETIMEDOUT, - => error.WouldBlock, - .WSAEINTR => error.Cancelled, - .WSAEINVAL => error.SocketNotBound, - .WSAEMSGSIZE => error.MessageTooLarge, - .WSAENETDOWN => error.NetworkSubsystemFailed, - .WSAENETRESET => error.NetworkReset, - .WSAENOBUFS => error.BufferDeadlock, - .WSAENOTCONN => error.SocketNotConnected, - .WSAENOTSOCK => error.FileDescriptorNotASocket, - .WSAEOPNOTSUPP => error.OperationNotSupported, - .WSAESHUTDOWN => error.AlreadyShutdown, - .WSA_OPERATION_ABORTED => error.OperationAborted, - else => |err| windows.unexpectedWSAError(err), - }; - } - - return @intCast(usize, num_bytes); - } - - /// Writes multiple I/O vectors with a prepended message header to the socket - /// with a set of flags specified. It returns the number of bytes that are - /// written to the socket. - pub fn writeMessage(self: Socket, msg: Socket.Message, flags: u32) !usize { - const call = try windows.loadWinsockExtensionFunction(ws2_32.LPFN_WSASENDMSG, self.fd, ws2_32.WSAID_WSASENDMSG); - - var num_bytes: u32 = undefined; - - const rc = call(self.fd, &msg, flags, &num_bytes, null, null); - if (rc == ws2_32.SOCKET_ERROR) { - return switch (ws2_32.WSAGetLastError()) { - .WSAECONNABORTED => error.ConnectionAborted, - .WSAECONNRESET => error.ConnectionResetByPeer, - .WSAEFAULT => error.BadBuffer, - .WSAEINPROGRESS, - .WSAEWOULDBLOCK, - .WSA_IO_PENDING, - .WSAETIMEDOUT, - => error.WouldBlock, - .WSAEINTR => error.Cancelled, - .WSAEINVAL => error.SocketNotBound, - .WSAEMSGSIZE => error.MessageTooLarge, - .WSAENETDOWN => error.NetworkSubsystemFailed, - .WSAENETRESET => error.NetworkReset, - .WSAENOBUFS => error.BufferDeadlock, - .WSAENOTCONN => error.SocketNotConnected, - .WSAENOTSOCK => error.FileDescriptorNotASocket, - .WSAEOPNOTSUPP => error.OperationNotSupported, - .WSAESHUTDOWN => error.AlreadyShutdown, - .WSA_OPERATION_ABORTED => error.OperationAborted, - else => |err| windows.unexpectedWSAError(err), - }; - } - - return @intCast(usize, num_bytes); - } - - /// Read multiple I/O vectors with a prepended message header from the socket - /// with a set of flags specified. It returns the number of bytes that were - /// read into the buffer provided. 
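`WSASendMsg` is resolved at runtime through its `WSAID_WSASENDMSG` GUID, as the deleted `writeMessage` does above; code that still needs scatter/gather sends after this wrapper is removed can perform the same lookup. A sketch, assuming `windows.loadWinsockExtensionFunction` as used in the deleted code:

```zig
const std = @import("std");
const windows = std.os.windows;
const ws2_32 = windows.ws2_32;

/// Resolves the WSASendMsg extension function for a given socket; the
/// returned routine is then called like any other Winsock function.
fn loadSendMsg(fd: ws2_32.SOCKET) !ws2_32.LPFN_WSASENDMSG {
    return windows.loadWinsockExtensionFunction(ws2_32.LPFN_WSASENDMSG, fd, ws2_32.WSAID_WSASENDMSG);
}
```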
- pub fn readMessage(self: Socket, msg: *Socket.Message, flags: u32) !usize { - _ = flags; - const call = try windows.loadWinsockExtensionFunction(ws2_32.LPFN_WSARECVMSG, self.fd, ws2_32.WSAID_WSARECVMSG); - - var num_bytes: u32 = undefined; - - const rc = call(self.fd, msg, &num_bytes, null, null); - if (rc == ws2_32.SOCKET_ERROR) { - return switch (ws2_32.WSAGetLastError()) { - .WSAECONNABORTED => error.ConnectionAborted, - .WSAECONNRESET => error.ConnectionResetByPeer, - .WSAEDISCON => error.ConnectionClosedByPeer, - .WSAEFAULT => error.BadBuffer, - .WSAEINPROGRESS, - .WSAEWOULDBLOCK, - .WSA_IO_PENDING, - .WSAETIMEDOUT, - => error.WouldBlock, - .WSAEINTR => error.Cancelled, - .WSAEINVAL => error.SocketNotBound, - .WSAEMSGSIZE => error.MessageTooLarge, - .WSAENETDOWN => error.NetworkSubsystemFailed, - .WSAENETRESET => error.NetworkReset, - .WSAENOTCONN => error.SocketNotConnected, - .WSAENOTSOCK => error.FileDescriptorNotASocket, - .WSAEOPNOTSUPP => error.OperationNotSupported, - .WSAESHUTDOWN => error.AlreadyShutdown, - .WSA_OPERATION_ABORTED => error.OperationAborted, - else => |err| windows.unexpectedWSAError(err), - }; - } - - return @intCast(usize, num_bytes); - } - - /// Query the address that the socket is locally bounded to. - pub fn getLocalAddress(self: Socket) !Socket.Address { - var address: Socket.Address.Native.Storage = undefined; - var address_len: c_int = @sizeOf(Socket.Address.Native.Storage); - - const rc = ws2_32.getsockname(self.fd, @ptrCast(*ws2_32.sockaddr, &address), &address_len); - if (rc == ws2_32.SOCKET_ERROR) { - return switch (ws2_32.WSAGetLastError()) { - .WSANOTINITIALISED => unreachable, - .WSAEFAULT => unreachable, - .WSAENETDOWN => error.NetworkSubsystemFailed, - .WSAENOTSOCK => error.FileDescriptorNotASocket, - .WSAEINVAL => error.SocketNotBound, - else => |err| windows.unexpectedWSAError(err), - }; - } - - return Socket.Address.fromNative(@ptrCast(*ws2_32.sockaddr, &address)); - } - - /// Query the address that the socket is connected to. - pub fn getRemoteAddress(self: Socket) !Socket.Address { - var address: Socket.Address.Native.Storage = undefined; - var address_len: c_int = @sizeOf(Socket.Address.Native.Storage); - - const rc = ws2_32.getpeername(self.fd, @ptrCast(*ws2_32.sockaddr, &address), &address_len); - if (rc == ws2_32.SOCKET_ERROR) { - return switch (ws2_32.WSAGetLastError()) { - .WSANOTINITIALISED => unreachable, - .WSAEFAULT => unreachable, - .WSAENETDOWN => error.NetworkSubsystemFailed, - .WSAENOTSOCK => error.FileDescriptorNotASocket, - .WSAEINVAL => error.SocketNotBound, - else => |err| windows.unexpectedWSAError(err), - }; - } - - return Socket.Address.fromNative(@ptrCast(*ws2_32.sockaddr, &address)); - } - - /// Query and return the latest cached error on the socket. - pub fn getError(self: Socket) !void { - _ = self; - return {}; - } - - /// Query the read buffer size of the socket. - pub fn getReadBufferSize(self: Socket) !u32 { - _ = self; - return 0; - } - - /// Query the write buffer size of the socket. - pub fn getWriteBufferSize(self: Socket) !u32 { - _ = self; - return 0; - } - - /// Set a socket option. 
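The `setLinger` wrappers (POSIX above, Windows below) exist because `SO_LINGER` takes a small struct whose field width differs per platform, per the removed `Socket.Linger` definition: `c_ushort` on Windows, `c_int` elsewhere. A POSIX-only sketch of the same idea, with `Linger` and `setLingerSeconds` as hypothetical names:

```zig
const std = @import("std");
const os = std.os;
const mem = std.mem;

/// Hypothetical POSIX-only replacement for the removed setLinger: have
/// close()/shutdown() block until queued data is sent or the timeout hits.
const Linger = extern struct {
    enabled: c_int,
    timeout_seconds: c_int,
};

fn setLingerSeconds(fd: os.socket_t, timeout_seconds: ?u16) !void {
    if (@hasDecl(os.SO, "LINGER")) {
        const settings = Linger{
            .enabled = @boolToInt(timeout_seconds != null),
            .timeout_seconds = if (timeout_seconds) |s| s else 0,
        };
        return os.setsockopt(fd, os.SOL.SOCKET, os.SO.LINGER, mem.asBytes(&settings));
    }
    return error.UnsupportedSocketOption;
}
```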
- pub fn setOption(self: Socket, level: u32, code: u32, value: []const u8) !void { - const rc = ws2_32.setsockopt(self.fd, @intCast(i32, level), @intCast(i32, code), value.ptr, @intCast(i32, value.len)); - if (rc == ws2_32.SOCKET_ERROR) { - return switch (ws2_32.WSAGetLastError()) { - .WSANOTINITIALISED => unreachable, - .WSAENETDOWN => return error.NetworkSubsystemFailed, - .WSAEFAULT => unreachable, - .WSAENOTSOCK => return error.FileDescriptorNotASocket, - .WSAEINVAL => return error.SocketNotBound, - else => |err| windows.unexpectedWSAError(err), - }; - } - } - - /// Have close() or shutdown() syscalls block until all queued messages in the socket have been successfully - /// sent, or if the timeout specified in seconds has been reached. It returns `error.UnsupportedSocketOption` - /// if the host does not support the option for a socket to linger around up until a timeout specified in - /// seconds. - pub fn setLinger(self: Socket, timeout_seconds: ?u16) !void { - const settings = Socket.Linger.init(timeout_seconds); - return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.LINGER, mem.asBytes(&settings)); - } - - /// On connection-oriented sockets, have keep-alive messages be sent periodically. The timing in which keep-alive - /// messages are sent are dependant on operating system settings. It returns `error.UnsupportedSocketOption` if - /// the host does not support periodically sending keep-alive messages on connection-oriented sockets. - pub fn setKeepAlive(self: Socket, enabled: bool) !void { - return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.KEEPALIVE, mem.asBytes(&@as(u32, @boolToInt(enabled)))); - } - - /// Allow multiple sockets on the same host to listen on the same address. It returns `error.UnsupportedSocketOption` if - /// the host does not support sockets listening the same address. - pub fn setReuseAddress(self: Socket, enabled: bool) !void { - return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.REUSEADDR, mem.asBytes(&@as(u32, @boolToInt(enabled)))); - } - - /// Allow multiple sockets on the same host to listen on the same port. It returns `error.UnsupportedSocketOption` if - /// the host does not supports sockets listening on the same port. - /// - /// TODO: verify if this truly mimicks SO.REUSEPORT behavior, or if SO.REUSE_UNICASTPORT provides the correct behavior - pub fn setReusePort(self: Socket, enabled: bool) !void { - try self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.BROADCAST, mem.asBytes(&@as(u32, @boolToInt(enabled)))); - try self.setReuseAddress(enabled); - } - - /// Set the write buffer size of the socket. - pub fn setWriteBufferSize(self: Socket, size: u32) !void { - return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.SNDBUF, mem.asBytes(&size)); - } - - /// Set the read buffer size of the socket. - pub fn setReadBufferSize(self: Socket, size: u32) !void { - return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.RCVBUF, mem.asBytes(&size)); - } - - /// WARNING: Timeouts only affect blocking sockets. It is undefined behavior if a timeout is - /// set on a non-blocking socket. - /// - /// Set a timeout on the socket that is to occur if no messages are successfully written - /// to its bound destination after a specified number of milliseconds. A subsequent write - /// to the socket will thereafter return `error.WouldBlock` should the timeout be exceeded. - pub fn setWriteTimeout(self: Socket, milliseconds: u32) !void { - return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.SNDTIMEO, mem.asBytes(&milliseconds)); - } - - /// WARNING: Timeouts only affect blocking sockets. 
It is undefined behavior if a timeout is - /// set on a non-blocking socket. - /// - /// Set a timeout on the socket that is to occur if no messages are successfully read - /// from its bound destination after a specified number of milliseconds. A subsequent - /// read from the socket will thereafter return `error.WouldBlock` should the timeout be - /// exceeded. - pub fn setReadTimeout(self: Socket, milliseconds: u32) !void { - return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.RCVTIMEO, mem.asBytes(&milliseconds)); - } - }; -} diff --git a/src/Compilation.zig b/src/Compilation.zig index a18b05a93969..4c7489c0c846 100644 --- a/src/Compilation.zig +++ b/src/Compilation.zig @@ -584,7 +584,17 @@ pub const AllErrors = struct { Message.HashContext, std.hash_map.default_max_load_percentage, ).init(allocator); - const err_source = try module_err_msg.src_loc.file_scope.getSource(module.gpa); + const err_source = module_err_msg.src_loc.file_scope.getSource(module.gpa) catch |err| { + const file_path = try module_err_msg.src_loc.file_scope.fullPath(allocator); + try errors.append(.{ + .plain = .{ + .msg = try std.fmt.allocPrint(allocator, "unable to load '{s}': {s}", .{ + file_path, @errorName(err), + }), + }, + }); + return; + }; const err_span = try module_err_msg.src_loc.span(module.gpa); const err_loc = std.zig.findLineColumn(err_source.bytes, err_span.main);
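The Compilation change above converts a failure to load a source file into a reported error message instead of aborting error rendering. The same catch-and-report shape is useful anywhere one bad input should not sink the whole batch; a small illustrative sketch, where `loadOrReport` and the 1 MiB read cap are arbitrary choices and not part of this patch:

```zig
const std = @import("std");

/// Illustrative only: mirrors the `catch |err|` shape used above, degrading a
/// failed load into a recorded diagnostic instead of an early error return.
fn loadOrReport(
    gpa: std.mem.Allocator,
    diagnostics: *std.ArrayList([]const u8),
    path: []const u8,
) !?[]const u8 {
    const bytes = std.fs.cwd().readFileAlloc(gpa, path, 1 << 20) catch |err| {
        try diagnostics.append(try std.fmt.allocPrint(gpa, "unable to load '{s}': {s}", .{
            path, @errorName(err),
        }));
        return null;
    };
    return bytes;
}
```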