Skip to content

Commit

Permalink
Add f16 inline ASM support for RISC-V
Browse files Browse the repository at this point in the history
  • Loading branch information
beetrees committed Jun 18, 2024
1 parent 92af831 commit a86b38e
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 29 deletions.
111 changes: 87 additions & 24 deletions compiler/rustc_codegen_llvm/src/asm.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::slice;

use crate::attributes;
use crate::builder::Builder;
use crate::common::Funclet;
Expand Down Expand Up @@ -64,7 +66,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
let mut layout = None;
let ty = if let Some(ref place) = place {
layout = Some(&place.layout);
llvm_fixup_output_type(self.cx, reg.reg_class(), &place.layout)
llvm_fixup_output_type(self, reg.reg_class(), &place.layout)
} else if matches!(
reg.reg_class(),
InlineAsmRegClass::X86(
Expand Down Expand Up @@ -112,7 +114,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
// so we just use the type of the input.
&in_value.layout
};
let ty = llvm_fixup_output_type(self.cx, reg.reg_class(), layout);
let ty = llvm_fixup_output_type(self, reg.reg_class(), layout);
output_types.push(ty);
op_idx.insert(idx, constraints.len());
let prefix = if late { "=" } else { "=&" };
Expand Down Expand Up @@ -913,6 +915,46 @@ fn llvm_asm_scalar_type<'ll>(cx: &CodegenCx<'ll, '_>, scalar: Scalar) -> &'ll Ty
}
}

fn function_target_features<'ll>(builder: &Builder<'_, 'll, '_>) -> impl Iterator<Item = &'ll str> {
let llfn = builder.llfn();
let key = "target-features";
let attr = unsafe {
llvm::LLVMGetStringAttributeAtIndex(
llfn,
llvm::AttributePlace::Function.as_uint(),
key.as_ptr().cast(),
key.len().try_into().unwrap(),
)
};
let Some(attr) = attr else {
return "".split(',');
};
let value = unsafe {
let mut length = 0;
let ptr = llvm::LLVMGetStringAttributeValue(attr, &mut length);
slice::from_raw_parts(ptr.cast(), length.try_into().unwrap())
};
let Ok(value) = std::str::from_utf8(value) else {
return "".split(',');
};
value.split(',')
}

fn is_zfhmin_enabled(builder: &Builder<'_, '_, '_>) -> bool {
let mut zfhmin_enabled = false;
let mut zfh_enabled = false;
for feature in function_target_features(builder) {
match feature {
"+zfhmin" => zfhmin_enabled = true,
"-zfhmin" => zfhmin_enabled = false,
"+zfh" => zfh_enabled = true,
"-zfh" => zfh_enabled = false,
_ => {}
}
}
zfhmin_enabled || zfh_enabled
}

/// Fix up an input value to work around LLVM bugs.
fn llvm_fixup_input<'ll, 'tcx>(
bx: &mut Builder<'_, 'll, 'tcx>,
Expand Down Expand Up @@ -1029,6 +1071,15 @@ fn llvm_fixup_input<'ll, 'tcx>(
_ => value,
}
}
(InlineAsmRegClass::RiscV(RiscVInlineAsmRegClass::freg), Abi::Scalar(s))
if s.primitive() == Primitive::Float(Float::F16) && !is_zfhmin_enabled(bx) =>
{
// Smaller floats are always "NaN-boxed" inside larger floats on RISC-V.
let value = bx.bitcast(value, bx.type_i16());
let value = bx.zext(value, bx.type_i32());
let value = bx.or(value, bx.const_u32(0xFFFF_0000));
bx.bitcast(value, bx.type_f32())
}
_ => value,
}
}
Expand Down Expand Up @@ -1140,56 +1191,63 @@ fn llvm_fixup_output<'ll, 'tcx>(
_ => value,
}
}
(InlineAsmRegClass::RiscV(RiscVInlineAsmRegClass::freg), Abi::Scalar(s))
if s.primitive() == Primitive::Float(Float::F16) && !is_zfhmin_enabled(bx) =>
{
let value = bx.bitcast(value, bx.type_i32());
let value = bx.trunc(value, bx.type_i16());
bx.bitcast(value, bx.type_f16())
}
_ => value,
}
}

/// Output type to use for llvm_fixup_output.
fn llvm_fixup_output_type<'ll, 'tcx>(
cx: &CodegenCx<'ll, 'tcx>,
bx: &Builder<'_, 'll, 'tcx>,
reg: InlineAsmRegClass,
layout: &TyAndLayout<'tcx>,
) -> &'ll Type {
match (reg, layout.abi) {
(InlineAsmRegClass::AArch64(AArch64InlineAsmRegClass::vreg), Abi::Scalar(s)) => {
if let Primitive::Int(Integer::I8, _) = s.primitive() {
cx.type_vector(cx.type_i8(), 8)
bx.type_vector(bx.type_i8(), 8)
} else {
layout.llvm_type(cx)
layout.llvm_type(bx)
}
}
(InlineAsmRegClass::AArch64(AArch64InlineAsmRegClass::vreg_low16), Abi::Scalar(s)) => {
let elem_ty = llvm_asm_scalar_type(cx, s);
let elem_ty = llvm_asm_scalar_type(bx, s);
let count = 16 / layout.size.bytes();
cx.type_vector(elem_ty, count)
bx.type_vector(elem_ty, count)
}
(
InlineAsmRegClass::AArch64(AArch64InlineAsmRegClass::vreg_low16),
Abi::Vector { element, count },
) if layout.size.bytes() == 8 => {
let elem_ty = llvm_asm_scalar_type(cx, element);
cx.type_vector(elem_ty, count * 2)
let elem_ty = llvm_asm_scalar_type(bx, element);
bx.type_vector(elem_ty, count * 2)
}
(InlineAsmRegClass::X86(X86InlineAsmRegClass::reg_abcd), Abi::Scalar(s))
if s.primitive() == Primitive::Float(Float::F64) =>
{
cx.type_i64()
bx.type_i64()
}
(
InlineAsmRegClass::X86(X86InlineAsmRegClass::xmm_reg | X86InlineAsmRegClass::zmm_reg),
Abi::Vector { .. },
) if layout.size.bytes() == 64 => cx.type_vector(cx.type_f64(), 8),
) if layout.size.bytes() == 64 => bx.type_vector(bx.type_f64(), 8),
(
InlineAsmRegClass::X86(
X86InlineAsmRegClass::xmm_reg
| X86InlineAsmRegClass::ymm_reg
| X86InlineAsmRegClass::zmm_reg,
),
Abi::Scalar(s),
) if cx.sess().asm_arch == Some(InlineAsmArch::X86)
) if bx.sess().asm_arch == Some(InlineAsmArch::X86)
&& s.primitive() == Primitive::Float(Float::F128) =>
{
cx.type_vector(cx.type_i32(), 4)
bx.type_vector(bx.type_i32(), 4)
}
(
InlineAsmRegClass::X86(
Expand All @@ -1198,7 +1256,7 @@ fn llvm_fixup_output_type<'ll, 'tcx>(
| X86InlineAsmRegClass::zmm_reg,
),
Abi::Scalar(s),
) if s.primitive() == Primitive::Float(Float::F16) => cx.type_vector(cx.type_i16(), 8),
) if s.primitive() == Primitive::Float(Float::F16) => bx.type_vector(bx.type_i16(), 8),
(
InlineAsmRegClass::X86(
X86InlineAsmRegClass::xmm_reg
Expand All @@ -1207,16 +1265,16 @@ fn llvm_fixup_output_type<'ll, 'tcx>(
),
Abi::Vector { element, count: count @ (8 | 16) },
) if element.primitive() == Primitive::Float(Float::F16) => {
cx.type_vector(cx.type_i16(), count)
bx.type_vector(bx.type_i16(), count)
}
(
InlineAsmRegClass::Arm(ArmInlineAsmRegClass::sreg | ArmInlineAsmRegClass::sreg_low16),
Abi::Scalar(s),
) => {
if let Primitive::Int(Integer::I32, _) = s.primitive() {
cx.type_f32()
bx.type_f32()
} else {
layout.llvm_type(cx)
layout.llvm_type(bx)
}
}
(
Expand All @@ -1228,20 +1286,25 @@ fn llvm_fixup_output_type<'ll, 'tcx>(
Abi::Scalar(s),
) => {
if let Primitive::Int(Integer::I64, _) = s.primitive() {
cx.type_f64()
bx.type_f64()
} else {
layout.llvm_type(cx)
layout.llvm_type(bx)
}
}
(InlineAsmRegClass::Mips(MipsInlineAsmRegClass::reg), Abi::Scalar(s)) => {
match s.primitive() {
// MIPS only supports register-length arithmetics.
Primitive::Int(Integer::I8 | Integer::I16, _) => cx.type_i32(),
Primitive::Float(Float::F32) => cx.type_i32(),
Primitive::Float(Float::F64) => cx.type_i64(),
_ => layout.llvm_type(cx),
Primitive::Int(Integer::I8 | Integer::I16, _) => bx.type_i32(),
Primitive::Float(Float::F32) => bx.type_i32(),
Primitive::Float(Float::F64) => bx.type_i64(),
_ => layout.llvm_type(bx),
}
}
_ => layout.llvm_type(cx),
(InlineAsmRegClass::RiscV(RiscVInlineAsmRegClass::freg), Abi::Scalar(s))
if s.primitive() == Primitive::Float(Float::F16) && !is_zfhmin_enabled(bx) =>
{
bx.type_f32()
}
_ => layout.llvm_type(bx),
}
}
7 changes: 7 additions & 0 deletions compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1000,6 +1000,13 @@ extern "C" {
Value: *const c_char,
ValueLen: c_uint,
) -> &Attribute;
pub fn LLVMGetStringAttributeAtIndex(
F: &Value,
Idx: c_uint,
K: *const c_char,
KLen: c_uint,
) -> Option<&Attribute>;
pub fn LLVMGetStringAttributeValue(A: &Attribute, Length: &mut c_uint) -> *const c_char;

// Operations on functions
pub fn LLVMSetFunctionCallConv(Fn: &Value, CC: c_uint);
Expand Down
7 changes: 4 additions & 3 deletions compiler/rustc_target/src/asm/riscv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ impl RiscVInlineAsmRegClass {
match self {
Self::reg => {
if arch == InlineAsmArch::RiscV64 {
types! { _: I8, I16, I32, I64, F32, F64; }
types! { _: I8, I16, I32, I64, F16, F32, F64; }
} else {
types! { _: I8, I16, I32, F32; }
types! { _: I8, I16, I32, F16, F32; }
}
}
Self::freg => types! { f: F32; d: F64; },
// FIXME(f16_f128): Add `q: F128;` once LLVM support the `Q` extension.
Self::freg => types! { f: F16, F32; d: F64; },
Self::vreg => &[],
}
}
Expand Down
50 changes: 48 additions & 2 deletions tests/assembly/asm/riscv-types.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,33 @@
//@ revisions: riscv64 riscv32
//@ revisions: riscv64 riscv32 riscv64-zfhmin riscv32-zfhmin riscv64-zfh riscv32-zfh
//@ assembly-output: emit-asm

//@[riscv64] compile-flags: --target riscv64imac-unknown-none-elf
//@[riscv64] needs-llvm-components: riscv

//@[riscv32] compile-flags: --target riscv32imac-unknown-none-elf
//@[riscv32] needs-llvm-components: riscv

//@[riscv64-zfhmin] compile-flags: --target riscv64imac-unknown-none-elf --cfg riscv64
//@[riscv64-zfhmin] needs-llvm-components: riscv
//@[riscv64-zfhmin] compile-flags: -C target-feature=+zfhmin
//@[riscv64-zfhmin] filecheck-flags: --check-prefix riscv64

//@[riscv32-zfhmin] compile-flags: --target riscv32imac-unknown-none-elf
//@[riscv32-zfhmin] needs-llvm-components: riscv
//@[riscv32-zfhmin] compile-flags: -C target-feature=+zfhmin

//@[riscv64-zfh] compile-flags: --target riscv64imac-unknown-none-elf --cfg riscv64
//@[riscv64-zfh] needs-llvm-components: riscv
//@[riscv64-zfh] compile-flags: -C target-feature=+zfh
//@[riscv64-zfh] filecheck-flags: --check-prefix riscv64

//@[riscv32-zfh] compile-flags: --target riscv32imac-unknown-none-elf
//@[riscv32-zfh] needs-llvm-components: riscv
//@[riscv32-zfh] compile-flags: -C target-feature=+zfh

//@ compile-flags: -C target-feature=+d

#![feature(no_core, lang_items, rustc_attrs)]
#![feature(no_core, lang_items, rustc_attrs, f16)]
#![crate_type = "rlib"]
#![no_core]
#![allow(asm_sub_register)]
Expand All @@ -33,6 +54,7 @@ type ptr = *mut u8;

impl Copy for i8 {}
impl Copy for i16 {}
impl Copy for f16 {}
impl Copy for i32 {}
impl Copy for f32 {}
impl Copy for i64 {}
Expand Down Expand Up @@ -103,6 +125,12 @@ macro_rules! check_reg {
// CHECK: #NO_APP
check!(reg_i8 i8 reg "mv");

// CHECK-LABEL: reg_f16:
// CHECK: #APP
// CHECK: mv {{[a-z0-9]+}}, {{[a-z0-9]+}}
// CHECK: #NO_APP
check!(reg_f16 f16 reg "mv");

// CHECK-LABEL: reg_i16:
// CHECK: #APP
// CHECK: mv {{[a-z0-9]+}}, {{[a-z0-9]+}}
Expand Down Expand Up @@ -141,6 +169,12 @@ check!(reg_f64 f64 reg "mv");
// CHECK: #NO_APP
check!(reg_ptr ptr reg "mv");

// CHECK-LABEL: freg_f16:
// CHECK: #APP
// CHECK: fmv.s f{{[a-z0-9]+}}, f{{[a-z0-9]+}}
// CHECK: #NO_APP
check!(freg_f16 f16 freg "fmv.s");

// CHECK-LABEL: freg_f32:
// CHECK: #APP
// CHECK: fmv.s f{{[a-z0-9]+}}, f{{[a-z0-9]+}}
Expand All @@ -165,6 +199,12 @@ check_reg!(a0_i8 i8 "a0" "mv");
// CHECK: #NO_APP
check_reg!(a0_i16 i16 "a0" "mv");

// CHECK-LABEL: a0_f16:
// CHECK: #APP
// CHECK: mv a0, a0
// CHECK: #NO_APP
check_reg!(a0_f16 f16 "a0" "mv");

// CHECK-LABEL: a0_i32:
// CHECK: #APP
// CHECK: mv a0, a0
Expand Down Expand Up @@ -197,6 +237,12 @@ check_reg!(a0_f64 f64 "a0" "mv");
// CHECK: #NO_APP
check_reg!(a0_ptr ptr "a0" "mv");

// CHECK-LABEL: fa0_f16:
// CHECK: #APP
// CHECK: fmv.s fa0, fa0
// CHECK: #NO_APP
check_reg!(fa0_f16 f16 "fa0" "fmv.s");

// CHECK-LABEL: fa0_f32:
// CHECK: #APP
// CHECK: fmv.s fa0, fa0
Expand Down

0 comments on commit a86b38e

Please sign in to comment.