From 93c1a645ec358a9b15765d4c09d43e4ee3de3fda Mon Sep 17 00:00:00 2001 From: Elabajaba Date: Tue, 28 May 2024 06:42:11 -0400 Subject: [PATCH] Naga 0.20 (#87) Pipeline override constants will be fully implemented in a followup PR. They're theoretically working, but they need tests and also need a function to map the unmangled names to the mangled names. Otherwise this appears to work. I've tested it with bevy and everything seemed to work. --------- Co-authored-by: robtfm <50659922+robtfm@users.noreply.github.com> --- Cargo.toml | 4 +- examples/pbr_compose_test.rs | 4 +- src/compose/error.rs | 5 +- src/compose/mod.rs | 83 ++++++++-- src/compose/test.rs | 155 +++--------------- src/compose/tests/expected/glsl_call_wgsl.txt | 2 +- src/derive.rs | 124 ++++++++++++-- src/prune/mod.rs | 60 ++++++- src/redirect.rs | 5 +- 9 files changed, 266 insertions(+), 176 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ca8cf7f..4f747ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ prune = [] allow_deprecated = [] [dependencies] -naga = { version = "0.19.0", features = ["wgsl-in", "wgsl-out", "clone"] } +naga = { version = "0.20", features = ["wgsl-in", "wgsl-out"] } tracing = "0.1" regex = "1.8" regex-syntax = "0.8" @@ -31,6 +31,6 @@ once_cell = "1.17.0" indexmap = "2" [dev-dependencies] -wgpu = { version = "0.19.0", features = ["naga-ir"] } +wgpu = { version = "0.20", features = ["naga-ir"] } futures-lite = "1" tracing-subscriber = { version = "0.3", features = ["std", "fmt"] } diff --git a/examples/pbr_compose_test.rs b/examples/pbr_compose_test.rs index 3adac17..a8aed5c 100644 --- a/examples/pbr_compose_test.rs +++ b/examples/pbr_compose_test.rs @@ -96,7 +96,7 @@ fn test_compose_full() -> Result { }) { Ok(module) => { // println!("shader: {:#?}", module); - // let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), naga::valid::Capabilities::default()).validate(&module).unwrap(); + // let info = composer.create_validator().validate(&module).unwrap(); // let _wgsl = naga::back::wgsl::write_string(&module, &info, naga::back::wgsl::WriterFlags::EXPLICIT_TYPES).unwrap(); // println!("wgsl: \n\n{}", wgsl); Ok(module) @@ -120,7 +120,7 @@ fn test_compose_final_module(n: usize, composer: &mut Composer) { }) { Ok(module) => { // println!("shader: {:#?}", module); - // let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), naga::valid::Capabilities::default()).validate(&module).unwrap(); + // let info = composer.create_validator().validate(&module).unwrap(); // let _wgsl = naga::back::wgsl::write_string(&module, &info, naga::back::wgsl::WriterFlags::EXPLICIT_TYPES).unwrap(); // println!("wgsl: \n\n{}", wgsl); Ok(module) diff --git a/src/compose/error.rs b/src/compose/error.rs index 2d1fc6e..171e3f3 100644 --- a/src/compose/error.rs +++ b/src/compose/error.rs @@ -78,7 +78,7 @@ pub enum ComposerErrorInner { WgslParseError(naga::front::wgsl::ParseError), #[cfg(feature = "glsl")] #[error("{0:?}")] - GlslParseError(Vec), + GlslParseError(naga::front::glsl::ParseError), #[error("naga_oil bug, please file a report: failed to convert imported module IR back into WGSL for use with WGSL shaders: {0}")] WgslBackError(naga::back::wgsl::Error), #[cfg(feature = "glsl")] @@ -226,7 +226,8 @@ impl ComposerError { ), #[cfg(feature = "glsl")] ComposerErrorInner::GlslParseError(e) => ( - e.iter() + e.errors + .iter() .map(|naga::front::glsl::Error { kind, meta }| { Label::primary((), map_span(meta.to_range().unwrap_or(0..0))) .with_message(kind.to_string()) diff --git a/src/compose/mod.rs b/src/compose/mod.rs index 4e0024a..3046131 100644 --- a/src/compose/mod.rs +++ b/src/compose/mod.rs @@ -126,7 +126,10 @@ use indexmap::IndexMap; /// /// codespan reporting for errors is available using the error `emit_to_string` method. this requires validation to be enabled, which is true by default. `Composer::non_validating()` produces a non-validating composer that is not able to give accurate error reporting. /// -use naga::EntryPoint; +use naga::{ + valid::{Capabilities, ShaderStages}, + EntryPoint, +}; use regex::Regex; use std::collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}; use tracing::{debug, trace}; @@ -318,6 +321,11 @@ pub struct Composer { pub module_sets: HashMap, pub module_index: HashMap, pub capabilities: naga::valid::Capabilities, + /// The shader stages that the subgroup operations are valid for. + /// Used when creating a validator for the module. + /// See https://github.com/gfx-rs/wgpu/blob/d9c054c645af0ea9ef81617c3e762fbf0f3fecda/wgpu-core/src/device/mod.rs#L515 + /// for how to set this for proper subgroup ops support. + pub subgroup_stages: ShaderStages, preprocessor: Preprocessor, check_decoration_regex: Regex, undecorate_regex: Regex, @@ -339,6 +347,7 @@ impl Default for Composer { Self { validate: true, capabilities: Default::default(), + subgroup_stages: ShaderStages::empty(), module_sets: Default::default(), module_index: Default::default(), preprocessor: Preprocessor::default(), @@ -417,6 +426,21 @@ impl Composer { String::from_utf8(data_encoding::BASE32_NOPAD.decode(from.as_bytes()).unwrap()).unwrap() } + /// This creates a validator that properly detects subgroup support. + fn create_validator(&self) -> naga::valid::Validator { + let subgroup_operations = if self.capabilities.contains(Capabilities::SUBGROUP) { + use naga::valid::SubgroupOperationSet as S; + S::BASIC | S::VOTE | S::ARITHMETIC | S::BALLOT | S::SHUFFLE | S::SHUFFLE_RELATIVE + } else { + naga::valid::SubgroupOperationSet::empty() + }; + let mut validator = + naga::valid::Validator::new(naga::valid::ValidationFlags::all(), self.capabilities); + validator.subgroup_stages(self.subgroup_stages); + validator.subgroup_operations(subgroup_operations); + validator + } + fn undecorate(&self, string: &str) -> String { let undecor = self .undecorate_regex @@ -476,10 +500,10 @@ impl Composer { #[allow(unused)] header_for: &str, // Only used when GLSL is enabled ) -> Result { // TODO: cache headers again - let info = - naga::valid::Validator::new(naga::valid::ValidationFlags::all(), self.capabilities) - .validate(naga_module) - .map_err(ComposerErrorInner::HeaderValidationError)?; + let info = self + .create_validator() + .validate(naga_module) + .map_err(ComposerErrorInner::HeaderValidationError)?; match language { ShaderLanguage::Wgsl => naga::back::wgsl::write_string( @@ -526,12 +550,10 @@ impl Composer { naga_module.entry_points.push(ep); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - self.capabilities, - ) - .validate(naga_module) - .map_err(ComposerErrorInner::HeaderValidationError)?; + let info = self + .create_validator() + .validate(naga_module) + .map_err(ComposerErrorInner::HeaderValidationError)?; let mut string = String::new(); let options = naga::back::glsl::Options { @@ -1002,6 +1024,17 @@ impl Composer { } } + // These are naga/wgpu's pipeline override constants, not naga_oil's overrides + let mut owned_pipeline_overrides = IndexMap::new(); + for (h, po) in source_ir.overrides.iter_mut() { + if let Some(name) = po.name.as_mut() { + if !name.contains(DECORATION_PRE) { + *name = format!("{name}{module_decoration}"); + owned_pipeline_overrides.insert(name.clone(), h); + } + } + } + let mut owned_vars = IndexMap::new(); for (h, gv) in source_ir.global_variables.iter_mut() { if let Some(name) = gv.name.as_mut() { @@ -1101,6 +1134,11 @@ impl Composer { module_builder.import_const(h); } + for h in owned_pipeline_overrides.values() { + header_builder.import_pipeline_override(h); + module_builder.import_pipeline_override(h); + } + for h in owned_vars.values() { header_builder.import_global(h); module_builder.import_global(h); @@ -1226,6 +1264,16 @@ impl Composer { } } + for (h, po) in source_ir.overrides.iter() { + if let Some(name) = &po.name { + if composable.owned_functions.contains(name) + && items.map_or(true, |items| items.contains(name)) + { + derived.import_pipeline_override(&h); + } + } + } + for (h, v) in source_ir.global_variables.iter() { if let Some(name) = &v.name { if composable.owned_vars.contains(name) @@ -1385,10 +1433,17 @@ impl Composer { /// specify capabilities to be used for naga module generation. /// purges any existing modules - pub fn with_capabilities(self, capabilities: naga::valid::Capabilities) -> Self { + /// See https://github.com/gfx-rs/wgpu/blob/d9c054c645af0ea9ef81617c3e762fbf0f3fecda/wgpu-core/src/device/mod.rs#L515 + /// for how to set the subgroup_stages value. + pub fn with_capabilities( + self, + capabilities: naga::valid::Capabilities, + subgroup_stages: naga::valid::ShaderStages, + ) -> Self { Self { capabilities, validate: self.validate, + subgroup_stages, ..Default::default() } } @@ -1748,9 +1803,7 @@ impl Composer { // validation if self.validate { - let info = - naga::valid::Validator::new(naga::valid::ValidationFlags::all(), self.capabilities) - .validate(&naga_module); + let info = self.create_validator().validate(&naga_module); match info { Ok(_) => Ok(naga_module), Err(e) => { diff --git a/src/compose/test.rs b/src/compose/test.rs index 846ac71..572c7ea 100644 --- a/src/compose/test.rs +++ b/src/compose/test.rs @@ -44,12 +44,7 @@ mod test { }) .unwrap(); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - naga::valid::Capabilities::default(), - ) - .validate(&module) - .unwrap(); + let info = composer.create_validator().validate(&module).unwrap(); let wgsl = naga::back::wgsl::write_string( &module, &info, @@ -90,12 +85,7 @@ mod test { }) .unwrap(); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - naga::valid::Capabilities::default(), - ) - .validate(&module) - .unwrap(); + let info = composer.create_validator().validate(&module).unwrap(); let wgsl = naga::back::wgsl::write_string( &module, &info, @@ -144,12 +134,7 @@ mod test { }) .unwrap(); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - naga::valid::Capabilities::default(), - ) - .validate(&module) - .unwrap(); + let info = composer.create_validator().validate(&module).unwrap(); let wgsl = naga::back::wgsl::write_string( &module, &info, @@ -322,12 +307,7 @@ mod test { }) .unwrap(); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - naga::valid::Capabilities::default(), - ) - .validate(&module) - .unwrap(); + let info = composer.create_validator().validate(&module).unwrap(); let wgsl = naga::back::wgsl::write_string( &module, &info, @@ -364,12 +344,7 @@ mod test { }) .unwrap(); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - naga::valid::Capabilities::default(), - ) - .validate(&module) - .unwrap(); + let info = composer.create_validator().validate(&module).unwrap(); let wgsl = naga::back::wgsl::write_string( &module, &info, @@ -418,12 +393,7 @@ mod test { }) .unwrap(); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - naga::valid::Capabilities::default(), - ) - .validate(&module) - .unwrap(); + let info = composer.create_validator().validate(&module).unwrap(); let wgsl = naga::back::wgsl::write_string( &module, &info, @@ -457,12 +427,7 @@ mod test { }) .unwrap(); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - naga::valid::Capabilities::default(), - ) - .validate(&module) - .unwrap(); + let info = composer.create_validator().validate(&module).unwrap(); let wgsl = naga::back::wgsl::write_string( &module, &info, @@ -550,12 +515,7 @@ mod test { }) .unwrap(); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - naga::valid::Capabilities::default(), - ) - .validate(&module) - .unwrap(); + let info = composer.create_validator().validate(&module).unwrap(); let wgsl = naga::back::wgsl::write_string( &module, &info, @@ -607,12 +567,7 @@ mod test { #[cfg(feature = "override_any")] { let module = module.unwrap(); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - naga::valid::Capabilities::default(), - ) - .validate(&module) - .unwrap(); + let info = composer.create_validator().validate(&module).unwrap(); let wgsl = naga::back::wgsl::write_string( &module, &info, @@ -662,12 +617,7 @@ mod test { // println!("{:#?}", module); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - naga::valid::Capabilities::default(), - ) - .validate(&module) - .unwrap(); + let info = composer.create_validator().validate(&module).unwrap(); let wgsl = naga::back::wgsl::write_string( &module, &info, @@ -704,12 +654,7 @@ mod test { }) .unwrap(); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - naga::valid::Capabilities::default(), - ) - .validate(&module) - .unwrap(); + let info = composer.create_validator().validate(&module).unwrap(); let wgsl = naga::back::wgsl::write_string( &module, &info, @@ -745,12 +690,7 @@ mod test { }) .unwrap(); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - naga::valid::Capabilities::default(), - ) - .validate(&module) - .unwrap(); + let info = composer.create_validator().validate(&module).unwrap(); let wgsl = naga::back::wgsl::write_string( &module, &info, @@ -785,12 +725,7 @@ mod test { }) .unwrap(); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - naga::valid::Capabilities::default(), - ) - .validate(&module) - .unwrap(); + let info = composer.create_validator().validate(&module).unwrap(); let wgsl = naga::back::wgsl::write_string( &module, &info, @@ -826,12 +761,7 @@ mod test { }) .unwrap(); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - naga::valid::Capabilities::default(), - ) - .validate(&module) - .unwrap(); + let info = composer.create_validator().validate(&module).unwrap(); let wgsl = naga::back::wgsl::write_string( &module, &info, @@ -893,12 +823,7 @@ mod test { }) .unwrap(); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - naga::valid::Capabilities::default(), - ) - .validate(&module) - .unwrap(); + let info = composer.create_validator().validate(&module).unwrap(); let wgsl = naga::back::wgsl::write_string( &module, &info, @@ -959,12 +884,7 @@ mod test { // println!("{}", module.emit_to_string(&composer)); // assert!(false); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - naga::valid::Capabilities::default(), - ) - .validate(&module) - .unwrap(); + let info = composer.create_validator().validate(&module).unwrap(); let wgsl = naga::back::wgsl::write_string( &module, &info, @@ -999,12 +919,7 @@ mod test { }) .unwrap(); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - naga::valid::Capabilities::default(), - ) - .validate(&module) - .unwrap(); + let info = composer.create_validator().validate(&module).unwrap(); let wgsl = naga::back::wgsl::write_string( &module, &info, @@ -1047,12 +962,7 @@ mod test { }) .unwrap(); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - naga::valid::Capabilities::default(), - ) - .validate(&module_a) - .unwrap(); + let info = composer.create_validator().validate(&module_a).unwrap(); let wgsl = naga::back::wgsl::write_string( &module_a, &info, @@ -1074,12 +984,7 @@ mod test { }) .unwrap(); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - naga::valid::Capabilities::default(), - ) - .validate(&module_b) - .unwrap(); + let info = composer.create_validator().validate(&module_b).unwrap(); let wgsl = naga::back::wgsl::write_string( &module_b, &info, @@ -1201,12 +1106,7 @@ mod test { }) .unwrap(); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - naga::valid::Capabilities::default(), - ) - .validate(&module) - .unwrap(); + let info = composer.create_validator().validate(&module).unwrap(); let wgsl = naga::back::wgsl::write_string( &module, &info, @@ -1240,12 +1140,7 @@ mod test { }) .unwrap(); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - naga::valid::Capabilities::default(), - ) - .validate(&module) - .unwrap(); + let info = composer.create_validator().validate(&module).unwrap(); let wgsl = naga::back::wgsl::write_string( &module, &info, @@ -1285,12 +1180,7 @@ mod test { }) .unwrap(); - let info = naga::valid::Validator::new( - naga::valid::ValidationFlags::all(), - naga::valid::Capabilities::default(), - ) - .validate(&module) - .unwrap(); + let info = composer.create_validator().validate(&module).unwrap(); let wgsl = naga::back::wgsl::write_string( &module, &info, @@ -1448,6 +1338,7 @@ mod test { ), module: &shader_module, entry_point: "run_test", + compilation_options: Default::default(), }); let bindgroup = device.create_bind_group(&BindGroupDescriptor { diff --git a/src/compose/tests/expected/glsl_call_wgsl.txt b/src/compose/tests/expected/glsl_call_wgsl.txt index 1df7ccf..d339ec2 100644 --- a/src/compose/tests/expected/glsl_call_wgsl.txt +++ b/src/compose/tests/expected/glsl_call_wgsl.txt @@ -1,5 +1,5 @@ struct VertexOutput { - @builtin(position) member: vec4, + @builtin(position) gl_Position: vec4, } var gl_Position: vec4; diff --git a/src/derive.rs b/src/derive.rs index f4481c6..058f276 100644 --- a/src/derive.rs +++ b/src/derive.rs @@ -1,8 +1,8 @@ use indexmap::IndexMap; use naga::{ Arena, AtomicFunction, Block, Constant, EntryPoint, Expression, Function, FunctionArgument, - FunctionResult, GlobalVariable, Handle, ImageQuery, LocalVariable, Module, SampleLevel, Span, - Statement, StructMember, SwitchCase, Type, TypeInner, UniqueArena, + FunctionResult, GatherMode, GlobalVariable, Handle, ImageQuery, LocalVariable, Module, + Override, SampleLevel, Span, Statement, StructMember, SwitchCase, Type, TypeInner, UniqueArena, }; use std::{cell::RefCell, rc::Rc}; @@ -11,17 +11,25 @@ pub struct DerivedModule<'a> { shader: Option<&'a Module>, span_offset: usize, + /// Maps the original type handle to the the mangled type handle. type_map: IndexMap, Handle>, + /// Maps the original const handle to the the mangled const handle. const_map: IndexMap, Handle>, - const_expression_map: Rc, Handle>>>, + /// Maps the original pipeline override handle to the the mangled pipeline override handle. + pipeline_override_map: IndexMap, Handle>, + /// Contains both const expressions and pipeline override constant expressions. + /// The expressions are stored together because that's what Naga expects. + global_expressions: Rc>>, + /// Maps the original expression handle to the new expression handle for const expressions and pipeline override expressions. + /// The expressions are stored together because that's what Naga expects. + global_expression_map: Rc, Handle>>>, global_map: IndexMap, Handle>, function_map: IndexMap>, - types: UniqueArena, constants: Arena, - const_expressions: Rc>>, globals: Arena, functions: Arena, + pipeline_overrides: Arena, } impl<'a> DerivedModule<'a> { @@ -38,7 +46,8 @@ impl<'a> DerivedModule<'a> { self.type_map.clear(); self.const_map.clear(); self.global_map.clear(); - self.const_expression_map.borrow_mut().clear(); + self.global_expression_map.borrow_mut().clear(); + self.pipeline_override_map.clear(); } pub fn map_span(&self, span: Span) -> Span { @@ -136,9 +145,8 @@ impl<'a> DerivedModule<'a> { let new_const = Constant { name: c.name.clone(), - r#override: c.r#override.clone(), ty: self.import_type(&c.ty), - init: self.import_const_expression(c.init), + init: self.import_global_expression(c.init), }; let span = self.shader.as_ref().unwrap().constants.get_span(*h_const); @@ -166,7 +174,7 @@ impl<'a> DerivedModule<'a> { space: gv.space, binding: gv.binding.clone(), ty: self.import_type(&gv.ty), - init: gv.init.map(|c| self.import_const_expression(c)), + init: gv.init.map(|c| self.import_global_expression(c)), }; let span = self @@ -182,18 +190,56 @@ impl<'a> DerivedModule<'a> { new_h }) } - // remap a const expression from source context into our derived context - pub fn import_const_expression(&mut self, h_cexpr: Handle) -> Handle { + + // remap either a const or pipeline override expression from source context into our derived context + pub fn import_global_expression(&mut self, h_expr: Handle) -> Handle { self.import_expression( - h_cexpr, - &self.shader.as_ref().unwrap().const_expressions, - self.const_expression_map.clone(), - self.const_expressions.clone(), + h_expr, + &self.shader.as_ref().unwrap().global_expressions, + self.global_expression_map.clone(), + self.global_expressions.clone(), false, true, ) } + // remap a pipeline override from source context into our derived context + pub fn import_pipeline_override(&mut self, h_override: &Handle) -> Handle { + self.pipeline_override_map + .get(h_override) + .copied() + .unwrap_or_else(|| { + let pipeline_override = self + .shader + .as_ref() + .unwrap() + .overrides + .try_get(*h_override) + .unwrap(); + + let new_override = Override { + name: pipeline_override.name.clone(), + id: pipeline_override.id, + ty: self.import_type(&pipeline_override.ty), + init: pipeline_override + .init + .map(|init| self.import_global_expression(init)), + }; + + let span = self + .shader + .as_ref() + .unwrap() + .overrides + .get_span(*h_override); + let new_h = self + .pipeline_overrides + .fetch_or_append(new_override, self.map_span(span)); + self.pipeline_override_map.insert(*h_override, new_h); + new_h + }) + } + // remap a block fn import_block( &mut self, @@ -363,7 +409,40 @@ impl<'a> DerivedModule<'a> { naga::RayQueryFunction::Terminate => naga::RayQueryFunction::Terminate, }, }, - + Statement::SubgroupBallot { result, predicate } => Statement::SubgroupBallot { + result: map_expr!(result), + predicate: map_expr_opt!(predicate), + }, + Statement::SubgroupGather { + mut mode, + argument, + result, + } => { + match mode { + GatherMode::BroadcastFirst => (), + GatherMode::Broadcast(ref mut h_src) + | GatherMode::Shuffle(ref mut h_src) + | GatherMode::ShuffleDown(ref mut h_src) + | GatherMode::ShuffleUp(ref mut h_src) + | GatherMode::ShuffleXor(ref mut h_src) => *h_src = map_expr!(h_src), + }; + Statement::SubgroupGather { + mode, + argument: map_expr!(argument), + result: map_expr!(result), + } + } + Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => Statement::SubgroupCollectiveOperation { + op: *op, + collective_op: *collective_op, + argument: map_expr!(argument), + result: map_expr!(result), + }, // else just copy Statement::Break | Statement::Continue @@ -462,7 +541,7 @@ impl<'a> DerivedModule<'a> { gather: *gather, coordinate: map_expr!(coordinate), array_index: map_expr_opt!(array_index), - offset: offset.map(|c| self.import_const_expression(c)), + offset: offset.map(|c| self.import_global_expression(c)), level: match level { SampleLevel::Auto | SampleLevel::Zero => *level, SampleLevel::Exact(expr) => SampleLevel::Exact(map_expr!(expr)), @@ -592,6 +671,14 @@ impl<'a> DerivedModule<'a> { committed: *committed, } } + Expression::Override(h_override) => { + is_external = true; + Expression::Override(self.import_pipeline_override(h_override)) + } + Expression::SubgroupBallotResult => expr.clone(), + Expression::SubgroupOperationResult { ty } => Expression::SubgroupOperationResult { + ty: self.import_type(ty), + }, }; if !non_emitting_only || is_external { @@ -748,12 +835,13 @@ impl<'a> From> for naga::Module { types: derived.types, constants: derived.constants, global_variables: derived.globals, - const_expressions: Rc::try_unwrap(derived.const_expressions) + global_expressions: Rc::try_unwrap(derived.global_expressions) .unwrap() .into_inner(), functions: derived.functions, special_types: Default::default(), entry_points: Default::default(), + overrides: derived.pipeline_overrides, } } } diff --git a/src/prune/mod.rs b/src/prune/mod.rs index 9b0cc4d..94c5bd8 100644 --- a/src/prune/mod.rs +++ b/src/prune/mod.rs @@ -248,6 +248,9 @@ impl FunctionReq { committed: *committed, } } + Expression::Override(_) => expr.clone(), + Expression::SubgroupBallotResult => expr.clone(), + Expression::SubgroupOperationResult { .. } => expr.clone(), } } @@ -1372,6 +1375,15 @@ impl<'a> Pruner<'a> { } => { self.add_expression(function, func_req, context, *query, &PartReq::All); } + Expression::Override(_) => { + // we don't prune overrides, so nothing to do + } + Expression::SubgroupBallotResult => { + // nothing, handled by the statement + } + Expression::SubgroupOperationResult { .. } => { + // nothing, handled by the statement + } } func_req.exprs_required.insert(h_expr, part.clone()); @@ -1676,6 +1688,48 @@ impl<'a> Pruner<'a> { } RayQuery(required) } + Statement::SubgroupBallot { result, predicate } => { + let var_ref = Self::resolve_var(function, *result, Vec::default()); + let required = self.store_required(context, &var_ref).is_some(); + if required { + if let Some(predicate) = predicate { + self.add_expression(function, func_req, context, *predicate, &PartReq::All); + } + } + RayQuery(required) + } + Statement::SubgroupGather { + mode, + argument, + result, + } => { + let var_ref = Self::resolve_var(function, *result, Vec::default()); + let required = self.store_required(context, &var_ref).is_some(); + if required { + match mode { + naga::GatherMode::BroadcastFirst => (), + naga::GatherMode::Broadcast(h_src) + | naga::GatherMode::Shuffle(h_src) + | naga::GatherMode::ShuffleDown(h_src) + | naga::GatherMode::ShuffleUp(h_src) + | naga::GatherMode::ShuffleXor(h_src) => { + self.add_expression(function, func_req, context, *h_src, &PartReq::All) + } + } + self.add_expression(function, func_req, context, *argument, &PartReq::All); + } + RayQuery(required) + } + Statement::SubgroupCollectiveOperation { + argument, result, .. + } => { + let var_ref = Self::resolve_var(function, *result, Vec::default()); + let required = self.store_required(context, &var_ref).is_some(); + if required { + self.add_expression(function, func_req, context, *argument, &PartReq::All); + } + RayQuery(required) + } } } @@ -1802,9 +1856,9 @@ impl<'a> Pruner<'a> { let mut derived = DerivedModule::default(); derived.set_shader_source(self.module, 0); - // just copy all the constants for now, so we can copy const handles as well - for (h_cexpr, _) in self.module.const_expressions.iter() { - derived.import_const_expression(h_cexpr); + // just copy all the (pipeline + normal) constants for now, so we can copy const handles as well + for (h_cexpr, _) in self.module.global_expressions.iter() { + derived.import_global_expression(h_cexpr); } for (h_f, f) in self.module.functions.iter() { diff --git a/src/redirect.rs b/src/redirect.rs index eca25bd..98caddd 100644 --- a/src/redirect.rs +++ b/src/redirect.rs @@ -68,7 +68,10 @@ impl Redirector { | Statement::Store { .. } | Statement::ImageStore { .. } | Statement::Atomic { .. } - | Statement::RayQuery { .. } => (), + | Statement::RayQuery { .. } + | Statement::SubgroupBallot { .. } + | Statement::SubgroupGather { .. } + | Statement::SubgroupCollectiveOperation { .. } => (), } } }