diff --git a/src/back/hlsl/conv.rs b/src/back/hlsl/conv.rs index 5eb24962f6..54c3f19a1e 100644 --- a/src/back/hlsl/conv.rs +++ b/src/back/hlsl/conv.rs @@ -40,11 +40,7 @@ impl crate::TypeInner { } } - pub(super) fn size_hlsl( - &self, - types: &crate::UniqueArena, - constants: &crate::Arena, - ) -> u32 { + pub(super) fn size_hlsl(&self, gctx: crate::GlobalCtx) -> u32 { match *self { Self::Matrix { columns, @@ -57,27 +53,24 @@ impl crate::TypeInner { } Self::Array { base, size, stride } => { let count = match size { - crate::ArraySize::Constant(handle) => { - constants[handle].to_array_length().unwrap_or(1) - } + crate::ArraySize::Constant(handle) => gctx.to_array_length(handle).unwrap_or(1), // A dynamically-sized array has to have at least one element crate::ArraySize::Dynamic => 1, }; - let last_el_size = types[base].inner.size_hlsl(types, constants); + let last_el_size = gctx.types[base].inner.size_hlsl(gctx.reborrow()); ((count - 1) * stride) + last_el_size } - _ => self.size(constants), + _ => self.size(gctx.reborrow()), } } /// Used to generate the name of the wrapped type constructor pub(super) fn hlsl_type_id<'a>( base: crate::Handle, - types: &crate::UniqueArena, - constants: &crate::Arena, + gctx: crate::GlobalCtx, names: &'a crate::FastHashMap, ) -> Result, Error> { - Ok(match types[base].inner { + Ok(match gctx.types[base].inner { crate::TypeInner::Scalar { kind, width } => Cow::Borrowed(kind.to_hlsl_str(width)?), crate::TypeInner::Vector { size, kind, width } => Cow::Owned(format!( "{}{}", @@ -100,8 +93,8 @@ impl crate::TypeInner { .. } => Cow::Owned(format!( "array{}_{}_", - constants[size].to_array_length().unwrap(), - Self::hlsl_type_id(base, types, constants, names)? + gctx.to_array_length(size).unwrap(), + Self::hlsl_type_id(base, gctx, names)? )), crate::TypeInner::Struct { .. } => { Cow::Borrowed(&names[&crate::proc::NameKey::Type(base)]) diff --git a/src/back/hlsl/help.rs b/src/back/hlsl/help.rs index bb18bf3ab4..e3db7df659 100644 --- a/src/back/hlsl/help.rs +++ b/src/back/hlsl/help.rs @@ -347,12 +347,7 @@ impl<'a, W: Write> super::Writer<'a, W> { module: &crate::Module, constructor: WrappedConstructor, ) -> BackendResult { - let name = crate::TypeInner::hlsl_type_id( - constructor.ty, - &module.types, - &module.constants, - &self.names, - )?; + let name = crate::TypeInner::hlsl_type_id(constructor.ty, module.to_ctx(), &self.names)?; write!(self.out, "Construct{name}")?; Ok(()) } @@ -411,7 +406,7 @@ impl<'a, W: Write> super::Writer<'a, W> { size: crate::ArraySize::Constant(size), .. } => { - let count = module.constants[size].to_array_length().unwrap(); + let count = module.to_ctx().to_array_length(size).unwrap(); for i in 0..count as usize { write_arg(i, base)?; } @@ -486,7 +481,7 @@ impl<'a, W: Write> super::Writer<'a, W> { write!(self.out, " {RETURN_VARIABLE_NAME}")?; self.write_array_size(module, base, crate::ArraySize::Constant(size))?; write!(self.out, " = {{ ")?; - let count = module.constants[size].to_array_length().unwrap(); + let count = module.to_ctx().to_array_length(size).unwrap(); for i in 0..count { if i != 0 { write!(self.out, ", ")?; diff --git a/src/back/hlsl/storage.rs b/src/back/hlsl/storage.rs index 21396c953a..cc8a0717fc 100644 --- a/src/back/hlsl/storage.rs +++ b/src/back/hlsl/storage.rs @@ -148,8 +148,8 @@ impl super::Writer<'_, W> { .. } => { write!(self.out, "{{")?; - let count = module.constants[const_handle].to_array_length().unwrap(); - let stride = module.types[base].inner.size(&module.constants); + let count = module.to_ctx().to_array_length(const_handle).unwrap(); + let stride = module.types[base].inner.size(module.to_ctx()); let iter = (0..count).map(|i| (TypeResolution::Handle(base), stride * i)); self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?; write!(self.out, "}}")?; @@ -311,8 +311,8 @@ impl super::Writer<'_, W> { self.write_store_value(module, &value, func_ctx)?; writeln!(self.out, ";")?; // then iterate the stores - let count = module.constants[const_handle].to_array_length().unwrap(); - let stride = module.types[base].inner.size(&module.constants); + let count = module.to_ctx().to_array_length(const_handle).unwrap(); + let stride = module.types[base].inner.size(module.to_ctx()); for i in 0..count { self.temp_access_chain.push(SubAccess::Offset(i * stride)); let sv = StoreValue::TempIndex { diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 1f0491c58a..ef003af987 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -822,7 +822,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { // Panics if `ArraySize::Constant` has a constant that isn't an sint or uint match size { crate::ArraySize::Constant(const_handle) => { - let size = module.constants[const_handle].to_array_length().unwrap(); + let size = module.to_ctx().to_array_length(const_handle).unwrap(); write!(self.out, "{size}")?; } crate::ArraySize::Dynamic => {} @@ -870,7 +870,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } let ty_inner = &module.types[member.ty].inner; - last_offset = member.offset + ty_inner.size_hlsl(&module.types, &module.constants); + last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx()); // The indentation is only for readability write!(self.out, "{}", back::INDENT)?; diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index bac178729f..1f968459c9 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -70,7 +70,7 @@ const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e"; struct TypeContext<'a> { handle: Handle, - module: &'a crate::Module, + gctx: crate::GlobalCtx<'a>, names: &'a FastHashMap, access: crate::StorageAccess, binding: Option<&'a super::ResolvedBinding>, @@ -79,7 +79,7 @@ struct TypeContext<'a> { impl<'a> Display for TypeContext<'a> { fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> { - let ty = &self.module.types[self.handle]; + let ty = &self.gctx.types[self.handle]; if ty.needs_alias() && !self.first_time { let name = &self.names[&NameKey::Type(self.handle)]; return write!(out, "{name}"); @@ -208,13 +208,8 @@ impl<'a> Display for TypeContext<'a> { { write!(out, "{NAMESPACE}::array<{base_tyname}, {override_size}>") } else if let crate::ArraySize::Constant(size) = size { - let constant_ctx = ConstantContext { - handle: size, - arena: &self.module.constants, - names: self.names, - first_time: false, - }; - write!(out, "{NAMESPACE}::array<{base_tyname}, {constant_ctx}>") + let size = self.gctx.to_array_length(size).unwrap(); + write!(out, "{NAMESPACE}::array<{base_tyname}, {size}>") } else { unreachable!("metal requires all arrays be constant sized"); } @@ -258,7 +253,7 @@ impl<'a> TypedGlobalVariable<'a> { }; let ty_name = TypeContext { handle: var.ty, - module: self.module, + gctx: self.module.to_ctx(), names: self.names, access: storage_access, binding: self.binding, @@ -294,21 +289,15 @@ impl<'a> TypedGlobalVariable<'a> { } struct ConstantContext<'a> { - handle: Handle, - arena: &'a crate::Arena, + handle: Handle, + gctx: crate::GlobalCtx<'a>, names: &'a FastHashMap, first_time: bool, } impl<'a> Display for ConstantContext<'a> { fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> { - let con = &self.arena[self.handle]; - if con.needs_alias() && !self.first_time { - let name = &self.names[&NameKey::Constant(self.handle)]; - return write!(out, "{name}"); - } - - match con.inner { + match self.gctx.const_expressions[self.handle] { crate::ConstantInner::Scalar { value, width: _ } => match value { crate::ScalarValue::Sint(value) => { write!(out, "{value}") @@ -386,7 +375,7 @@ fn should_pack_struct_member( } let ty_inner = &module.types[member.ty].inner; - let last_offset = member.offset + ty_inner.size(&module.constants); + let last_offset = member.offset + ty_inner.size(module.to_ctx()); let next_offset = match members.get(index + 1) { Some(next) => next.offset, None => span, @@ -490,16 +479,6 @@ impl crate::Type { } } -impl crate::Constant { - // Returns `true` if we need to emit an alias for this constant. - const fn needs_alias(&self) -> bool { - match self.inner { - crate::ConstantInner::Scalar { .. } => self.name.is_some(), - crate::ConstantInner::Composite { .. } => true, - } - } -} - enum FunctionOrigin { Handle(Handle), EntryPoint(proc::EntryPointIndex), @@ -1136,7 +1115,7 @@ impl Writer { crate::TypeInner::Array { base, stride, .. } => ( context.module.types[base] .inner - .size(&context.module.constants), + .size(context.module.to_ctx()), stride, ), _ => return Err(Error::Validation), @@ -1307,15 +1286,6 @@ impl Writer { self.put_access_chain(expr_handle, policy, context)?; } } - crate::Expression::Constant(handle) => { - let coco = ConstantContext { - handle, - arena: &context.module.constants, - names: &self.names, - first_time: false, - }; - write!(self.out, "{coco}")?; - } crate::Expression::Splat { size, value } => { let scalar_kind = match *context.resolve_type(value) { crate::TypeInner::Scalar { kind, .. } => kind, @@ -1337,6 +1307,17 @@ impl Writer { write!(self.out, "{}", back::COMPONENTS[sc as usize])?; } } + crate::Expression::Literal(_) => todo!(), + crate::Expression::Constant(handle) => { + let coco = ConstantContext { + handle, + arena: &context.module.constants, + names: &self.names, + first_time: false, + }; + write!(self.out, "{coco}")?; + } + crate::Expression::New(_) => todo!(), crate::Expression::Compose { ty, ref components } => { self.put_compose(ty, components, context)?; } @@ -2253,8 +2234,10 @@ impl Writer { .. } = context.module.types[member.ty].inner { - let size = context.module.constants[const_handle] - .to_array_length() + let size = context + .module + .to_ctx() + .to_array_length(const_handle) .unwrap(); write!(self.out, "{comma} {{")?; for j in 0..size { @@ -2359,7 +2342,7 @@ impl Writer { TypeResolution::Handle(ty_handle) => { let ty_name = TypeContext { handle: ty_handle, - module: context.module, + gctx: context.module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -2812,8 +2795,11 @@ impl Writer { // we can't assign fixed-size arrays if let Some(const_handle) = array_size { - let size = context.expression.module.constants[const_handle] - .to_array_length() + let size = context + .expression + .module + .to_ctx() + .to_array_length(const_handle) .unwrap(); write!(self.out, "{level}for(int _i=0; _i<{size}; ++_i) ")?; self.put_access_chain(pointer, policy, &context.expression)?; @@ -2892,9 +2878,8 @@ impl Writer { } }; - self.write_scalar_constants(module)?; self.write_type_defs(module)?; - self.write_composite_constants(module)?; + self.write_global_constants(module)?; self.write_functions(module, info, options, pipeline_options) } @@ -2947,7 +2932,7 @@ impl Writer { } => { let base_name = TypeContext { handle: base, - module, + gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -2958,7 +2943,7 @@ impl Writer { crate::ArraySize::Constant(const_handle) => { let coco = ConstantContext { handle: const_handle, - arena: &module.constants, + gctx: module.to_ctx(), names: &self.names, first_time: false, }; @@ -2992,7 +2977,7 @@ impl Writer { writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?; } let ty_inner = &module.types[member.ty].inner; - last_offset = member.offset + ty_inner.size(&module.constants); + last_offset = member.offset + ty_inner.size(module.to_ctx()); let member_name = &self.names[&NameKey::StructMember(handle, index as u32)]; @@ -3011,7 +2996,7 @@ impl Writer { None => { let base_name = TypeContext { handle: member.ty, - module, + gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -3042,7 +3027,7 @@ impl Writer { _ => { let ty_name = TypeContext { handle, - module, + gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -3055,60 +3040,32 @@ impl Writer { Ok(()) } - fn write_scalar_constants(&mut self, module: &crate::Module) -> BackendResult { + fn write_global_constants(&mut self, module: &crate::Module) -> BackendResult { for (handle, constant) in module.constants.iter() { - match constant.inner { - crate::ConstantInner::Scalar { - width: _, - ref value, - } if constant.name.is_some() => { - debug_assert!(constant.needs_alias()); - write!(self.out, "constexpr constant ")?; - match *value { - crate::ScalarValue::Sint(_) => { - write!(self.out, "int")?; - } - crate::ScalarValue::Uint(_) => { - write!(self.out, "unsigned")?; - } - crate::ScalarValue::Float(_) => { - write!(self.out, "float")?; - } - crate::ScalarValue::Bool(_) => { - write!(self.out, "bool")?; - } - } + let ty_name = TypeContext { + handle: constant.ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + let name = &self.names[&NameKey::Constant(handle)]; + write!(self.out, "constexpr constant {ty_name} {name} = ")?; + + match module.const_expressions[constant.init.unwrap()] { + crate::Expression::Literal(lit) => { let name = &self.names[&NameKey::Constant(handle)]; let coco = ConstantContext { - handle, - arena: &module.constants, + handle: constant.init.unwrap(), + gctx: module.to_ctx(), names: &self.names, first_time: true, }; - writeln!(self.out, " {name} = {coco};")?; + writeln!(self.out, "{coco};")?; } - _ => {} - } - } - Ok(()) - } - - fn write_composite_constants(&mut self, module: &crate::Module) -> BackendResult { - for (handle, constant) in module.constants.iter() { - match constant.inner { - crate::ConstantInner::Scalar { .. } => {} - crate::ConstantInner::Composite { ty, ref components } => { - debug_assert!(constant.needs_alias()); - let name = &self.names[&NameKey::Constant(handle)]; - let ty_name = TypeContext { - handle: ty, - module, - names: &self.names, - access: crate::StorageAccess::empty(), - binding: None, - first_time: false, - }; - write!(self.out, "constant {ty_name} {name} = {{",)?; + crate::Expression::Compose { ty, components } => { + write!(self.out, "{{",)?; for (i, &sub_handle) in components.iter().enumerate() { // insert padding initialization, if needed if self.struct_member_pads.contains(&(ty, i as u32)) { @@ -3117,7 +3074,7 @@ impl Writer { let separator = if i != 0 { ", " } else { "" }; let coco = ConstantContext { handle: sub_handle, - arena: &module.constants, + gctx: module.to_ctx(), names: &self.names, first_time: false, }; @@ -3125,8 +3082,15 @@ impl Writer { } writeln!(self.out, "}};")?; } + _ => unreachable!(), } } + + if !module.constants.is_empty() { + // Add extra newline for readability + writeln!(self.out)?; + } + Ok(()) } @@ -3242,7 +3206,7 @@ impl Writer { Some(ref result) => { let ty_name = TypeContext { handle: result.ty, - module, + gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -3260,7 +3224,7 @@ impl Writer { let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)]; let param_type_name = TypeContext { handle: arg.ty, - module, + gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -3309,7 +3273,7 @@ impl Writer { for (local_handle, local) in fun.local_variables.iter() { let ty_name = TypeContext { handle: local.ty, - module, + gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -3321,7 +3285,7 @@ impl Writer { Some(value) => { let coco = ConstantContext { handle: value, - arena: &module.constants, + gctx: module.to_ctx(), names: &self.names, first_time: false, }; @@ -3506,7 +3470,7 @@ impl Writer { let name = &self.names[name_key]; let ty_name = TypeContext { handle: ty, - module, + gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -3550,7 +3514,7 @@ impl Writer { for (name, ty, binding) in result_members { let ty_name = TypeContext { handle: ty, - module, + gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -3569,7 +3533,7 @@ impl Writer { crate::TypeInner::Array { size: crate::ArraySize::Constant(handle), .. - } => module.constants[handle].to_array_length(), + } => module.to_ctx().to_array_length(handle), _ => None, }; let resolved = options.resolve_local_binding(binding, out_mode)?; @@ -3645,7 +3609,7 @@ impl Writer { let ty_name = TypeContext { handle: ty, - module, + gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -3882,7 +3846,7 @@ impl Writer { let name = &self.names[&NameKey::EntryPointLocal(ep_index as _, local_handle)]; let ty_name = TypeContext { handle: local.ty, - module, + gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -3893,7 +3857,7 @@ impl Writer { Some(value) => { let coco = ConstantContext { handle: value, - arena: &module.constants, + gctx: module.to_ctx(), names: &self.names, first_time: false, }; @@ -4167,7 +4131,7 @@ fn test_stack_size() { let constant = module.constants.append( crate::Constant { name: None, - specialization: None, + specialization: crate::Specialization::None, inner: crate::ConstantInner::Scalar { value: crate::ScalarValue::Float(1.0), width: 4, diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 7bfb0e964d..b9af6b78fa 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -118,10 +118,15 @@ impl Writer { // Write all constants for (handle, constant) in module.constants.iter() { if constant.name.is_some() { - self.write_global_constant(module, &constant.inner, handle)?; + self.write_global_constant(module, handle)?; } } + if !module.constants.is_empty() { + // Add extra newline for readability + writeln!(self.out)?; + } + // Write all globals for (ty, global) in module.global_variables.iter() { self.write_global(module, global, ty)?; @@ -179,25 +184,6 @@ impl Writer { Ok(()) } - /// Helper method used to write [`ScalarValue`](crate::ScalarValue) - /// - /// # Notes - /// Adds no trailing or leading whitespace - fn write_scalar_value(&mut self, value: crate::ScalarValue) -> BackendResult { - use crate::ScalarValue as Sv; - - match value { - Sv::Sint(value) => write!(self.out, "{value}")?, - Sv::Uint(value) => write!(self.out, "{value}u")?, - // Floats are written using `Debug` instead of `Display` because it always appends the - // decimal part even it's zero - Sv::Float(value) => write!(self.out, "{value:?}")?, - Sv::Bool(value) => write!(self.out, "{value}")?, - } - - Ok(()) - } - /// Helper method used to write struct name /// /// # Notes @@ -305,7 +291,7 @@ impl Writer { // Write the constant // `write_constant` adds no trailing or leading space/newline - self.write_constant(module, init)?; + self.write_expr(module, init, None)?; } // Finish the local with `;` and add a newline (only for readability) @@ -515,7 +501,7 @@ impl Writer { crate::ArraySize::Constant(handle) => { self.write_type(module, base)?; write!(self.out, ",")?; - self.write_constant(module, handle)?; + self.write_expr(module, handle, None)?; } crate::ArraySize::Dynamic => { self.write_type(module, base)?; @@ -530,7 +516,7 @@ impl Writer { crate::ArraySize::Constant(handle) => { self.write_type(module, base)?; write!(self.out, ",")?; - self.write_constant(module, handle)?; + self.write_expr(module, handle, None)?; } crate::ArraySize::Dynamic => { self.write_type(module, base)?; @@ -644,7 +630,7 @@ impl Writer { Some(self.namer.call(name)) } else if info.ref_count == 0 { write!(self.out, "{level}_ = ")?; - self.write_expr(module, handle, func_ctx)?; + self.write_expr(module, handle, Some(func_ctx))?; writeln!(self.out, ";")?; continue; } else { @@ -667,7 +653,7 @@ impl Writer { if let Some(name) = expr_name { write!(self.out, "{level}")?; self.start_named_expr(module, handle, func_ctx, &name)?; - self.write_expr(module, handle, func_ctx)?; + self.write_expr(module, handle, Some(func_ctx))?; self.named_expressions.insert(handle, name); writeln!(self.out, ";")?; } @@ -681,7 +667,7 @@ impl Writer { } => { write!(self.out, "{level}")?; write!(self.out, "if ")?; - self.write_expr(module, condition, func_ctx)?; + self.write_expr(module, condition, Some(func_ctx))?; writeln!(self.out, " {{")?; let l2 = level.next(); @@ -709,7 +695,7 @@ impl Writer { if let Some(return_value) = value { // The leading space is important write!(self.out, " ")?; - self.write_expr(module, return_value, func_ctx)?; + self.write_expr(module, return_value, Some(func_ctx))?; } writeln!(self.out, ";")?; } @@ -730,9 +716,9 @@ impl Writer { }; if is_atomic { write!(self.out, "atomicStore(")?; - self.write_expr(module, pointer, func_ctx)?; + self.write_expr(module, pointer, Some(func_ctx))?; write!(self.out, ", ")?; - self.write_expr(module, value, func_ctx)?; + self.write_expr(module, value, Some(func_ctx))?; write!(self.out, ")")?; } else { self.write_expr_with_indirection( @@ -742,7 +728,7 @@ impl Writer { Indirection::Reference, )?; write!(self.out, " = ")?; - self.write_expr(module, value, func_ctx)?; + self.write_expr(module, value, Some(func_ctx))?; } writeln!(self.out, ";")? } @@ -760,7 +746,7 @@ impl Writer { let func_name = &self.names[&NameKey::Function(function)]; write!(self.out, "{func_name}(")?; for (index, &argument) in arguments.iter().enumerate() { - self.write_expr(module, argument, func_ctx)?; + self.write_expr(module, argument, Some(func_ctx))?; // Only write a comma if isn't the last element if index != arguments.len().saturating_sub(1) { // The leading space is for readability only @@ -782,13 +768,13 @@ impl Writer { let fun_str = fun.to_wgsl(); write!(self.out, "atomic{fun_str}(")?; - self.write_expr(module, pointer, func_ctx)?; + self.write_expr(module, pointer, Some(func_ctx))?; if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun { write!(self.out, ", ")?; - self.write_expr(module, cmp, func_ctx)?; + self.write_expr(module, cmp, Some(func_ctx))?; } write!(self.out, ", ")?; - self.write_expr(module, value, func_ctx)?; + self.write_expr(module, value, Some(func_ctx))?; writeln!(self.out, ");")? } Statement::ImageStore { @@ -799,15 +785,15 @@ impl Writer { } => { write!(self.out, "{level}")?; write!(self.out, "textureStore(")?; - self.write_expr(module, image, func_ctx)?; + self.write_expr(module, image, Some(func_ctx))?; write!(self.out, ", ")?; - self.write_expr(module, coordinate, func_ctx)?; + self.write_expr(module, coordinate, Some(func_ctx))?; if let Some(array_index_expr) = array_index { write!(self.out, ", ")?; - self.write_expr(module, array_index_expr, func_ctx)?; + self.write_expr(module, array_index_expr, Some(func_ctx))?; } write!(self.out, ", ")?; - self.write_expr(module, value, func_ctx)?; + self.write_expr(module, value, Some(func_ctx))?; writeln!(self.out, ");")?; } // TODO: copy-paste from glsl-out @@ -827,7 +813,7 @@ impl Writer { // Start the switch write!(self.out, "{level}")?; write!(self.out, "switch ")?; - self.write_expr(module, selector, func_ctx)?; + self.write_expr(module, selector, Some(func_ctx))?; writeln!(self.out, " {{")?; let l2 = level.next(); @@ -912,7 +898,7 @@ impl Writer { if let Some(condition) = break_if { // The trailing space is important write!(self.out, "{}break if ", l2.next())?; - self.write_expr(module, condition, func_ctx)?; + self.write_expr(module, condition, Some(func_ctx))?; // Close the `break if` statement writeln!(self.out, ";")?; } @@ -1035,9 +1021,13 @@ impl Writer { &mut self, module: &Module, expr: Handle, - func_ctx: &back::FunctionCtx<'_>, + func_ctx: Option<&back::FunctionCtx<'_>>, ) -> BackendResult { - self.write_expr_with_indirection(module, expr, func_ctx, Indirection::Ordinary) + if let Some(func_ctx) = func_ctx { + self.write_expr_with_indirection(module, expr, func_ctx, Indirection::Ordinary) + } else { + self.write_expr_plain_form(module, expr, None, Indirection::Ordinary) + } } /// Write `expr` as a WGSL expression with the requested indirection. @@ -1065,15 +1055,15 @@ impl Writer { match (requested, plain) { (Indirection::Ordinary, Indirection::Reference) => { write!(self.out, "(&")?; - self.write_expr_plain_form(module, expr, func_ctx, plain)?; + self.write_expr_plain_form(module, expr, Some(func_ctx), plain)?; write!(self.out, ")")?; } (Indirection::Reference, Indirection::Ordinary) => { write!(self.out, "(*")?; - self.write_expr_plain_form(module, expr, func_ctx, plain)?; + self.write_expr_plain_form(module, expr, Some(func_ctx), plain)?; write!(self.out, ")")?; } - (_, _) => self.write_expr_plain_form(module, expr, func_ctx, plain)?, + (_, _) => self.write_expr_plain_form(module, expr, Some(func_ctx), plain)?, } Ok(()) @@ -1090,7 +1080,7 @@ impl Writer { &mut self, module: &Module, expr: Handle, - func_ctx: &back::FunctionCtx<'_>, + func_ctx: Option<&back::FunctionCtx<'_>>, indirection: Indirection, ) -> BackendResult { use crate::Expression; @@ -1100,7 +1090,11 @@ impl Writer { return Ok(()); } - let expression = &func_ctx.expressions[expr]; + let expression = if let Some(func_ctx) = func_ctx { + &func_ctx.expressions[expr] + } else { + &module.const_expressions[expr] + }; // Write the plain WGSL form of a Naga expression. // @@ -1111,7 +1105,31 @@ impl Writer { // `postfix_expression` forms for member/component access and // subscripting. match *expression { - Expression::Constant(constant) => self.write_constant(module, constant)?, + Expression::Literal(literal) => { + match literal { + crate::Literal::I32(value) => write!(self.out, "{}", value)?, + crate::Literal::U32(value) => write!(self.out, "{}u", value)?, + // Floats are written using `Debug` instead of `Display` because it always appends the + // decimal part even it's zero + crate::Literal::F32(value) => write!(self.out, "{:?}", value)?, + crate::Literal::Bool(value) => write!(self.out, "{}", value)?, + crate::Literal::F64(_) => { + return Err(Error::Unimplemented("f64 literal".to_string())); + } + } + } + Expression::Constant(handle) => { + let constant = &module.constants[handle]; + if constant.name.is_some() { + write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?; + } else { + self.write_expr(module, constant.init.unwrap(), func_ctx)?; + } + } + Expression::New(ty) => { + self.write_type(module, ty)?; + write!(self.out, "()")?; + } Expression::Compose { ty, ref components } => { self.write_type(module, ty)?; write!(self.out, "(")?; @@ -1126,7 +1144,7 @@ impl Writer { write!(self.out, ")")? } Expression::FunctionArgument(pos) => { - let name_key = func_ctx.argument_key(pos); + let name_key = func_ctx.expect("non-global context").argument_key(pos); let name = &self.names[&name_key]; write!(self.out, "{name}")?; } @@ -1138,16 +1156,26 @@ impl Writer { write!(self.out, ")")?; } Expression::Access { base, index } => { - self.write_expr_with_indirection(module, base, func_ctx, indirection)?; + self.write_expr_with_indirection( + module, + base, + func_ctx.expect("non-global context"), + indirection, + )?; write!(self.out, "[")?; self.write_expr(module, index, func_ctx)?; write!(self.out, "]")? } Expression::AccessIndex { base, index } => { - let base_ty_res = &func_ctx.info[base].ty; + let base_ty_res = &func_ctx.expect("non-global context").info[base].ty; let mut resolved = base_ty_res.inner_with(&module.types); - self.write_expr_with_indirection(module, base, func_ctx, indirection)?; + self.write_expr_with_indirection( + module, + base, + func_ctx.expect("non-global context"), + indirection, + )?; let base_ty_handle = match *resolved { TypeInner::Pointer { base, space: _ } => { @@ -1246,7 +1274,7 @@ impl Writer { if let Some(offset) = offset { write!(self.out, ", ")?; - self.write_constant(module, offset)?; + self.write_expr(module, offset, None)?; } write!(self.out, ")")?; @@ -1267,15 +1295,20 @@ impl Writer { }; write!(self.out, "textureGather{suffix_cmp}(")?; - match *func_ctx.info[image].ty.inner_with(&module.types) { - TypeInner::Image { - class: crate::ImageClass::Depth { multi: _ }, - .. - } => {} - _ => { - write!(self.out, "{}, ", component as u8)?; - } - } + // match *func_ctx.expect("non-global context").info[image] + // .ty + // .inner_with(&module.types) + // { + // TypeInner::Image { + // class: crate::ImageClass::Depth { multi: _ }, + // .. + // } => {} + // _ => { + // write!(self.out, "{}, ", component as u8)?; + // } + // } + self.write_expr(module, component, func_ctx)?; + write!(self.out, ", ")?; self.write_expr(module, image, func_ctx)?; write!(self.out, ", ")?; self.write_expr(module, sampler, func_ctx)?; @@ -1294,7 +1327,7 @@ impl Writer { if let Some(offset) = offset { write!(self.out, ", ")?; - self.write_constant(module, offset)?; + self.write_expr(module, offset, None)?; } write!(self.out, ")")?; @@ -1347,7 +1380,9 @@ impl Writer { kind, convert, } => { - let inner = func_ctx.info[expr].ty.inner_with(&module.types); + let inner = func_ctx.expect("non-global context").info[expr] + .ty + .inner_with(&module.types); match *inner { TypeInner::Matrix { columns, @@ -1392,7 +1427,9 @@ impl Writer { write!(self.out, ")")?; } Expression::Splat { size, value } => { - let inner = func_ctx.info[value].ty.inner_with(&module.types); + let inner = func_ctx.expect("non-global context").info[value] + .ty + .inner_with(&module.types); let (scalar_kind, scalar_width) = match *inner { crate::TypeInner::Scalar { kind, width } => (kind, width), _ => { @@ -1409,7 +1446,10 @@ impl Writer { write!(self.out, ")")?; } Expression::Load { pointer } => { - let is_atomic = match *func_ctx.info[pointer].ty.inner_with(&module.types) { + let is_atomic = match *func_ctx.expect("non-global context").info[pointer] + .ty + .inner_with(&module.types) + { crate::TypeInner::Pointer { base, .. } => match module.types[base].inner { crate::TypeInner::Atomic { .. } => true, _ => false, @@ -1425,14 +1465,16 @@ impl Writer { self.write_expr_with_indirection( module, pointer, - func_ctx, + func_ctx.expect("non-global context"), Indirection::Reference, )?; } } - Expression::LocalVariable(handle) => { - write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])? - } + Expression::LocalVariable(handle) => write!( + self.out, + "{}", + self.names[&func_ctx.expect("non-global context").name_key(handle)] + )?, Expression::ArrayLength(expr) => { write!(self.out, "arrayLength(")?; self.write_expr(module, expr, func_ctx)?; @@ -1561,7 +1603,10 @@ impl Writer { let unary = match op { crate::UnaryOperator::Negate => "-", crate::UnaryOperator::Not => { - match *func_ctx.info[expr].ty.inner_with(&module.types) { + match *func_ctx.expect("non-global context").info[expr] + .ty + .inner_with(&module.types) + { TypeInner::Scalar { kind: crate::ScalarKind::Bool, .. @@ -1663,7 +1708,7 @@ impl Writer { // Write initializer if let Some(init) = global.init { write!(self.out, " = ")?; - self.write_constant(module, init)?; + self.write_expr(module, init, None)?; } // End with semicolon @@ -1672,47 +1717,6 @@ impl Writer { Ok(()) } - /// Helper method used to write constants - /// - /// # Notes - /// Doesn't add any newlines or leading/trailing spaces - fn write_constant( - &mut self, - module: &Module, - handle: Handle, - ) -> BackendResult { - let constant = &module.constants[handle]; - match constant.inner { - crate::ConstantInner::Scalar { - width: _, - ref value, - } => { - if constant.name.is_some() { - write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?; - } else { - self.write_scalar_value(*value)?; - } - } - crate::ConstantInner::Composite { ty, ref components } => { - self.write_type(module, ty)?; - write!(self.out, "(")?; - - // Write the comma separated constants - for (index, constant) in components.iter().enumerate() { - self.write_constant(module, *constant)?; - // Only write a comma if isn't the last element - if index != components.len().saturating_sub(1) { - // The leading space is for readability only - write!(self.out, ", ")?; - } - } - write!(self.out, ")")? - } - } - - Ok(()) - } - /// Helper method used to write global constants /// /// # Notes @@ -1720,61 +1724,17 @@ impl Writer { fn write_global_constant( &mut self, module: &Module, - inner: &crate::ConstantInner, handle: Handle, ) -> BackendResult { - match *inner { - crate::ConstantInner::Scalar { - width: _, - ref value, - } => { - let name = &self.names[&NameKey::Constant(handle)]; - // First write only constant name - write!(self.out, "const {name}: ")?; - // Next write constant type and value - match *value { - crate::ScalarValue::Sint(value) => { - write!(self.out, "i32 = {value}")?; - } - crate::ScalarValue::Uint(value) => { - write!(self.out, "u32 = {value}u")?; - } - crate::ScalarValue::Float(value) => { - // Floats are written using `Debug` instead of `Display` because it always appends the - // decimal part even it's zero - write!(self.out, "f32 = {value:?}")?; - } - crate::ScalarValue::Bool(value) => { - write!(self.out, "bool = {value}")?; - } - }; - // End with semicolon - writeln!(self.out, ";")?; - } - crate::ConstantInner::Composite { ty, ref components } => { - let name = &self.names[&NameKey::Constant(handle)]; - // First write only constant name - write!(self.out, "const {name}: ")?; - // Next write constant type - self.write_type(module, ty)?; - - write!(self.out, " = ")?; - self.write_type(module, ty)?; + let name = &self.names[&NameKey::Constant(handle)]; + // First write only constant name + write!(self.out, "const {name}: ")?; + self.write_type(module, module.constants[handle].ty)?; + write!(self.out, " = ")?; + let init = module.constants[handle].init.unwrap(); + self.write_expr(module, init, None)?; + writeln!(self.out, ";")?; - write!(self.out, "(")?; - for (index, constant) in components.iter().enumerate() { - self.write_constant(module, *constant)?; - // Only write a comma if isn't the last element - if index != components.len().saturating_sub(1) { - // The leading space is for readability only - write!(self.out, ", ")?; - } - } - write!(self.out, ");")?; - } - } - // End with extra newline for readability - writeln!(self.out)?; Ok(()) } diff --git a/src/front/glsl/constants.rs b/src/front/glsl/constants.rs index d9a6fc7cd7..8ca17c5750 100644 --- a/src/front/glsl/constants.rs +++ b/src/front/glsl/constants.rs @@ -546,7 +546,7 @@ impl<'a> ConstantSolver<'a> { self.constants.fetch_or_append( Constant { name: None, - specialization: None, + specialization: crate::Specialization::None, inner, }, span, @@ -650,7 +650,7 @@ mod tests { let h = constants.append( Constant { name: None, - specialization: None, + specialization: crate::Specialization::None, inner: ConstantInner::Scalar { width: 4, value: ScalarValue::Sint(4), @@ -662,7 +662,7 @@ mod tests { let h1 = constants.append( Constant { name: None, - specialization: None, + specialization: crate::Specialization::None, inner: ConstantInner::Scalar { width: 4, value: ScalarValue::Sint(8), @@ -674,7 +674,7 @@ mod tests { let vec_h = constants.append( Constant { name: None, - specialization: None, + specialization: crate::Specialization::None, inner: ConstantInner::Composite { ty: vec_ty, components: vec![h, h1], @@ -773,7 +773,7 @@ mod tests { let h = constants.append( Constant { name: None, - specialization: None, + specialization: crate::Specialization::None, inner: ConstantInner::Scalar { width: 4, value: ScalarValue::Sint(4), @@ -847,7 +847,7 @@ mod tests { let h = constants.append( Constant { name: None, - specialization: None, + specialization: crate::Specialization::None, inner: ConstantInner::Scalar { width: 4, value: ScalarValue::Float(i as f64), @@ -863,7 +863,7 @@ mod tests { let h = constants.append( Constant { name: None, - specialization: None, + specialization: crate::Specialization::None, inner: ConstantInner::Scalar { width: 4, value: ScalarValue::Float(i as f64), @@ -878,7 +878,7 @@ mod tests { let vec1 = constants.append( Constant { name: None, - specialization: None, + specialization: crate::Specialization::None, inner: ConstantInner::Composite { ty: vec_ty, components: vec1_components, @@ -890,7 +890,7 @@ mod tests { let vec2 = constants.append( Constant { name: None, - specialization: None, + specialization: crate::Specialization::None, inner: ConstantInner::Composite { ty: vec_ty, components: vec2_components, @@ -902,7 +902,7 @@ mod tests { let h = constants.append( Constant { name: None, - specialization: None, + specialization: crate::Specialization::None, inner: ConstantInner::Composite { ty: matrix_ty, components: vec![vec1, vec2], diff --git a/src/front/glsl/context.rs b/src/front/glsl/context.rs index 6df4850efa..500c148668 100644 --- a/src/front/glsl/context.rs +++ b/src/front/glsl/context.rs @@ -1303,7 +1303,7 @@ impl Context { let constant_1 = frontend.module.constants.append( Constant { name: None, - specialization: None, + specialization: crate::Specialization::None, inner, }, Default::default(), diff --git a/src/front/glsl/functions.rs b/src/front/glsl/functions.rs index 10c964b5e0..bd85ff7f75 100644 --- a/src/front/glsl/functions.rs +++ b/src/front/glsl/functions.rs @@ -40,7 +40,7 @@ impl Frontend { self.module.constants.fetch_or_append( Constant { name: None, - specialization: None, + specialization: crate::Specialization::None, inner: ConstantInner::Scalar { width: 4, value }, }, meta, @@ -285,7 +285,7 @@ impl Frontend { let zero_constant = self.module.constants.fetch_or_append( Constant { name: None, - specialization: None, + specialization: crate::Specialization::None, inner: ConstantInner::Scalar { width, value: ScalarValue::Float(0.0), @@ -326,7 +326,7 @@ impl Frontend { let zero_constant = self.module.constants.fetch_or_append( Constant { name: None, - specialization: None, + specialization: crate::Specialization::None, inner: ConstantInner::Scalar { width, value: ScalarValue::Float(0.0), @@ -338,7 +338,7 @@ impl Frontend { let one_constant = self.module.constants.fetch_or_append( Constant { name: None, - specialization: None, + specialization: crate::Specialization::None, inner: ConstantInner::Scalar { width, value: ScalarValue::Float(1.0), @@ -409,7 +409,7 @@ impl Frontend { let vec_constant = self.module.constants.fetch_or_append( Constant { name: None, - specialization: None, + specialization: crate::Specialization::None, inner: ConstantInner::Composite { ty: vector_ty, components: (0..rows as u32) @@ -1501,7 +1501,7 @@ impl Frontend { offset: span, }); - span += self.module.types[ty].inner.size(&self.module.constants); + span += self.module.types[ty].inner.size(self.module.to_ctx()); let len = expressions.len(); let load = expressions.append(Expression::Load { pointer }, Default::default()); diff --git a/src/front/glsl/parser/expressions.rs b/src/front/glsl/parser/expressions.rs index f09e58b6f6..1f3d94fe61 100644 --- a/src/front/glsl/parser/expressions.rs +++ b/src/front/glsl/parser/expressions.rs @@ -62,7 +62,7 @@ impl<'source> ParsingContext<'source> { let handle = frontend.module.constants.fetch_or_append( Constant { name: None, - specialization: None, + specialization: crate::Specialization::None, inner: ConstantInner::Scalar { width, value }, }, token.meta, @@ -139,7 +139,7 @@ impl<'source> ParsingContext<'source> { let constant = frontend.module.constants.fetch_or_append( Constant { name: None, - specialization: None, + specialization: crate::Specialization::None, inner: ConstantInner::Scalar { width: 4, value: ScalarValue::Uint(args.len() as u64), diff --git a/src/front/glsl/parser/types.rs b/src/front/glsl/parser/types.rs index 08a70669a0..ede4397927 100644 --- a/src/front/glsl/parser/types.rs +++ b/src/front/glsl/parser/types.rs @@ -40,7 +40,7 @@ impl<'source> ParsingContext<'source> { let constant = frontend.module.constants.fetch_or_append( crate::Constant { name: None, - specialization: None, + specialization: crate::Specialization::None, inner: crate::ConstantInner::Scalar { width: 4, value: crate::ScalarValue::Uint(value as u64), @@ -55,7 +55,7 @@ impl<'source> ParsingContext<'source> { frontend .layouter - .update(&frontend.module.types, &frontend.module.constants) + .update(frontend.module.to_ctx()) .unwrap(); let stride = frontend.layouter[*ty].to_stride(); *ty = frontend.module.types.insert( diff --git a/src/front/glsl/parser_tests.rs b/src/front/glsl/parser_tests.rs index ae529a6fbe..3cbf927cd6 100644 --- a/src/front/glsl/parser_tests.rs +++ b/src/front/glsl/parser_tests.rs @@ -532,7 +532,7 @@ fn constants() { constants.next().unwrap().1, &Constant { name: Some("a".to_owned()), - specialization: None, + specialization: crate::Specialization::None, inner: ConstantInner::Scalar { width: 4, value: ScalarValue::Float(1.0) @@ -543,7 +543,7 @@ fn constants() { constants.next().unwrap().1, &Constant { name: Some("b".to_owned()), - specialization: None, + specialization: crate::Specialization::None, inner: ConstantInner::Scalar { width: 4, value: ScalarValue::Float(1.0) diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index fde52ed270..e60b52e77e 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -4301,9 +4301,7 @@ impl> Frontend { let decor = self.future_decor.remove(&id).unwrap_or_default(); let base = self.lookup_type.lookup(type_id)?.handle; - self.layouter - .update(&module.types, &module.constants) - .unwrap(); + self.layouter.update(module.to_ctx()).unwrap(); // HACK if the underlying type is an image or a sampler, let's assume // that we're dealing with a binding-array @@ -4384,9 +4382,7 @@ impl> Frontend { let decor = self.future_decor.remove(&id).unwrap_or_default(); let base = self.lookup_type.lookup(type_id)?.handle; - self.layouter - .update(&module.types, &module.constants) - .unwrap(); + self.layouter.update(module.to_ctx()).unwrap(); // HACK same case as in `parse_type_array()` let inner = if let crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } = @@ -4437,9 +4433,7 @@ impl> Frontend { .as_ref() .map_or(false, |decor| decor.storage_buffer); - self.layouter - .update(&module.types, &module.constants) - .unwrap(); + self.layouter.update(module.to_ctx()).unwrap(); let mut members = Vec::::with_capacity(inst.wc as usize - 2); let mut member_lookups = Vec::with_capacity(members.capacity()); @@ -4671,11 +4665,11 @@ impl> Frontend { let low = self.next()?; match width { 4 => crate::Literal::U32(low), - 8 => { - inst.expect(5)?; - let high = self.next()?; - crate::Literal::U64((u64::from(high) << 32) | u64::from(low)) - } + // 8 => { + // inst.expect(5)?; + // let high = self.next()?; + // crate::Literal::U64((u64::from(high) << 32) | u64::from(low)) + // } _ => return Err(Error::InvalidTypeWidth(width as u32)), } } @@ -4686,13 +4680,13 @@ impl> Frontend { let low = self.next()?; match width { 4 => crate::Literal::I32(low as i32), - 8 => { - inst.expect(5)?; - let high = self.next()?; - crate::Literal::I64( - (i64::from(high as i32) << 32) | ((i64::from(low as i32) << 32) >> 32), - ) - } + // 8 => { + // inst.expect(5)?; + // let high = self.next()?; + // crate::Literal::I64( + // (i64::from(high as i32) << 32) | ((i64::from(low as i32) << 32) >> 32), + // ) + // } _ => return Err(Error::InvalidTypeWidth(width as u32)), } } @@ -4778,7 +4772,7 @@ impl> Frontend { handle: module.constants.append( crate::Constant { name: self.future_decor.remove(&id).and_then(|dec| dec.name), - specialization: None, + specialization: crate::Specialization::None, ty: None, init: Some(expr), }, @@ -4811,7 +4805,7 @@ impl> Frontend { let handle = module.constants.append( crate::Constant { name: self.future_decor.remove(&id).and_then(|dec| dec.name), - specialization: None, //TODO + specialization: crate::Specialization::None, //TODO ty: None, init: Some(expr), }, @@ -4845,7 +4839,7 @@ impl> Frontend { handle: module.constants.append( crate::Constant { name: self.future_decor.remove(&id).and_then(|dec| dec.name), - specialization: None, //TODO + specialization: crate::Specialization::None, //TODO ty: None, init: Some(expr), }, diff --git a/src/front/spv/null.rs b/src/front/spv/null.rs index b3e2314e81..9d9224419f 100644 --- a/src/front/spv/null.rs +++ b/src/front/spv/null.rs @@ -1,117 +1,6 @@ use super::Error; use crate::arena::{Arena, Handle, UniqueArena}; -// const fn make_scalar_inner(kind: crate::ScalarKind, width: crate::Bytes) -> crate::ConstantInner { -// crate::ConstantInner::Scalar { -// width, -// value: match kind { -// crate::ScalarKind::Uint => crate::ScalarValue::Uint(0), -// crate::ScalarKind::Sint => crate::ScalarValue::Sint(0), -// crate::ScalarKind::Float => crate::ScalarValue::Float(0.0), -// crate::ScalarKind::Bool => crate::ScalarValue::Bool(false), -// }, -// } -// } - -// pub fn generate_null_constant( -// ty: Handle, -// type_arena: &UniqueArena, -// constant_arena: &mut Arena, -// span: crate::Span, -// ) -> Result { -// let inner = match type_arena[ty].inner { -// crate::TypeInner::Scalar { kind, width } => make_scalar_inner(kind, width), -// crate::TypeInner::Vector { size, kind, width } => { -// let mut components = Vec::with_capacity(size as usize); -// for _ in 0..size as usize { -// components.push(constant_arena.fetch_or_append( -// crate::Constant { -// name: None, -// specialization: None, -// inner: make_scalar_inner(kind, width), -// }, -// span, -// )); -// } -// crate::ConstantInner::Composite { ty, components } -// } -// crate::TypeInner::Matrix { -// columns, -// rows, -// width, -// } => { -// // If we successfully declared a matrix type, we have declared a vector type for it too. -// let vector_ty = type_arena -// .get(&crate::Type { -// name: None, -// inner: crate::TypeInner::Vector { -// kind: crate::ScalarKind::Float, -// size: rows, -// width, -// }, -// }) -// .unwrap(); -// let vector_inner = generate_null_constant(vector_ty, type_arena, constant_arena, span)?; -// let vector_handle = constant_arena.fetch_or_append( -// crate::Constant { -// name: None, -// specialization: None, -// inner: vector_inner, -// }, -// span, -// ); -// crate::ConstantInner::Composite { -// ty, -// components: vec![vector_handle; columns as usize], -// } -// } -// crate::TypeInner::Struct { ref members, .. } => { -// let mut components = Vec::with_capacity(members.len()); -// // copy out the types to avoid borrowing `members` -// let member_tys = members.iter().map(|member| member.ty).collect::>(); -// for member_ty in member_tys { -// let inner = generate_null_constant(member_ty, type_arena, constant_arena, span)?; -// components.push(constant_arena.fetch_or_append( -// crate::Constant { -// name: None, -// specialization: None, -// inner, -// }, -// span, -// )); -// } -// crate::ConstantInner::Composite { ty, components } -// } -// crate::TypeInner::Array { -// base, -// size: crate::ArraySize::Constant(handle), -// .. -// } => { -// let size = constant_arena[handle] -// .to_array_length() -// .ok_or(Error::InvalidArraySize(handle))?; -// let inner = generate_null_constant(base, type_arena, constant_arena, span)?; -// let value = constant_arena.fetch_or_append( -// crate::Constant { -// name: None, -// specialization: None, -// inner, -// }, -// span, -// ); -// crate::ConstantInner::Composite { -// ty, -// components: vec![value; size as usize], -// } -// } -// ref other => { -// log::warn!("null constant type {:?}", other); -// return Err(Error::UnsupportedType(ty)); -// } -// }; -// Ok(inner) -// } - /// Create a default value for an output built-in. pub fn generate_default_built_in( built_in: Option, diff --git a/src/front/wgsl/lower/construction.rs b/src/front/wgsl/lower/construction.rs index d8b1ba2fa0..cb573a1e64 100644 --- a/src/front/wgsl/lower/construction.rs +++ b/src/front/wgsl/lower/construction.rs @@ -2,7 +2,7 @@ use crate::front::wgsl::parse::ast; use crate::{Handle, Span}; use crate::front::wgsl::error::Error; -use crate::front::wgsl::lower::{ExpressionContext, Lowerer, OutputContext}; +use crate::front::wgsl::lower::{ExpressionContext, Lowerer}; use crate::proc::TypeResolution; enum ConcreteConstructorHandle { @@ -143,7 +143,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { components: &[Handle>], mut ctx: ExpressionContext<'source, '_, '_>, ) -> Result, Error<'source>> { - let constructor_h = self.constructor(constructor, ctx.as_output())?; + let constructor_h = self.constructor(constructor, ctx.reborrow())?; let components_h = match *components { [] => ComponentsHandle::None, @@ -151,7 +151,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let span = ctx.ast_expressions.get_span(component); let component = self.expression(component, ctx.reborrow())?; ctx.grow_types(component)?; - let ty = &ctx.typifier[component]; + let ty = &ctx.typifier()[component]; ComponentsHandle::One { component, @@ -177,7 +177,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .collect(); ctx.grow_types(component)?; - let ty = &ctx.typifier[component]; + let ty = &ctx.typifier()[component]; ComponentsHandle::Many { components, @@ -193,14 +193,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ); let expr = match (components, constructor) { // Empty constructor - (Components::None, dst_ty) => { - let ty = match dst_ty { - ConcreteConstructor::Type(ty, _) => ty, - _ => return Err(Error::TypeNotInferrable(ty_span)), - }; - - return Ok(ctx.interrupt_emitter(crate::Expression::New(ty), span)); - } + (Components::None, dst_ty) => match dst_ty { + ConcreteConstructor::Type(ty, _) => crate::Expression::New(ty), + _ => return Err(Error::TypeNotInferrable(ty_span)), + }, // Scalar constructor & conversion (scalar -> scalar) ( @@ -400,7 +396,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let components = components .chunks(rows as usize) .map(|vec_components| { - ctx.naga_expressions.append( + ctx.expressions().append( crate::Expression::Compose { ty: vec_ty, components: Vec::from(vec_components), @@ -456,25 +452,16 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let base = ctx.register_type(components[0])?; - let size = crate::Constant { - name: None, - specialization: None, - ty: None, - init: Some(ctx.module.const_expressions.append( - crate::Expression::Literal(crate::Literal::U32(components.len() as _)), - span, - )), - }; + let size_expr = ctx.module.const_expressions.append( + crate::Expression::Literal(crate::Literal::U32(components.len() as _)), + span, + ); let inner = crate::TypeInner::Array { base, - size: crate::ArraySize::Constant( - ctx.module.constants.fetch_or_append(size, Span::UNDEFINED), - ), + size: crate::ArraySize::Constant(size_expr), stride: { - self.layouter - .update(&ctx.module.types, &ctx.module.constants) - .unwrap(); + self.layouter.update(ctx.module.to_ctx()).unwrap(); self.layouter[base].to_stride() }, }; @@ -535,55 +522,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { _ => return Err(Error::TypeNotConstructible(ty_span)), }; - let expr = ctx.naga_expressions.append(expr, span); + let expr = ctx.expressions().append(expr, span); Ok(expr) } - // /// Build a Naga IR [`ConstantInner`] given a WGSL construction expression. - // /// - // /// Given `constructor`, representing the head of a WGSL [`type constructor - // /// expression`], and a slice of [`ast::Expression`] handles representing - // /// the constructor's arguments, build a Naga [`ConstantInner`] value - // /// representing the given value. - // /// - // /// If `constructor` is for a composite type, this may entail adding new - // /// [`Type`]s and [`Constant`]s to [`ctx.module`], if it doesn't already - // /// have what we need. - // /// - // /// If the arguments cannot be evaluated at compile time, return an error. - // /// - // /// [`ConstantInner`]: crate::ConstantInner - // /// [`type constructor expression`]: https://gpuweb.github.io/gpuweb/wgsl/#type-constructor-expr - // /// [`Function::expressions`]: ast::Function::expressions - // /// [`TranslationUnit::global_expressions`]: ast::TranslationUnit::global_expressions - // /// [`Type`]: crate::Type - // /// [`Constant`]: crate::Constant - // /// [`ctx.module`]: OutputContext::module - // pub fn const_construct( - // &mut self, - // span: Span, - // constructor: &ast::ConstructorType<'source>, - // components: &[Handle>], - // mut ctx: OutputContext<'source, '_, '_>, - // ) -> Result> { - // // TODO: Support zero values, splatting and inference. - - // let constructor = self.constructor(constructor, ctx.reborrow())?; - - // let c = match constructor { - // ConcreteConstructorHandle::Type(ty) => { - // let components = components - // .iter() - // .map(|&expr| self.constant(expr, ctx.reborrow())) - // .collect::>()?; - - // crate::ConstantInner::Composite { ty, components } - // } - // _ => return Err(Error::ConstExprUnsupported(span)), - // }; - // Ok(c) - // } - /// Build a Naga IR [`Type`] for `constructor` if there is enough /// information to do so. /// @@ -601,13 +543,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { /// array's length. /// /// [`Type`]: crate::Type - /// [`ctx.module`]: OutputContext::module + /// [`ctx.module`]: GlobalContext::module /// [`Array`]: crate::TypeInner::Array /// [`Constant`]: crate::Constant fn constructor<'out>( &mut self, constructor: &ast::ConstructorType<'source>, - mut ctx: OutputContext<'source, '_, 'out>, + mut ctx: ExpressionContext<'source, '_, 'out>, ) -> Result> { let c = match *constructor { ast::ConstructorType::Scalar { width, kind } => { @@ -638,17 +580,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } ast::ConstructorType::PartialArray => ConcreteConstructorHandle::PartialArray, ast::ConstructorType::Array { base, size } => { - let base = self.resolve_ast_type(base, ctx.reborrow())?; + let base = self.resolve_ast_type(base, ctx.as_global())?; let size = match size { ast::ArraySize::Constant(expr) => { - crate::ArraySize::Constant(self.constant(expr, ctx.reborrow())?) + crate::ArraySize::Constant(self.expression(expr, ctx.reborrow())?) } ast::ArraySize::Dynamic => crate::ArraySize::Dynamic, }; - self.layouter - .update(&ctx.module.types, &ctx.module.constants) - .unwrap(); + self.layouter.update(ctx.module.to_ctx()).unwrap(); let ty = ctx.ensure_type_exists(crate::TypeInner::Array { base, size, diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 4434c5c3be..b405d0e1cd 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -10,7 +10,7 @@ use indexmap::IndexMap; mod construction; /// State for constructing a `crate::Module`. -pub struct OutputContext<'source, 'temp, 'out> { +pub struct GlobalContext<'source, 'temp, 'out> { /// The `TranslationUnit`'s expressions arena. ast_expressions: &'temp Arena>, @@ -24,15 +24,29 @@ pub struct OutputContext<'source, 'temp, 'out> { /// The module we're constructing. module: &'out mut crate::Module, + + const_typifier: &'temp mut Typifier, } -impl<'source> OutputContext<'source, '_, '_> { - fn reborrow(&mut self) -> OutputContext<'source, '_, '_> { - OutputContext { +impl<'source> GlobalContext<'source, '_, '_> { + fn reborrow(&mut self) -> GlobalContext<'source, '_, '_> { + GlobalContext { + ast_expressions: self.ast_expressions, + globals: self.globals, + types: self.types, + module: self.module, + const_typifier: self.const_typifier, + } + } + + fn as_const(&mut self) -> ExpressionContext<'source, '_, '_> { + ExpressionContext { ast_expressions: self.ast_expressions, globals: self.globals, types: self.types, module: self.module, + const_typifier: self.const_typifier, + expr_type: ExpressionContextType::Constant, } } @@ -72,6 +86,7 @@ pub struct StatementContext<'source, 'temp, 'out> { /// [`FunctionArgument`]: crate::Expression::FunctionArgument local_table: &'temp mut FastHashMap, TypedExpression>, + const_typifier: &'temp mut Typifier, typifier: &'temp mut Typifier, variables: &'out mut Arena, naga_expressions: &'out mut Arena, @@ -89,6 +104,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> { globals: self.globals, types: self.types, ast_expressions: self.ast_expressions, + const_typifier: self.const_typifier, typifier: self.typifier, variables: self.variables, naga_expressions: self.naga_expressions, @@ -107,26 +123,30 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> { 'temp: 't, { ExpressionContext { - local_table: self.local_table, globals: self.globals, types: self.types, ast_expressions: self.ast_expressions, - typifier: self.typifier, - naga_expressions: self.naga_expressions, + const_typifier: self.const_typifier, module: self.module, - local_vars: self.variables, - arguments: self.arguments, - block, - emitter, + expr_type: ExpressionContextType::Runtime(RuntimeExpressionContext { + local_table: self.local_table, + naga_expressions: self.naga_expressions, + local_vars: self.variables, + arguments: self.arguments, + typifier: self.typifier, + block, + emitter, + }), } } - fn as_output(&mut self) -> OutputContext<'a, '_, '_> { - OutputContext { + fn as_global(&mut self) -> GlobalContext<'a, '_, '_> { + GlobalContext { ast_expressions: self.ast_expressions, globals: self.globals, types: self.types, module: self.module, + const_typifier: self.const_typifier, } } @@ -144,12 +164,41 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> { } } +pub struct RuntimeExpressionContext<'temp, 'out> { + local_table: &'temp mut FastHashMap, TypedExpression>, + + naga_expressions: &'out mut Arena, + local_vars: &'out Arena, + arguments: &'out [crate::FunctionArgument], + block: &'temp mut crate::Block, + emitter: &'temp mut Emitter, + typifier: &'temp mut Typifier, +} + +impl RuntimeExpressionContext<'_, '_> { + fn reborrow(&mut self) -> RuntimeExpressionContext<'_, '_> { + RuntimeExpressionContext { + local_table: self.local_table, + naga_expressions: self.naga_expressions, + local_vars: self.local_vars, + arguments: self.arguments, + block: self.block, + emitter: self.emitter, + typifier: self.typifier, + } + } +} + +pub enum ExpressionContextType<'temp, 'out> { + Runtime(RuntimeExpressionContext<'temp, 'out>), + Constant, +} + /// State for lowering an `ast::Expression` to Naga IR. /// /// Not to be confused with `parser::ExpressionContext`. pub struct ExpressionContext<'source, 'temp, 'out> { // WGSL AST values. - local_table: &'temp mut FastHashMap, TypedExpression>, ast_expressions: &'temp Arena>, types: &'temp Arena>, @@ -158,38 +207,51 @@ pub struct ExpressionContext<'source, 'temp, 'out> { /// `Handle`s we have built for them, owned by `Lowerer::lower`. globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>, - typifier: &'temp mut Typifier, - naga_expressions: &'out mut Arena, - local_vars: &'out Arena, - arguments: &'out [crate::FunctionArgument], + const_typifier: &'temp mut Typifier, module: &'out mut crate::Module, - block: &'temp mut crate::Block, - emitter: &'temp mut Emitter, + expr_type: ExpressionContextType<'temp, 'out>, } -impl<'a> ExpressionContext<'a, '_, '_> { - fn reborrow(&mut self) -> ExpressionContext<'a, '_, '_> { +impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { + fn reborrow(&mut self) -> ExpressionContext<'source, 'temp, 'out> { ExpressionContext { - local_table: self.local_table, globals: self.globals, types: self.types, ast_expressions: self.ast_expressions, - typifier: self.typifier, - naga_expressions: self.naga_expressions, + const_typifier: self.const_typifier, module: self.module, - local_vars: self.local_vars, - arguments: self.arguments, - block: self.block, - emitter: self.emitter, + expr_type: self.expr_type, } } - fn as_output(&mut self) -> OutputContext<'a, '_, '_> { - OutputContext { + fn as_global(&mut self) -> GlobalContext<'source, '_, '_> { + GlobalContext { ast_expressions: self.ast_expressions, globals: self.globals, types: self.types, module: self.module, + const_typifier: self.const_typifier, + } + } + + fn expressions(&mut self) -> &'out mut Arena { + match self.expr_type { + ExpressionContextType::Runtime(ctx) => ctx.naga_expressions, + ExpressionContextType::Constant => &mut self.module.const_expressions, + } + } + + fn typifier(&mut self) -> &'temp mut Typifier { + match self.expr_type { + ExpressionContextType::Runtime(ctx) => ctx.typifier, + ExpressionContextType::Constant => &mut self.const_typifier, + } + } + + fn runtime_expression_ctx(&mut self) -> RuntimeExpressionContext<'temp, 'out> { + match self.expr_type { + ExpressionContextType::Runtime(ctx) => ctx.reborrow(), + ExpressionContextType::Constant => panic!(), } } @@ -208,9 +270,11 @@ impl<'a> ExpressionContext<'a, '_, '_> { fn register_type( &mut self, handle: Handle, - ) -> Result, Error<'a>> { + ) -> Result, Error<'source>> { self.grow_types(handle)?; - Ok(self.typifier.register_type(handle, &mut self.module.types)) + Ok(self + .typifier() + .register_type(handle, &mut self.module.types)) } /// Resolve the types of all expressions up through `handle`. @@ -233,30 +297,39 @@ impl<'a> ExpressionContext<'a, '_, '_> { /// [`self.typifier`]: ExpressionContext::typifier /// [`self.resolved_inner(handle)`]: ExpressionContext::resolved_inner /// [`Typifier`]: Typifier - fn grow_types(&mut self, handle: Handle) -> Result<&mut Self, Error<'a>> { + fn grow_types( + &mut self, + handle: Handle, + ) -> Result<&mut Self, Error<'source>> { let resolve_ctx = ResolveContext { constants: &self.module.constants, types: &self.module.types, global_vars: &self.module.global_variables, - local_vars: self.local_vars, + local_vars: match self.expr_type { + ExpressionContextType::Runtime(ctx) => ctx.local_vars, + ExpressionContextType::Constant => &Arena::new(), + }, functions: &self.module.functions, - arguments: self.arguments, + arguments: match self.expr_type { + ExpressionContextType::Runtime(ctx) => ctx.arguments, + ExpressionContextType::Constant => &[], + }, }; - self.typifier - .grow(handle, self.naga_expressions, &resolve_ctx) + self.typifier() + .grow(handle, self.expressions(), &resolve_ctx) .map_err(Error::InvalidResolve)?; Ok(self) } fn resolved_inner(&self, handle: Handle) -> &crate::TypeInner { - self.typifier[handle].inner_with(&self.module.types) + self.typifier()[handle].inner_with(&self.module.types) } fn image_data( &mut self, image: Handle, span: Span, - ) -> Result<(crate::ImageClass, bool), Error<'a>> { + ) -> Result<(crate::ImageClass, bool), Error<'source>> { self.grow_types(image)?; match *self.resolved_inner(image) { crate::TypeInner::Image { class, arrayed, .. } => Ok((class, arrayed)), @@ -266,10 +339,10 @@ impl<'a> ExpressionContext<'a, '_, '_> { fn prepare_args<'b>( &mut self, - args: &'b [Handle>], + args: &'b [Handle>], min_args: u32, span: Span, - ) -> ArgumentContext<'b, 'a> { + ) -> ArgumentContext<'b, 'source> { ArgumentContext { args: args.iter(), min_args, @@ -285,7 +358,7 @@ impl<'a> ExpressionContext<'a, '_, '_> { op: crate::BinaryOperator, left: &mut Handle, right: &mut Handle, - ) -> Result<(), Error<'a>> { + ) -> Result<(), Error<'source>> { if op != crate::BinaryOperator::Multiply { self.grow_types(*left)?.grow_types(*right)?; @@ -296,18 +369,20 @@ impl<'a> ExpressionContext<'a, '_, '_> { match (left_size, self.resolved_inner(*right)) { (Some(size), &crate::TypeInner::Scalar { .. }) => { - *right = self.naga_expressions.append( + let expressions = self.expressions(); + *right = expressions.append( crate::Expression::Splat { size, value: *right, }, - self.naga_expressions.get_span(*right), + expressions.get_span(*right), ); } (None, &crate::TypeInner::Vector { size, .. }) => { - *left = self.naga_expressions.append( + let expressions = self.expressions(); + *left = expressions.append( crate::Expression::Splat { size, value: *left }, - self.naga_expressions.get_span(*left), + expressions.get_span(*left), ); } _ => {} @@ -326,10 +401,20 @@ impl<'a> ExpressionContext<'a, '_, '_> { expression: crate::Expression, span: Span, ) -> Handle { - self.block - .extend(self.emitter.finish(self.naga_expressions)); - let result = self.naga_expressions.append(expression, span); - self.emitter.start(self.naga_expressions); + let expressions = self.expressions(); + match self.expr_type { + ExpressionContextType::Runtime(rctx) => { + rctx.block.extend(rctx.emitter.finish(expressions)); + } + _ => {} + } + let result = expressions.append(expression, span); + match self.expr_type { + ExpressionContextType::Runtime(rctx) => { + rctx.emitter.start(expressions); + } + _ => {} + } result } @@ -342,90 +427,14 @@ impl<'a> ExpressionContext<'a, '_, '_> { let load = crate::Expression::Load { pointer: expr.handle, }; - let span = self.naga_expressions.get_span(expr.handle); - self.naga_expressions.append(load, span) + let expressions = self.expressions(); + let span = expressions.get_span(expr.handle); + expressions.append(load, span) } else { expr.handle } } - // /// Creates a zero value constant of type `ty` - // /// - // /// Returns `None` if the given `ty` is not a constructible type - // fn create_zero_value_constant( - // &mut self, - // ty: Handle, - // ) -> Option> { - // let inner = match self.module.types[ty].inner { - // crate::TypeInner::Scalar { kind, width } => { - // let value = match kind { - // crate::ScalarKind::Sint => crate::ScalarValue::Sint(0), - // crate::ScalarKind::Uint => crate::ScalarValue::Uint(0), - // crate::ScalarKind::Float => crate::ScalarValue::Float(0.), - // crate::ScalarKind::Bool => crate::ScalarValue::Bool(false), - // }; - // crate::ConstantInner::Scalar { width, value } - // } - // crate::TypeInner::Vector { size, kind, width } => { - // let scalar_ty = self.ensure_type_exists(crate::TypeInner::Scalar { width, kind }); - // let component = self.create_zero_value_constant(scalar_ty)?; - // crate::ConstantInner::Composite { - // ty, - // components: (0..size as u8).map(|_| component).collect(), - // } - // } - // crate::TypeInner::Matrix { - // columns, - // rows, - // width, - // } => { - // let vec_ty = self.ensure_type_exists(crate::TypeInner::Vector { - // width, - // kind: crate::ScalarKind::Float, - // size: rows, - // }); - // let component = self.create_zero_value_constant(vec_ty)?; - // crate::ConstantInner::Composite { - // ty, - // components: (0..columns as u8).map(|_| component).collect(), - // } - // } - // crate::TypeInner::Array { - // base, - // size: crate::ArraySize::Constant(size), - // .. - // } => { - // let size = self.module.constants[size].to_array_length()?; - // let component = self.create_zero_value_constant(base)?; - // crate::ConstantInner::Composite { - // ty, - // components: (0..size).map(|_| component).collect(), - // } - // } - // crate::TypeInner::Struct { ref members, .. } => { - // let members = members.clone(); - // crate::ConstantInner::Composite { - // ty, - // components: members - // .iter() - // .map(|member| self.create_zero_value_constant(member.ty)) - // .collect::>()?, - // } - // } - // _ => return None, - // }; - - // let constant = self.module.constants.fetch_or_append( - // crate::Constant { - // name: None, - // specialization: None, - // inner, - // }, - // Span::UNDEFINED, - // ); - // Some(constant) - // } - fn format_typeinner(&self, inner: &crate::TypeInner) -> String { inner.to_wgsl(self.module.to_ctx()) } @@ -446,7 +455,7 @@ impl<'a> ExpressionContext<'a, '_, '_> { } fn ensure_type_exists(&mut self, inner: crate::TypeInner) -> Handle { - self.as_output().ensure_type_exists(inner) + self.as_global().ensure_type_exists(inner) } } @@ -637,11 +646,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ) -> Result> { let mut module = crate::Module::default(); - let mut ctx = OutputContext { + let mut ctx = GlobalContext { ast_expressions: &tu.expressions, globals: &mut FastHashMap::default(), types: &tu.types, module: &mut module, + const_typifier: &mut Typifier::new(), }; for decl in self.index.visit_ordered() { @@ -658,7 +668,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let init = v .init - .map(|init| self.expression(init, ctx.reborrow())) + .map(|init| self.expression(init, ctx.as_const())) .transpose()?; let handle = ctx.module.global_variables.append( @@ -676,14 +686,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .insert(v.name.name, LoweredGlobalDecl::Var(handle)); } ast::GlobalDeclKind::Const(ref c) => { - let init = self.expression(c.init, ctx.reborrow())?; - // let inferred_type = ctx.resolve_type(c.init)?; + let ectx = ctx.as_const(); + let init = self.expression(c.init, ectx)?; + let ty = ectx.register_type(init)?; let handle = ctx.module.constants.append( crate::Constant { name: Some(c.name.name.to_string()), - specialization: None, - ty: None, + specialization: crate::Specialization::None, + ty, init: Some(init), }, span, @@ -738,7 +749,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { &mut self, f: &ast::Function<'source>, span: Span, - mut ctx: OutputContext<'source, '_, '_>, + mut ctx: GlobalContext<'source, '_, '_>, ) -> Result> { let mut local_table = FastHashMap::default(); let mut local_variables = Arena::new(); @@ -783,6 +794,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { local_table: &mut local_table, globals: ctx.globals, ast_expressions: ctx.ast_expressions, + const_typifier: ctx.const_typifier, typifier: &mut typifier, variables: &mut local_variables, naga_expressions: &mut expressions, @@ -855,7 +867,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let value = self.expression(l.init, ctx.as_expression(block, &mut emitter))?; let explicit_ty = - l.ty.map(|ty| self.resolve_ast_type(ty, ctx.as_output())) + l.ty.map(|ty| self.resolve_ast_type(ty, ctx.as_global())) .transpose()?; if let Some(ty) = explicit_ty { @@ -897,7 +909,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }; let explicit_ty = - v.ty.map(|ty| self.resolve_ast_type(ty, ctx.as_output())) + v.ty.map(|ty| self.resolve_ast_type(ty, ctx.as_global())) .transpose()?; let ty = match (explicit_ty, initializer) { @@ -1091,7 +1103,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let mut ctx = ctx.as_expression(block, &mut emitter); let mut left = ctx.apply_load_rule(expr); ctx.binary_op_splat(op, &mut left, &mut value)?; - ctx.naga_expressions.append( + ctx.expressions().append( crate::Expression::Binary { op, left, @@ -1144,15 +1156,17 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { _ => return Err(Error::BadIncrDecrReferenceType(value_span)), }; - let left = ectx.naga_expressions.append( + let rctx = ectx.runtime_expression_ctx(); + let left = rctx.naga_expressions.append( crate::Expression::Load { pointer: reference.handle, }, value_span, ); - let right = - ectx.interrupt_emitter(crate::Expression::Literal(literal), Span::UNDEFINED); - let value = ectx + let right = rctx + .naga_expressions + .append(crate::Expression::Literal(literal), Span::UNDEFINED); + let value = rctx .naga_expressions .append(crate::Expression::Binary { op, left, right }, stmt.span); @@ -1205,11 +1219,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } ast::Literal::Bool(b) => crate::Literal::Bool(b), }; - let handle = ctx.interrupt_emitter(crate::Expression::Literal(literal), span); + let handle = ctx + .expressions() + .append(crate::Expression::Literal(literal), span); return Ok(TypedExpression::non_reference(handle)); } ast::Expression::Ident(ast::IdentExpr::Local(local)) => { - return Ok(ctx.local_table[&local]) + let rctx = ctx.runtime_expression_ctx(); + return Ok(rctx.local_table[&local]); } ast::Expression::Ident(ast::IdentExpr::Unresolved(name)) => { return if let Some(global) = ctx.globals.get(name) { @@ -1427,13 +1444,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } ast::Expression::Bitcast { expr, to, ty_span } => { let expr = self.expression(expr, ctx.reborrow())?; - let to_resolved = self.resolve_ast_type(to, ctx.as_output())?; + let to_resolved = self.resolve_ast_type(to, ctx.as_global())?; let kind = match ctx.module.types[to_resolved].inner { crate::TypeInner::Scalar { kind, .. } => kind, crate::TypeInner::Vector { kind, .. } => kind, _ => { - let ty = &ctx.typifier[expr]; + let ty = &ctx.typifier()[expr]; return Err(Error::BadTypeCast { from_type: ctx.format_type_resolution(ty), span: ty_span, @@ -1453,7 +1470,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } }; - let handle = ctx.naga_expressions.append(expr, span); + let handle = ctx.expressions().append(expr, span); Ok(TypedExpression { handle, is_reference, @@ -1506,13 +1523,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .map(|&arg| self.expression(arg, ctx.reborrow())) .collect::, _>>()?; - ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions)); + let rctx = ctx.runtime_expression_ctx(); + rctx.block + .extend(rctx.emitter.finish(rctx.naga_expressions)); let result = ctx.module.functions[function].result.is_some().then(|| { - ctx.naga_expressions + rctx.naga_expressions .append(crate::Expression::CallResult(function), span) }); - ctx.emitter.start(ctx.naga_expressions); - ctx.block.push( + rctx.emitter.start(rctx.naga_expressions); + rctx.block.push( crate::Statement::Call { function, arguments, @@ -1606,9 +1625,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let value = self.expression(args.next()?, ctx.reborrow())?; args.finish()?; - ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions)); - ctx.emitter.start(ctx.naga_expressions); - ctx.block + let rctx = ctx.runtime_expression_ctx(); + rctx.block + .extend(rctx.emitter.finish(rctx.naga_expressions)); + rctx.emitter.start(rctx.naga_expressions); + rctx.block .push(crate::Statement::Store { pointer, value }, span); return Ok(None); } @@ -1742,8 +1763,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { _ => return Err(Error::InvalidAtomicOperandType(value_span)), }; + let rctx = ctx.runtime_expression_ctx(); let result = ctx.interrupt_emitter(expression, span); - ctx.block.push( + rctx.block.push( crate::Statement::Atomic { pointer, fun: crate::AtomicFunction::Exchange { @@ -1759,14 +1781,16 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { "storageBarrier" => { ctx.prepare_args(arguments, 0, span).finish()?; - ctx.block + let rctx = ctx.runtime_expression_ctx(); + rctx.block .push(crate::Statement::Barrier(crate::Barrier::STORAGE), span); return Ok(None); } "workgroupBarrier" => { ctx.prepare_args(arguments, 0, span).finish()?; - ctx.block + let rctx = ctx.runtime_expression_ctx(); + rctx.block .push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span); return Ok(None); } @@ -1788,15 +1812,17 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { args.finish()?; - ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions)); - ctx.emitter.start(ctx.naga_expressions); + let rctx = ctx.runtime_expression_ctx(); + rctx.block + .extend(rctx.emitter.finish(rctx.naga_expressions)); + rctx.emitter.start(rctx.naga_expressions); let stmt = crate::Statement::ImageStore { image, coordinate, array_index, value, }; - ctx.block.push(stmt, span); + rctx.block.push(stmt, span); return Ok(None); } "textureLoad" => { @@ -1882,7 +1908,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } }; - let expr = ctx.naga_expressions.append(expr, span); + let expr = ctx.expressions().append(expr, span); Ok(Some(expr)) } } @@ -1929,6 +1955,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { args.finish()?; + let rctx = ctx.runtime_expression_ctx(); let result = ctx.interrupt_emitter( crate::Expression::AtomicResult { ty, @@ -1936,7 +1963,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }, span, ); - ctx.block.push( + rctx.block.push( crate::Statement::Atomic { pointer, fun, @@ -2047,7 +2074,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let offset = args .next() - .map(|arg| self.expression(arg, ctx.as_output())) + .map(|arg| self.expression(arg, ctx.reborrow())) .ok() .transpose()?; @@ -2065,45 +2092,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }) } - // fn gather_component( - // &mut self, - // expr: Handle>, - // mut ctx: ExpressionContext<'source, '_, '_>, - // ) -> Result, Error<'source>> { - // let span = ctx.ast_expressions.get_span(expr); - - // let constant = match self.constant_inner(expr, ctx.as_output()).ok() { - // Some(ConstantOrInner::Constant(c)) => ctx.module.constants[c].inner.clone(), - // Some(ConstantOrInner::Inner(inner)) => inner, - // None => return Ok(None), - // }; - - // let int = match constant { - // crate::ConstantInner::Scalar { - // value: crate::ScalarValue::Sint(i), - // .. - // } if i >= 0 => i as u64, - // crate::ConstantInner::Scalar { - // value: crate::ScalarValue::Uint(i), - // .. - // } => i, - // _ => { - // return Err(Error::InvalidGatherComponent(span)); - // } - // }; - - // crate::SwizzleComponent::XYZW - // .get(int as usize) - // .copied() - // .map(Some) - // .ok_or(Error::InvalidGatherComponent(span)) - // } - fn r#struct( &mut self, s: &ast::Struct<'source>, span: Span, - mut ctx: OutputContext<'source, '_, '_>, + mut ctx: GlobalContext<'source, '_, '_>, ) -> Result, Error<'source>> { let mut offset = 0; let mut struct_alignment = Alignment::ONE; @@ -2112,9 +2105,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { for member in s.members.iter() { let ty = self.resolve_ast_type(member.ty, ctx.reborrow())?; - self.layouter - .update(&ctx.module.types, &ctx.module.constants) - .unwrap(); + self.layouter.update(ctx.module.to_ctx()).unwrap(); let member_min_size = self.layouter[ty].size; let member_min_alignment = self.layouter[ty].alignment; @@ -2178,7 +2169,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { fn resolve_ast_type( &mut self, handle: Handle>, - mut ctx: OutputContext<'source, '_, '_>, + mut ctx: GlobalContext<'source, '_, '_>, ) -> Result, Error<'source>> { let inner = match ctx.types[handle] { ast::Type::Scalar { kind, width } => crate::TypeInner::Scalar { kind, width }, @@ -2201,15 +2192,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } ast::Type::Array { base, size } => { let base = self.resolve_ast_type(base, ctx.reborrow())?; - self.layouter - .update(&ctx.module.types, &ctx.module.constants) - .unwrap(); + self.layouter.update(ctx.module.to_ctx()).unwrap(); crate::TypeInner::Array { base, size: match size { ast::ArraySize::Constant(constant) => { - let constant = self.constant(constant, ctx.reborrow())?; + let constant = self.expression(constant, ctx.as_const())?; crate::ArraySize::Constant(constant) } ast::ArraySize::Dynamic => crate::ArraySize::Dynamic, @@ -2234,7 +2223,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { base, size: match size { ast::ArraySize::Constant(constant) => { - let constant = self.constant(constant, ctx.reborrow())?; + let constant = self.expression(constant, ctx.as_const())?; crate::ArraySize::Constant(constant) } ast::ArraySize::Dynamic => crate::ArraySize::Dynamic, @@ -2253,105 +2242,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Ok(ctx.ensure_type_exists(inner)) } - // /// Find or construct a Naga [`Constant`] whose value is `expr`. - // /// - // /// The `ctx` indicates the Naga [`Module`] to which we should add - // /// new `Constant`s or [`Type`]s as needed. - // /// - // /// [`Module`]: crate::Module - // /// [`Constant`]: crate::Constant - // /// [`Type`]: crate::Type - // fn constant( - // &mut self, - // expr: Handle>, - // mut ctx: OutputContext<'source, '_, '_>, - // ) -> Result, Error<'source>> { - // let inner = match self.constant_inner(expr, ctx.reborrow())? { - // ConstantOrInner::Constant(c) => return Ok(c), - // ConstantOrInner::Inner(inner) => inner, - // }; - - // let c = ctx.module.constants.fetch_or_append( - // crate::Constant { - // name: None, - // specialization: None, - // inner, - // }, - // Span::UNDEFINED, - // ); - // Ok(c) - // } - - // fn constant_inner( - // &mut self, - // expr: Handle>, - // mut ctx: OutputContext<'source, '_, '_>, - // ) -> Result> { - // let span = ctx.ast_expressions.get_span(expr); - // let inner = match ctx.ast_expressions[expr] { - // ast::Expression::Literal(literal) => match literal { - // ast::Literal::Number(Number::F32(f)) => crate::ConstantInner::Scalar { - // width: 4, - // value: crate::ScalarValue::Float(f as _), - // }, - // ast::Literal::Number(Number::I32(i)) => crate::ConstantInner::Scalar { - // width: 4, - // value: crate::ScalarValue::Sint(i as _), - // }, - // ast::Literal::Number(Number::U32(u)) => crate::ConstantInner::Scalar { - // width: 4, - // value: crate::ScalarValue::Uint(u as _), - // }, - // ast::Literal::Number(_) => { - // unreachable!("got abstract numeric type when not expected"); - // } - // ast::Literal::Bool(b) => crate::ConstantInner::Scalar { - // width: 1, - // value: crate::ScalarValue::Bool(b), - // }, - // }, - // ast::Expression::Ident(ast::IdentExpr::Local(_)) => { - // return Err(Error::Unexpected(span, ExpectedToken::Constant)) - // } - // ast::Expression::Ident(ast::IdentExpr::Unresolved(name)) => { - // return if let Some(global) = ctx.globals.get(name) { - // match *global { - // LoweredGlobalDecl::Const(handle) => Ok(ConstantOrInner::Constant(handle)), - // _ => Err(Error::Unexpected(span, ExpectedToken::Constant)), - // } - // } else { - // Err(Error::UnknownIdent(span, name)) - // } - // } - // ast::Expression::Construct { - // ref ty, - // ref components, - // .. - // } => self.const_construct(span, ty, components, ctx.reborrow())?, - // ast::Expression::Call { - // ref function, - // ref arguments, - // } => match ctx.globals.get(function.name) { - // Some(&LoweredGlobalDecl::Type(ty)) => self.const_construct( - // span, - // &ast::ConstructorType::Type(ty), - // arguments, - // ctx.reborrow(), - // )?, - // Some(_) => return Err(Error::ConstExprUnsupported(span)), - // None => return Err(Error::UnknownIdent(function.span, function.name)), - // }, - // _ => return Err(Error::ConstExprUnsupported(span)), - // }; - - // Ok(ConstantOrInner::Inner(inner)) - // } - fn interpolate_default( &mut self, binding: &Option, ty: Handle, - ctx: OutputContext<'source, '_, '_>, + ctx: GlobalContext<'source, '_, '_>, ) -> Option { let mut binding = binding.clone(); if let Some(ref mut binding) = binding { diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 77806b564c..1b3a57b783 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -133,26 +133,16 @@ impl crate::TypeInner { let member_type = &global_ctx.types[base]; let base = member_type.name.as_deref().unwrap_or("unknown"); match size { - crate::ArraySize::Constant(size) => { - let constant = &global_ctx.constants[size]; - let size = constant - .name - .clone() - .unwrap_or_else(|| match constant.init { - Some(handle) => { - let expr = global_ctx.const_expressions[handle]; - match expr { - crate::Expression::Literal(crate::Literal::U32(size)) => { - size.to_string() - } - crate::Expression::Literal(crate::Literal::I32(size)) => { - size.to_string() - } - _ => "?".to_string(), - } - } - _ => "?".to_string(), - }); + crate::ArraySize::Constant(size_expr) => { + let size = match global_ctx.const_expressions[size_expr] { + crate::Expression::Literal(crate::Literal::U32(size)) => { + size.to_string() + } + crate::Expression::Literal(crate::Literal::I32(size)) => { + size.to_string() + } + _ => "?".to_string(), + }; format!("array<{base}, {size}>") } crate::ArraySize::Dynamic => format!("array<{base}>"), @@ -208,11 +198,16 @@ impl crate::TypeInner { let member_type = &global_ctx.types[base]; let base = member_type.name.as_deref().unwrap_or("unknown"); match size { - crate::ArraySize::Constant(size) => { - let size = global_ctx.constants[size] - .name - .as_deref() - .unwrap_or("unknown"); + crate::ArraySize::Constant(size_expr) => { + let size = match global_ctx.const_expressions[size_expr] { + crate::Expression::Literal(crate::Literal::U32(size)) => { + size.to_string() + } + crate::Expression::Literal(crate::Literal::I32(size)) => { + size.to_string() + } + _ => "?".to_string(), + }; format!("binding_array<{base}, {size}>") } crate::ArraySize::Dynamic => format!("binding_array<{base}>"), @@ -228,16 +223,8 @@ mod type_inner_tests { let mut types = crate::UniqueArena::new(); let mut constants = crate::Arena::new(); let mut const_expressions = crate::Arena::new(); - let c = constants.append( - crate::Constant { - name: Some("C".to_string()), - specialization: None, - ty: None, - init: Some(const_expressions.append( - crate::Expression::Literal(crate::Literal::U32(32)), - Default::default(), - )), - }, + let size_expr = const_expressions.append( + crate::Expression::Literal(crate::Literal::U32(32)), Default::default(), ); @@ -270,7 +257,7 @@ mod type_inner_tests { let array = crate::TypeInner::Array { base: mytype1, stride: 4, - size: crate::ArraySize::Constant(c), + size: crate::ArraySize::Constant(size_expr), }; assert_eq!(array.to_wgsl(global_ctx.reborrow()), "array"); @@ -324,7 +311,7 @@ mod type_inner_tests { let array = crate::TypeInner::BindingArray { base: mytype1, - size: crate::ArraySize::Constant(c), + size: crate::ArraySize::Constant(size_expr), }; assert_eq!( array.to_wgsl(global_ctx.reborrow()), diff --git a/src/lib.rs b/src/lib.rs index 2db6aea626..d92af30a00 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -401,7 +401,9 @@ pub enum ScalarKind { #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum ArraySize { /// The array size is constant. - Constant(Handle), + /// + /// Expression handle lives in const_expressions + Constant(Handle), /// The array size can change at runtime. Dynamic, } @@ -765,24 +767,62 @@ pub enum TypeInner { BindingArray { base: Handle, size: ArraySize }, } -#[derive(Debug, Clone, Copy, PartialOrd, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum Literal { F64(f64), - U64(u64), - I64(i64), F32(f32), U32(u32), I32(i32), Bool(bool), } +// impl PartialEq for Literal { +// fn eq(&self, other: &Self) -> bool { +// match (*self, *other) { +// (Self::F64(a), Self::F64(b)) => a.to_bits() == b.to_bits(), +// (Self::F32(a), Self::F32(b)) => a.to_bits() == b.to_bits(), +// (Self::U32(a), Self::U32(b)) => a == b, +// (Self::I32(a), Self::I32(b)) => a == b, +// (Self::Bool(a), Self::Bool(b)) => a == b, +// _ => false, +// } +// } +// } +// impl Eq for Literal {} +// impl std::hash::Hash for Literal { +// fn hash(&self, hasher: &mut H) { +// match *self { +// Self::F64(v) => { +// hasher.write_u8(0); +// v.to_bits().hash(hasher); +// } +// Self::F32(v) => { +// hasher.write_u8(1); +// v.to_bits().hash(hasher); +// } +// Self::U32(v) => { +// hasher.write_u8(2); +// v.hash(hasher); +// } +// Self::I32(v) => { +// hasher.write_u8(3); +// v.hash(hasher); +// } +// Self::Bool(v) => { +// hasher.write_u8(4); +// v.hash(hasher); +// } +// } +// } +// } + impl Literal { pub const fn width(&self) -> Bytes { match *self { - Self::F64(_) | Self::U64(_) | Self::I64(_) => 8, + Self::F64(_) => 8, Self::F32(_) | Self::U32(_) | Self::I32(_) => 4, Self::Bool(_) => 1, } @@ -790,8 +830,8 @@ impl Literal { pub const fn scalar_kind(&self) -> ScalarKind { match *self { Self::F64(_) | Self::F32(_) => ScalarKind::Float, - Self::U64(_) | Self::U32(_) => ScalarKind::Uint, - Self::I64(_) | Self::I32(_) => ScalarKind::Sint, + Self::U32(_) => ScalarKind::Uint, + Self::I32(_) => ScalarKind::Sint, Self::Bool(_) => ScalarKind::Bool, } } @@ -803,6 +843,17 @@ impl Literal { } } +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "clone", derive(Clone))] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +enum Specialization { + None, + ByName, + ByNameOrId(u32), +} + /// Constant value. #[derive(Debug, PartialEq)] #[cfg_attr(feature = "clone", derive(Clone))] @@ -811,8 +862,8 @@ impl Literal { #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct Constant { pub name: Option, - pub specialization: Option>, - pub ty: Option>, + pub specialization: Specialization, + pub ty: Handle, /// Expression handle lives in const_expressions pub init: Option>, } @@ -1457,7 +1508,6 @@ pub enum Expression { pub use block::Block; /// The value of the switch case. -// Clone is used only for error reporting and is not intended for end users #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] diff --git a/src/proc/index.rs b/src/proc/index.rs index 3fea79ec01..20df20a439 100644 --- a/src/proc/index.rs +++ b/src/proc/index.rs @@ -339,8 +339,12 @@ impl GuardedIndex { /// [`Constant`]: crate::Expression::Constant fn try_resolve_to_constant(&mut self, function: &crate::Function, module: &crate::Module) { if let GuardedIndex::Expression(expr) = *self { + // TODO: check if this is right if let crate::Expression::Constant(handle) = function.expressions[expr] { - if let Some(value) = module.constants[handle].to_array_length() { + if let Some(value) = module + .to_ctx() + .to_array_length(module.constants[handle].init.unwrap()) + { *self = GuardedIndex::Known(value); } } @@ -353,7 +357,7 @@ pub enum IndexableLengthError { #[error("Type is not indexable, and has no length (validation error)")] TypeNotIndexable, #[error("Array length constant {0:?} is invalid")] - InvalidArrayLength(Handle), + InvalidArrayLength(Handle), } impl crate::TypeInner { @@ -420,16 +424,30 @@ impl crate::ArraySize { ) -> Result { Ok(match self { Self::Constant(k) => { - let constant = &module.constants[k]; - if constant.specialization.is_some() { - // Specializable constants are not supported as array lengths. - // See valid::TypeError::UnsupportedSpecializedArrayLength. - return Err(IndexableLengthError::InvalidArrayLength(k)); + let const_expr = &module.const_expressions[k]; + match *const_expr { + crate::Expression::Constant(c) + if matches!( + module.constants[c], + crate::Constant { + specialization: crate::Specialization::ByName + | crate::Specialization::ByNameOrId(_), + .. + } + ) => + { + // Specializable constants are not supported as array lengths. + // See valid::TypeError::UnsupportedSpecializedArrayLength. + return Err(IndexableLengthError::InvalidArrayLength(k)); + } + _ => { + let length = module + .to_ctx() + .to_array_length(k) + .ok_or(IndexableLengthError::InvalidArrayLength(k))?; + IndexableLength::Known(length) + } } - let length = constant - .to_array_length() - .ok_or(IndexableLengthError::InvalidArrayLength(k))?; - IndexableLength::Known(length) } Self::Dynamic => IndexableLength::Dynamic, }) diff --git a/src/proc/layouter.rs b/src/proc/layouter.rs index db07f261a4..924fea6a19 100644 --- a/src/proc/layouter.rs +++ b/src/proc/layouter.rs @@ -1,4 +1,4 @@ -use crate::arena::{Arena, Handle, UniqueArena}; +use crate::arena::Handle; use std::{fmt::Display, num::NonZeroU32, ops}; /// A newtype struct where its only valid values are powers of 2 @@ -165,15 +165,11 @@ impl Layouter { /// constant arenas, and then assume that layouts are available for all /// types. #[allow(clippy::or_fun_call)] - pub fn update( - &mut self, - types: &UniqueArena, - constants: &Arena, - ) -> Result<(), LayoutError> { + pub fn update(&mut self, gctx: crate::GlobalCtx) -> Result<(), LayoutError> { use crate::TypeInner as Ti; - for (ty_handle, ty) in types.iter().skip(self.layouts.len()) { - let size = ty.inner.size(constants); + for (ty_handle, ty) in gctx.types.iter().skip(self.layouts.len()) { + let size = ty.inner.size(gctx.reborrow()); let layout = match ty.inner { Ti::Scalar { width, .. } | Ti::Atomic { width, .. } => { let alignment = Alignment::new(width as u32) diff --git a/src/proc/mod.rs b/src/proc/mod.rs index 308f810163..55adda830f 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -8,8 +8,6 @@ mod namer; mod terminator; mod typifier; -use std::cmp::PartialEq; - pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError}; pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout}; pub use namer::{EntryPointIndex, NameKey, Namer}; @@ -93,7 +91,7 @@ impl super::TypeInner { } /// Get the size of this type. - pub fn size(&self, constants: &super::Arena) -> u32 { + pub fn size(&self, gctx: crate::GlobalCtx) -> u32 { match *self { Self::Scalar { kind: _, width } | Self::Atomic { kind: _, width } => width as u32, Self::Vector { @@ -114,9 +112,7 @@ impl super::TypeInner { stride, } => { let count = match size { - super::ArraySize::Constant(handle) => { - constants[handle].to_array_length().unwrap_or(1) - } + super::ArraySize::Constant(handle) => gctx.to_array_length(handle).unwrap_or(1), // A dynamically-sized array has to have at least one element super::ArraySize::Dynamic => 1, }; @@ -320,7 +316,7 @@ impl crate::Expression { pub fn is_dynamic_index(&self, module: &crate::Module) -> bool { if let Self::Constant(handle) = *self { let constant = &module.constants[handle]; - constant.specialization.is_some() + !matches!(constant.specialization, crate::Specialization::None) } else { true } @@ -363,7 +359,7 @@ impl crate::SampleLevel { } } -impl crate::Constant { +impl crate::GlobalCtx<'_> { /// Interpret this constant as an array length, and return it as a `u32`. /// /// Ignore any specialization available for this constant; return its @@ -376,19 +372,17 @@ impl crate::Constant { /// ends, it may indicate ill-formed input (for example, a SPIR-V /// `OpArrayType` referring to an inappropriate `OpConstant`). So we return /// `Option` and let the caller sort things out. - pub(crate) fn to_array_length(&self) -> Option { - match self.inner { - crate::ConstantInner::Scalar { value, width: _ } => match value { - crate::ScalarValue::Uint(value) => value.try_into().ok(), - // Accept a signed integer size to avoid - // requiring an explicit uint - // literal. Type inference should make - // this unnecessary. - crate::ScalarValue::Sint(value) => value.try_into().ok(), + pub(crate) fn to_array_length(&self, handle: crate::Handle) -> Option { + fn get(gctx: crate::GlobalCtx, handle: crate::Handle) -> Option { + match gctx.const_expressions[handle] { + crate::Expression::Literal(crate::Literal::U32(value)) => Some(value), + crate::Expression::Literal(crate::Literal::I32(value)) => value.try_into().ok(), _ => None, - }, - // caught by type validation - crate::ConstantInner::Composite { .. } => None, + } + } + match self.const_expressions[handle] { + crate::Expression::Constant(c) => get(self.reborrow(), self.constants[c].init.unwrap()), + _ => get(self.reborrow(), handle), } } } @@ -402,30 +396,6 @@ impl crate::Binding { } } -//TODO: should we use an existing crate for hashable floats? -impl PartialEq for crate::ScalarValue { - fn eq(&self, other: &Self) -> bool { - match (*self, *other) { - (Self::Uint(a), Self::Uint(b)) => a == b, - (Self::Sint(a), Self::Sint(b)) => a == b, - (Self::Float(a), Self::Float(b)) => a.to_bits() == b.to_bits(), - (Self::Bool(a), Self::Bool(b)) => a == b, - _ => false, - } - } -} -impl Eq for crate::ScalarValue {} -impl std::hash::Hash for crate::ScalarValue { - fn hash(&self, hasher: &mut H) { - match *self { - Self::Sint(v) => v.hash(hasher), - Self::Uint(v) => v.hash(hasher), - Self::Float(v) => v.to_bits().hash(hasher), - Self::Bool(v) => v.hash(hasher), - } - } -} - impl super::SwizzleComponent { pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W]; @@ -465,14 +435,14 @@ impl super::ImageClass { #[test] fn test_matrix_size() { - let constants = crate::Arena::new(); + let module = crate::Module::default(); assert_eq!( crate::TypeInner::Matrix { columns: crate::VectorSize::Tri, rows: crate::VectorSize::Tri, width: 4 } - .size(&constants), + .size(module.to_ctx()), 48, ); } diff --git a/src/proc/namer.rs b/src/proc/namer.rs index 053126b8ac..abb1ad8ff4 100644 --- a/src/proc/namer.rs +++ b/src/proc/namer.rs @@ -206,43 +206,7 @@ impl Namer { use std::fmt::Write; // Try to be more descriptive about the constant values temp.clear(); - match constant.inner { - crate::ConstantInner::Scalar { - width: _, - value: crate::ScalarValue::Sint(v), - } => write!(temp, "const_{v}i"), - crate::ConstantInner::Scalar { - width: _, - value: crate::ScalarValue::Uint(v), - } => write!(temp, "const_{v}u"), - crate::ConstantInner::Scalar { - width: _, - value: crate::ScalarValue::Float(v), - } => { - let abs = v.abs(); - write!( - temp, - "const_{}{}", - if v < 0.0 { "n" } else { "" }, - abs.trunc(), - ) - .unwrap(); - let fract = abs.fract(); - if fract == 0.0 { - write!(temp, "f") - } else { - write!(temp, "_{:02}f", (fract * 100.0) as i8) - } - } - crate::ConstantInner::Scalar { - width: _, - value: crate::ScalarValue::Bool(v), - } => write!(temp, "const_{v}"), - crate::ConstantInner::Composite { ty, components: _ } => { - write!(temp, "const_{}", output[&NameKey::Type(ty)]) - } - } - .unwrap(); + write!(temp, "const_{}", output[&NameKey::Type(constant.ty)]).unwrap(); &temp } }; diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index 3b9fa1d50b..7aab8050fd 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -148,18 +148,6 @@ impl Clone for TypeResolution { } } -impl crate::ConstantInner { - pub const fn resolve_type(&self) -> TypeResolution { - match *self { - Self::Scalar { width, ref value } => TypeResolution::Value(crate::TypeInner::Scalar { - kind: value.scalar_kind(), - width, - }), - Self::Composite { ty, components: _ } => TypeResolution::Handle(ty), - } - } -} - #[derive(Clone, Debug, Error, PartialEq)] pub enum ResolveError { #[error("Index {index} is out of bounds for expression {expr:?}")] @@ -399,15 +387,6 @@ impl<'a> ResolveContext<'a> { } } } - crate::Expression::Constant(h) => match self.constants[h].inner { - crate::ConstantInner::Scalar { width, ref value } => { - TypeResolution::Value(Ti::Scalar { - kind: value.scalar_kind(), - width, - }) - } - crate::ConstantInner::Composite { ty, components: _ } => TypeResolution::Handle(ty), - }, crate::Expression::Splat { size, value } => match *past(value)?.inner_with(types) { Ti::Scalar { kind, width } => { TypeResolution::Value(Ti::Vector { size, kind, width }) @@ -432,6 +411,9 @@ impl<'a> ResolveContext<'a> { return Err(ResolveError::InvalidVector(vector)); } }, + crate::Expression::Literal(lit) => TypeResolution::Value(lit.ty_inner()), + crate::Expression::Constant(h) => TypeResolution::Handle(self.constants[h].ty), + crate::Expression::New(ty) => TypeResolution::Handle(ty), crate::Expression::Compose { ty, .. } => TypeResolution::Handle(ty), crate::Expression::FunctionArgument(index) => { let arg = self diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index 30fb32d076..c66b82ea37 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -492,7 +492,6 @@ impl FunctionInfo { requirements: UniformityRequirements::empty(), }, // always uniform - E::Constant(_) => Uniformity::new(), E::Splat { size: _, value } => Uniformity { non_uniform_result: self.add_ref(value), requirements: UniformityRequirements::empty(), @@ -501,6 +500,7 @@ impl FunctionInfo { non_uniform_result: self.add_ref(vector), requirements: UniformityRequirements::empty(), }, + E::Literal(_) | E::Constant(_) | E::New(_) => Uniformity::new(), E::Compose { ref components, .. } => { let non_uniform_result = components .iter() @@ -963,17 +963,6 @@ fn uniform_control_flow() { use crate::{Expression as E, Statement as S}; let mut constant_arena = Arena::new(); - let constant = constant_arena.append( - crate::Constant { - name: None, - specialization: None, - inner: crate::ConstantInner::Scalar { - width: 4, - value: crate::ScalarValue::Uint(0), - }, - }, - Default::default(), - ); let mut type_arena = crate::UniqueArena::new(); let ty = type_arena.insert( crate::Type { @@ -1010,7 +999,7 @@ fn uniform_control_flow() { let mut expressions = Arena::new(); // checks the uniform control flow - let constant_expr = expressions.append(E::Constant(constant), Default::default()); + let constant_expr = expressions.append(E::Literal(crate::Literal::U32(0)), Default::default()); // checks the non-uniform control flow let derivative_expr = expressions.append( E::Derivative { diff --git a/src/valid/compose.rs b/src/valid/compose.rs index e77d538255..6faa2d877b 100644 --- a/src/valid/compose.rs +++ b/src/valid/compose.rs @@ -1,8 +1,5 @@ #[cfg(feature = "validate")] -use crate::{ - arena::{Arena, UniqueArena}, - proc::TypeResolution, -}; +use crate::proc::TypeResolution; use crate::arena::Handle; @@ -20,18 +17,17 @@ pub enum ComposeError { #[cfg(feature = "validate")] pub fn validate_compose( self_ty_handle: Handle, - constant_arena: &Arena, - type_arena: &UniqueArena, + gctx: crate::GlobalCtx, component_resolutions: impl ExactSizeIterator, ) -> Result<(), ComposeError> { use crate::TypeInner as Ti; - match type_arena[self_ty_handle].inner { + match gctx.types[self_ty_handle].inner { // vectors are composed from scalars or other vectors Ti::Vector { size, kind, width } => { let mut total = 0; for (index, comp_res) in component_resolutions.enumerate() { - total += match *comp_res.inner_with(type_arena) { + total += match *comp_res.inner_with(gctx.types) { Ti::Scalar { kind: comp_kind, width: comp_width, @@ -74,7 +70,7 @@ pub fn validate_compose( }); } for (index, comp_res) in component_resolutions.enumerate() { - if comp_res.inner_with(type_arena) != &inner { + if comp_res.inner_with(gctx.types) != &inner { log::error!("Matrix component[{}] type {:?}", index, comp_res); return Err(ComposeError::ComponentType { index: index as u32, @@ -87,7 +83,7 @@ pub fn validate_compose( size: crate::ArraySize::Constant(handle), stride: _, } => { - let count = constant_arena[handle].to_array_length().unwrap(); + let count = gctx.to_array_length(handle).unwrap(); if count as usize != component_resolutions.len() { return Err(ComposeError::ComponentCount { expected: count, @@ -95,11 +91,11 @@ pub fn validate_compose( }); } for (index, comp_res) in component_resolutions.enumerate() { - let base_inner = &type_arena[base].inner; - let comp_res_inner = comp_res.inner_with(type_arena); + let base_inner = &gctx.types[base].inner; + let comp_res_inner = comp_res.inner_with(gctx.types); // We don't support arrays of pointers, but it seems best not to // embed that assumption here, so use `TypeInner::equivalent`. - if !base_inner.equivalent(comp_res_inner, type_arena) { + if !base_inner.equivalent(comp_res_inner, gctx.types) { log::error!("Array component[{}] type {:?}", index, comp_res); return Err(ComposeError::ComponentType { index: index as u32, @@ -116,11 +112,11 @@ pub fn validate_compose( } for (index, (member, comp_res)) in members.iter().zip(component_resolutions).enumerate() { - let member_inner = &type_arena[member.ty].inner; - let comp_res_inner = comp_res.inner_with(type_arena); + let member_inner = &gctx.types[member.ty].inner; + let comp_res_inner = comp_res.inner_with(gctx.types); // We don't support pointers in structs, but it seems best not to embed // that assumption here, so use `TypeInner::equivalent`. - if !comp_res_inner.equivalent(member_inner, type_arena) { + if !comp_res_inner.equivalent(member_inner, gctx.types) { log::error!("Struct component[{}] type {:?}", index, comp_res); return Err(ComposeError::ComponentType { index: index as u32, diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 01d6910eba..1212f274c1 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -25,8 +25,8 @@ pub enum ExpressionError { InvalidBaseType(Handle), #[error("Accessing with index {0:?} can't be done")] InvalidIndexType(Handle), - #[error("Accessing index {1:?} is out of {0:?} bounds")] - IndexOutOfBounds(Handle, crate::ScalarValue), + #[error("Accessing index {1} is out of {0:?} bounds")] + IndexOutOfBounds(Handle, u32), #[error("The expression {0:?} may only be indexed by a constant")] IndexMustBeConstant(Handle), #[error("Function argument {0:?} doesn't exist")] @@ -199,32 +199,32 @@ impl super::Validator { // If we know both the length and the index, we can do the // bounds check now. - if let crate::proc::IndexableLength::Known(known_length) = - base_type.indexable_length(module)? - { - if let E::Constant(k) = function.expressions[index] { - if let crate::Constant { - // We must treat specializable constants as unknown. - specialization: None, - // Non-scalar indices should have been caught above. - inner: crate::ConstantInner::Scalar { value, .. }, - .. - } = module.constants[k] - { - match value { - crate::ScalarValue::Uint(u) if u >= known_length as u64 => { - return Err(ExpressionError::IndexOutOfBounds(base, value)); - } - crate::ScalarValue::Sint(s) - if s < 0 || s >= known_length as i64 => - { - return Err(ExpressionError::IndexOutOfBounds(base, value)); - } - _ => (), - } - } - } - } + // if let crate::proc::IndexableLength::Known(known_length) = + // base_type.indexable_length(module)? + // { + // if let E::Constant(k) = function.expressions[index] { + // if let crate::Constant { + // // We must treat specializable constants as unknown. + // specialization: crate::Specialization::None, + // // Non-scalar indices should have been caught above. + // inner: crate::ConstantInner::Scalar { value, .. }, + // .. + // } = module.constants[k] + // { + // match value { + // crate::ScalarValue::Uint(u) if u >= known_length as u64 => { + // return Err(ExpressionError::IndexOutOfBounds(base, value)); + // } + // crate::ScalarValue::Sint(s) + // if s < 0 || s >= known_length as i64 => + // { + // return Err(ExpressionError::IndexOutOfBounds(base, value)); + // } + // _ => (), + // } + // } + // } + // } ShaderStages::all() } @@ -242,9 +242,9 @@ impl super::Validator { } => size as u32, Ti::Matrix { columns, .. } => columns as u32, Ti::Array { - size: crate::ArraySize::Constant(handle), + size: crate::ArraySize::Constant(const_expr), .. - } => module.constants[handle].to_array_length().unwrap(), + } => module.to_ctx().to_array_length(const_expr).unwrap(), Ti::Array { .. } | Ti::BindingArray { .. } => u32::MAX, // can't statically know, but need run-time checks Ti::Pointer { base, .. } if top_level => { resolve_index_limit(module, top, &module.types[base].inner, false)? @@ -260,14 +260,10 @@ impl super::Validator { let limit = resolve_index_limit(module, base, &resolver[base], true)?; if index >= limit { - return Err(ExpressionError::IndexOutOfBounds( - base, - crate::ScalarValue::Uint(limit as _), - )); + return Err(ExpressionError::IndexOutOfBounds(base, limit)); } ShaderStages::all() } - E::Constant(_handle) => ShaderStages::all(), E::Splat { size: _, value } => match resolver[value] { Ti::Scalar { .. } => ShaderStages::all(), ref other => { @@ -294,11 +290,11 @@ impl super::Validator { } ShaderStages::all() } + E::Literal(_) | E::Constant(_) | E::New(_) => ShaderStages::all(), E::Compose { ref components, ty } => { validate_compose( ty, - &module.constants, - &module.types, + module.to_ctx(), components.iter().map(|&handle| info[handle].ty.clone()), )?; ShaderStages::all() @@ -408,28 +404,29 @@ impl super::Validator { } // check constant offset - if let Some(const_handle) = offset { - let good = match module.constants[const_handle].inner { - crate::ConstantInner::Scalar { - width: _, - value: crate::ScalarValue::Sint(_), - } => num_components == 1, - crate::ConstantInner::Scalar { .. } => false, - crate::ConstantInner::Composite { ty, .. } => { - match module.types[ty].inner { - Ti::Vector { - size, - kind: Sk::Sint, - .. - } => size as u32 == num_components, - _ => false, - } - } - }; - if !good { - return Err(ExpressionError::InvalidSampleOffset(dim, const_handle)); - } - } + // TODO + // if let Some(const_handle) = offset { + // let good = match module.constants[const_handle].inner { + // crate::ConstantInner::Scalar { + // width: _, + // value: crate::ScalarValue::Sint(_), + // } => num_components == 1, + // crate::ConstantInner::Scalar { .. } => false, + // crate::ConstantInner::Composite { ty, .. } => { + // match module.types[ty].inner { + // Ti::Vector { + // size, + // kind: Sk::Sint, + // .. + // } => size as u32 == num_components, + // _ => false, + // } + // } + // }; + // if !good { + // return Err(ExpressionError::InvalidSampleOffset(dim, const_handle)); + // } + // } // check depth reference type if let Some(expr) = depth_ref { @@ -456,9 +453,10 @@ impl super::Validator { crate::ImageClass::Depth { .. } => crate::SwizzleComponent::X, _ => crate::SwizzleComponent::W, }; - if component > max_component { - return Err(ExpressionError::InvalidGatherComponent(component)); - } + // TODO + // if component > max_component { + // return Err(ExpressionError::InvalidGatherComponent(component)); + // } match level { crate::SampleLevel::Zero => {} _ => return Err(ExpressionError::InvalidGatherLevel), diff --git a/src/valid/function.rs b/src/valid/function.rs index 464496f6d6..fee22f47bb 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -845,24 +845,24 @@ impl super::Validator { return Err(LocalVariableError::InvalidType(var.ty)); } - if let Some(const_handle) = var.init { - match constants[const_handle].inner { - crate::ConstantInner::Scalar { width, ref value } => { - let ty_inner = crate::TypeInner::Scalar { - width, - kind: value.scalar_kind(), - }; - if types[var.ty].inner != ty_inner { - return Err(LocalVariableError::InitializerType); - } - } - crate::ConstantInner::Composite { ty, components: _ } => { - if ty != var.ty { - return Err(LocalVariableError::InitializerType); - } - } - } - } + // if let Some(const_handle) = var.init { + // match constants[const_handle].inner { + // crate::ConstantInner::Scalar { width, ref value } => { + // let ty_inner = crate::TypeInner::Scalar { + // width, + // kind: value.scalar_kind(), + // }; + // if types[var.ty].inner != ty_inner { + // return Err(LocalVariableError::InitializerType); + // } + // } + // crate::ConstantInner::Composite { ty, components: _ } => { + // if ty != var.ty { + // return Err(LocalVariableError::InitializerType); + // } + // } + // } + // } Ok(()) } diff --git a/src/valid/handles.rs b/src/valid/handles.rs index be87e54d48..09c9af80f8 100644 --- a/src/valid/handles.rs +++ b/src/valid/handles.rs @@ -9,7 +9,7 @@ use crate::{ use crate::{Arena, UniqueArena}; #[cfg(feature = "validate")] -use super::{TypeError, ValidationError}; +use super::ValidationError; #[cfg(feature = "validate")] use std::{convert::TryInto, hash::Hash, num::NonZeroU32}; @@ -39,6 +39,7 @@ impl super::Validator { ref functions, ref global_variables, ref types, + ref const_expressions, } = module; // NOTE: Types being first is important. All other forms of validation depend on this. @@ -48,21 +49,17 @@ impl super::Validator { ref inner, } = ty; - let validate_array_size = |size| { + let validate_array_size = |size| -> Result, ValidationError> { match size { - crate::ArraySize::Constant(constant) => { - let &crate::Constant { - name: _, - specialization: _, - ref inner, - } = constants.try_get(constant)?; - if !matches!(inner, &crate::ConstantInner::Scalar { .. }) { - return Err(ValidationError::Type { - handle: this_handle, - name: name.clone().unwrap_or_default(), - source: TypeError::InvalidArraySizeConstant(constant), - }); - } + crate::ArraySize::Constant(const_expr) => { + let _ = const_expressions.try_get(const_expr)?; + // if !matches!(inner, &crate::ConstantInner::Scalar { .. }) { + // return Err(ValidationError::Type { + // handle: this_handle, + // name: name.clone().unwrap_or_default(), + // source: TypeError::InvalidArraySizeConstant(const_expr), + // }); + // } } crate::ArraySize::Dynamic => (), }; @@ -99,24 +96,22 @@ impl super::Validator { } let validate_type = |handle| Self::validate_type_handle(handle, types); + let validate_const_expr = + |handle| Self::validate_constant_expression_handle(handle, const_expressions); for (this_handle, constant) in constants.iter() { let &crate::Constant { name: _, specialization: _, - ref inner, + ty, + init, } = constant; - match *inner { - crate::ConstantInner::Scalar { .. } => (), - crate::ConstantInner::Composite { ty, ref components } => { - validate_type(ty)?; - this_handle.check_dep_iter(components.iter().copied())?; - } + validate_type(ty)?; + if let Some(init_expr) = init { + validate_const_expr(init_expr)?; } } - let validate_constant = |handle| Self::validate_constant_handle(handle, constants); - for (_handle, global_variable) in global_variables.iter() { let &crate::GlobalVariable { name: _, @@ -127,7 +122,7 @@ impl super::Validator { } = global_variable; validate_type(ty)?; if let Some(init_expr) = init { - validate_constant(init_expr)?; + validate_const_expr(init_expr)?; } } @@ -159,7 +154,7 @@ impl super::Validator { let &crate::LocalVariable { name: _, ty, init } = local_variable; validate_type(ty)?; if let Some(init_constant) = init { - validate_constant(init_constant)?; + validate_const_expr(init_constant)?; } } @@ -171,6 +166,7 @@ impl super::Validator { Self::validate_expression_handles( handle_and_expr, constants, + const_expressions, types, local_variables, global_variables, @@ -209,6 +205,13 @@ impl super::Validator { handle.check_valid_for(constants).map(|_| ()) } + fn validate_constant_expression_handle( + handle: Handle, + const_expressions: &Arena, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for(const_expressions).map(|_| ()) + } + fn validate_expression_handle( handle: Handle, expressions: &Arena, @@ -226,6 +229,7 @@ impl super::Validator { fn validate_expression_handles( (handle, expression): (Handle, &crate::Expression), constants: &Arena, + const_expressions: &Arena, types: &UniqueArena, local_variables: &Arena, global_variables: &Arena, @@ -234,6 +238,8 @@ impl super::Validator { current_function: Option>, ) -> Result<(), InvalidHandleError> { let validate_constant = |handle| Self::validate_constant_handle(handle, constants); + let validate_const_expr = + |handle| Self::validate_constant_expression_handle(handle, const_expressions); let validate_type = |handle| Self::validate_type_handle(handle, types); match *expression { @@ -243,15 +249,19 @@ impl super::Validator { crate::Expression::AccessIndex { base, .. } => { handle.check_dep(base)?; } - crate::Expression::Constant(constant) => { - validate_constant(constant)?; - } crate::Expression::Splat { value, .. } => { handle.check_dep(value)?; } crate::Expression::Swizzle { vector, .. } => { handle.check_dep(vector)?; } + crate::Expression::Literal(_) => {} + crate::Expression::Constant(constant) => { + validate_constant(constant)?; + } + crate::Expression::New(ty) => { + validate_type(ty)?; + } crate::Expression::Compose { ty, ref components } => { validate_type(ty)?; handle.check_dep_iter(components.iter().copied())?; @@ -269,15 +279,18 @@ impl super::Validator { crate::Expression::ImageSample { image, sampler, - gather: _, + gather, coordinate, array_index, offset, level, depth_ref, } => { + if let Some(gather) = gather { + validate_const_expr(gather)?; + } if let Some(offset) = offset { - validate_constant(offset)?; + validate_const_expr(offset)?; } handle diff --git a/src/valid/mod.rs b/src/valid/mod.rs index eb92e8892d..b497929cf6 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -300,35 +300,9 @@ impl Validator { types: &UniqueArena, ) -> Result<(), ConstantError> { let con = &constants[handle]; - match con.inner { - crate::ConstantInner::Scalar { width, ref value } => { - if !self.check_width(value.scalar_kind(), width) { - return Err(ConstantError::InvalidType); - } - } - crate::ConstantInner::Composite { ty, ref components } => { - match types[ty].inner { - crate::TypeInner::Array { - size: crate::ArraySize::Constant(size_handle), - .. - } if handle <= size_handle => { - return Err(ConstantError::UnresolvedSize(size_handle)); - } - _ => {} - } - if let Some(&comp) = components.iter().find(|&&comp| handle <= comp) { - return Err(ConstantError::UnresolvedComponent(comp)); - } - compose::validate_compose( - ty, - constants, - types, - components - .iter() - .map(|&component| constants[component].inner.resolve_type()), - )?; - } - } + let type_info = &self.types[con.ty.index()]; + // TODO: if !type_info.flags.contains(TypeFlags::) { return Err() } + Ok(()) } @@ -343,12 +317,10 @@ impl Validator { #[cfg(feature = "validate")] Self::validate_module_handles(module).map_err(|e| e.with_span())?; - self.layouter - .update(&module.types, &module.constants) - .map_err(|e| { - let handle = e.ty; - ValidationError::from(e).with_span_handle(handle, &module.types) - })?; + self.layouter.update(module.to_ctx()).map_err(|e| { + let handle = e.ty; + ValidationError::from(e).with_span_handle(handle, &module.types) + })?; #[cfg(feature = "validate")] if self.flags.contains(ValidationFlags::CONSTANTS) { @@ -373,7 +345,7 @@ impl Validator { for (handle, ty) in module.types.iter() { let ty_info = self - .validate_type(handle, &module.types, &module.constants) + .validate_type(handle, module.to_ctx()) .map_err(|source| { ValidationError::Type { handle, diff --git a/src/valid/type.rs b/src/valid/type.rs index 4fcc1a1c58..8c9994cf26 100644 --- a/src/valid/type.rs +++ b/src/valid/type.rs @@ -1,8 +1,5 @@ use super::Capabilities; -use crate::{ - arena::{Arena, Handle, UniqueArena}, - proc::Alignment, -}; +use crate::{arena::Handle, proc::Alignment}; bitflags::bitflags! { /// Flags associated with [`Type`]s by [`Validator`]. @@ -108,11 +105,11 @@ pub enum TypeError { #[error("Base type {0:?} for the array is invalid")] InvalidArrayBaseType(Handle), #[error("The constant {0:?} can not be used for an array size")] - InvalidArraySizeConstant(Handle), + InvalidArraySizeConstant(Handle), #[error("The constant {0:?} is specialized, and cannot be used as an array size")] UnsupportedSpecializedArrayLength(Handle), #[error("Array type {0:?} must have a length of one or more")] - NonPositiveArrayLength(Handle), + NonPositiveArrayLength(Handle), #[error("Array stride {stride} does not match the expected {expected}")] InvalidArrayStride { stride: u32, expected: u32 }, #[error("Field '{0}' can't be dynamically-sized, has type {1:?}")] @@ -222,11 +219,10 @@ impl super::Validator { pub(super) fn validate_type( &self, handle: Handle, - types: &UniqueArena, - constants: &Arena, + gctx: crate::GlobalCtx, ) -> Result { use crate::TypeInner as Ti; - Ok(match types[handle].inner { + Ok(match gctx.types[handle].inner { Ti::Scalar { kind, width } => { if !self.check_width(kind, width) { return Err(TypeError::InvalidWidth(kind, width)); @@ -409,13 +405,18 @@ impl super::Validator { }; let type_info_mask = match size { - crate::ArraySize::Constant(const_handle) => { - let constant = &constants[const_handle]; - let length_is_positive = match *constant { - crate::Constant { - specialization: Some(_), - .. - } => { + crate::ArraySize::Constant(const_expr) => { + match gctx.const_expressions[const_expr] { + crate::Expression::Constant(const_handle) + if matches!( + gctx.constants[const_handle], + crate::Constant { + specialization: crate::Specialization::ByName + | crate::Specialization::ByNameOrId(_), + .. + } + ) => + { // Many of our back ends don't seem to support // specializable array lengths. If you want to try to make // this work, be sure to address all uses of @@ -425,34 +426,16 @@ impl super::Validator { const_handle, )); } - crate::Constant { - inner: - crate::ConstantInner::Scalar { - width: _, - value: crate::ScalarValue::Uint(length), - }, - .. - } => length > 0, - // Accept a signed integer size to avoid - // requiring an explicit uint - // literal. Type inference should make - // this unnecessary. - crate::Constant { - inner: - crate::ConstantInner::Scalar { - width: _, - value: crate::ScalarValue::Sint(length), - }, - .. - } => length > 0, - _ => { - log::warn!("Array size {:?}", constant); - return Err(TypeError::InvalidArraySizeConstant(const_handle)); - } - }; + _ => {} + } - if !length_is_positive { - return Err(TypeError::NonPositiveArrayLength(const_handle)); + if let Some(len) = gctx.to_array_length(const_expr) { + if len == 0 { + return Err(TypeError::NonPositiveArrayLength(const_expr)); + } + } else { + log::warn!("Array size {:?}", const_expr); + return Err(TypeError::InvalidArraySizeConstant(const_expr)); } TypeFlags::DATA @@ -529,7 +512,7 @@ impl super::Validator { } } - let base_size = types[member.ty].inner.size(constants); + let base_size = gctx.types[member.ty].inner.size(gctx.reborrow()); min_offset = member.offset + base_size; if min_offset > span { return Err(TypeError::MemberOutOfBounds { @@ -573,14 +556,14 @@ impl super::Validator { } }; - prev_struct_data = match types[member.ty].inner { + prev_struct_data = match gctx.types[member.ty].inner { crate::TypeInner::Struct { span, .. } => Some((span, member.offset)), _ => None, }; // The last field may be an unsized array. if !base_info.flags.contains(TypeFlags::SIZED) { - let is_array = match types[member.ty].inner { + let is_array = match gctx.types[member.ty].inner { crate::TypeInner::Array { .. } => true, _ => false, };