Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor TLS, add TLS server #19308

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 180 additions & 23 deletions lib/std/crypto/Certificate.zig
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,28 @@ pub const rsa = struct {
try EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash);
}

fn EMSA_PSS_VERIFY(msg: []const u8, em: []const u8, emBit: usize, sLen: usize, comptime Hash: type) !void {
pub fn sign(
comptime modulus_len: usize,
msg: []const u8,
comptime Hash: type,
private_key: PrivateKey,
salt: [Hash.digest_length]u8,
) ![modulus_len]u8 {
const mod_bits = modulus_len * 8;

var out: [modulus_len]u8 = undefined;
const em = try EMSA_PSS_ENCODE(&out, msg, mod_bits - 1, Hash, salt);

return try private_key.decrypt(modulus_len, em[0..modulus_len].*);
}

fn EMSA_PSS_VERIFY(
msg: []const u8,
em: []const u8,
emBit: usize,
sLen: usize,
comptime Hash: type,
) !void {
// 1. If the length of M is greater than the input limitation for
// the hash function (2^61 - 1 octets for SHA-1), output
// "inconsistent" and stop.
Expand Down Expand Up @@ -1103,6 +1124,99 @@ pub const rsa = struct {

return out[0..len];
}

fn EMSA_PSS_ENCODE(
out: []u8,
msg: []const u8,
emBit: usize,
comptime Hash: type,
salt: [Hash.digest_length]u8,
) ![]u8 {
// 1. If the length of M is greater than the input limitation for
// the hash function (2^61 - 1 octets for SHA-1), output
// "inconsistent" and stop.
// All the cryptographic hash functions in the standard library have a limit of >= 2^61 - 1.
// Even then, this check is only there for paranoia. In the context of TLS certifcates, emBit cannot exceed 4096.
if (emBit >= 1 << 61) return error.InvalidSignature;
// emLen = \c2yyeil(emBits/8)
const emLen = ((emBit - 1) / 8) + 1;

// 2. Let mHash = Hash(M), an octet string of length hLen.
var mHash: [Hash.digest_length]u8 = undefined;
Hash.hash(msg, &mHash, .{});

// 3. If emLen < hLen + sLen + 2, output "encoding error" and stop.
if (emLen < Hash.digest_length + salt.len + 2) {
return error.EncodingError;
}

// 4. Generate a random octet string salt of length sLen; if sLen =
// 0, then salt is the empty string.
// 5. Let
// M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt;
// M' is an octet string of length 8 + hLen + sLen with eight
// initial zero octets.
var m_p: [8 + Hash.digest_length + salt.len]u8 = undefined;

@memcpy(m_p[0..8], &([_]u8{0} ** 8));
@memcpy(m_p[8 .. 8 + mHash.len], &mHash);
@memcpy(m_p[(8 + Hash.digest_length)..], &salt);

// 6. Let H = Hash(M'), an octet string of length hLen.
var hash: [Hash.digest_length]u8 = undefined;
Hash.hash(&m_p, &hash, .{});

// 7. Generate an octet string PS consisting of emLen - sLen - hLen
// - 2 zero octets. The length of PS may be 0.
const ps_len = emLen - salt.len - Hash.digest_length - 2;

// 8. Let DB = PS || 0x01 || salt; DB is an octet string of length
// emLen - hLen - 1.
const mgf_len = emLen - Hash.digest_length - 1;
var db_buf: [512]u8 = undefined;
var db = db_buf[0 .. emLen - Hash.digest_length - 1];
var i: usize = 0;
while (i < ps_len) : (i += 1) {
db[i] = 0x00;
}
db[i] = 0x01;
i += 1;
@memcpy(db[i..], &salt);

// 9. Let dbMask = MGF(H, emLen - hLen - 1).
var mgf_out_buf: [512]u8 = undefined;
const mgf_out = mgf_out_buf[0 .. ((mgf_len - 1) / Hash.digest_length + 1) * Hash.digest_length];
const dbMask = try MGF1(Hash, mgf_out, &hash, mgf_len);

// 10. Let maskedDB = DB \xor dbMask.
i = 0;
while (i < db.len) : (i += 1) {
db[i] = db[i] ^ dbMask[i];
}

// 11. Set the leftmost 8emLen - emBits bits of the leftmost octet
// in maskedDB to zero.
const zero_bits = emLen * 8 - emBit;
var mask: u8 = 0;
i = 0;
while (i < 8 - zero_bits) : (i += 1) {
mask = mask << 1;
mask += 1;
}
db[0] = db[0] & mask;

// 12. Let EM = maskedDB || H || 0xbc.
i = 0;
@memcpy(out[0..db.len], db);
i += db.len;
@memcpy(out[i .. i + hash.len], &hash);
i += hash.len;
out[i] = 0xbc;
i += 1;

// 13. Output EM.
return out[0..i];
}
};

pub const PublicKey = struct {
Expand All @@ -1111,11 +1225,11 @@ pub const rsa = struct {
/// public exponent
e: Fe,

pub fn fromBytes(pub_bytes: []const u8, modulus_bytes: []const u8) !PublicKey {
pub fn fromBytes(mod: []const u8, exp: []const u8) !PublicKey {
// Reject modulus below 512 bits.
// 512-bit RSA was factored in 1999, so this limit barely means anything,
// but establish some limit now to ratchet in what we can.
const _n = Modulus.fromBytes(modulus_bytes, .big) catch return error.CertificatePublicKeyInvalid;
const _n = Modulus.fromBytes(mod, .big) catch return error.CertificatePublicKeyInvalid;
if (_n.bits() < 512) return error.CertificatePublicKeyInvalid;

// Exponent must be odd and greater than 2.
Expand All @@ -1124,45 +1238,88 @@ pub const rsa = struct {
// unlikely that exponents larger than 32 bits are being used for anything
// Windows commonly does.
// [1] https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-rsapubkey
if (pub_bytes.len > 4) return error.CertificatePublicKeyInvalid;
const _e = Fe.fromBytes(_n, pub_bytes, .big) catch return error.CertificatePublicKeyInvalid;
if (exp.len > 4) return error.CertificatePublicKeyInvalid;
const _e = Fe.fromBytes(_n, exp, .big) catch return error.CertificatePublicKeyInvalid;
if (!_e.isOdd()) return error.CertificatePublicKeyInvalid;
const e_v = _e.toPrimitive(u32) catch return error.CertificatePublicKeyInvalid;
if (e_v < 2) return error.CertificatePublicKeyInvalid;

return .{
.n = _n,
.e = _e,
};
return .{ .n = _n, .e = _e };
}

// RFC8017 Appendix A.1.1
pub fn fromDer(pub_key: []const u8) !PublicKey {
const pub_key_seq = try der.Element.parse(pub_key, 0);
if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType;
const modulus_elem = try der.Element.parse(pub_key, pub_key_seq.slice.start);
pub fn fromDer(bytes: []const u8) !PublicKey {
const seq = try der.Element.parse(bytes, 0);
if (seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType;
const modulus_elem = try der.Element.parse(bytes, seq.slice.start);
if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
const exponent_elem = try der.Element.parse(pub_key, modulus_elem.slice.end);
const exponent_elem = try der.Element.parse(bytes, modulus_elem.slice.end);
if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
// Skip over meaningless zeroes in the modulus.
const modulus_raw = pub_key[modulus_elem.slice.start..modulus_elem.slice.end];
const modulus_offset = for (modulus_raw, 0..) |byte, i| {
if (byte != 0) break i;
} else modulus_raw.len;
const exponent = pub_key[exponent_elem.slice.start..exponent_elem.slice.end];
const modulus_raw = bytes[modulus_elem.slice.start..modulus_elem.slice.end];
const modulus_offset = std.mem.indexOfNone(u8, modulus_raw, &[_]u8{0}) orelse modulus_raw.len;
const modulus = modulus_raw[modulus_offset..];
const exponent = bytes[exponent_elem.slice.start..exponent_elem.slice.end];

return try fromBytes(exponent, modulus);
return try fromBytes(modulus, exponent);
clickingbuttons marked this conversation as resolved.
Show resolved Hide resolved
}

fn encrypt(public_key: PublicKey, comptime modulus_len: usize, msg: [modulus_len]u8) ![modulus_len]u8 {
const m = Fe.fromBytes(public_key.n, &msg, .big) catch return error.MessageTooLong;
const e = public_key.n.powPublic(m, public_key.e) catch unreachable;
fn encrypt(self: PublicKey, comptime modulus_len: usize, msg: [modulus_len]u8) ![modulus_len]u8 {
const m = Fe.fromBytes(self.n, &msg, .big) catch return error.MessageTooLong;
const e = self.n.powPublic(m, self.e) catch unreachable;
var res: [modulus_len]u8 = undefined;
e.toBytes(&res, .big) catch unreachable;
return res;
}
};

pub const PrivateKey = struct {
public: PublicKey,
/// private exponent
d: Fe,

pub fn fromBytes(mod: []const u8, public: []const u8, private: []const u8) !PrivateKey {
const _public = try PublicKey.fromBytes(mod, public);

const _d = Fe.fromBytes(_public.n, private, .big) catch return error.CertificatePrivateKeyInvalid;
if (!_d.isOdd()) return error.CertificatePrivateKeyInvalid;

return .{ .public = _public, .d = _d };
}

// RFC8017 Appendix A.1.2
pub fn fromDer(bytes: []const u8) !PrivateKey {
const seq = try der.Element.parse(bytes, 0);
if (seq.identifier.tag != .sequence) return error.PrivateKeyWrongDataType;

// We're just interested in the first 3 fields which don't vary by version
const version_elem = try der.Element.parse(bytes, seq.slice.start);
if (version_elem.identifier.tag != .integer) return error.PrivateKeyFieldWrongDataType;

const modulus_elem = try der.Element.parse(bytes, version_elem.slice.end);
if (modulus_elem.identifier.tag != .integer) return error.PrivateKeyFieldWrongDataType;
const pub_exponent_elem = try der.Element.parse(bytes, modulus_elem.slice.end);
if (pub_exponent_elem.identifier.tag != .integer) return error.PrivateKeyFieldWrongDataType;
const priv_exponent_elem = try der.Element.parse(bytes, pub_exponent_elem.slice.end);
if (priv_exponent_elem.identifier.tag != .integer) return error.PrivateKeyFieldWrongDataType;
// Skip over meaningless zeroes in the modulus.
const modulus_raw = bytes[modulus_elem.slice.start..modulus_elem.slice.end];
const modulus_offset = std.mem.indexOfNone(u8, modulus_raw, &[_]u8{0}) orelse modulus_raw.len;
const modulus = modulus_raw[modulus_offset..];
const pub_exponent = bytes[pub_exponent_elem.slice.start..pub_exponent_elem.slice.end];
const priv_exponent = bytes[priv_exponent_elem.slice.start..priv_exponent_elem.slice.end];

return try fromBytes(modulus, pub_exponent, priv_exponent);
}

fn decrypt(self: PrivateKey, comptime modulus_len: usize, msg: [modulus_len]u8) ![modulus_len]u8 {
const m = Fe.fromBytes(self.public.n, &msg, .big) catch return error.MessageTooLong;
const e = self.public.n.pow(m, self.d) catch unreachable;
var res: [modulus_len]u8 = undefined;
try e.toBytes(&res, .big);
return res;
}
};
};

const use_vectors = @import("builtin").zig_backend != .stage2_x86_64;
Binary file modified lib/std/crypto/testdata/key.der
Binary file not shown.
Loading