Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow vector types to be scalar aligned #1158

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 59 additions & 75 deletions crates/rustc_codegen_spirv/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use rspirv::spirv::{StorageClass, Word};
use rustc_data_structures::fx::FxHashMap;
use rustc_errors::ErrorGuaranteed;
use rustc_index::Idx;
use rustc_middle::query::Providers;
use rustc_middle::query::{Key, Providers};
use rustc_middle::ty::layout::{FnAbiOf, LayoutOf, TyAndLayout};
use rustc_middle::ty::GenericArgsRef;
use rustc_middle::ty::{
Expand All @@ -21,7 +21,7 @@ use rustc_span::DUMMY_SP;
use rustc_span::{Span, Symbol};
use rustc_target::abi::call::{ArgAbi, ArgAttributes, FnAbi, PassMode};
use rustc_target::abi::{
Abi, Align, FieldsShape, LayoutS, Primitive, Scalar, Size, TagEncoding, VariantIdx, Variants,
Abi, Align, FieldsShape, LayoutS, Primitive, Scalar, Size, VariantIdx, Variants,
};
use rustc_target::spec::abi::Abi as SpecAbi;
use std::cell::RefCell;
Expand Down Expand Up @@ -94,86 +94,66 @@ pub(crate) fn provide(providers: &mut Providers) {
Ok(readjust_fn_abi(tcx, result?))
};

// FIXME(eddyb) remove this by deriving `Clone` for `LayoutS` upstream.
// FIXME(eddyb) the `S` suffix is a naming antipattern, rename upstream.
fn clone_layout(layout: &LayoutS) -> LayoutS {
let LayoutS {
ref fields,
ref variants,
abi,
largest_niche,
align,
size,
max_repr_align,
unadjusted_abi_align,
} = *layout;
LayoutS {
fields: match *fields {
FieldsShape::Primitive => FieldsShape::Primitive,
FieldsShape::Union(count) => FieldsShape::Union(count),
FieldsShape::Array { stride, count } => FieldsShape::Array { stride, count },
FieldsShape::Arbitrary {
ref offsets,
ref memory_index,
} => FieldsShape::Arbitrary {
offsets: offsets.clone(),
memory_index: memory_index.clone(),
},
},
variants: match *variants {
Variants::Single { index } => Variants::Single { index },
Variants::Multiple {
tag,
ref tag_encoding,
tag_field,
ref variants,
} => Variants::Multiple {
tag,
tag_encoding: match *tag_encoding {
TagEncoding::Direct => TagEncoding::Direct,
TagEncoding::Niche {
untagged_variant,
ref niche_variants,
niche_start,
} => TagEncoding::Niche {
untagged_variant,
niche_variants: niche_variants.clone(),
niche_start,
},
},
tag_field,
variants: variants.clone(),
},
},
abi,
largest_niche,
align,
size,
max_repr_align,
unadjusted_abi_align,
}
}
providers.layout_of = |tcx, key| {
let TyAndLayout { ty, mut layout } =
(rustc_interface::DEFAULT_QUERY_PROVIDERS.layout_of)(tcx, key)?;
let orig = (rustc_interface::DEFAULT_QUERY_PROVIDERS.layout_of)(tcx, key)?;

#[allow(clippy::match_like_matches_macro)]
let hide_niche = match ty.kind() {
ty::Bool => true,
_ => false,
};
let mut modified: Option<LayoutS> = None;

if hide_niche {
layout = tcx.mk_layout(LayoutS {
largest_niche: None,
..clone_layout(layout.0.0)
});
}
hide_bool_niche(&tcx, orig, &mut modified);
adjust_vector_layout(&tcx, orig, &mut modified);

let ty = orig.ty;

let layout = if let Some(modified) = modified {
tcx.mk_layout(modified)
} else {
orig.layout
};
Ok(TyAndLayout { ty, layout })
};
}

fn hide_bool_niche<'tcx>(_cx: &TyCtxt<'tcx>, orig: TyAndLayout<'tcx>, modified: &mut Option<LayoutS>) {
#[allow(clippy::match_like_matches_macro)]
let hide_niche = match orig.ty.kind() {
ty::Bool => true,
_ => false,
};

if hide_niche {
let layout = modified.get_or_insert_with(|| orig.layout.0.0.clone());
layout.largest_niche = None;
}
}

fn adjust_vector_layout<'tcx>(cx: &TyCtxt<'tcx>, orig: TyAndLayout<'tcx>, modified: &mut Option<LayoutS>) {
// in spirv, in most cases vectors have align equal to their element type. in block resource
// contexts (storage, uniform, etc) the rules are sometimes more restrictive than that, but
// it's best to use the least common denominator here since if a layout get somputed that is
// not valid under the given vulkan environment then spirv-val will catch it and the user
// can manually adjust the block layout, whereas if we use the most restrictive rules then
// everywhere else suffers.
//
// it may be possible for us to figure out whether we're in a block context and what the
// currently enabled restrictions allow and compute this more intelligently, but I'm not sure
// whether that can break rustc assumptions about abi (what if the same type is used in
// multiple places with competing requirements?)
if let Abi::Vector { element, count } = orig.abi {
let layout = modified.get_or_insert_with(|| orig.layout.0.0.clone());
let element_align = element.align(cx).abi;
let element_size = element.size(cx);
let repr_align_align = orig.ty.ty_adt_id()
.and_then(|did| cx.repr_options_of_def(did).align);
if let Some(align) = repr_align_align {
layout.align.abi = align.max(element_align);
} else {
layout.align.abi = element_align;
}
layout.align.pref = layout.align.abi;
layout.size = (element_size * count).align_to(layout.align.abi);
}
}

/// If a struct contains a pointer to itself, even indirectly, then doing a naiive recursive walk
/// of the fields will result in an infinite loop. Because pointers are the only thing that are
/// allowed to be recursive, keep track of what pointers we've translated, or are currently in the
Expand Down Expand Up @@ -634,7 +614,8 @@ fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>
}
}
FieldsShape::Array { stride, count } => {
let element_type = ty.field(cx, 0).spirv_type(span, cx);
let element_type_hl = ty.field(cx, 0);
let element_type = element_type_hl.spirv_type(span, cx);
if ty.is_unsized() {
// There's a potential for this array to be sized, but the element to be unsized, e.g. `[[u8]; 5]`.
// However, I think rust disallows all these cases, so assert this here.
Expand All @@ -653,6 +634,9 @@ fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>
.sizeof(cx)
.expect("Unexpected unsized type in sized FieldsShape::Array")
.align_to(element_spv.alignof(cx));
if stride_spv != stride {
eprintln!("element type: {:?}\nelement spv type: {:?}", &element_type_hl, &element_spv);
}
assert_eq!(stride_spv, stride);
SpirvType::Array {
element: element_type,
Expand Down
12 changes: 3 additions & 9 deletions crates/rustc_codegen_spirv/src/spirv_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ impl SpirvType<'_> {
Self::Integer(width, _) | Self::Float(width) => Size::from_bits(width),
Self::Adt { size, .. } => size?,
Self::Vector { element, count } => {
cx.lookup_type(element).sizeof(cx)? * count.next_power_of_two() as u64
cx.lookup_type(element).sizeof(cx)? * count as u64
}
Self::Matrix { element, count } => cx.lookup_type(element).sizeof(cx)? * count as u64,
Self::Array { element, count } => {
Expand All @@ -357,14 +357,8 @@ impl SpirvType<'_> {
Self::Bool => Align::from_bytes(1).unwrap(),
Self::Integer(width, _) | Self::Float(width) => Align::from_bits(width as u64).unwrap(),
Self::Adt { align, .. } => align,
// Vectors have size==align
Self::Vector { .. } => Align::from_bytes(
self.sizeof(cx)
.expect("alignof: Vectors must be sized")
.bytes(),
)
.expect("alignof: Vectors must have power-of-2 size"),
Self::Array { element, .. }
Self::Vector { element, .. }
| Self::Array { element, .. }
| Self::RuntimeArray { element }
| Self::Matrix { element, .. } => cx.lookup_type(element).alignof(cx),
Self::Pointer { .. } => cx.tcx.data_layout.pointer_align.abi,
Expand Down
1 change: 1 addition & 0 deletions crates/rustc_codegen_spirv/src/target.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ impl SpirvTarget {
Target {
llvm_target: self.to_string().into(),
pointer_width: 32,
// note: vector size and align gets patched in abi.rs too
data_layout: "e-m:e-p:32:32:32-i64:64-n8:16:32:64".into(),
arch: ARCH.into(),
options: self.init_target_opts(),
Expand Down
Loading