Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Arm Fused Multiply-Add fixes #1219

Merged
merged 2 commits into from
Sep 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions crates/core_arch/src/arm_shared/neon/generated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8721,7 +8721,7 @@ pub unsafe fn vmull_laneq_u32<const LANE: i32>(a: uint32x2_t, b: uint32x4_t) ->
/// Floating-point fused Multiply-Add to accumulator(vector)
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfma))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmla))]
pub unsafe fn vfma_f32(a: float32x2_t, b: float32x2_t, c: float32x2_t) -> float32x2_t {
Expand All @@ -8737,7 +8737,7 @@ vfma_f32_(b, c, a)
/// Floating-point fused Multiply-Add to accumulator(vector)
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfma))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmla))]
pub unsafe fn vfmaq_f32(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {
Expand All @@ -8753,27 +8753,27 @@ vfmaq_f32_(b, c, a)
/// Floating-point fused Multiply-Add to accumulator(vector)
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfma))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmla))]
pub unsafe fn vfma_n_f32(a: float32x2_t, b: float32x2_t, c: f32) -> float32x2_t {
vfma_f32(a, b, vdup_n_f32(c))
vfma_f32(a, b, vdup_n_f32_vfp4(c))
}

/// Floating-point fused Multiply-Add to accumulator(vector)
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfma))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmla))]
pub unsafe fn vfmaq_n_f32(a: float32x4_t, b: float32x4_t, c: f32) -> float32x4_t {
vfmaq_f32(a, b, vdupq_n_f32(c))
vfmaq_f32(a, b, vdupq_n_f32_vfp4(c))
}

/// Floating-point fused multiply-subtract from accumulator
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfms))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmls))]
pub unsafe fn vfms_f32(a: float32x2_t, b: float32x2_t, c: float32x2_t) -> float32x2_t {
Expand All @@ -8784,7 +8784,7 @@ pub unsafe fn vfms_f32(a: float32x2_t, b: float32x2_t, c: float32x2_t) -> float3
/// Floating-point fused multiply-subtract from accumulator
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfms))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmls))]
pub unsafe fn vfmsq_f32(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {
Expand All @@ -8795,21 +8795,21 @@ pub unsafe fn vfmsq_f32(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float
/// Floating-point fused Multiply-subtract to accumulator(vector)
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfms))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmls))]
pub unsafe fn vfms_n_f32(a: float32x2_t, b: float32x2_t, c: f32) -> float32x2_t {
vfms_f32(a, b, vdup_n_f32(c))
vfms_f32(a, b, vdup_n_f32_vfp4(c))
}

/// Floating-point fused Multiply-subtract to accumulator(vector)
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfms))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmls))]
pub unsafe fn vfmsq_n_f32(a: float32x4_t, b: float32x4_t, c: f32) -> float32x4_t {
vfmsq_f32(a, b, vdupq_n_f32(c))
vfmsq_f32(a, b, vdupq_n_f32_vfp4(c))
}

/// Subtract
Expand Down
26 changes: 26 additions & 0 deletions crates/core_arch/src/arm_shared/neon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3786,6 +3786,19 @@ pub unsafe fn vdupq_n_f32(value: f32) -> float32x4_t {
float32x4_t(value, value, value, value)
}

/// Duplicate vector element to vector or scalar
///
/// Private vfp4 version used by FMA intrinsics because LLVM does
/// not inline the non-vfp4 version in vfp4 functions.
///
/// Returns a `float32x4_t` with all four lanes set to `value`. On `arm`
/// targets this function is additionally compiled with the `vfp4` target
/// feature enabled (on top of `neon`).
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
#[cfg_attr(all(test, target_arch = "arm"), assert_instr("vdup.32"))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(dup))]
unsafe fn vdupq_n_f32_vfp4(value: f32) -> float32x4_t {
float32x4_t(value, value, value, value)
}

/// Duplicate vector element to vector or scalar
#[inline]
#[target_feature(enable = "neon")]
Expand Down Expand Up @@ -3896,6 +3909,19 @@ pub unsafe fn vdup_n_f32(value: f32) -> float32x2_t {
float32x2_t(value, value)
}

/// Duplicate vector element to vector or scalar
///
/// Private vfp4 version used by FMA intrinsics because LLVM does
/// not inline the non-vfp4 version in vfp4 functions.
///
/// Returns a `float32x2_t` with both lanes set to `value`. On `arm`
/// targets this function is additionally compiled with the `vfp4` target
/// feature enabled (on top of `neon`).
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
#[cfg_attr(all(test, target_arch = "arm"), assert_instr("vdup.32"))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(dup))]
unsafe fn vdup_n_f32_vfp4(value: f32) -> float32x2_t {
float32x2_t(value, value)
}

/// Duplicate vector element to vector or scalar
#[inline]
#[target_feature(enable = "neon")]
Expand Down
12 changes: 6 additions & 6 deletions crates/stdarch-gen/neon.spec
Original file line number Diff line number Diff line change
Expand Up @@ -2733,15 +2733,15 @@ generate float64x1_t
aarch64 = fmla
generate float64x2_t

target = fp-armv8
target = vfp4
arm = vfma
link-arm = llvm.fma._EXT_
generate float*_t

/// Floating-point fused Multiply-Add to accumulator(vector)
name = vfma
n-suffix
multi_fn = vfma-self-noext, a, b, {vdup-nself-noext, c}
multi_fn = vfma-self-noext, a, b, {vdup-nselfvfp4-noext, c}
a = 2.0, 3.0, 4.0, 5.0
b = 6.0, 4.0, 7.0, 8.0
c = 8.0
Expand All @@ -2752,7 +2752,7 @@ generate float64x1_t:float64x1_t:f64:float64x1_t
aarch64 = fmla
generate float64x2_t:float64x2_t:f64:float64x2_t

target = fp-armv8
target = vfp4
arm = vfma
generate float32x2_t:float32x2_t:f32:float32x2_t, float32x4_t:float32x4_t:f32:float32x4_t

Expand Down Expand Up @@ -2811,14 +2811,14 @@ generate float64x1_t
aarch64 = fmls
generate float64x2_t

target = fp-armv8
target = vfp4
arm = vfms
generate float*_t

/// Floating-point fused Multiply-subtract to accumulator(vector)
name = vfms
n-suffix
multi_fn = vfms-self-noext, a, b, {vdup-nself-noext, c}
multi_fn = vfms-self-noext, a, b, {vdup-nselfvfp4-noext, c}
a = 50.0, 35.0, 60.0, 69.0
b = 6.0, 4.0, 7.0, 8.0
c = 8.0
Expand All @@ -2829,7 +2829,7 @@ generate float64x1_t:float64x1_t:f64:float64x1_t
aarch64 = fmls
generate float64x2_t:float64x2_t:f64:float64x2_t

target = fp-armv8
target = vfp4
arm = vfms
generate float32x2_t:float32x2_t:f32:float32x2_t, float32x4_t:float32x4_t:f32:float32x4_t

Expand Down
17 changes: 16 additions & 1 deletion crates/stdarch-gen/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ enum Suffix {
enum TargetFeature {
Default,
ArmV7,
Vfp4,
FPArmV8,
AES,
}
Expand Down Expand Up @@ -980,6 +981,7 @@ fn gen_aarch64(
let current_target = match target {
Default => "neon",
ArmV7 => "v7",
Vfp4 => "vfp4",
FPArmV8 => "fp-armv8,v8",
AES => "neon,aes",
};
Expand Down Expand Up @@ -1120,6 +1122,7 @@ fn gen_aarch64(
out_t,
fixed,
None,
true,
));
}
calls
Expand Down Expand Up @@ -1630,12 +1633,14 @@ fn gen_arm(
let current_target_aarch64 = match target {
Default => "neon",
ArmV7 => "neon",
Vfp4 => "neon",
FPArmV8 => "neon",
AES => "neon,aes",
};
let current_target_arm = match target {
Default => "v7",
ArmV7 => "v7",
Vfp4 => "vfp4",
FPArmV8 => "fp-armv8,v8",
AES => "aes,v8",
};
Expand Down Expand Up @@ -1916,6 +1921,7 @@ fn gen_arm(
out_t,
fixed,
None,
false,
));
}
calls
Expand Down Expand Up @@ -2283,6 +2289,7 @@ fn get_call(
out_t: &str,
fixed: &Vec<String>,
n: Option<i32>,
aarch64: bool,
) -> String {
let params: Vec<_> = in_str.split(',').map(|v| v.trim().to_string()).collect();
assert!(params.len() > 0);
Expand Down Expand Up @@ -2450,7 +2457,8 @@ fn get_call(
in_t,
out_t,
fixed,
Some(i as i32)
Some(i as i32),
aarch64
)
);
call.push_str(&sub_match);
Expand Down Expand Up @@ -2499,6 +2507,7 @@ fn get_call(
out_t,
fixed,
n.clone(),
aarch64,
);
if !param_str.is_empty() {
param_str.push_str(", ");
Expand Down Expand Up @@ -2569,6 +2578,11 @@ fn get_call(
fn_name.push_str(type_to_suffix(in_t[1]));
} else if fn_format[1] == "nself" {
fn_name.push_str(type_to_n_suffix(in_t[1]));
} else if fn_format[1] == "nselfvfp4" {
fn_name.push_str(type_to_n_suffix(in_t[1]));
if !aarch64 {
fn_name.push_str("_vfp4");
}
} else if fn_format[1] == "out" {
fn_name.push_str(type_to_suffix(out_t));
} else if fn_format[1] == "in0" {
Expand Down Expand Up @@ -2854,6 +2868,7 @@ mod test {
target = match Some(String::from(&line[9..])) {
Some(input) => match input.as_str() {
"v7" => ArmV7,
"vfp4" => Vfp4,
"fp-armv8" => FPArmV8,
"aes" => AES,
_ => Default,
Expand Down