Skip to content

Commit

Permalink
add support for zero-initializing workgroup memory
Browse files Browse the repository at this point in the history
  • Loading branch information
teoxoy committed Jan 9, 2023
1 parent 78f4fef commit 5f2f940
Show file tree
Hide file tree
Showing 45 changed files with 1,476 additions and 85 deletions.
1 change: 1 addition & 0 deletions benches/criterion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ fn backends(c: &mut Criterion) {
version: naga::back::glsl::Version::new_gles(320),
writer_flags: naga::back::glsl::WriterFlags::empty(),
binding_map: Default::default(),
zero_initialize_workgroup_memory: true,
};
for &(ref module, ref info) in inputs.iter() {
for ep in module.entry_points.iter() {
Expand Down
49 changes: 47 additions & 2 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ pub struct Options {
pub writer_flags: WriterFlags,
/// Map of resources association to binding locations.
pub binding_map: BindingMap,
/// Should workgroup variables be zero initialized (by polyfilling)?
pub zero_initialize_workgroup_memory: bool,
}

impl Default for Options {
Expand All @@ -236,6 +238,7 @@ impl Default for Options {
version: Version::new_gles(310),
writer_flags: WriterFlags::ADJUST_COORDINATE_SPACE,
binding_map: BindingMap::default(),
zero_initialize_workgroup_memory: true,
}
}
}
Expand Down Expand Up @@ -1399,6 +1402,12 @@ impl<'a, W: Write> Writer<'a, W> {
// Close the parentheses and open braces to start the function body
writeln!(self.out, ") {{")?;

if self.options.zero_initialize_workgroup_memory
&& ctx.ty.is_compute_entry_point(self.module)
{
self.write_workgroup_variables_initialization(&ctx)?;
}

// Compose the function arguments from globals, in case of an entry point.
if let back::FunctionType::EntryPoint(ep_index) = ctx.ty {
let stage = self.module.entry_points[ep_index as usize].stage;
Expand Down Expand Up @@ -1487,6 +1496,42 @@ impl<'a, W: Write> Writer<'a, W> {
Ok(())
}

fn write_workgroup_variables_initialization(
&mut self,
ctx: &back::FunctionCtx,
) -> BackendResult {
let mut vars = self
.module
.global_variables
.iter()
.filter(|&(handle, var)| {
!ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
})
.peekable();

if vars.peek().is_some() {
let level = back::Level(1);

writeln!(
self.out,
"{}if (gl_GlobalInvocationID == uvec3(0u)) {{",
level
)?;

for (handle, var) in vars {
let name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, "{}{} = ", level.next(), name)?;
self.write_zero_init_value(var.ty)?;
writeln!(self.out, ";")?;
}

writeln!(self.out, "{}}}", level)?;
self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
}

Ok(())
}

/// Helper method that writes a list of comma separated `T` with a writer function `F`
///
/// The writer function `F` receives a mutable reference to `self` that if needed won't cause
Expand Down Expand Up @@ -3515,7 +3560,7 @@ impl<'a, W: Write> Writer<'a, W> {
fn write_zero_init_value(&mut self, ty: Handle<crate::Type>) -> BackendResult {
let inner = &self.module.types[ty].inner;
match *inner {
TypeInner::Scalar { kind, .. } => {
TypeInner::Scalar { kind, .. } | TypeInner::Atomic { kind, .. } => {
self.write_zero_init_scalar(kind)?;
}
TypeInner::Vector { kind, .. } => {
Expand Down Expand Up @@ -3560,7 +3605,7 @@ impl<'a, W: Write> Writer<'a, W> {
}
write!(self.out, ")")?;
}
_ => {} // TODO:
_ => unreachable!(),
}

Ok(())
Expand Down
3 changes: 3 additions & 0 deletions src/back/hlsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ pub struct Options {
pub special_constants_binding: Option<BindTarget>,
/// Bind target of the push constant buffer
pub push_constants_target: Option<BindTarget>,
/// Should workgroup variables be zero initialized (by polyfilling)?
pub zero_initialize_workgroup_memory: bool,
}

impl Default for Options {
Expand All @@ -201,6 +203,7 @@ impl Default for Options {
fake_missing_bindings: true,
special_constants_binding: None,
push_constants_target: None,
zero_initialize_workgroup_memory: true,
}
}
}
Expand Down
57 changes: 57 additions & 0 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// Write function name
write!(self.out, " {}(", name)?;

let need_workgroup_variables_initialization =
self.need_workgroup_variables_initialization(func_ctx, module);

// Write function arguments for non entry point functions
match func_ctx.ty {
back::FunctionType::Function(handle) => {
Expand Down Expand Up @@ -1129,6 +1132,16 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_semantic(binding, Some((stage, Io::Input)))?;
}
}

if need_workgroup_variables_initialization && stage == ShaderStage::Compute {
if !func.arguments.is_empty() {
write!(self.out, ", ")?;
}
write!(
self.out,
"uint3 __global_invocation_id : SV_DispatchThreadID"
)?;
}
}
}
}
Expand All @@ -1151,6 +1164,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out)?;
writeln!(self.out, "{{")?;

if need_workgroup_variables_initialization {
self.write_workgroup_variables_initialization(func_ctx, module)?;
}

if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
self.write_ep_arguments_initialization(module, func, index)?;
}
Expand Down Expand Up @@ -1204,6 +1221,46 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Ok(())
}

fn need_workgroup_variables_initialization(
&mut self,
func_ctx: &back::FunctionCtx,
module: &Module,
) -> bool {
self.options.zero_initialize_workgroup_memory
&& func_ctx.ty.is_compute_entry_point(module)
&& module.global_variables.iter().any(|(handle, var)| {
!func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
})
}

fn write_workgroup_variables_initialization(
&mut self,
func_ctx: &back::FunctionCtx,
module: &Module,
) -> BackendResult {
let level = back::Level(1);

writeln!(
self.out,
"{}if (all(__global_invocation_id == uint3(0u, 0u, 0u))) {{",
level
)?;

let vars = module.global_variables.iter().filter(|&(handle, var)| {
!func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
});

for (handle, var) in vars {
let name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, "{}{} = ", level.next(), name)?;
self.write_default_init(module, var.ty)?;
writeln!(self.out, ";")?;
}

writeln!(self.out, "{}}}", level)?;
self.write_barrier(crate::Barrier::WORK_GROUP, level)
}

/// Helper method used to write statements
///
/// # Notes
Expand Down
11 changes: 11 additions & 0 deletions src/back/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@ enum FunctionType {
EntryPoint(crate::proc::EntryPointIndex),
}

impl FunctionType {
fn is_compute_entry_point(&self, module: &crate::Module) -> bool {
match *self {
FunctionType::EntryPoint(index) => {
module.entry_points[index as usize].stage == crate::ShaderStage::Compute
}
_ => false,
}
}
}

/// Helper structure that stores data needed when writing the function
struct FunctionCtx<'a> {
/// The current function being written
Expand Down
3 changes: 3 additions & 0 deletions src/back/msl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ pub struct Options {
/// Bounds checking policies.
#[cfg_attr(feature = "deserialize", serde(default))]
pub bounds_check_policies: index::BoundsCheckPolicies,
/// Should workgroup variables be zero initialized (by polyfilling)?
pub zero_initialize_workgroup_memory: bool,
}

impl Default for Options {
Expand All @@ -220,6 +222,7 @@ impl Default for Options {
spirv_cross_compatibility: false,
fake_missing_bindings: true,
bounds_check_policies: index::BoundsCheckPolicies::default(),
zero_initialize_workgroup_memory: true,
}
}
}
Expand Down
Loading

0 comments on commit 5f2f940

Please sign in to comment.