Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API to enable ext/capabilities, and remove default capabilities #630

Merged
merged 3 commits into from
May 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion crates/rustc_codegen_spirv/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use rustc_middle::ty::{
};
use rustc_span::def_id::DefId;
use rustc_span::Span;
use rustc_span::DUMMY_SP;
use rustc_target::abi::call::{CastTarget, FnAbi, PassMode, Reg, RegKind};
use rustc_target::abi::{
Abi, Align, FieldsShape, LayoutOf, Primitive, Scalar, Size, TagEncoding, VariantIdx, Variants,
Expand Down Expand Up @@ -323,11 +324,15 @@ impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {

fn trans_type_impl<'tcx>(
cx: &CodegenCx<'tcx>,
span: Span,
mut span: Span,
ty: TyAndLayout<'tcx>,
is_immediate: bool,
) -> Word {
if let TyKind::Adt(adt, substs) = *ty.ty.kind() {
if span == DUMMY_SP {
span = cx.tcx.def_span(adt.did);
}

let attrs = AggregatedSpirvAttributes::parse(cx, cx.tcx.get_attrs(adt.did));

if let Some(intrinsic_type_attr) = attrs.intrinsic_type.map(|attr| attr.value) {
Expand Down
25 changes: 18 additions & 7 deletions crates/rustc_codegen_spirv/src/builder/ext_inst.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use super::Builder;
use crate::builder_spirv::{SpirvValue, SpirvValueExt};
use crate::codegen_cx::CodegenCx;
use rspirv::spirv::{CLOp, GLOp, Word};
use rspirv::{dr::Operand, spirv::Capability};

Expand Down Expand Up @@ -40,14 +39,26 @@ impl ExtInst {
}
}

pub fn import_integer_functions_2_intel<'tcx>(&mut self, cx: &CodegenCx<'tcx>) {
pub fn require_integer_functions_2_intel<'a, 'tcx>(
&mut self,
bx: &Builder<'a, 'tcx>,
to_zombie: Word,
) {
if !self.integer_functions_2_intel {
assert!(!cx.target.is_kernel());
assert!(!bx.target.is_kernel());
self.integer_functions_2_intel = true;
cx.emit_global()
.extension("SPV_INTEL_shader_integer_functions2");
cx.emit_global()
.capability(Capability::IntegerFunctions2INTEL);
if !bx
.builder
.has_capability(Capability::IntegerFunctions2INTEL)
{
bx.zombie(to_zombie, "capability IntegerFunctions2INTEL is required");
}
if !bx
.builder
.has_extension("SPV_INTEL_shader_integer_functions2")
{
bx.zombie(to_zombie, "extension IntegerFunctions2INTEL is required");
}
}
}
}
Expand Down
26 changes: 14 additions & 12 deletions crates/rustc_codegen_spirv/src/builder/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,34 +362,36 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> {
if self.target.is_kernel() {
self.cl_op(CLOp::clz, ret_ty, [args[0].immediate()])
} else {
self.ext_inst
.borrow_mut()
.import_integer_functions_2_intel(self);
self.emit()
let result = self
.emit()
.u_count_leading_zeros_intel(
args[0].immediate().ty,
None,
args[0].immediate().def(self),
)
.unwrap()
.with_type(args[0].immediate().ty)
.unwrap();
self.ext_inst
.borrow_mut()
.require_integer_functions_2_intel(self, result);
result.with_type(args[0].immediate().ty)
}
}
sym::cttz | sym::cttz_nonzero => {
if self.target.is_kernel() {
self.cl_op(CLOp::ctz, ret_ty, [args[0].immediate()])
} else {
self.ext_inst
.borrow_mut()
.import_integer_functions_2_intel(self);
self.emit()
let result = self
.emit()
.u_count_trailing_zeros_intel(
args[0].immediate().ty,
None,
args[0].immediate().def(self),
)
.unwrap()
.with_type(args[0].immediate().ty)
.unwrap();
self.ext_inst
.borrow_mut()
.require_integer_functions_2_intel(self, result);
result.with_type(args[0].immediate().ty)
}
}

Expand Down
16 changes: 12 additions & 4 deletions crates/rustc_codegen_spirv/src/builder_spirv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,10 +350,6 @@ impl BuilderSpirv {

// The linker will always be ran on this module
builder.capability(Capability::Linkage);
builder.capability(Capability::Int8);
builder.capability(Capability::Int16);
builder.capability(Capability::Int64);
builder.capability(Capability::Float64);

let addressing_model = if target.is_kernel() {
builder.capability(Capability::Addresses);
Expand Down Expand Up @@ -425,6 +421,18 @@ impl BuilderSpirv {
})
}

pub fn has_extension(&self, extension: &str) -> bool {
self.builder
.borrow()
.module_ref()
.extensions
.iter()
.any(|inst| {
inst.class.opcode == Op::Extension
&& inst.operands[0].unwrap_literal_string() == extension
})
Comment on lines +429 to +433
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be O(1) as a set of interned strings, maybe open an issue about simple algorithmic inefficiencies like this and leave // FIXME(#1234) comments? At least that's what I would should do.

}

pub fn select_function_by_id(&self, id: Word) -> BuilderCursor {
let mut builder = self.builder.borrow_mut();
for (index, func) in builder.module_ref().functions.iter().enumerate() {
Expand Down
15 changes: 13 additions & 2 deletions crates/rustc_codegen_spirv/src/codegen_cx/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,22 @@ impl<'tcx> CodegenCx<'tcx> {
pub fn new(tcx: TyCtxt<'tcx>, codegen_unit: &'tcx CodegenUnit<'tcx>) -> Self {
let sym = Symbols::get();

let features = tcx
let mut feature_names = tcx
.sess
.target_features
.iter()
.filter(|s| *s != &sym.bindless)
.map(|s| s.as_str().parse())
.map(|s| s.as_str())
.collect::<Vec<_>>();

// target_features is a HashSet, not a Vec, so we need to sort to have deterministic
// compilation - otherwise, the order of capabilities in binaries depends on the iteration
// order of the hashset. Sort by the string, since that's easy.
feature_names.sort();
Comment on lines +107 to +110
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a FxHashSet, right? Its order should already be deterministic. But yes this is probably a good idea just to keep the output clean, if nothing else.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be deterministic! In practice, somehow, it isn't - I hit a different order when upgrading rust-toolchain, not sure if the actual cause of the different order was the different rustc binary, or if the different order was more common/subtle than that. (Random guess, do symbols like, use their pointer as a hash key, instead of the string contents? So if the symbols are allocated in a different order, the FxHashSet<Symbol> order is different?)


let features = feature_names
.into_iter()
.map(|s| s.parse())
.collect::<Result<_, String>>()
.unwrap_or_else(|error| {
tcx.sess.err(&error);
Expand Down Expand Up @@ -221,6 +231,7 @@ impl<'tcx> CodegenCx<'tcx> {
|| self.tcx.crate_name(LOCAL_CRATE) == self.sym.spirv_std
|| self.tcx.crate_name(LOCAL_CRATE) == self.sym.libm
|| self.tcx.crate_name(LOCAL_CRATE) == self.sym.num_traits
|| self.tcx.crate_name(LOCAL_CRATE) == self.sym.glam
}

// FIXME(eddyb) should this just be looking at `kernel_mode`?
Expand Down
105 changes: 0 additions & 105 deletions crates/rustc_codegen_spirv/src/linker/capability_computation.rs

This file was deleted.

51 changes: 48 additions & 3 deletions crates/rustc_codegen_spirv/src/linker/import_export_link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ fn find_import_export_pairs_and_killed_params(
};
let import_type = *type_map.get(&import_id).expect("Unexpected op");
// Make sure the import/export pair has the same type.
check_tys_equal(sess, name, import_type, export_type)?;
check_tys_equal(sess, module, name, import_type, export_type)?;
rewrite_rules.insert(import_id, export_id);
if let Some(params) = fn_parameters.get(&import_id) {
for &param in params {
Expand Down Expand Up @@ -108,11 +108,56 @@ fn fn_parameters(module: &Module) -> FxHashMap<Word, Vec<Word>> {
.collect()
}

fn check_tys_equal(sess: &Session, name: &str, import_type: Word, export_type: Word) -> Result<()> {
fn check_tys_equal(
sess: &Session,
module: &Module,
name: &str,
import_type: Word,
export_type: Word,
) -> Result<()> {
if import_type == export_type {
Ok(())
} else {
sess.err(&format!("Types mismatch for {:?}", name));
// We have an error. It's okay to do something really slow now to report the error.
use std::fmt::Write;
let ty_defs = module
.types_global_values
.iter()
.filter_map(|inst| Some((inst.result_id?, inst)))
.collect();
fn format_ty(ty_defs: &FxHashMap<Word, &Instruction>, ty: Word, buf: &mut String) {
match ty_defs.get(&ty) {
Some(def) => {
write!(buf, "({}", def.class.opname).unwrap();
if let Some(result_type) = def.result_type {
write!(buf, " {}", result_type).unwrap();
}
for op in &def.operands {
if let Some(id) = op.id_ref_any() {
write!(buf, " ").unwrap();
format_ty(ty_defs, id, buf);
}
}
write!(buf, ")").unwrap();
}
None => write!(buf, "{}", ty).unwrap(),
}
}
fn format_ty_(ty_defs: &FxHashMap<Word, &Instruction>, ty: Word) -> String {
let mut result = String::new();
format_ty(ty_defs, ty, &mut result);
result
}
sess.struct_err(&format!("Types mismatch for {:?}", name))
.note(&format!(
"import type: {}",
format_ty_(&ty_defs, import_type)
))
.note(&format!(
"export type: {}",
format_ty_(&ty_defs, export_type)
))
.emit();
Err(ErrorReported)
}
}
Expand Down
6 changes: 0 additions & 6 deletions crates/rustc_codegen_spirv/src/linker/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#[cfg(test)]
mod test;

mod capability_computation;
mod dce;
mod duplicates;
mod import_export_link;
Expand Down Expand Up @@ -292,11 +291,6 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<L
let _timer = sess.timer("link_dce_2");
dce::dce(output);
}
{
let _timer = sess.timer("link_remove_extra_capabilities");
capability_computation::remove_extra_capabilities(output);
capability_computation::remove_extra_extensions(output);
}

if opts.compact_ids {
let _timer = sess.timer("link_compact_ids");
Expand Down
2 changes: 1 addition & 1 deletion crates/rustc_codegen_spirv/src/linker/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ fn type_mismatch() {
let result = assemble_and_link(&[&a, &b]);
assert_eq!(
result.err().as_deref(),
Some("error: Types mismatch for \"foo\"")
Some("error: Types mismatch for \"foo\"\n |\n = note: import type: (TypeFloat)\n = note: export type: (TypeInt)")
);
}

Expand Down
2 changes: 2 additions & 0 deletions crates/rustc_codegen_spirv/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub struct Symbols {
pub spirv_std: Symbol,
pub libm: Symbol,
pub num_traits: Symbol,
pub glam: Symbol,
pub entry_point_name: Symbol,
descriptor_set: Symbol,
binding: Symbol,
Expand Down Expand Up @@ -374,6 +375,7 @@ impl Symbols {
spirv_std: Symbol::intern("spirv_std"),
libm: Symbol::intern("libm"),
num_traits: Symbol::intern("num_traits"),
glam: Symbol::intern("glam"),
descriptor_set: Symbol::intern("descriptor_set"),
binding: Symbol::intern("binding"),
image_type: Symbol::intern("image_type"),
Expand Down
Loading