crypto.sha2: Use intrinsics for SHA-256 on x86-64 and AArch64
There's probably plenty of room to optimize these further in the
future, but for the moment this gives ~3x improvement on Intel
x86-64 processors, ~5x on AMD, and ~10x on M1 Macs.

These extensions are very new; most processors prior to 2020 do not
support them.

AVX-512 is a slightly older alternative that we could use on Intel
for a much bigger performance bump, but it's been fused off on
Intel's latest hybrid architectures and it relies on computing
independent SHA hashes in parallel. In contrast, these SHA intrinsics
provide the usual single-threaded, single-stream interface, and should
continue working on new processors.

AArch64 also has SHA-512 intrinsics that we could take advantage
of in the future.
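
The public interface is unchanged: callers keep using the same one-shot and
streaming API and pick up the intrinsic code path automatically when the
target CPU advertises the relevant feature. A minimal sketch of that usage
(assuming the Sha256 alias exposed as std.crypto.hash.sha2.Sha256):

    const std = @import("std");
    const Sha256 = std.crypto.hash.sha2.Sha256;

    test "sha256 one-shot and streaming give the same digest" {
        const msg = "abc";

        // One-shot hashing; the intrinsic round() path is selected
        // automatically when the target CPU has the sha2/sha feature.
        var out: [Sha256.digest_length]u8 = undefined;
        Sha256.hash(msg, &out, .{});

        // Streaming interface over the same data.
        var h = Sha256.init(.{});
        h.update(msg);
        var out2: [Sha256.digest_length]u8 = undefined;
        h.final(&out2);

        try std.testing.expectEqualSlices(u8, &out, &out2);
    }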
topolarity committed Oct 23, 2022
1 parent 0ae60f7 commit 3ed22de
Showing 1 changed file with 164 additions and 70 deletions.
234 changes: 164 additions & 70 deletions lib/std/crypto/sha2.zig
@@ -1,4 +1,5 @@
const std = @import("../std.zig");
const builtin = @import("builtin");
const mem = std.mem;
const math = std.math;
const htest = @import("test.zig");
@@ -16,10 +17,9 @@ const RoundParam256 = struct {
g: usize,
h: usize,
i: usize,
k: u32,
};

fn roundParam256(a: usize, b: usize, c: usize, d: usize, e: usize, f: usize, g: usize, h: usize, i: usize, k: u32) RoundParam256 {
fn roundParam256(a: usize, b: usize, c: usize, d: usize, e: usize, f: usize, g: usize, h: usize, i: usize) RoundParam256 {
return RoundParam256{
.a = a,
.b = b,
@@ -30,7 +30,6 @@ fn roundParam256(a: usize, b: usize, c: usize, d: usize, e: usize, f: usize, g:
.g = g,
.h = h,
.i = i,
.k = k,
};
}

@@ -70,6 +69,8 @@ const Sha256Params = Sha2Params32{
.digest_bits = 256,
};

const v4u32 = @Vector(4, u32);

/// SHA-224
pub const Sha224 = Sha2x32(Sha224Params);

@@ -83,7 +84,7 @@ fn Sha2x32(comptime params: Sha2Params32) type {
pub const digest_length = params.digest_bits / 8;
pub const Options = struct {};

s: [8]u32,
s: [8]u32 align(16),
// Streaming Cache
buf: [64]u8 = undefined,
buf_len: u8 = 0,
@@ -168,8 +169,19 @@ fn Sha2x32(comptime params: Sha2Params32) type {
}
}

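// Round constants: the first 32 bits of the fractional parts of the cube
// roots of the first 64 primes (the K constants from FIPS 180-4).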
const W = [64]u32{
0x428A2F98, 0x71374491, 0xB5C0FBCF, 0xE9B5DBA5, 0x3956C25B, 0x59F111F1, 0x923F82A4, 0xAB1C5ED5,
0xD807AA98, 0x12835B01, 0x243185BE, 0x550C7DC3, 0x72BE5D74, 0x80DEB1FE, 0x9BDC06A7, 0xC19BF174,
0xE49B69C1, 0xEFBE4786, 0x0FC19DC6, 0x240CA1CC, 0x2DE92C6F, 0x4A7484AA, 0x5CB0A9DC, 0x76F988DA,
0x983E5152, 0xA831C66D, 0xB00327C8, 0xBF597FC7, 0xC6E00BF3, 0xD5A79147, 0x06CA6351, 0x14292967,
0x27B70A85, 0x2E1B2138, 0x4D2C6DFC, 0x53380D13, 0x650A7354, 0x766A0ABB, 0x81C2C92E, 0x92722C85,
0xA2BFE8A1, 0xA81A664B, 0xC24B8B70, 0xC76C51A3, 0xD192E819, 0xD6990624, 0xF40E3585, 0x106AA070,
0x19A4C116, 0x1E376C08, 0x2748774C, 0x34B0BCB5, 0x391C0CB3, 0x4ED8AA4A, 0x5B9CCA4F, 0x682E6FF3,
0x748F82EE, 0x78A5636F, 0x84C87814, 0x8CC70208, 0x90BEFFFA, 0xA4506CEB, 0xBEF9A3F7, 0xC67178F2,
};

fn round(d: *Self, b: *const [64]u8) void {
var s: [64]u32 = undefined;
var s: [64]u32 align(16) = undefined;

var i: usize = 0;
while (i < 16) : (i += 1) {
@@ -179,6 +191,88 @@ fn Sha2x32(comptime params: Sha2Params32) type {
s[i] |= @as(u32, b[i * 4 + 2]) << 8;
s[i] |= @as(u32, b[i * 4 + 3]) << 0;
}

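// ARMv8 Cryptography Extension path: sha256su0/sha256su1 expand the message
// schedule four words at a time, and sha256h/sha256h2 run four rounds of the
// compression function per iteration, each updating half of the state.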
if (builtin.cpu.arch == .aarch64 and builtin.cpu.features.isEnabled(@enumToInt(std.Target.aarch64.Feature.sha2))) {
var x: v4u32 = d.s[0..4].*;
var y: v4u32 = d.s[4..8].*;
const s_v = @ptrCast(*[16]v4u32, &s);

comptime var k: u8 = 0;
inline while (k < 16) : (k += 1) {
if (k > 3) {
s_v[k] = asm (
\\sha256su0.4s %[w0_3], %[w4_7]
\\sha256su1.4s %[w0_3], %[w8_11], %[w12_15]
: [w0_3] "=w" (-> v4u32),
: [_] "0" (s_v[k - 4]),
[w4_7] "w" (s_v[k - 3]),
[w8_11] "w" (s_v[k - 2]),
[w12_15] "w" (s_v[k - 1]),
);
}

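// Add this group's round constants to the schedule words, then feed them to
// the two hash-update instructions.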
const w: v4u32 = s_v[k] +% @as(v4u32, W[4 * k ..][0..4].*);
asm volatile (
\\mov.4s v0, %[x]
\\sha256h.4s %[x], %[y], %[w]
\\sha256h2.4s %[y], v0, %[w]
: [x] "=w" (x),
[y] "=w" (y),
: [_] "0" (x),
[_] "1" (y),
[w] "w" (w),
: "v0"
);
}

d.s[0..4].* = x +% @as(v4u32, d.s[0..4].*);
d.s[4..8].* = y +% @as(v4u32, d.s[4..8].*);
return;
} else if (builtin.cpu.arch == .x86_64 and builtin.cpu.features.isEnabled(@enumToInt(std.Target.x86.Feature.sha))) {
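// Intel SHA extensions path: sha256rnds2 expects the state split across two
// xmm registers laid out as (a, b, e, f) and (c, d, g, h) from the most to
// the least significant dword, with the constant-added message words passed
// implicitly in xmm0.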
var x: v4u32 = [_]u32{ d.s[5], d.s[4], d.s[1], d.s[0] };
var y: v4u32 = [_]u32{ d.s[7], d.s[6], d.s[3], d.s[2] };
const s_v = @ptrCast(*[16]v4u32, &s);

comptime var k: u8 = 0;
inline while (k < 16) : (k += 1) {
if (k < 12) {
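// sha256msg1 and sha256msg2, together with a manual shuffle for the W[i-7]
// term, compute the next four message-schedule words from the previous sixteen.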
const r = asm ("sha256msg1 %[w4_7], %[w0_3]"
: [w0_3] "=x" (-> v4u32),
: [_] "0" (s_v[k]),
[w4_7] "x" (s_v[k + 1]),
);
const t = @shuffle(u32, s_v[k + 2], s_v[k + 3], [_]i32{ 1, 2, 3, -1 });
s_v[k + 4] = asm ("sha256msg2 %[w12_15], %[t]"
: [t] "=x" (-> v4u32),
: [_] "0" (r +% t),
[w12_15] "x" (s_v[k + 3]),
);
}

const w: v4u32 = s_v[k] +% @as(v4u32, W[4 * k ..][0..4].*);
asm volatile (
\\sha256rnds2 %[x], %[y]
\\pshufd $0xe, %%xmm0, %%xmm0
\\sha256rnds2 %[y], %[x]
: [y] "=x" (y),
[x] "=x" (x),
: [_] "0" (y),
[_] "1" (x),
[_] "{xmm0}" (w),
);
}

d.s[0] +%= x[3];
d.s[1] +%= x[2];
d.s[4] +%= x[1];
d.s[5] +%= x[0];
d.s[2] +%= y[3];
d.s[3] +%= y[2];
d.s[6] +%= y[1];
d.s[7] +%= y[0];
return;
}

while (i < 64) : (i += 1) {
s[i] = s[i - 16] +% s[i - 7] +% (math.rotr(u32, s[i - 15], @as(u32, 7)) ^ math.rotr(u32, s[i - 15], @as(u32, 18)) ^ (s[i - 15] >> 3)) +% (math.rotr(u32, s[i - 2], @as(u32, 17)) ^ math.rotr(u32, s[i - 2], @as(u32, 19)) ^ (s[i - 2] >> 10));
}
@@ -195,73 +289,73 @@ fn Sha2x32(comptime params: Sha2Params32) type {
};

const round0 = comptime [_]RoundParam256{
roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 0, 0x428A2F98),
roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 1, 0x71374491),
roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 2, 0xB5C0FBCF),
roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 3, 0xE9B5DBA5),
roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 4, 0x3956C25B),
roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 5, 0x59F111F1),
roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 6, 0x923F82A4),
roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 7, 0xAB1C5ED5),
roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 8, 0xD807AA98),
roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 9, 0x12835B01),
roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 10, 0x243185BE),
roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 11, 0x550C7DC3),
roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 12, 0x72BE5D74),
roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 13, 0x80DEB1FE),
roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 14, 0x9BDC06A7),
roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 15, 0xC19BF174),
roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 16, 0xE49B69C1),
roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 17, 0xEFBE4786),
roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 18, 0x0FC19DC6),
roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 19, 0x240CA1CC),
roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 20, 0x2DE92C6F),
roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 21, 0x4A7484AA),
roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 22, 0x5CB0A9DC),
roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 23, 0x76F988DA),
roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 24, 0x983E5152),
roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 25, 0xA831C66D),
roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 26, 0xB00327C8),
roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 27, 0xBF597FC7),
roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 28, 0xC6E00BF3),
roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 29, 0xD5A79147),
roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 30, 0x06CA6351),
roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 31, 0x14292967),
roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 32, 0x27B70A85),
roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 33, 0x2E1B2138),
roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 34, 0x4D2C6DFC),
roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 35, 0x53380D13),
roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 36, 0x650A7354),
roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 37, 0x766A0ABB),
roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 38, 0x81C2C92E),
roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 39, 0x92722C85),
roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 40, 0xA2BFE8A1),
roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 41, 0xA81A664B),
roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 42, 0xC24B8B70),
roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 43, 0xC76C51A3),
roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 44, 0xD192E819),
roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 45, 0xD6990624),
roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 46, 0xF40E3585),
roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 47, 0x106AA070),
roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 48, 0x19A4C116),
roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 49, 0x1E376C08),
roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 50, 0x2748774C),
roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 51, 0x34B0BCB5),
roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 52, 0x391C0CB3),
roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 53, 0x4ED8AA4A),
roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 54, 0x5B9CCA4F),
roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 55, 0x682E6FF3),
roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 56, 0x748F82EE),
roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 57, 0x78A5636F),
roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 58, 0x84C87814),
roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 59, 0x8CC70208),
roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 60, 0x90BEFFFA),
roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 61, 0xA4506CEB),
roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 62, 0xBEF9A3F7),
roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 63, 0xC67178F2),
roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 0),
roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 1),
roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 2),
roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 3),
roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 4),
roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 5),
roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 6),
roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 7),
roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 8),
roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 9),
roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 10),
roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 11),
roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 12),
roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 13),
roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 14),
roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 15),
roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 16),
roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 17),
roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 18),
roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 19),
roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 20),
roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 21),
roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 22),
roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 23),
roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 24),
roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 25),
roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 26),
roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 27),
roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 28),
roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 29),
roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 30),
roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 31),
roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 32),
roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 33),
roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 34),
roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 35),
roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 36),
roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 37),
roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 38),
roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 39),
roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 40),
roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 41),
roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 42),
roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 43),
roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 44),
roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 45),
roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 46),
roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 47),
roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 48),
roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 49),
roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 50),
roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 51),
roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 52),
roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 53),
roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 54),
roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 55),
roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 56),
roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 57),
roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 58),
roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 59),
roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 60),
roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 61),
roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 62),
roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 63),
};
inline for (round0) |r| {
v[r.h] = v[r.h] +% (math.rotr(u32, v[r.e], @as(u32, 6)) ^ math.rotr(u32, v[r.e], @as(u32, 11)) ^ math.rotr(u32, v[r.e], @as(u32, 25))) +% (v[r.g] ^ (v[r.e] & (v[r.f] ^ v[r.g]))) +% r.k +% s[r.i];
v[r.h] = v[r.h] +% (math.rotr(u32, v[r.e], @as(u32, 6)) ^ math.rotr(u32, v[r.e], @as(u32, 11)) ^ math.rotr(u32, v[r.e], @as(u32, 25))) +% (v[r.g] ^ (v[r.e] & (v[r.f] ^ v[r.g]))) +% W[r.i] +% s[r.i];

v[r.d] = v[r.d] +% v[r.h];

