Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

p521: Improve field arithmetic #948

Merged
merged 2 commits into from
Nov 2, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 13 additions & 17 deletions p521/src/arithmetic/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ impl FieldElement {

/// Returns self^(2^n) mod p
const fn sqn(&self, n: usize) -> Self {
let mut x = *self;
let mut i = 0;
let mut x = self.square();
let mut i = 1;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I doubt this will have any effect on the generated code. LLVM is usually quite good at these sorts of optimizations.

Copy link
Contributor Author

@MasterAwesome MasterAwesome Nov 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LLVM actually doesn't optimize this even in release, here's how it looks on x86-64:

After

p521::arithmetic::field::FieldElement::sqn:
 push    r15
 push    r14
 push    r13
 push    r12
 push    rbx
 sub     rsp, 160
 mov     r14, rdx
 mov     rbx, rdi
 mov     rax, qword, ptr, [rsi, +, 64]
 movups  xmm0, xmmword, ptr, [rsi]
 movaps  xmmword, ptr, [rsp, +, 80], xmm0
 movups  xmm0, xmmword, ptr, [rsi, +, 16]
 movaps  xmmword, ptr, [rsp, +, 96], xmm0
 movups  xmm0, xmmword, ptr, [rsi, +, 32]
 movaps  xmmword, ptr, [rsp, +, 112], xmm0
 movups  xmm0, xmmword, ptr, [rsi, +, 48]
 movaps  xmmword, ptr, [rsp, +, 128], xmm0
 mov     qword, ptr, [rsp, +, 144], rax
 mov     rdi, rsp
 lea     rsi, [rsp, +, 80]
 call    qword, ptr, [rip, +, _ZN4p52110arithmetic5field5loose17LooseFieldElement6square17hbbdaf448d86087daE@GOTPCREL]
 cmp     r14, 2
 jb      .LBB27_3
 dec     r14
 mov     r15, rsp
 lea     r12, [rsp, +, 80]
 mov     r13, qword, ptr, [rip, +, _ZN4p52110arithmetic5field5loose17LooseFieldElement6square17hbbdaf448d86087daE@GOTPCREL]
.LBB27_2:
 mov     rax, qword, ptr, [rsp, +, 64]
 movaps  xmm0, xmmword, ptr, [rsp]
 movaps  xmm1, xmmword, ptr, [rsp, +, 16]
 movaps  xmm2, xmmword, ptr, [rsp, +, 32]
 movaps  xmm3, xmmword, ptr, [rsp, +, 48]
 movaps  xmmword, ptr, [rsp, +, 80], xmm0
 movaps  xmmword, ptr, [rsp, +, 96], xmm1
 movaps  xmmword, ptr, [rsp, +, 112], xmm2
 movaps  xmmword, ptr, [rsp, +, 128], xmm3
 mov     qword, ptr, [rsp, +, 144], rax
 mov     rdi, r15
 mov     rsi, r12
 call    r13
 dec     r14
 jne     .LBB27_2
.LBB27_3:
 mov     rax, qword, ptr, [rsp, +, 64]
 mov     qword, ptr, [rbx, +, 64], rax
 movaps  xmm0, xmmword, ptr, [rsp]
 movaps  xmm1, xmmword, ptr, [rsp, +, 16]
 movaps  xmm2, xmmword, ptr, [rsp, +, 32]
 movaps  xmm3, xmmword, ptr, [rsp, +, 48]
 movups  xmmword, ptr, [rbx, +, 48], xmm3
 movups  xmmword, ptr, [rbx, +, 32], xmm2
 movups  xmmword, ptr, [rbx, +, 16], xmm1
 movups  xmmword, ptr, [rbx], xmm0
 mov     rax, rbx
 add     rsp, 160
 pop     rbx
 pop     r12
 pop     r13
 pop     r14
 pop     r15
 ret

Before

p521::arithmetic::field::FieldElement::sqn:
 push    rbp
 push    r15
 push    r14
 push    r13
 push    r12
 push    rbx
 sub     rsp, 264
 mov     r14, rsi
 mov     qword, ptr, [rsp, +, 112], rdi
 mov     r13, qword, ptr, [rsi]
 mov     rbp, qword, ptr, [rsi, +, 8]
 mov     rax, qword, ptr, [rsi, +, 16]
 mov     rcx, qword, ptr, [rsi, +, 24]
 mov     r15, qword, ptr, [rsi, +, 32]
 mov     r12, qword, ptr, [rsi, +, 40]
 mov     rdx, qword, ptr, [rsi, +, 48]
 mov     rsi, qword, ptr, [rsi, +, 56]
 mov     rdi, qword, ptr, [r14, +, 64]
 mov     ebx, 519
.LBB31_1:
 mov     qword, ptr, [rsp], r13
 mov     qword, ptr, [rsp, +, 8], rbp
 mov     qword, ptr, [rsp, +, 16], rax
 mov     qword, ptr, [rsp, +, 24], rcx
 mov     qword, ptr, [rsp, +, 32], r15
 mov     qword, ptr, [rsp, +, 40], r12
 mov     qword, ptr, [rsp, +, 48], rdx
 mov     qword, ptr, [rsp, +, 56], rsi
 mov     qword, ptr, [rsp, +, 64], rdi
 lea     rdi, [rsp, +, 120]
 mov     rsi, rsp
 call    qword, ptr, [rip, +, _ZN4p52110arithmetic5field5loose17LooseFieldElement6square17hbbdaf448d86087daE@GOTPCREL]
 mov     r13, qword, ptr, [rsp, +, 120]
 mov     rbp, qword, ptr, [rsp, +, 128]
 mov     rax, qword, ptr, [rsp, +, 136]
 mov     rcx, qword, ptr, [rsp, +, 144]
 mov     r15, qword, ptr, [rsp, +, 152]
 mov     r12, qword, ptr, [rsp, +, 160]
 mov     rdx, qword, ptr, [rsp, +, 168]
 mov     rsi, qword, ptr, [rsp, +, 176]
 mov     rdi, qword, ptr, [rsp, +, 184]
 dec     rbx
 jne     .LBB31_1
 mov     qword, ptr, [rsp], r13
 mov     qword, ptr, [rsp, +, 8], rbp
 mov     qword, ptr, [rsp, +, 104], rax
 mov     qword, ptr, [rsp, +, 16], rax
 mov     qword, ptr, [rsp, +, 96], rcx
 mov     qword, ptr, [rsp, +, 24], rcx
 mov     qword, ptr, [rsp, +, 32], r15
 mov     qword, ptr, [rsp, +, 40], r12
 mov     qword, ptr, [rsp, +, 88], rdx
 mov     qword, ptr, [rsp, +, 48], rdx
 mov     qword, ptr, [rsp, +, 80], rsi
 mov     qword, ptr, [rsp, +, 56], rsi
 mov     qword, ptr, [rsp, +, 72], rdi
 mov     qword, ptr, [rsp, +, 64], rdi
 lea     rbx, [rsp, +, 120]
 mov     rsi, rsp
 mov     rdi, rbx
 call    qword, ptr, [rip, +, _ZN4p52110arithmetic5field5loose17LooseFieldElement6square17hbbdaf448d86087daE@GOTPCREL]
 lea     rdi, [rsp, +, 198]
 mov     rsi, rbx
 call    p521::arithmetic::field::field_impl::fiat_p521_to_bytes
 mov     rdi, rsp
 mov     rsi, r14
 call    p521::arithmetic::field::field_impl::fiat_p521_to_bytes
 mov     bl, 1
 xor     eax, eax
.LBB31_3:
 lea     r14, [rax, +, 1]
 movzx   ecx, byte, ptr, [rsp, +, rax, +, 198]
 xor     edi, edi
 cmp     cl, byte, ptr, [rsp, +, rax]
 sete    dil
 call    qword, ptr, [rip, +, _ZN6subtle9black_box17h67d940d0400f0e9dE@GOTPCREL]
 and     bl, al
 mov     rax, r14
 cmp     r14, 66
 jne     .LBB31_3
 movzx   edi, bl
 call    qword, ptr, [rip, +, _ZN6subtle9black_box17h67d940d0400f0e9dE@GOTPCREL]
 mov     rcx, qword, ptr, [rsp, +, 112]
 mov     qword, ptr, [rcx], r13
 mov     qword, ptr, [rcx, +, 8], rbp
 mov     rdx, qword, ptr, [rsp, +, 104]
 mov     qword, ptr, [rcx, +, 16], rdx
 mov     rdx, qword, ptr, [rsp, +, 96]
 mov     qword, ptr, [rcx, +, 24], rdx
 mov     qword, ptr, [rcx, +, 32], r15
 mov     qword, ptr, [rcx, +, 40], r12
 mov     rdx, qword, ptr, [rsp, +, 88]
 mov     qword, ptr, [rcx, +, 48], rdx
 mov     rdx, qword, ptr, [rsp, +, 80]
 mov     qword, ptr, [rcx, +, 56], rdx
 mov     rdx, qword, ptr, [rsp, +, 72]
 mov     qword, ptr, [rcx, +, 64], rdx
 mov     byte, ptr, [rcx, +, 72], al
 mov     rax, rcx
 add     rsp, 264
 pop     rbx
 pop     r12
 pop     r13
 pop     r14
 pop     r15
 pop     rbp
 ret

Copy link
Contributor Author

@MasterAwesome MasterAwesome Nov 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous solution also takes up ~100bytes more of stack space :O

Tested on:

rustc 1.73.0 (cc66ad468 2023-10-03)
binary: rustc
commit-hash: cc66ad468955717ab92600c770da8c1601a4ff33
commit-date: 2023-10-03
host: x86_64-unknown-linux-gnu
release: 1.73.0
LLVM version: 17.0.2

while i < n {
x = x.square();
i += 1;
Expand Down Expand Up @@ -276,22 +276,18 @@ impl FieldElement {

/// Returns the square root of self mod p, or `None` if no square root
/// exists.
///
/// # Implementation details
/// If _x_ has a sqrt, then due to Euler's criterion this implies x<sup>(p - 1)/2</sup> = 1.
/// 1. x<sup>(p + 1)/2</sup> = x.
/// 2. There's a special property due to _p ≡ 3 (mod 4)_ which implies _(p + 1)/4_ is an integer.
/// 3. We can rewrite `1.` as x<sup>((p+1)/4)<sup>2</sup></sup>
/// 4. x<sup>(p+1)/4</sup> is the square root.
/// 5. This is simplified as (2<sup>251</sup> - 1 + 1) /4 = 2<sup>519</sup>
/// 6. Hence, x<sup>2<sup>519</sup></sup> is the square root iff _result.square() == self_
pub fn sqrt(&self) -> CtOption<Self> {
// Tonelli-Shank's algorithm for q mod 4 = 3 (i.e. Shank's algorithm)
// https://eprint.iacr.org/2012/685.pdf
let w = self.pow_vartime(&[
0x0000000000000000,
0x0000000000000000,
0x0000000000000000,
0x0000000000000000,
0x0000000000000000,
0x0000000000000000,
0x0000000000000000,
0x0000000000000000,
0x0000000000000080,
]);

CtOption::new(w, w.square().ct_eq(self))
let sqrt = self.sqn(519);
CtOption::new(sqrt, sqrt.square().ct_eq(self))
}

/// Relax a tight field element into a loose one.
Expand Down