feat(fft128): use twice as many lanes for avx512
sarah el kazdadi committed Feb 21, 2024
1 parent 9cd97a4 commit c95cc56
Showing 2 changed files with 445 additions and 38 deletions.
233 changes: 223 additions & 10 deletions src/fft128/f128_impl.rs
@@ -621,9 +621,9 @@ impl f128 {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg_attr(docsrs, doc(cfg(any(target_arch = "x86", target_arch = "x86_64"))))]
pub mod x86 {
-    use pulp::{f64x4, x86::V3, Simd};
    #[cfg(feature = "nightly")]
-    use pulp::{f64x8, x86::V4};
+    use pulp::{b8, f64x8, x86::V4};
+    use pulp::{f64x4, x86::V3, Simd};

    #[inline(always)]
    pub(crate) fn quick_two_sum_f64x4(simd: V3, a: f64x4, b: f64x4) -> (f64x4, f64x4) {
@@ -687,6 +687,67 @@ pub mod x86 {
        (p, simd.mul_sub_f64x8(a, b, p))
    }

    #[cfg(feature = "nightly")]
    #[inline(always)]
    pub(crate) fn quick_two_sum_f64x16(simd: V4, a: f64x16, b: f64x16) -> (f64x16, f64x16) {
        let s = simd.add_f64x16(a, b);
        (s, simd.sub_f64x16(b, simd.sub_f64x16(s, a)))
    }
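
    // Editor's sketch (not part of this commit): the scalar fast-two-sum that
    // quick_two_sum_f64x16 applies to each of the 16 lanes; the name
    // `quick_two_sum_scalar` is editorial. Assuming |a| >= |b|, `s` is the
    // rounded sum and `err` is exactly the rounding error, so a + b == s + err
    // holds exactly.
    fn quick_two_sum_scalar(a: f64, b: f64) -> (f64, f64) {
        let s = a + b;
        let err = b - (s - a);
        (s, err)
    }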

    #[inline(always)]
    #[cfg(feature = "nightly")]
    pub(crate) fn two_sum_f64x16(simd: V4, a: f64x16, b: f64x16) -> (f64x16, f64x16) {
        let sign_bit = simd.splat_f64x16(-0.0);
        let cmp = simd.cmp_gt_f64x16(
            simd.andnot_f64x16(sign_bit, a),
            simd.andnot_f64x16(sign_bit, b),
        );
        let (a, b) = (simd.select_f64x16(cmp, a, b), simd.select_f64x16(cmp, b, a));

        quick_two_sum_f64x16(simd, a, b)
    }
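
    // Editor's sketch (not part of this commit): what two_sum_f64x16 does per
    // lane, without branches. `andnot` with the sign mask -0.0 clears the sign
    // bit, i.e. computes |x|; the compare/select pair then routes the
    // larger-magnitude operand into the `a` slot so the fast-two-sum
    // precondition |a| >= |b| holds in every lane.
    fn two_sum_scalar(a: f64, b: f64) -> (f64, f64) {
        let (a, b) = if a.abs() > b.abs() { (a, b) } else { (b, a) };
        quick_two_sum_scalar(a, b)
    }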

    #[inline(always)]
    #[cfg(feature = "nightly")]
    pub(crate) fn two_diff_f64x16(simd: V4, a: f64x16, b: f64x16) -> (f64x16, f64x16) {
        two_sum_f64x16(
            simd,
            a,
            f64x16 {
                lo: simd.f64s_neg(b.lo),
                hi: simd.f64s_neg(b.hi),
            },
        )
    }

    #[cfg(feature = "nightly")]
    #[inline(always)]
    pub(crate) fn two_prod_f64x16(simd: V4, a: f64x16, b: f64x16) -> (f64x16, f64x16) {
        let p = simd.mul_f64x16(a, b);
        (p, simd.mul_sub_f64x16(a, b, p))
    }
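
    // Editor's sketch (not part of this commit): two_prod via FMA, per lane.
    // A fused multiply-add rounds only once, so fl(a*b - p) recovers exactly
    // the rounding error of p = fl(a*b), giving a * b == p + err as a
    // double-double. `two_prod_scalar` is an editorial name.
    fn two_prod_scalar(a: f64, b: f64) -> (f64, f64) {
        let p = a * b;
        let err = f64::mul_add(a, b, -p);
        (p, err)
    }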

    #[cfg(feature = "nightly")]
    #[derive(Copy, Clone, Debug)]
    #[repr(C)]
    pub struct f64x16 {
        pub lo: f64x8,
        pub hi: f64x8,
    }

    #[cfg(feature = "nightly")]
    #[derive(Copy, Clone, Debug)]
    #[repr(C)]
    pub struct b16 {
        pub lo: b8,
        pub hi: b8,
    }

    #[cfg(feature = "nightly")]
    unsafe impl bytemuck::Zeroable for f64x16 {}
    #[cfg(feature = "nightly")]
    unsafe impl bytemuck::Pod for f64x16 {}
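
    // Editor's note (not part of this commit): one f64x16 spans two AVX-512
    // registers, i.e. 16 doubles per value, which is where the commit title's
    // "twice as many lanes" comes from. The Pod/Zeroable impls let bytemuck
    // reinterpret raw f64 storage as f64x16 without copying. A hedged sketch
    // (editorial helper name), using try_cast_slice since f64x16 is more
    // strictly aligned than f64 and the cast can fail:
    #[cfg(feature = "nightly")]
    fn view_as_f64x16(xs: &[f64]) -> Option<&[f64x16]> {
        bytemuck::try_cast_slice(xs).ok()
    }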

    pub trait V3F128Ext {
        fn add_estimate_f128x4(self, a0: f64x4, a1: f64x4, b0: f64x4, b1: f64x4) -> (f64x4, f64x4);
        fn sub_estimate_f128x4(self, a0: f64x4, a1: f64x4, b0: f64x4, b1: f64x4) -> (f64x4, f64x4);
@@ -702,6 +763,34 @@ pub mod x86 {
        fn add_f128x8(self, a0: f64x8, a1: f64x8, b0: f64x8, b1: f64x8) -> (f64x8, f64x8);
        fn sub_f128x8(self, a0: f64x8, a1: f64x8, b0: f64x8, b1: f64x8) -> (f64x8, f64x8);
        fn mul_f128x8(self, a0: f64x8, a1: f64x8, b0: f64x8, b1: f64x8) -> (f64x8, f64x8);

        fn add_estimate_f128x16(
            self,
            a0: f64x16,
            a1: f64x16,
            b0: f64x16,
            b1: f64x16,
        ) -> (f64x16, f64x16);
        fn sub_estimate_f128x16(
            self,
            a0: f64x16,
            a1: f64x16,
            b0: f64x16,
            b1: f64x16,
        ) -> (f64x16, f64x16);
        fn add_f128x16(self, a0: f64x16, a1: f64x16, b0: f64x16, b1: f64x16) -> (f64x16, f64x16);
        fn sub_f128x16(self, a0: f64x16, a1: f64x16, b0: f64x16, b1: f64x16) -> (f64x16, f64x16);
        fn mul_f128x16(self, a0: f64x16, a1: f64x16, b0: f64x16, b1: f64x16) -> (f64x16, f64x16);

        fn splat_f64x16(self, value: f64) -> f64x16;
        fn add_f64x16(self, a: f64x16, b: f64x16) -> f64x16;
        fn sub_f64x16(self, a: f64x16, b: f64x16) -> f64x16;
        fn mul_f64x16(self, a: f64x16, b: f64x16) -> f64x16;
        fn mul_add_f64x16(self, a: f64x16, b: f64x16, c: f64x16) -> f64x16;
        fn mul_sub_f64x16(self, a: f64x16, b: f64x16, c: f64x16) -> f64x16;
        fn andnot_f64x16(self, a: f64x16, b: f64x16) -> f64x16;
        fn cmp_gt_f64x16(self, a: f64x16, b: f64x16) -> b16;
        fn select_f64x16(self, mask: b16, if_true: f64x16, if_false: f64x16) -> f64x16;
    }
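
    // Editor's note (not part of this commit): judging from the context lines
    // above, the new f64x16/f128x16 methods extend the existing extension
    // trait for V4 (the V4 counterpart of V3F128Ext) rather than pulp itself,
    // so any function holding a pulp::x86::V4 token can call them once the
    // trait is in scope. A hypothetical call site (editorial name):
    #[cfg(feature = "nightly")]
    fn double_f128x16(simd: V4, x0: f64x16, x1: f64x16) -> (f64x16, f64x16) {
        // adds each of the 16 f128 lanes to itself
        simd.add_f128x16(x0, x1, x0, x1)
    }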

    impl V3F128Ext for V3 {
@@ -747,10 +836,7 @@ pub mod x86 {
        #[inline(always)]
        fn mul_f128x4(self, a0: f64x4, a1: f64x4, b0: f64x4, b1: f64x4) -> (f64x4, f64x4) {
            let (p1, p2) = two_prod_f64x4(self, a0, b0);
-            let p2 = self.add_f64x4(
-                p2,
-                self.add_f64x4(self.mul_f64x4(a0, b1), self.mul_f64x4(a1, b0)),
-            );
+            let p2 = self.add_f64x4(p2, self.mul_add_f64x4(a0, b1, self.mul_f64x4(a1, b0)));
            quick_two_sum_f64x4(self, p1, p2)
        }
    }
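
    // Editor's note (not part of this commit): the only change to mul_f128x4
    // (and to mul_f128x8 below) is the cross-term sum, rewritten from
    // add(mul(a0, b1), mul(a1, b0)) to mul_add(a0, b1, mul(a1, b0)): one FMA
    // replaces a separate multiply and add. Scalar equivalent (editorial name):
    fn cross_term(a0: f64, a1: f64, b0: f64, b1: f64) -> f64 {
        f64::mul_add(a0, b1, a1 * b0) // was: a0 * b1 + a1 * b0
    }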
@@ -799,12 +885,139 @@ pub mod x86 {
        #[inline(always)]
        fn mul_f128x8(self, a0: f64x8, a1: f64x8, b0: f64x8, b1: f64x8) -> (f64x8, f64x8) {
            let (p1, p2) = two_prod_f64x8(self, a0, b0);
-            let p2 = self.add_f64x8(
-                p2,
-                self.add_f64x8(self.mul_f64x8(a0, b1), self.mul_f64x8(a1, b0)),
-            );
+            let p2 = self.add_f64x8(p2, self.mul_add_f64x8(a0, b1, self.mul_f64x8(a1, b0)));
            quick_two_sum_f64x8(self, p1, p2)
        }

        #[inline(always)]
        fn add_estimate_f128x16(
            self,
            a0: f64x16,
            a1: f64x16,
            b0: f64x16,
            b1: f64x16,
        ) -> (f64x16, f64x16) {
            let (s, e) = two_sum_f64x16(self, a0, b0);
            let e = self.add_f64x16(e, self.add_f64x16(a1, b1));
            quick_two_sum_f64x16(self, s, e)
        }
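
        // Editor's note (not part of this commit): the *_estimate variants
        // fold both low-order words into the error term with a single
        // renormalization, costing one two_sum plus one quick_two_sum instead
        // of the two of each used by the full add_f128x16 below, at the price
        // of a slightly larger error bound.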

        #[inline(always)]
        fn sub_estimate_f128x16(
            self,
            a0: f64x16,
            a1: f64x16,
            b0: f64x16,
            b1: f64x16,
        ) -> (f64x16, f64x16) {
            let (s, e) = two_diff_f64x16(self, a0, b0);
            let e = self.add_f64x16(e, a1);
            let e = self.sub_f64x16(e, b1);
            quick_two_sum_f64x16(self, s, e)
        }

        #[inline(always)]
        fn add_f128x16(self, a0: f64x16, a1: f64x16, b0: f64x16, b1: f64x16) -> (f64x16, f64x16) {
            let (s1, s2) = two_sum_f64x16(self, a0, b0);
            let (t1, t2) = two_sum_f64x16(self, a1, b1);

            let s2 = self.add_f64x16(s2, t1);
            let (s1, s2) = quick_two_sum_f64x16(self, s1, s2);
            let s2 = self.add_f64x16(s2, t2);
            let (s1, s2) = quick_two_sum_f64x16(self, s1, s2);
            (s1, s2)
        }
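
        // Editor's sketch (not part of this commit; module-scope sketch shown
        // here next to the lane-wise code): the scalar double-double addition
        // that add_f128x16 evaluates in each lane, in terms of the scalar
        // helpers sketched earlier. Both word sums are exact; the two
        // quick_two_sum passes renormalize so the (hi, lo) result does not
        // overlap.
        fn f128_add_scalar(a0: f64, a1: f64, b0: f64, b1: f64) -> (f64, f64) {
            let (s1, s2) = two_sum_scalar(a0, b0); // exact sum of high words
            let (t1, t2) = two_sum_scalar(a1, b1); // exact sum of low words
            let (s1, s2) = quick_two_sum_scalar(s1, s2 + t1);
            quick_two_sum_scalar(s1, s2 + t2)
        }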

        #[inline(always)]
        fn sub_f128x16(self, a0: f64x16, a1: f64x16, b0: f64x16, b1: f64x16) -> (f64x16, f64x16) {
            let (s1, s2) = two_diff_f64x16(self, a0, b0);
            let (t1, t2) = two_diff_f64x16(self, a1, b1);

            let s2 = self.add_f64x16(s2, t1);
            let (s1, s2) = quick_two_sum_f64x16(self, s1, s2);
            let s2 = self.add_f64x16(s2, t2);
            let (s1, s2) = quick_two_sum_f64x16(self, s1, s2);
            (s1, s2)
        }

        #[inline(always)]
        fn mul_f128x16(self, a0: f64x16, a1: f64x16, b0: f64x16, b1: f64x16) -> (f64x16, f64x16) {
            let (p1, p2) = two_prod_f64x16(self, a0, b0);
            let p2 = self.add_f64x16(p2, self.mul_add_f64x16(a0, b1, self.mul_f64x16(a1, b0)));
            quick_two_sum_f64x16(self, p1, p2)
        }
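
        // Editor's sketch (not part of this commit; module-scope sketch shown
        // here next to the lane-wise code): the scalar double-double product
        // behind mul_f128x16. two_prod splits the high product exactly; the
        // cross terms a0*b1 and a1*b0 only need double precision, and a1*b1
        // falls below the roughly 106-bit result, so it is dropped.
        fn f128_mul_scalar(a0: f64, a1: f64, b0: f64, b1: f64) -> (f64, f64) {
            let (p1, p2) = two_prod_scalar(a0, b0);
            let p2 = p2 + f64::mul_add(a0, b1, a1 * b0);
            quick_two_sum_scalar(p1, p2)
        }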

        #[inline(always)]
        fn add_f64x16(self, a: f64x16, b: f64x16) -> f64x16 {
            f64x16 {
                lo: self.add_f64x8(a.lo, b.lo),
                hi: self.add_f64x8(a.hi, b.hi),
            }
        }

        #[inline(always)]
        fn sub_f64x16(self, a: f64x16, b: f64x16) -> f64x16 {
            f64x16 {
                lo: self.sub_f64x8(a.lo, b.lo),
                hi: self.sub_f64x8(a.hi, b.hi),
            }
        }

        #[inline(always)]
        fn mul_f64x16(self, a: f64x16, b: f64x16) -> f64x16 {
            f64x16 {
                lo: self.mul_f64x8(a.lo, b.lo),
                hi: self.mul_f64x8(a.hi, b.hi),
            }
        }

        #[inline(always)]
        fn mul_add_f64x16(self, a: f64x16, b: f64x16, c: f64x16) -> f64x16 {
            f64x16 {
                lo: self.mul_add_f64x8(a.lo, b.lo, c.lo),
                hi: self.mul_add_f64x8(a.hi, b.hi, c.hi),
            }
        }

        #[inline(always)]
        fn mul_sub_f64x16(self, a: f64x16, b: f64x16, c: f64x16) -> f64x16 {
            f64x16 {
                lo: self.mul_sub_f64x8(a.lo, b.lo, c.lo),
                hi: self.mul_sub_f64x8(a.hi, b.hi, c.hi),
            }
        }

        #[inline(always)]
        fn andnot_f64x16(self, a: f64x16, b: f64x16) -> f64x16 {
            f64x16 {
                lo: self.andnot_f64x8(a.lo, b.lo),
                hi: self.andnot_f64x8(a.hi, b.hi),
            }
        }

        #[inline(always)]
        fn cmp_gt_f64x16(self, a: f64x16, b: f64x16) -> b16 {
            b16 {
                lo: self.cmp_gt_f64x8(a.lo, b.lo),
                hi: self.cmp_gt_f64x8(a.hi, b.hi),
            }
        }

        #[inline(always)]
        fn select_f64x16(self, mask: b16, if_true: f64x16, if_false: f64x16) -> f64x16 {
            f64x16 {
                lo: self.select_f64x8(mask.lo, if_true.lo, if_false.lo),
                hi: self.select_f64x8(mask.hi, if_true.hi, if_false.hi),
            }
        }

        #[inline(always)]
        fn splat_f64x16(self, value: f64) -> f64x16 {
            f64x16 {
                lo: self.splat_f64x8(value),
                hi: self.splat_f64x8(value),
            }
        }
    }
}
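
    // Editor's sketch (not part of this commit): a hypothetical end-to-end
    // caller on the nightly feature, assuming pulp's V4::try_new runtime
    // detection of AVX-512 and the extension trait in scope. Each call now
    // processes 16 f128 lanes instead of the 8 available through f64x8.
    #[cfg(feature = "nightly")]
    fn example_f128x16_add() {
        if let Some(simd) = V4::try_new() {
            let one = simd.splat_f64x16(1.0);
            let tiny = simd.splat_f64x16(1e-30);
            // every lane holds the f128 value 1.0 + 1e-30; double it
            let (hi, lo) = simd.add_f128x16(one, tiny, one, tiny);
            let _ = (hi, lo);
        }
    }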

