Skip to content

Commit

Permalink
implement SIMD sqrt and fma
Browse files Browse the repository at this point in the history
  • Loading branch information
RalfJung committed Mar 17, 2022
1 parent a9a0d0e commit 4fd5dca
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
37 changes: 36 additions & 1 deletion src/shims/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
| "simd_ceil"
| "simd_floor"
| "simd_round"
| "simd_trunc" => {
| "simd_trunc"
| "simd_fsqrt" => {
let &[ref op] = check_arg_count(args)?;
let (op, op_len) = this.operand_to_simd(op)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
Expand All @@ -342,6 +343,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
Floor,
Round,
Trunc,
Sqrt,
}
#[derive(Copy, Clone)]
enum Op {
Expand All @@ -356,6 +358,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
"simd_floor" => Op::HostOp(HostFloatOp::Floor),
"simd_round" => Op::HostOp(HostFloatOp::Round),
"simd_trunc" => Op::HostOp(HostFloatOp::Trunc),
"simd_fsqrt" => Op::HostOp(HostFloatOp::Sqrt),
_ => unreachable!(),
};

Expand Down Expand Up @@ -388,6 +391,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
HostFloatOp::Floor => f.floor(),
HostFloatOp::Round => f.round(),
HostFloatOp::Trunc => f.trunc(),
HostFloatOp::Sqrt => f.sqrt(),
};
Scalar::from_u32(res.to_bits())
}
Expand All @@ -398,6 +402,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
HostFloatOp::Floor => f.floor(),
HostFloatOp::Round => f.round(),
HostFloatOp::Trunc => f.trunc(),
HostFloatOp::Sqrt => f.sqrt(),
};
Scalar::from_u64(res.to_bits())
}
Expand Down Expand Up @@ -508,6 +513,36 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
this.write_scalar(val, &dest.into())?;
}
}
// `simd_fma`: takes three SIMD operands `a`, `b`, `c` and, lane by lane,
// writes the result of `a[i].mul_add(b[i], c[i])` into `dest[i]`.
"simd_fma" => {
// Exactly three arguments are expected; `check_arg_count` errors otherwise.
let &[ref a, ref b, ref c] = check_arg_count(args)?;
// View each operand (and the destination) as a SIMD vector plus its lane count.
let (a, a_len) = this.operand_to_simd(a)?;
let (b, b_len) = this.operand_to_simd(b)?;
let (c, c_len) = this.operand_to_simd(c)?;
let (dest, dest_len) = this.place_to_simd(dest)?;

// Every input must have the same number of lanes as the destination.
assert_eq!(dest_len, a_len);
assert_eq!(dest_len, b_len);
assert_eq!(dest_len, c_len);

for i in 0..dest_len {
// Read lane `i` of each input as a scalar; these shadow the vector bindings
// for the rest of this iteration.
let a = this.read_immediate(&this.mplace_index(&a, i)?.into())?.to_scalar()?;
let b = this.read_immediate(&this.mplace_index(&b, i)?.into())?.to_scalar()?;
let c = this.read_immediate(&this.mplace_index(&c, i)?.into())?.to_scalar()?;
// Lane `i` of the destination, where the result is written below.
let dest = this.mplace_index(&dest, i)?;

// Works for f32 and f64.
// The destination lane's type selects the float width; any non-float lane
// type is reported through `bug!`.
let ty::Float(float_ty) = dest.layout.ty.kind() else {
bug!("{} operand is not a float", intrinsic_name)
};
// Convert the three scalars to the selected float type, combine them with
// `mul_add`, and keep only the `.value` field of what `mul_add` returns.
let val = match float_ty {
FloatTy::F32 =>
Scalar::from_f32(a.to_f32()?.mul_add(b.to_f32()?, c.to_f32()?).value),
FloatTy::F64 =>
Scalar::from_f64(a.to_f64()?.mul_add(b.to_f64()?, c.to_f64()?).value),
};
this.write_scalar(val, &dest.into())?;
}
}
#[rustfmt::skip]
| "simd_reduce_and"
| "simd_reduce_or"
Expand Down
10 changes: 10 additions & 0 deletions tests/run-pass/portable-simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ fn simd_ops_f32() {
assert_eq!(a.max(b * f32x4::splat(4.0)), f32x4::from_array([10.0, 10.0, 12.0, 10.0]));
assert_eq!(a.min(b * f32x4::splat(4.0)), f32x4::from_array([4.0, 8.0, 10.0, -16.0]));

assert_eq!(a.mul_add(b, a), (a*b)+a);
assert_eq!(b.mul_add(b, a), (b*b)+a);
assert_eq!((a*a).sqrt(), a);
assert_eq!((b*b).sqrt(), b.abs());

assert_eq!(a.lanes_eq(f32x4::splat(5.0) * b), Mask::from_array([false, true, false, false]));
assert_eq!(a.lanes_ne(f32x4::splat(5.0) * b), Mask::from_array([true, false, true, true]));
assert_eq!(a.lanes_le(f32x4::splat(5.0) * b), Mask::from_array([false, true, true, false]));
Expand Down Expand Up @@ -59,6 +64,11 @@ fn simd_ops_f64() {
assert_eq!(a.max(b * f64x4::splat(4.0)), f64x4::from_array([10.0, 10.0, 12.0, 10.0]));
assert_eq!(a.min(b * f64x4::splat(4.0)), f64x4::from_array([4.0, 8.0, 10.0, -16.0]));

assert_eq!(a.mul_add(b, a), (a*b)+a);
assert_eq!(b.mul_add(b, a), (b*b)+a);
assert_eq!((a*a).sqrt(), a);
assert_eq!((b*b).sqrt(), b.abs());

assert_eq!(a.lanes_eq(f64x4::splat(5.0) * b), Mask::from_array([false, true, false, false]));
assert_eq!(a.lanes_ne(f64x4::splat(5.0) * b), Mask::from_array([true, false, true, true]));
assert_eq!(a.lanes_le(f64x4::splat(5.0) * b), Mask::from_array([false, true, true, false]));
Expand Down

0 comments on commit 4fd5dca

Please sign in to comment.