Skip to content

Commit

Permalink
mesh shader ext: add support for mesh shaders, based on @BeastLe9enD
Browse files Browse the repository at this point in the history
…work
  • Loading branch information
Firestar99 committed Feb 1, 2024
1 parent ede757b commit de312eb
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 4 deletions.
3 changes: 2 additions & 1 deletion crates/rustc_codegen_spirv/src/linker/simple_passes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ pub fn outgoing_edges(block: &Block) -> impl Iterator<Item = Word> + '_ {
| Op::Kill
| Op::Unreachable
| Op::IgnoreIntersectionKHR
| Op::TerminateRayKHR => (0..0).step_by(1),
| Op::TerminateRayKHR
| Op::EmitMeshTasksEXT => (0..0).step_by(1),
_ => panic!("Invalid block terminator: {terminator:?}"),
};
operand_indices.map(move |i| terminator.operands[i].unwrap_id_ref())
Expand Down
3 changes: 2 additions & 1 deletion crates/rustc_codegen_spirv/src/spirv_type_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,8 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> {
}
// SPV_EXT_mesh_shader
Op::EmitMeshTasksEXT | Op::SetMeshOutputsEXT => {
reserved!(SPV_EXT_mesh_shader)
// NOTE(eddyb) we actually use these despite not being in the standard yet.
// reserved!(SPV_EXT_mesh_shader)
}
// SPV_NV_ray_tracing_motion_blur
Op::TraceMotionNV | Op::TraceRayMotionNV => reserved!(SPV_NV_ray_tracing_motion_blur),
Expand Down
25 changes: 23 additions & 2 deletions crates/rustc_codegen_spirv/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@ const BUILTINS: &[(&str, BuiltIn)] = {
("bary_coord_no_persp_nv", BuiltIn::BaryCoordNoPerspNV),
("bary_coord", BaryCoordKHR),
("bary_coord_no_persp", BaryCoordNoPerspKHR),
("primitive_point_indices_ext", PrimitivePointIndicesEXT),
("primitive_line_indices_ext", PrimitiveLineIndicesEXT),
(
"primitive_triangle_indices_ext",
PrimitiveTriangleIndicesEXT,
),
("cull_primitive_ext", CullPrimitiveEXT),
("frag_size_ext", FragSizeEXT),
("frag_invocation_count_ext", FragInvocationCountEXT),
("launch_id", BuiltIn::LaunchIdKHR),
Expand Down Expand Up @@ -171,6 +178,7 @@ const STORAGE_CLASSES: &[(&str, StorageClass)] = {
("incoming_ray_payload", StorageClass::IncomingRayPayloadKHR),
("shader_record_buffer", StorageClass::ShaderRecordBufferKHR),
("physical_storage_buffer", PhysicalStorageBuffer),
("task_payload_workgroup_ext", TaskPayloadWorkgroupEXT),
]
};

Expand All @@ -185,6 +193,8 @@ const EXECUTION_MODELS: &[(&str, ExecutionModel)] = {
("compute", GLCompute),
("task_nv", TaskNV),
("mesh_nv", MeshNV),
("task_ext", TaskEXT),
("mesh_ext", MeshEXT),
("ray_generation", ExecutionModel::RayGenerationKHR),
("intersection", ExecutionModel::IntersectionKHR),
("any_hit", ExecutionModel::AnyHitKHR),
Expand Down Expand Up @@ -265,6 +275,17 @@ const EXECUTION_MODES: &[(&str, ExecutionMode, ExecutionModeExtraDim)] = {
("output_primitives_nv", OutputPrimitivesNV, Value),
("derivative_group_quads_nv", DerivativeGroupQuadsNV, None),
("output_triangles_nv", OutputTrianglesNV, None),
("output_lines_ext", ExecutionMode::OutputLinesEXT, None),
(
"output_triangles_ext",
ExecutionMode::OutputTrianglesEXT,
None,
),
(
"output_primitives_ext",
ExecutionMode::OutputPrimitivesEXT,
Value,
),
(
"pixel_interlock_ordered_ext",
PixelInterlockOrderedEXT,
Expand Down Expand Up @@ -717,7 +738,7 @@ fn parse_entry_attrs(
.execution_modes
.push((origin_mode, ExecutionModeExtra::new([])));
}
GLCompute | MeshNV | TaskNV => {
GLCompute | MeshNV | TaskNV | TaskEXT | MeshEXT => {
if let Some(local_size) = local_size {
entry
.execution_modes
Expand All @@ -726,7 +747,7 @@ fn parse_entry_attrs(
return Err((
arg.span(),
String::from(
"The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]` or `#[spirv(task_nv)]`",
"The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]`, `#[spirv(task_nv)]`, `#[spirv(task_ext)]` or `#[spirv(mesh_ext)]`",
),
));
}
Expand Down
2 changes: 2 additions & 0 deletions crates/spirv-std/src/arch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ mod atomics;
mod barrier;
mod demote_to_helper_invocation_ext;
mod derivative;
mod mesh_shading;
mod primitive;
mod ray_tracing;

pub use atomics::*;
pub use barrier::*;
pub use demote_to_helper_invocation_ext::*;
pub use derivative::*;
pub use mesh_shading::*;
pub use primitive::*;
pub use ray_tracing::*;

Expand Down
66 changes: 66 additions & 0 deletions crates/spirv-std/src/arch/mesh_shading.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#[cfg(target_arch = "spirv")]
use core::arch::asm;

/// Sets the actual output size of the primitives and vertices that the mesh shader
/// workgroup will emit upon completion.
///
/// 'Vertex Count' must be a 32-bit unsigned integer value.
/// It defines the array size of per-vertex outputs.
///
/// 'Primitive Count' must a 32-bit unsigned integer value.
/// It defines the array size of per-primitive outputs.
///
/// The arguments are taken from the first invocation in each workgroup.
/// Any invocation must execute this instruction no more than once and under
/// uniform control flow.
/// There must not be any control flow path to an output write that is not preceded
/// by this instruction.
///
/// This instruction is only valid in the *MeshEXT* Execution Model.
#[spirv_std_macros::gpu_only]
#[doc(alias = "OpSetMeshOutputsEXT")]
#[inline]
pub unsafe fn set_mesh_outputs_ext(vertex_count: u32, primitive_count: u32) {
asm! {
"OpSetMeshOutputsEXT {vertex_count} {primitive_count}",
vertex_count = in(reg) vertex_count,
primitive_count = in(reg) primitive_count,
}
}

/// Defines the grid size of subsequent mesh shader workgroups to generate
/// upon completion of the task shader workgroup.
///
/// 'Group Count X Y Z' must each be a 32-bit unsigned integer value.
/// They configure the number of local workgroups in each respective dimensions
/// for the launch of child mesh tasks. See Vulkan API specification for more detail.
///
/// 'Payload' is an optional pointer to the payload structure to pass to the generated mesh shader invocations.
/// 'Payload' must be the result of an *OpVariable* with a storage class of *TaskPayloadWorkgroupEXT*.
///
/// The arguments are taken from the first invocation in each workgroup.
/// Any invocation must execute this instruction exactly once and under uniform
/// control flow.
/// This instruction also serves as an *OpControlBarrier* instruction, and also
/// performs and adheres to the description and semantics of an *OpControlBarrier*
/// instruction with the 'Execution' and 'Memory' operands set to *Workgroup* and
/// the 'Semantics' operand set to a combination of *WorkgroupMemory* and
/// *AcquireRelease*.
/// Ceases all further processing: Only instructions executed before
/// *OpEmitMeshTasksEXT* have observable side effects.
///
/// This instruction must be the last instruction in a block.
///
/// This instruction is only valid in the *TaskEXT* Execution Model.
#[spirv_std_macros::gpu_only]
#[doc(alias = "OpEmitMeshTasksEXT")]
#[inline]
pub unsafe fn emit_mesh_tasks_ext(group_count_x: u32, group_count_y: u32, group_count_z: u32) -> ! {
asm! {
"OpEmitMeshTasksEXT {group_count_x} {group_count_y} {group_count_z}",
group_count_x = in(reg) group_count_x,
group_count_y = in(reg) group_count_y,
group_count_z = in(reg) group_count_z,
options(noreturn),
}
}

0 comments on commit de312eb

Please sign in to comment.