From 130f9e5d35c5803e9428d421ab5125d6d5cd0a9b Mon Sep 17 00:00:00 2001 From: ComfyFluffy <24245520+ComfyFluffy@users.noreply.github.com> Date: Sat, 14 Dec 2024 17:40:31 +0800 Subject: [PATCH] Ray Tracing Pipeline (KHR) (#2564) * ray tracing pipeline * fix invalid pointers * sync * triangle-raytracing * working example * refactor * refactor SBT * example for AutoCommandBufferBuilder * bind & trace_rays validation * trace_rays validation * doc for triangle-raytracing-auto * move mod.rs * fmt & clippy * fix clippy ci * add unit tests for sbt builder * fmt * undo sbt copy refactor * fix clippy * rename example & cleanup * rmiss -> miss * refactor & add doc * fmt * update dep & remove abbreviates * refactor: - add doc - add check for stages/groups - pub modification - `pub use` StridedDeviceAddressRegionKHR in lib.rs * implement StridedDeviceAddressRegion * rename examples --- Cargo.lock | 23 + examples/ray-tracing-auto/Cargo.toml | 19 + examples/ray-tracing-auto/main.rs | 405 ++++++++ examples/ray-tracing-auto/raytrace.miss | 6 + examples/ray-tracing-auto/raytrace.rchit | 10 + examples/ray-tracing-auto/raytrace.rgen | 43 + examples/ray-tracing-auto/scene.rs | 481 +++++++++ examples/ray-tracing/Cargo.toml | 20 + examples/ray-tracing/main.rs | 416 ++++++++ examples/ray-tracing/raytrace.miss | 6 + examples/ray-tracing/raytrace.rchit | 10 + examples/ray-tracing/raytrace.rgen | 43 + examples/ray-tracing/scene.rs | 499 +++++++++ .../src/command_buffer/commands/bind_push.rs | 29 +- .../src/command_buffer/commands/pipeline.rs | 46 +- vulkano-taskgraph/src/resource.rs | 104 +- vulkano/src/buffer/usage.rs | 4 +- vulkano/src/command_buffer/auto/builder.rs | 2 + .../src/command_buffer/commands/bind_push.rs | 116 ++- .../src/command_buffer/commands/pipeline.rs | 162 +++ vulkano/src/command_buffer/mod.rs | 1 + vulkano/src/device/mod.rs | 113 +++ vulkano/src/lib.rs | 19 + vulkano/src/pipeline/compute.rs | 2 +- vulkano/src/pipeline/graphics/mod.rs | 2 +- vulkano/src/pipeline/mod.rs | 
5 +- vulkano/src/pipeline/ray_tracing.rs | 957 ++++++++++++++++++ 27 files changed, 3477 insertions(+), 66 deletions(-) create mode 100644 examples/ray-tracing-auto/Cargo.toml create mode 100644 examples/ray-tracing-auto/main.rs create mode 100644 examples/ray-tracing-auto/raytrace.miss create mode 100644 examples/ray-tracing-auto/raytrace.rchit create mode 100644 examples/ray-tracing-auto/raytrace.rgen create mode 100644 examples/ray-tracing-auto/scene.rs create mode 100644 examples/ray-tracing/Cargo.toml create mode 100644 examples/ray-tracing/main.rs create mode 100644 examples/ray-tracing/raytrace.miss create mode 100644 examples/ray-tracing/raytrace.rchit create mode 100644 examples/ray-tracing/raytrace.rgen create mode 100644 examples/ray-tracing/scene.rs create mode 100644 vulkano/src/pipeline/ray_tracing.rs diff --git a/Cargo.lock b/Cargo.lock index f2aa9703c6..64630b654e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1320,6 +1320,29 @@ dependencies = [ "objc2-quartz-core", ] +[[package]] +name = "ray-tracing" +version = "0.0.0" +dependencies = [ + "ash", + "glam", + "vulkano", + "vulkano-shaders", + "vulkano-taskgraph", + "winit", +] + +[[package]] +name = "ray-tracing-auto" +version = "0.0.0" +dependencies = [ + "ash", + "glam", + "vulkano", + "vulkano-shaders", + "winit", +] + [[package]] name = "redox_syscall" version = "0.4.1" diff --git a/examples/ray-tracing-auto/Cargo.toml b/examples/ray-tracing-auto/Cargo.toml new file mode 100644 index 0000000000..292c75270c --- /dev/null +++ b/examples/ray-tracing-auto/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "ray-tracing-auto" +version = "0.0.0" +edition = "2021" +publish = false + +[[bin]] +name = "ray-tracing-auto" +path = "main.rs" +test = false +bench = false +doc = false + +[dependencies] +vulkano = { workspace = true, default-features = true } +vulkano-shaders = { workspace = true } +winit = { workspace = true, default-features = true } +ash = { workspace = true } +glam = { workspace = true } diff 
--git a/examples/ray-tracing-auto/main.rs b/examples/ray-tracing-auto/main.rs new file mode 100644 index 0000000000..d39d912a35 --- /dev/null +++ b/examples/ray-tracing-auto/main.rs @@ -0,0 +1,405 @@ +use scene::Scene; +use std::{error::Error, sync::Arc}; +use vulkano::{ + command_buffer::{ + allocator::StandardCommandBufferAllocator, AutoCommandBufferBuilder, CommandBufferUsage, + }, + descriptor_set::{ + allocator::StandardDescriptorSetAllocator, + layout::{ + DescriptorSetLayout, DescriptorSetLayoutBinding, DescriptorSetLayoutCreateInfo, + DescriptorType, + }, + }, + device::{ + physical::PhysicalDeviceType, Device, DeviceCreateInfo, DeviceExtensions, DeviceFeatures, + Queue, QueueCreateInfo, QueueFlags, + }, + image::{ImageFormatInfo, ImageUsage}, + instance::{Instance, InstanceCreateFlags, InstanceCreateInfo, InstanceExtensions}, + memory::allocator::StandardMemoryAllocator, + pipeline::{layout::PipelineLayoutCreateInfo, PipelineLayout}, + shader::ShaderStages, + swapchain::{ + acquire_next_image, Surface, Swapchain, SwapchainCreateInfo, SwapchainPresentInfo, + }, + sync::{self, GpuFuture}, + Version, VulkanLibrary, +}; +use winit::{ + application::ApplicationHandler, + event::WindowEvent, + event_loop::{ActiveEventLoop, EventLoop}, + window::{Window, WindowId}, +}; + +mod scene; + +fn main() -> Result<(), impl Error> { + let event_loop = EventLoop::new().unwrap(); + let mut app = App::new(&event_loop); + + event_loop.run_app(&mut app) +} + +struct App { + instance: Arc, + device: Arc, + queue: Arc, + rcx: Option, + command_buffer_allocator: Arc, +} + +pub struct RenderContext { + window: Arc, + swapchain: Arc, + recreate_swapchain: bool, + scene: Scene, + previous_frame_end: Option>, +} + +impl App { + fn new(event_loop: &EventLoop<()>) -> Self { + let library = VulkanLibrary::new().unwrap(); + let required_extensions = Surface::required_extensions(event_loop).unwrap(); + let instance = Instance::new( + library, + InstanceCreateInfo { + flags: 
InstanceCreateFlags::ENUMERATE_PORTABILITY, + enabled_extensions: InstanceExtensions { + ext_swapchain_colorspace: true, + ..required_extensions + }, + ..Default::default() + }, + ) + .unwrap(); + + let device_extensions = DeviceExtensions { + khr_swapchain: true, + khr_ray_tracing_pipeline: true, + khr_ray_tracing_maintenance1: true, + khr_synchronization2: true, + khr_deferred_host_operations: true, + khr_acceleration_structure: true, + ..DeviceExtensions::empty() + }; + let (physical_device, queue_family_index) = instance + .enumerate_physical_devices() + .unwrap() + .filter(|p| p.api_version() >= Version::V1_3) + .filter(|p| p.supported_extensions().contains(&device_extensions)) + .filter_map(|p| { + p.queue_family_properties() + .iter() + .enumerate() + .position(|(i, q)| { + q.queue_flags + .contains(QueueFlags::GRAPHICS | QueueFlags::COMPUTE) + && p.presentation_support(i as u32, event_loop).unwrap() + }) + .map(|i| (p, i as u32)) + }) + .min_by_key(|(p, _)| match p.properties().device_type { + PhysicalDeviceType::DiscreteGpu => 0, + PhysicalDeviceType::IntegratedGpu => 1, + PhysicalDeviceType::VirtualGpu => 2, + PhysicalDeviceType::Cpu => 3, + PhysicalDeviceType::Other => 4, + _ => 5, + }) + .unwrap(); + + let (device, mut queues) = Device::new( + physical_device, + DeviceCreateInfo { + enabled_extensions: device_extensions, + queue_create_infos: vec![QueueCreateInfo { + queue_family_index, + ..Default::default() + }], + enabled_features: DeviceFeatures { + acceleration_structure: true, + ray_tracing_pipeline: true, + buffer_device_address: true, + synchronization2: true, + ..Default::default() + }, + ..Default::default() + }, + ) + .unwrap(); + + let queue = queues.next().unwrap(); + + let command_buffer_allocator = Arc::new(StandardCommandBufferAllocator::new( + device.clone(), + Default::default(), + )); + + App { + instance, + device, + queue, + rcx: None, + command_buffer_allocator, + } + } +} + +impl ApplicationHandler for App { + fn resumed(&mut 
self, event_loop: &ActiveEventLoop) { + let window = Arc::new( + event_loop + .create_window(Window::default_attributes()) + .unwrap(), + ); + let surface = Surface::from_window(self.instance.clone(), window.clone()).unwrap(); + + let physical_device = self.device.physical_device(); + let supported_surface_formats = physical_device + .surface_formats(&surface, Default::default()) + .unwrap(); + + // For each supported format, check if it is supported for storage images + let supported_storage_formats = supported_surface_formats + .into_iter() + .filter(|(format, _)| { + physical_device + .image_format_properties(ImageFormatInfo { + format: *format, + usage: ImageUsage::STORAGE, + ..Default::default() + }) + .unwrap() + .is_some() + }) + .collect::>(); + + println!( + "Using device: {} (type: {:?})", + physical_device.properties().device_name, + physical_device.properties().device_type, + ); + + let (swapchain, images) = { + let surface_capabilities = self + .device + .physical_device() + .surface_capabilities(&surface, Default::default()) + .unwrap(); + + let (swapchain_format, swapchain_color_space) = supported_storage_formats + .first() + .map(|(format, color_space)| (*format, *color_space)) + .unwrap(); + Swapchain::new( + self.device.clone(), + surface.clone(), + SwapchainCreateInfo { + min_image_count: surface_capabilities.min_image_count.max(2), + image_format: swapchain_format, + image_color_space: swapchain_color_space, + image_extent: window.inner_size().into(), + // To simplify the example, we will directly write to the swapchain images + // from the ray tracing shader. This requires the images to support storage + // usage. 
+ image_usage: ImageUsage::STORAGE, + composite_alpha: surface_capabilities + .supported_composite_alpha + .into_iter() + .next() + .unwrap(), + ..Default::default() + }, + ) + .unwrap() + }; + + let pipeline_layout = PipelineLayout::new( + self.device.clone(), + PipelineLayoutCreateInfo { + set_layouts: vec![ + DescriptorSetLayout::new( + self.device.clone(), + DescriptorSetLayoutCreateInfo { + bindings: [ + ( + 0, + DescriptorSetLayoutBinding { + stages: ShaderStages::RAYGEN, + ..DescriptorSetLayoutBinding::descriptor_type( + DescriptorType::AccelerationStructure, + ) + }, + ), + ( + 1, + DescriptorSetLayoutBinding { + stages: ShaderStages::RAYGEN, + ..DescriptorSetLayoutBinding::descriptor_type( + DescriptorType::UniformBuffer, + ) + }, + ), + ] + .into_iter() + .collect(), + ..Default::default() + }, + ) + .unwrap(), + DescriptorSetLayout::new( + self.device.clone(), + DescriptorSetLayoutCreateInfo { + bindings: [( + 0, + DescriptorSetLayoutBinding { + stages: ShaderStages::RAYGEN, + ..DescriptorSetLayoutBinding::descriptor_type( + DescriptorType::StorageImage, + ) + }, + )] + .into_iter() + .collect(), + ..Default::default() + }, + ) + .unwrap(), + ], + ..Default::default() + }, + ) + .unwrap(); + + let descriptor_set_allocator = Arc::new(StandardDescriptorSetAllocator::new( + self.device.clone(), + Default::default(), + )); + + let memory_allocator = Arc::new(StandardMemoryAllocator::new_default(self.device.clone())); + + let scene = Scene::new( + self, + &images, + pipeline_layout, + descriptor_set_allocator.clone(), + memory_allocator.clone(), + self.command_buffer_allocator.clone(), + ); + self.rcx = Some(RenderContext { + window, + swapchain, + recreate_swapchain: false, + previous_frame_end: None, + scene, + }); + } + + fn window_event( + &mut self, + event_loop: &ActiveEventLoop, + _window_id: WindowId, + event: WindowEvent, + ) { + let rcx = self.rcx.as_mut().unwrap(); + + match event { + WindowEvent::CloseRequested => { + event_loop.exit(); + } + 
WindowEvent::Resized(_) => { + rcx.recreate_swapchain = true; + } + WindowEvent::RedrawRequested => { + let window_size = rcx.window.inner_size(); + + if window_size.width == 0 || window_size.height == 0 { + return; + } + + // Cleanup previous frame + if let Some(previous_frame_end) = rcx.previous_frame_end.as_mut() { + previous_frame_end.cleanup_finished(); + } + + // Recreate swapchain if needed + if rcx.recreate_swapchain { + let (new_swapchain, new_images) = + match rcx.swapchain.recreate(SwapchainCreateInfo { + image_extent: window_size.into(), + ..rcx.swapchain.create_info() + }) { + Ok(r) => r, + Err(e) => panic!("Failed to recreate swapchain: {e:?}"), + }; + + rcx.swapchain = new_swapchain; + rcx.scene.handle_resize(&new_images); + rcx.recreate_swapchain = false; + } + + // Acquire next image + let (image_index, suboptimal, acquire_future) = + match acquire_next_image(rcx.swapchain.clone(), None) { + Ok(r) => r, + Err(e) => { + eprintln!("Failed to acquire next image: {e:?}"); + rcx.recreate_swapchain = true; + return; + } + }; + + if suboptimal { + rcx.recreate_swapchain = true; + } + + let mut builder = AutoCommandBufferBuilder::primary( + self.command_buffer_allocator.clone(), + self.queue.queue_family_index(), + CommandBufferUsage::OneTimeSubmit, + ) + .unwrap(); + + rcx.scene.record_commands(image_index, &mut builder); + + let command_buffer = builder.build().unwrap(); + + let future = rcx + .previous_frame_end + .take() + .unwrap_or_else(|| { + Box::new(sync::now(self.device.clone())) as Box + }) + .join(acquire_future) + .then_execute(self.queue.clone(), command_buffer) + .unwrap() + .then_swapchain_present( + self.queue.clone(), + SwapchainPresentInfo::swapchain_image_index( + rcx.swapchain.clone(), + image_index, + ), + ) + .then_signal_fence_and_flush(); + + match future { + Ok(future) => { + rcx.previous_frame_end = Some(Box::new(future) as Box); + } + Err(e) => { + println!("Failed to flush future: {e:?}"); + rcx.previous_frame_end = 
Some(Box::new(sync::now(self.device.clone()))); + } + } + } + _ => {} + } + } + + fn about_to_wait(&mut self, _event_loop: &ActiveEventLoop) { + let rcx = self.rcx.as_mut().unwrap(); + rcx.window.request_redraw(); + } +} diff --git a/examples/ray-tracing-auto/raytrace.miss b/examples/ray-tracing-auto/raytrace.miss new file mode 100644 index 0000000000..1c584d5420 --- /dev/null +++ b/examples/ray-tracing-auto/raytrace.miss @@ -0,0 +1,6 @@ +#version 460 +#extension GL_EXT_ray_tracing : require + +layout(location = 0) rayPayloadInEXT vec3 hitValue; + +void main() { hitValue = vec3(0.0, 0.0, 0.2); } diff --git a/examples/ray-tracing-auto/raytrace.rchit b/examples/ray-tracing-auto/raytrace.rchit new file mode 100644 index 0000000000..52c407b96a --- /dev/null +++ b/examples/ray-tracing-auto/raytrace.rchit @@ -0,0 +1,10 @@ +#version 460 +#extension GL_EXT_ray_tracing : require + +layout(location = 0) rayPayloadInEXT vec3 hitValue; +hitAttributeEXT vec2 attribs; + +void main() { + vec3 barycentrics = vec3(1.0 - attribs.x - attribs.y, attribs.x, attribs.y); + hitValue = barycentrics; +} diff --git a/examples/ray-tracing-auto/raytrace.rgen b/examples/ray-tracing-auto/raytrace.rgen new file mode 100644 index 0000000000..8a9416e201 --- /dev/null +++ b/examples/ray-tracing-auto/raytrace.rgen @@ -0,0 +1,43 @@ +#version 460 +#extension GL_EXT_ray_tracing : require + +struct Camera { + mat4 viewProj; // Camera view * projection + mat4 viewInverse; // Camera inverse view matrix + mat4 projInverse; // Camera inverse projection matrix +}; + +layout(location = 0) rayPayloadEXT vec3 hitValue; + +layout(set = 0, binding = 0) uniform accelerationStructureEXT topLevelAS; +layout(set = 0, binding = 1) uniform _Camera { Camera camera; }; +layout(set = 1, binding = 0, rgba32f) uniform image2D image; + +void main() { + const vec2 pixelCenter = vec2(gl_LaunchIDEXT.xy) + vec2(0.5); + const vec2 inUV = pixelCenter / vec2(gl_LaunchSizeEXT.xy); + vec2 d = inUV * 2.0 - 1.0; + + vec4 origin = 
camera.viewInverse * vec4(0, 0, 0, 1); + vec4 target = camera.projInverse * vec4(d.x, d.y, 1, 1); + vec4 direction = camera.viewInverse * vec4(normalize(target.xyz), 0); + + uint rayFlags = gl_RayFlagsOpaqueEXT; + float tMin = 0.001; + float tMax = 10000.0; + + traceRayEXT(topLevelAS, // acceleration structure + rayFlags, // rayFlags + 0xFF, // cullMask + 0, // sbtRecordOffset + 0, // sbtRecordStride + 0, // missIndex + origin.xyz, // ray origin + tMin, // ray min range + direction.xyz, // ray direction + tMax, // ray max range + 0 // payload (location = 0) + ); + + imageStore(image, ivec2(gl_LaunchIDEXT.xy), vec4(hitValue, 1.0)); +} diff --git a/examples/ray-tracing-auto/scene.rs b/examples/ray-tracing-auto/scene.rs new file mode 100644 index 0000000000..019b61e631 --- /dev/null +++ b/examples/ray-tracing-auto/scene.rs @@ -0,0 +1,481 @@ +use crate::App; +use glam::{Mat4, Vec3}; +use std::{iter, mem::size_of, sync::Arc}; +use vulkano::{ + acceleration_structure::{ + AccelerationStructure, AccelerationStructureBuildGeometryInfo, + AccelerationStructureBuildRangeInfo, AccelerationStructureBuildType, + AccelerationStructureCreateInfo, AccelerationStructureGeometries, + AccelerationStructureGeometryInstancesData, AccelerationStructureGeometryInstancesDataType, + AccelerationStructureGeometryTrianglesData, AccelerationStructureInstance, + AccelerationStructureType, BuildAccelerationStructureFlags, BuildAccelerationStructureMode, + }, + buffer::{Buffer, BufferContents, BufferCreateInfo, BufferUsage, Subbuffer}, + command_buffer::{ + allocator::CommandBufferAllocator, AutoCommandBufferBuilder, CommandBufferUsage, + PrimaryAutoCommandBuffer, PrimaryCommandBufferAbstract, + }, + descriptor_set::{ + allocator::StandardDescriptorSetAllocator, DescriptorSet, WriteDescriptorSet, + }, + device::{Device, Queue}, + format::Format, + image::{view::ImageView, Image}, + memory::allocator::{AllocationCreateInfo, MemoryAllocator, MemoryTypeFilter}, + pipeline::{ + 
graphics::vertex_input::Vertex, + ray_tracing::{ + RayTracingPipeline, RayTracingPipelineCreateInfo, RayTracingShaderGroupCreateInfo, + ShaderBindingTable, + }, + PipelineBindPoint, PipelineLayout, PipelineShaderStageCreateInfo, + }, + sync::GpuFuture, +}; + +mod raygen { + vulkano_shaders::shader! { + ty: "raygen", + path: "raytrace.rgen", + vulkan_version: "1.2" + } +} + +mod closest_hit { + vulkano_shaders::shader! { + ty: "closesthit", + path: "raytrace.rchit", + vulkan_version: "1.2" + } +} + +mod miss { + vulkano_shaders::shader! { + ty: "miss", + path: "raytrace.miss", + vulkan_version: "1.2" + } +} + +#[derive(BufferContents, Vertex)] +#[repr(C)] +struct MyVertex { + #[format(R32G32B32_SFLOAT)] + position: [f32; 3], +} + +pub struct Scene { + descriptor_set_0: Arc, + swapchain_image_sets: Vec<(Arc, Arc)>, + pipeline_layout: Arc, + descriptor_set_allocator: Arc, + shader_binding_table: ShaderBindingTable, + pipeline: Arc, + // The bottom-level acceleration structure is required to be kept alive + // as we reference it in the top-level acceleration structure. + _blas: Arc, + _tlas: Arc, +} + +impl Scene { + pub fn new( + app: &App, + images: &[Arc], + pipeline_layout: Arc, + descriptor_set_allocator: Arc, + memory_allocator: Arc, + command_buffer_allocator: Arc, + ) -> Self { + let pipeline = { + let raygen = raygen::load(app.device.clone()) + .unwrap() + .entry_point("main") + .unwrap(); + let closest_hit = closest_hit::load(app.device.clone()) + .unwrap() + .entry_point("main") + .unwrap(); + + let miss = miss::load(app.device.clone()) + .unwrap() + .entry_point("main") + .unwrap(); + + // Make a list of the shader stages that the pipeline will have. + let stages = [ + PipelineShaderStageCreateInfo::new(raygen), + PipelineShaderStageCreateInfo::new(miss), + PipelineShaderStageCreateInfo::new(closest_hit), + ]; + + // Define the shader groups that will eventually turn into the shader binding table. 
+ // The numbers are the indices of the stages in the `stages` array. + let groups = [ + RayTracingShaderGroupCreateInfo::General { general_shader: 0 }, + RayTracingShaderGroupCreateInfo::General { general_shader: 1 }, + RayTracingShaderGroupCreateInfo::TrianglesHit { + closest_hit_shader: Some(2), + any_hit_shader: None, + }, + ]; + + RayTracingPipeline::new( + app.device.clone(), + None, + RayTracingPipelineCreateInfo { + stages: stages.into_iter().collect(), + groups: groups.into_iter().collect(), + max_pipeline_ray_recursion_depth: 1, + + ..RayTracingPipelineCreateInfo::layout(pipeline_layout.clone()) + }, + ) + .unwrap() + }; + + let vertices = [ + MyVertex { + position: [-0.5, -0.25, 0.0], + }, + MyVertex { + position: [0.0, 0.5, 0.0], + }, + MyVertex { + position: [0.25, -0.1, 0.0], + }, + ]; + let vertex_buffer = Buffer::from_iter( + memory_allocator.clone(), + BufferCreateInfo { + usage: BufferUsage::VERTEX_BUFFER + | BufferUsage::SHADER_DEVICE_ADDRESS + | BufferUsage::ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY, + ..Default::default() + }, + AllocationCreateInfo { + memory_type_filter: MemoryTypeFilter::PREFER_DEVICE + | MemoryTypeFilter::HOST_SEQUENTIAL_WRITE, + ..Default::default() + }, + vertices, + ) + .unwrap(); + + // Build the bottom-level acceleration structure and then the top-level acceleration + // structure. Acceleration structures are used to accelerate ray tracing. + // The bottom-level acceleration structure contains the geometry data. + // The top-level acceleration structure contains the instances of the bottom-level + // acceleration structures. In our shader, we will trace rays against the top-level + // acceleration structure. 
+ let blas = unsafe { + build_acceleration_structure_triangles( + vertex_buffer, + memory_allocator.clone(), + command_buffer_allocator.clone(), + app.device.clone(), + app.queue.clone(), + ) + }; + + let tlas = unsafe { + build_top_level_acceleration_structure( + blas.clone(), + memory_allocator.clone(), + command_buffer_allocator.clone(), + app.device.clone(), + app.queue.clone(), + ) + }; + + let proj = Mat4::perspective_rh_gl(std::f32::consts::FRAC_PI_2, 4.0 / 3.0, 0.01, 100.0); + let view = Mat4::look_at_rh( + Vec3::new(0.0, 0.0, 1.0), + Vec3::new(0.0, 0.0, 0.0), + Vec3::new(0.0, -1.0, 0.0), + ); + + let uniform_buffer = Buffer::from_data( + memory_allocator.clone(), + BufferCreateInfo { + usage: BufferUsage::UNIFORM_BUFFER, + ..Default::default() + }, + AllocationCreateInfo { + memory_type_filter: MemoryTypeFilter::PREFER_DEVICE + | MemoryTypeFilter::HOST_SEQUENTIAL_WRITE, + ..Default::default() + }, + raygen::Camera { + viewInverse: view.inverse().to_cols_array_2d(), + projInverse: proj.inverse().to_cols_array_2d(), + viewProj: (proj * view).to_cols_array_2d(), + }, + ) + .unwrap(); + + let descriptor_set_0 = DescriptorSet::new( + descriptor_set_allocator.clone(), + pipeline_layout.set_layouts()[0].clone(), + [ + WriteDescriptorSet::acceleration_structure(0, tlas.clone()), + WriteDescriptorSet::buffer(1, uniform_buffer.clone()), + ], + [], + ) + .unwrap(); + + let swapchain_image_sets = + window_size_dependent_setup(images, &pipeline_layout, &descriptor_set_allocator); + + let shader_binding_table = + ShaderBindingTable::new(memory_allocator.clone(), &pipeline).unwrap(); + + Scene { + descriptor_set_0, + swapchain_image_sets, + descriptor_set_allocator, + pipeline_layout, + shader_binding_table, + pipeline, + _blas: blas, + _tlas: tlas, + } + } + + pub fn handle_resize(&mut self, images: &[Arc]) { + self.swapchain_image_sets = window_size_dependent_setup( + images, + &self.pipeline_layout, + &self.descriptor_set_allocator, + ); + } + + pub fn 
record_commands( + &self, + image_index: u32, + builder: &mut AutoCommandBufferBuilder, + ) { + builder + .bind_descriptor_sets( + PipelineBindPoint::RayTracing, + self.pipeline_layout.clone(), + 0, + vec![ + self.descriptor_set_0.clone(), + self.swapchain_image_sets[image_index as usize].1.clone(), + ], + ) + .unwrap(); + + builder + .bind_pipeline_ray_tracing(self.pipeline.clone()) + .unwrap(); + + let extent = self.swapchain_image_sets[0].0.image().extent(); + + unsafe { + builder + .trace_rays( + self.shader_binding_table.addresses().clone(), + extent[0], + extent[1], + 1, + ) + .unwrap(); + } + } +} + +/// This function is called once during initialization, then again whenever the window is resized. +fn window_size_dependent_setup( + images: &[Arc], + pipeline_layout: &Arc, + descriptor_set_allocator: &Arc, +) -> Vec<(Arc, Arc)> { + let swapchain_image_sets = images + .iter() + .map(|image| { + let image_view = ImageView::new_default(image.clone()).unwrap(); + let descriptor_set = DescriptorSet::new( + descriptor_set_allocator.clone(), + pipeline_layout.set_layouts()[1].clone(), + [WriteDescriptorSet::image_view(0, image_view.clone())], + [], + ) + .unwrap(); + (image_view, descriptor_set) + }) + .collect(); + + swapchain_image_sets +} + +/// A helper function to build a acceleration structure and wait for its completion. +/// # SAFETY +/// - If you are referencing a bottom-level acceleration structure in a top-level acceleration +/// structure, you must ensure that the bottom-level acceleration structure is kept alive. 
+unsafe fn build_acceleration_structure_common( + geometries: AccelerationStructureGeometries, + primitive_count: u32, + ty: AccelerationStructureType, + memory_allocator: Arc, + command_buffer_allocator: Arc, + device: Arc, + queue: Arc, +) -> Arc { + let mut as_build_geometry_info = AccelerationStructureBuildGeometryInfo { + mode: BuildAccelerationStructureMode::Build, + flags: BuildAccelerationStructureFlags::PREFER_FAST_TRACE, + ..AccelerationStructureBuildGeometryInfo::new(geometries) + }; + + let as_build_sizes_info = device + .acceleration_structure_build_sizes( + AccelerationStructureBuildType::Device, + &as_build_geometry_info, + &[primitive_count], + ) + .unwrap(); + + // We build a new scratch buffer for each acceleration structure for simplicity. + // You may want to reuse scratch buffers if you need to build many acceleration structures. + let scratch_buffer = Buffer::new_slice::( + memory_allocator.clone(), + BufferCreateInfo { + usage: BufferUsage::SHADER_DEVICE_ADDRESS | BufferUsage::STORAGE_BUFFER, + ..Default::default() + }, + AllocationCreateInfo::default(), + as_build_sizes_info.build_scratch_size, + ) + .unwrap(); + + let as_create_info = AccelerationStructureCreateInfo { + ty, + ..AccelerationStructureCreateInfo::new( + Buffer::new_slice::( + memory_allocator, + BufferCreateInfo { + usage: BufferUsage::ACCELERATION_STRUCTURE_STORAGE + | BufferUsage::SHADER_DEVICE_ADDRESS, + ..Default::default() + }, + AllocationCreateInfo::default(), + as_build_sizes_info.acceleration_structure_size, + ) + .unwrap(), + ) + }; + + let acceleration = unsafe { AccelerationStructure::new(device, as_create_info).unwrap() }; + + as_build_geometry_info.dst_acceleration_structure = Some(acceleration.clone()); + as_build_geometry_info.scratch_data = Some(scratch_buffer); + + let as_build_range_info = AccelerationStructureBuildRangeInfo { + primitive_count, + ..Default::default() + }; + + // For simplicity, we build a single command buffer + // that builds the 
acceleration structure, then waits + // for its execution to complete. + let mut builder = AutoCommandBufferBuilder::primary( + command_buffer_allocator, + queue.queue_family_index(), + CommandBufferUsage::OneTimeSubmit, + ) + .unwrap(); + + builder + .build_acceleration_structure( + as_build_geometry_info, + iter::once(as_build_range_info).collect(), + ) + .unwrap(); + + builder + .build() + .unwrap() + .execute(queue) + .unwrap() + .then_signal_fence_and_flush() + .unwrap() + .wait(None) + .unwrap(); + + acceleration +} + +unsafe fn build_acceleration_structure_triangles( + vertex_buffer: Subbuffer<[MyVertex]>, + memory_allocator: Arc, + command_buffer_allocator: Arc, + device: Arc, + queue: Arc, +) -> Arc { + let primitive_count = (vertex_buffer.len() / 3) as u32; + let as_geometry_triangles_data = AccelerationStructureGeometryTrianglesData { + max_vertex: vertex_buffer.len() as _, + vertex_data: Some(vertex_buffer.into_bytes()), + vertex_stride: size_of::() as _, + ..AccelerationStructureGeometryTrianglesData::new(Format::R32G32B32_SFLOAT) + }; + + let geometries = AccelerationStructureGeometries::Triangles(vec![as_geometry_triangles_data]); + + build_acceleration_structure_common( + geometries, + primitive_count, + AccelerationStructureType::BottomLevel, + memory_allocator, + command_buffer_allocator, + device, + queue, + ) +} + +unsafe fn build_top_level_acceleration_structure( + acceleration_structure: Arc, + allocator: Arc, + command_buffer_allocator: Arc, + device: Arc, + queue: Arc, +) -> Arc { + let as_instance = AccelerationStructureInstance { + acceleration_structure_reference: acceleration_structure.device_address().into(), + ..Default::default() + }; + + let instance_buffer = Buffer::from_iter( + allocator.clone(), + BufferCreateInfo { + usage: BufferUsage::SHADER_DEVICE_ADDRESS + | BufferUsage::ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY, + ..Default::default() + }, + AllocationCreateInfo { + memory_type_filter: MemoryTypeFilter::PREFER_DEVICE + | 
MemoryTypeFilter::HOST_SEQUENTIAL_WRITE, + ..Default::default() + }, + [as_instance], + ) + .unwrap(); + + let as_geometry_instances_data = AccelerationStructureGeometryInstancesData::new( + AccelerationStructureGeometryInstancesDataType::Values(Some(instance_buffer)), + ); + + let geometries = AccelerationStructureGeometries::Instances(as_geometry_instances_data); + + build_acceleration_structure_common( + geometries, + 1, + AccelerationStructureType::TopLevel, + allocator, + command_buffer_allocator, + device, + queue, + ) +} diff --git a/examples/ray-tracing/Cargo.toml b/examples/ray-tracing/Cargo.toml new file mode 100644 index 0000000000..005d984195 --- /dev/null +++ b/examples/ray-tracing/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "ray-tracing" +version = "0.0.0" +edition = "2021" +publish = false + +[[bin]] +name = "ray-tracing" +path = "main.rs" +test = false +bench = false +doc = false + +[dependencies] +vulkano = { workspace = true, default-features = true } +vulkano-shaders = { workspace = true } +vulkano-taskgraph = { workspace = true } +winit = { workspace = true, default-features = true } +ash = { workspace = true } +glam = { workspace = true } diff --git a/examples/ray-tracing/main.rs b/examples/ray-tracing/main.rs new file mode 100644 index 0000000000..18f26d2def --- /dev/null +++ b/examples/ray-tracing/main.rs @@ -0,0 +1,416 @@ +// TODO: document + +use scene::SceneTask; +use std::{error::Error, sync::Arc}; +use vulkano::{ + command_buffer::allocator::StandardCommandBufferAllocator, + descriptor_set::{ + allocator::StandardDescriptorSetAllocator, + layout::{ + DescriptorSetLayout, DescriptorSetLayoutBinding, DescriptorSetLayoutCreateInfo, + DescriptorType, + }, + }, + device::{ + physical::PhysicalDeviceType, Device, DeviceCreateInfo, DeviceExtensions, DeviceFeatures, + Queue, QueueCreateInfo, QueueFlags, + }, + image::{ImageFormatInfo, ImageUsage}, + instance::{Instance, InstanceCreateFlags, InstanceCreateInfo, InstanceExtensions}, + 
memory::allocator::StandardMemoryAllocator, + pipeline::{layout::PipelineLayoutCreateInfo, PipelineLayout}, + shader::ShaderStages, + swapchain::{Surface, Swapchain, SwapchainCreateInfo}, + Validated, Version, VulkanError, VulkanLibrary, +}; +use vulkano_taskgraph::{ + graph::{CompileInfo, ExecutableTaskGraph, ExecuteError, NodeId, TaskGraph}, + resource::{AccessType, Flight, ImageLayoutType, Resources}, + resource_map, Id, QueueFamilyType, +}; +use winit::{ + application::ApplicationHandler, + event::WindowEvent, + event_loop::{ActiveEventLoop, EventLoop}, + window::{Window, WindowId}, +}; + +mod scene; + +const MAX_FRAMES_IN_FLIGHT: u32 = 2; + +fn main() -> Result<(), impl Error> { + let event_loop = EventLoop::new().unwrap(); + let mut app = App::new(&event_loop); + + event_loop.run_app(&mut app) +} + +struct App { + instance: Arc, + device: Arc, + queue: Arc, + resources: Arc, + flight_id: Id, + rcx: Option, +} + +pub struct RenderContext { + window: Arc, + swapchain_id: Id, + recreate_swapchain: bool, + task_graph: ExecutableTaskGraph, + scene_node_id: NodeId, + virtual_swapchain_id: Id, +} + +impl App { + fn new(event_loop: &EventLoop<()>) -> Self { + let library = VulkanLibrary::new().unwrap(); + let required_extensions = Surface::required_extensions(event_loop).unwrap(); + let instance = Instance::new( + library, + InstanceCreateInfo { + flags: InstanceCreateFlags::ENUMERATE_PORTABILITY, + enabled_extensions: InstanceExtensions { + ext_debug_utils: true, + ext_swapchain_colorspace: true, + ..required_extensions + }, + ..Default::default() + }, + ) + .unwrap(); + + let device_extensions = DeviceExtensions { + khr_swapchain: true, + khr_ray_tracing_pipeline: true, + khr_ray_tracing_maintenance1: true, + khr_synchronization2: true, + khr_deferred_host_operations: true, + khr_acceleration_structure: true, + ..DeviceExtensions::empty() + }; + let (physical_device, queue_family_index) = instance + .enumerate_physical_devices() + .unwrap() + .filter(|p| 
p.api_version() >= Version::V1_3) + .filter(|p| p.supported_extensions().contains(&device_extensions)) + .filter_map(|p| { + p.queue_family_properties() + .iter() + .enumerate() + .position(|(i, q)| { + q.queue_flags + .contains(QueueFlags::GRAPHICS | QueueFlags::COMPUTE) + && p.presentation_support(i as u32, event_loop).unwrap() + }) + .map(|i| (p, i as u32)) + }) + .min_by_key(|(p, _)| match p.properties().device_type { + PhysicalDeviceType::DiscreteGpu => 0, + PhysicalDeviceType::IntegratedGpu => 1, + PhysicalDeviceType::VirtualGpu => 2, + PhysicalDeviceType::Cpu => 3, + PhysicalDeviceType::Other => 4, + _ => 5, + }) + .unwrap(); + + let (device, mut queues) = Device::new( + physical_device, + DeviceCreateInfo { + enabled_extensions: device_extensions, + queue_create_infos: vec![QueueCreateInfo { + queue_family_index, + ..Default::default() + }], + enabled_features: DeviceFeatures { + acceleration_structure: true, + ray_tracing_pipeline: true, + buffer_device_address: true, + synchronization2: true, + ..Default::default() + }, + ..Default::default() + }, + ) + .unwrap(); + + let queue = queues.next().unwrap(); + + let resources = Resources::new(&device, &Default::default()); + + let flight_id = resources.create_flight(MAX_FRAMES_IN_FLIGHT).unwrap(); + + App { + instance, + device, + queue, + resources, + flight_id, + rcx: None, + } + } +} + +impl ApplicationHandler for App { + fn resumed(&mut self, event_loop: &ActiveEventLoop) { + let window = Arc::new( + event_loop + .create_window(Window::default_attributes()) + .unwrap(), + ); + let surface = Surface::from_window(self.instance.clone(), window.clone()).unwrap(); + + let physical_device = self.device.physical_device(); + let supported_surface_formats = physical_device + .surface_formats(&surface, Default::default()) + .unwrap(); + + // For each supported format, check if it is supported for storage images + let supported_storage_formats = supported_surface_formats + .into_iter() + .filter(|(format, _)| { + 
physical_device + .image_format_properties(ImageFormatInfo { + format: *format, + usage: ImageUsage::STORAGE, + ..Default::default() + }) + .unwrap() + .is_some() + }) + .collect::>(); + + println!( + "Using device: {} (type: {:?})", + physical_device.properties().device_name, + physical_device.properties().device_type, + ); + + let swapchain_id = { + let surface_capabilities = self + .device + .physical_device() + .surface_capabilities(&surface, Default::default()) + .unwrap(); + + let (swapchain_format, swapchain_color_space) = supported_storage_formats + .first() + .map(|(format, color_space)| (*format, *color_space)) + .unwrap(); + + self.resources + .create_swapchain( + self.flight_id, + surface, + SwapchainCreateInfo { + min_image_count: surface_capabilities.min_image_count.max(3), + image_format: swapchain_format, + image_extent: window.inner_size().into(), + image_usage: ImageUsage::STORAGE | ImageUsage::COLOR_ATTACHMENT, + image_color_space: swapchain_color_space, + composite_alpha: surface_capabilities + .supported_composite_alpha + .into_iter() + .next() + .unwrap(), + ..Default::default() + }, + ) + .unwrap() + }; + + let pipeline_layout = PipelineLayout::new( + self.device.clone(), + PipelineLayoutCreateInfo { + set_layouts: vec![ + DescriptorSetLayout::new( + self.device.clone(), + DescriptorSetLayoutCreateInfo { + bindings: [ + ( + 0, + DescriptorSetLayoutBinding { + stages: ShaderStages::RAYGEN, + ..DescriptorSetLayoutBinding::descriptor_type( + DescriptorType::AccelerationStructure, + ) + }, + ), + ( + 1, + DescriptorSetLayoutBinding { + stages: ShaderStages::RAYGEN, + ..DescriptorSetLayoutBinding::descriptor_type( + DescriptorType::UniformBuffer, + ) + }, + ), + ] + .into_iter() + .collect(), + ..Default::default() + }, + ) + .unwrap(), + DescriptorSetLayout::new( + self.device.clone(), + DescriptorSetLayoutCreateInfo { + bindings: [( + 0, + DescriptorSetLayoutBinding { + stages: ShaderStages::RAYGEN, + 
..DescriptorSetLayoutBinding::descriptor_type( + DescriptorType::StorageImage, + ) + }, + )] + .into_iter() + .collect(), + ..Default::default() + }, + ) + .unwrap(), + ], + push_constant_ranges: vec![], + ..Default::default() + }, + ) + .unwrap(); + + let descriptor_set_allocator = Arc::new(StandardDescriptorSetAllocator::new( + self.device.clone(), + Default::default(), + )); + + let memory_allocator = Arc::new(StandardMemoryAllocator::new_default(self.device.clone())); + + let command_buffer_allocator = Arc::new(StandardCommandBufferAllocator::new( + self.device.clone(), + Default::default(), + )); + + let mut task_graph = TaskGraph::new(&self.resources, 3, 2); + + let virtual_swapchain_id = task_graph.add_swapchain(&SwapchainCreateInfo::default()); + + let scene_node_id = task_graph + .create_task_node( + "Scene", + QueueFamilyType::Graphics, + SceneTask::new( + self, + pipeline_layout.clone(), + swapchain_id, + virtual_swapchain_id, + descriptor_set_allocator, + memory_allocator, + command_buffer_allocator, + ), + ) + .image_access( + virtual_swapchain_id.current_image_id(), + AccessType::RayTracingShaderStorageWrite, + ImageLayoutType::General, + ) + .build(); + + let task_graph = unsafe { + task_graph.compile(&CompileInfo { + queues: &[&self.queue], + present_queue: Some(&self.queue), + flight_id: self.flight_id, + ..Default::default() + }) + } + .unwrap(); + + self.rcx = Some(RenderContext { + window, + swapchain_id, + virtual_swapchain_id, + recreate_swapchain: false, + task_graph, + scene_node_id, + }); + } + + fn window_event( + &mut self, + event_loop: &ActiveEventLoop, + _window_id: WindowId, + event: WindowEvent, + ) { + let rcx = self.rcx.as_mut().unwrap(); + + match event { + WindowEvent::CloseRequested => { + event_loop.exit(); + } + WindowEvent::Resized(_) => { + rcx.recreate_swapchain = true; + } + WindowEvent::RedrawRequested => { + let window_size = rcx.window.inner_size(); + + if window_size.width == 0 || window_size.height == 0 { + return; + 
} + + let flight = self.resources.flight(self.flight_id).unwrap(); + + if rcx.recreate_swapchain { + rcx.swapchain_id = self + .resources + .recreate_swapchain(rcx.swapchain_id, |create_info| SwapchainCreateInfo { + image_extent: window_size.into(), + ..create_info + }) + .expect("failed to recreate swapchain"); + + rcx.task_graph + .task_node_mut(rcx.scene_node_id) + .unwrap() + .task_mut() + .downcast_mut::() + .unwrap() + .handle_resize(&self.resources, rcx.swapchain_id); + + rcx.recreate_swapchain = false; + } + + flight.wait(None).unwrap(); + + let resource_map = resource_map!( + &rcx.task_graph, + rcx.virtual_swapchain_id => rcx.swapchain_id, + ) + .unwrap(); + + match unsafe { + rcx.task_graph + .execute(resource_map, rcx, || rcx.window.pre_present_notify()) + } { + Ok(()) => {} + Err(ExecuteError::Swapchain { + error: Validated::Error(VulkanError::OutOfDate), + .. + }) => { + rcx.recreate_swapchain = true; + } + Err(e) => { + panic!("failed to execute next frame: {e:?}"); + } + } + } + _ => {} + } + } + + fn about_to_wait(&mut self, _event_loop: &ActiveEventLoop) { + let rcx = self.rcx.as_mut().unwrap(); + rcx.window.request_redraw(); + } +} diff --git a/examples/ray-tracing/raytrace.miss b/examples/ray-tracing/raytrace.miss new file mode 100644 index 0000000000..1c584d5420 --- /dev/null +++ b/examples/ray-tracing/raytrace.miss @@ -0,0 +1,6 @@ +#version 460 +#extension GL_EXT_ray_tracing : require + +layout(location = 0) rayPayloadInEXT vec3 hitValue; + +void main() { hitValue = vec3(0.0, 0.0, 0.2); } diff --git a/examples/ray-tracing/raytrace.rchit b/examples/ray-tracing/raytrace.rchit new file mode 100644 index 0000000000..52c407b96a --- /dev/null +++ b/examples/ray-tracing/raytrace.rchit @@ -0,0 +1,10 @@ +#version 460 +#extension GL_EXT_ray_tracing : require + +layout(location = 0) rayPayloadInEXT vec3 hitValue; +hitAttributeEXT vec2 attribs; + +void main() { + vec3 barycentrics = vec3(1.0 - attribs.x - attribs.y, attribs.x, attribs.y); + hitValue = 
barycentrics; +} diff --git a/examples/ray-tracing/raytrace.rgen b/examples/ray-tracing/raytrace.rgen new file mode 100644 index 0000000000..8a9416e201 --- /dev/null +++ b/examples/ray-tracing/raytrace.rgen @@ -0,0 +1,43 @@ +#version 460 +#extension GL_EXT_ray_tracing : require + +struct Camera { + mat4 viewProj; // Camera view * projection + mat4 viewInverse; // Camera inverse view matrix + mat4 projInverse; // Camera inverse projection matrix +}; + +layout(location = 0) rayPayloadEXT vec3 hitValue; + +layout(set = 0, binding = 0) uniform accelerationStructureEXT topLevelAS; +layout(set = 0, binding = 1) uniform _Camera { Camera camera; }; +layout(set = 1, binding = 0, rgba32f) uniform image2D image; + +void main() { + const vec2 pixelCenter = vec2(gl_LaunchIDEXT.xy) + vec2(0.5); + const vec2 inUV = pixelCenter / vec2(gl_LaunchSizeEXT.xy); + vec2 d = inUV * 2.0 - 1.0; + + vec4 origin = camera.viewInverse * vec4(0, 0, 0, 1); + vec4 target = camera.projInverse * vec4(d.x, d.y, 1, 1); + vec4 direction = camera.viewInverse * vec4(normalize(target.xyz), 0); + + uint rayFlags = gl_RayFlagsOpaqueEXT; + float tMin = 0.001; + float tMax = 10000.0; + + traceRayEXT(topLevelAS, // acceleration structure + rayFlags, // rayFlags + 0xFF, // cullMask + 0, // sbtRecordOffset + 0, // sbtRecordStride + 0, // missIndex + origin.xyz, // ray origin + tMin, // ray min range + direction.xyz, // ray direction + tMax, // ray max range + 0 // payload (location = 0) + ); + + imageStore(image, ivec2(gl_LaunchIDEXT.xy), vec4(hitValue, 1.0)); +} diff --git a/examples/ray-tracing/scene.rs b/examples/ray-tracing/scene.rs new file mode 100644 index 0000000000..b01eec3ed0 --- /dev/null +++ b/examples/ray-tracing/scene.rs @@ -0,0 +1,499 @@ +use crate::{App, RenderContext}; +use glam::{Mat4, Vec3}; +use std::{iter, mem::size_of, sync::Arc}; +use vulkano::{ + acceleration_structure::{ + AccelerationStructure, AccelerationStructureBuildGeometryInfo, + AccelerationStructureBuildRangeInfo, 
AccelerationStructureBuildType, + AccelerationStructureCreateInfo, AccelerationStructureGeometries, + AccelerationStructureGeometryInstancesData, AccelerationStructureGeometryInstancesDataType, + AccelerationStructureGeometryTrianglesData, AccelerationStructureInstance, + AccelerationStructureType, BuildAccelerationStructureFlags, BuildAccelerationStructureMode, + }, + buffer::{Buffer, BufferContents, BufferCreateInfo, BufferUsage, Subbuffer}, + command_buffer::{ + allocator::CommandBufferAllocator, AutoCommandBufferBuilder, CommandBufferUsage, + PrimaryCommandBufferAbstract, + }, + descriptor_set::{ + allocator::StandardDescriptorSetAllocator, sys::RawDescriptorSet, WriteDescriptorSet, + }, + device::{Device, Queue}, + format::Format, + image::view::ImageView, + memory::allocator::{AllocationCreateInfo, MemoryAllocator, MemoryTypeFilter}, + pipeline::{ + graphics::vertex_input::Vertex, + ray_tracing::{ + RayTracingPipeline, RayTracingPipelineCreateInfo, RayTracingShaderGroupCreateInfo, + ShaderBindingTable, + }, + PipelineBindPoint, PipelineLayout, PipelineShaderStageCreateInfo, + }, + swapchain::Swapchain, + sync::GpuFuture, +}; +use vulkano_taskgraph::{ + command_buffer::RecordingCommandBuffer, resource::Resources, Id, Task, TaskContext, TaskResult, +}; + +mod raygen { + vulkano_shaders::shader! { + ty: "raygen", + path: "raytrace.rgen", + vulkan_version: "1.2" + } +} + +mod closest_hit { + vulkano_shaders::shader! { + ty: "closesthit", + path: "raytrace.rchit", + vulkan_version: "1.2" + } +} + +mod miss { + vulkano_shaders::shader! 
{ + ty: "miss", + path: "raytrace.miss", + vulkan_version: "1.2" + } +} + +#[derive(BufferContents, Vertex)] +#[repr(C)] +struct MyVertex { + #[format(R32G32B32_SFLOAT)] + position: [f32; 3], +} + +pub struct SceneTask { + descriptor_set_0: Arc, + swapchain_image_sets: Vec<(Arc, Arc)>, + pipeline_layout: Arc, + descriptor_set_allocator: Arc, + virtual_swapchain_id: Id, + shader_binding_table: ShaderBindingTable, + pipeline: Arc, + blas: Arc, + tlas: Arc, + uniform_buffer: Subbuffer, +} + +impl SceneTask { + pub fn new( + app: &App, + pipeline_layout: Arc, + swapchain_id: Id, + virtual_swapchain_id: Id, + descriptor_set_allocator: Arc, + memory_allocator: Arc, + command_buffer_allocator: Arc, + ) -> Self { + let pipeline = { + let raygen = raygen::load(app.device.clone()) + .unwrap() + .entry_point("main") + .unwrap(); + let closest_hit = closest_hit::load(app.device.clone()) + .unwrap() + .entry_point("main") + .unwrap(); + + let miss = miss::load(app.device.clone()) + .unwrap() + .entry_point("main") + .unwrap(); + + // Make a list of the shader stages that the pipeline will have. 
+ let stages = [ + PipelineShaderStageCreateInfo::new(raygen), + PipelineShaderStageCreateInfo::new(miss), + PipelineShaderStageCreateInfo::new(closest_hit), + ]; + + let groups = [ + RayTracingShaderGroupCreateInfo::General { general_shader: 0 }, + RayTracingShaderGroupCreateInfo::General { general_shader: 1 }, + RayTracingShaderGroupCreateInfo::TrianglesHit { + closest_hit_shader: Some(2), + any_hit_shader: None, + }, + ]; + + RayTracingPipeline::new( + app.device.clone(), + None, + RayTracingPipelineCreateInfo { + stages: stages.into_iter().collect(), + groups: groups.into_iter().collect(), + max_pipeline_ray_recursion_depth: 1, + + ..RayTracingPipelineCreateInfo::layout(pipeline_layout.clone()) + }, + ) + .unwrap() + }; + + let vertices = [ + MyVertex { + position: [-0.5, -0.25, 0.0], + }, + MyVertex { + position: [0.0, 0.5, 0.0], + }, + MyVertex { + position: [0.25, -0.1, 0.0], + }, + ]; + let vertex_buffer = Buffer::from_iter( + memory_allocator.clone(), + BufferCreateInfo { + usage: BufferUsage::VERTEX_BUFFER + | BufferUsage::SHADER_DEVICE_ADDRESS + | BufferUsage::ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY, + ..Default::default() + }, + AllocationCreateInfo { + memory_type_filter: MemoryTypeFilter::PREFER_DEVICE + | MemoryTypeFilter::HOST_SEQUENTIAL_WRITE, + ..Default::default() + }, + vertices, + ) + .unwrap(); + + let blas = unsafe { + build_acceleration_structure_triangles( + vertex_buffer, + memory_allocator.clone(), + command_buffer_allocator.clone(), + app.device.clone(), + app.queue.clone(), + ) + }; + + let tlas = unsafe { + build_top_level_acceleration_structure( + blas.clone(), + memory_allocator.clone(), + command_buffer_allocator.clone(), + app.device.clone(), + app.queue.clone(), + ) + }; + + let proj = Mat4::perspective_rh_gl(std::f32::consts::FRAC_PI_2, 4.0 / 3.0, 0.01, 100.0); + let view = Mat4::look_at_rh( + Vec3::new(0.0, 0.0, 1.0), + Vec3::new(0.0, 0.0, 0.0), + Vec3::new(0.0, -1.0, 0.0), + ); + + let uniform_buffer = Buffer::from_data( + 
memory_allocator.clone(), + BufferCreateInfo { + usage: BufferUsage::UNIFORM_BUFFER, + ..Default::default() + }, + AllocationCreateInfo { + memory_type_filter: MemoryTypeFilter::PREFER_DEVICE + | MemoryTypeFilter::HOST_SEQUENTIAL_WRITE, + ..Default::default() + }, + raygen::Camera { + viewInverse: view.inverse().to_cols_array_2d(), + projInverse: proj.inverse().to_cols_array_2d(), + viewProj: (proj * view).to_cols_array_2d(), + }, + ) + .unwrap(); + + let descriptor_set_0 = RawDescriptorSet::new( + descriptor_set_allocator.clone(), + &pipeline_layout.set_layouts()[0], + 0, + ) + .unwrap(); + + unsafe { + let writes = &[ + WriteDescriptorSet::acceleration_structure(0, tlas.clone()), + WriteDescriptorSet::buffer(1, uniform_buffer.clone()), + ]; + descriptor_set_0.update(writes, &[]).unwrap(); + } + + let swapchain_image_sets = window_size_dependent_setup( + &app.resources, + swapchain_id, + &pipeline_layout, + &descriptor_set_allocator, + ); + + let shader_binding_table = + ShaderBindingTable::new(memory_allocator.clone(), &pipeline).unwrap(); + + SceneTask { + descriptor_set_0: Arc::new(descriptor_set_0), + swapchain_image_sets, + descriptor_set_allocator, + pipeline_layout, + virtual_swapchain_id, + shader_binding_table, + pipeline, + blas, + tlas, + uniform_buffer, + } + } + + pub fn handle_resize(&mut self, resources: &Resources, swapchain_id: Id) { + self.swapchain_image_sets = window_size_dependent_setup( + resources, + swapchain_id, + &self.pipeline_layout, + &self.descriptor_set_allocator, + ); + } +} + +impl Task for SceneTask { + type World = RenderContext; + + unsafe fn execute( + &self, + cbf: &mut RecordingCommandBuffer<'_>, + tcx: &mut TaskContext<'_>, + _rcx: &Self::World, + ) -> TaskResult { + let swapchain_state = tcx.swapchain(self.virtual_swapchain_id)?; + let image_index = swapchain_state.current_image_index().unwrap(); + + cbf.as_raw().bind_descriptor_sets( + PipelineBindPoint::RayTracing, + &self.pipeline_layout, + 0, + &[ + 
&self.descriptor_set_0, + &self.swapchain_image_sets[image_index as usize].1, + ], + &[], + )?; + + cbf.bind_pipeline_ray_tracing(&self.pipeline)?; + + let extent = self.swapchain_image_sets[0].0.image().extent(); + + unsafe { + cbf.trace_rays( + self.shader_binding_table.addresses(), + extent[0], + extent[1], + 1, + ) + }?; + + for (image_view, descriptor_set) in self.swapchain_image_sets.iter() { + cbf.destroy_object(descriptor_set.clone()); + cbf.destroy_object(image_view.clone()); + } + cbf.destroy_object(self.blas.clone()); + cbf.destroy_object(self.tlas.clone()); + cbf.destroy_object(self.uniform_buffer.clone().into()); + cbf.destroy_object(self.descriptor_set_0.clone()); + + Ok(()) + } +} + +/// This function is called once during initialization, then again whenever the window is resized. +fn window_size_dependent_setup( + resources: &Resources, + swapchain_id: Id, + pipeline_layout: &Arc, + descriptor_set_allocator: &Arc, +) -> Vec<(Arc, Arc)> { + let swapchain_state = resources.swapchain(swapchain_id).unwrap(); + let images = swapchain_state.images(); + + let swapchain_image_sets = images + .iter() + .map(|image| { + let descriptor_set = RawDescriptorSet::new( + descriptor_set_allocator.clone(), + &pipeline_layout.set_layouts()[1], + 0, + ) + .unwrap(); + let image_view = ImageView::new_default(image.clone()).unwrap(); + let writes = &[WriteDescriptorSet::image_view(0, image_view.clone())]; + unsafe { descriptor_set.update(writes, &[]) }.unwrap(); + (image_view, Arc::new(descriptor_set)) + }) + .collect(); + + swapchain_image_sets +} + +unsafe fn build_acceleration_structure_common( + geometries: AccelerationStructureGeometries, + primitive_count: u32, + ty: AccelerationStructureType, + memory_allocator: Arc, + command_buffer_allocator: Arc, + device: Arc, + queue: Arc, +) -> Arc { + let mut builder = AutoCommandBufferBuilder::primary( + command_buffer_allocator, + queue.queue_family_index(), + CommandBufferUsage::OneTimeSubmit, + ) + .unwrap(); + + let 
mut as_build_geometry_info = AccelerationStructureBuildGeometryInfo { + mode: BuildAccelerationStructureMode::Build, + flags: BuildAccelerationStructureFlags::PREFER_FAST_TRACE, + ..AccelerationStructureBuildGeometryInfo::new(geometries) + }; + + let as_build_sizes_info = device + .acceleration_structure_build_sizes( + AccelerationStructureBuildType::Device, + &as_build_geometry_info, + &[primitive_count], + ) + .unwrap(); + + let scratch_buffer = Buffer::new_slice::( + memory_allocator.clone(), + BufferCreateInfo { + usage: BufferUsage::SHADER_DEVICE_ADDRESS | BufferUsage::STORAGE_BUFFER, + ..Default::default() + }, + AllocationCreateInfo::default(), + as_build_sizes_info.build_scratch_size, + ) + .unwrap(); + + let as_create_info = AccelerationStructureCreateInfo { + ty, + ..AccelerationStructureCreateInfo::new( + Buffer::new_slice::( + memory_allocator, + BufferCreateInfo { + usage: BufferUsage::ACCELERATION_STRUCTURE_STORAGE + | BufferUsage::SHADER_DEVICE_ADDRESS, + ..Default::default() + }, + AllocationCreateInfo::default(), + as_build_sizes_info.acceleration_structure_size, + ) + .unwrap(), + ) + }; + + let acceleration = unsafe { AccelerationStructure::new(device, as_create_info).unwrap() }; + + as_build_geometry_info.dst_acceleration_structure = Some(acceleration.clone()); + as_build_geometry_info.scratch_data = Some(scratch_buffer); + + let as_build_range_info = AccelerationStructureBuildRangeInfo { + primitive_count, + ..Default::default() + }; + + builder + .build_acceleration_structure( + as_build_geometry_info, + iter::once(as_build_range_info).collect(), + ) + .unwrap(); + + builder + .build() + .unwrap() + .execute(queue) + .unwrap() + .then_signal_fence_and_flush() + .unwrap() + .wait(None) + .unwrap(); + + acceleration +} + +unsafe fn build_acceleration_structure_triangles( + vertex_buffer: Subbuffer<[MyVertex]>, + memory_allocator: Arc, + command_buffer_allocator: Arc, + device: Arc, + queue: Arc, +) -> Arc { + let primitive_count = 
(vertex_buffer.len() / 3) as u32; + let as_geometry_triangles_data = AccelerationStructureGeometryTrianglesData { + max_vertex: vertex_buffer.len() as _, + vertex_data: Some(vertex_buffer.into_bytes()), + vertex_stride: size_of::() as _, + ..AccelerationStructureGeometryTrianglesData::new(Format::R32G32B32_SFLOAT) + }; + + let geometries = AccelerationStructureGeometries::Triangles(vec![as_geometry_triangles_data]); + + build_acceleration_structure_common( + geometries, + primitive_count, + AccelerationStructureType::BottomLevel, + memory_allocator, + command_buffer_allocator, + device, + queue, + ) +} + +unsafe fn build_top_level_acceleration_structure( + acceleration_structure: Arc, + allocator: Arc, + command_buffer_allocator: Arc, + device: Arc, + queue: Arc, +) -> Arc { + let as_instance = AccelerationStructureInstance { + acceleration_structure_reference: acceleration_structure.device_address().into(), + ..Default::default() + }; + + let instance_buffer = Buffer::from_iter( + allocator.clone(), + BufferCreateInfo { + usage: BufferUsage::SHADER_DEVICE_ADDRESS + | BufferUsage::ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY, + ..Default::default() + }, + AllocationCreateInfo { + memory_type_filter: MemoryTypeFilter::PREFER_DEVICE + | MemoryTypeFilter::HOST_SEQUENTIAL_WRITE, + ..Default::default() + }, + [as_instance], + ) + .unwrap(); + + let as_geometry_instances_data = AccelerationStructureGeometryInstancesData::new( + AccelerationStructureGeometryInstancesDataType::Values(Some(instance_buffer)), + ); + + let geometries = AccelerationStructureGeometries::Instances(as_geometry_instances_data); + + build_acceleration_structure_common( + geometries, + 1, + AccelerationStructureType::TopLevel, + allocator, + command_buffer_allocator, + device, + queue, + ) +} diff --git a/vulkano-taskgraph/src/command_buffer/commands/bind_push.rs b/vulkano-taskgraph/src/command_buffer/commands/bind_push.rs index a15a0a997f..17a98b5029 100644 --- 
a/vulkano-taskgraph/src/command_buffer/commands/bind_push.rs +++ b/vulkano-taskgraph/src/command_buffer/commands/bind_push.rs @@ -9,7 +9,9 @@ use vulkano::{ self, buffer::{Buffer, BufferContents, IndexType}, device::DeviceOwned, - pipeline::{ComputePipeline, GraphicsPipeline, PipelineLayout}, + pipeline::{ + ray_tracing::RayTracingPipeline, ComputePipeline, GraphicsPipeline, PipelineLayout, + }, DeviceSize, Version, VulkanObject, }; @@ -115,6 +117,31 @@ impl RecordingCommandBuffer<'_> { self } + pub unsafe fn bind_pipeline_ray_tracing( + &mut self, + pipeline: &Arc, + ) -> Result<&mut Self> { + Ok(unsafe { self.bind_pipeline_ray_tracing_unchecked(pipeline) }) + } + + pub unsafe fn bind_pipeline_ray_tracing_unchecked( + &mut self, + pipeline: &Arc, + ) -> &mut Self { + let fns = self.device().fns(); + unsafe { + (fns.v1_0.cmd_bind_pipeline)( + self.handle(), + vk::PipelineBindPoint::RAY_TRACING_KHR, + pipeline.handle(), + ) + }; + + self.death_row.push(pipeline.clone()); + + self + } + /// Binds vertex buffers for future draw calls. 
pub unsafe fn bind_vertex_buffers( &mut self, diff --git a/vulkano-taskgraph/src/command_buffer/commands/pipeline.rs b/vulkano-taskgraph/src/command_buffer/commands/pipeline.rs index e87d112b6c..8cac98bdaa 100644 --- a/vulkano-taskgraph/src/command_buffer/commands/pipeline.rs +++ b/vulkano-taskgraph/src/command_buffer/commands/pipeline.rs @@ -7,7 +7,10 @@ use vulkano::command_buffer::{ DispatchIndirectCommand, DrawIndexedIndirectCommand, DrawIndirectCommand, DrawMeshTasksIndirectCommand, }; -use vulkano::{buffer::Buffer, device::DeviceOwned, DeviceSize, Version, VulkanObject}; +use vulkano::{ + buffer::Buffer, device::DeviceOwned, pipeline::ray_tracing::ShaderBindingTableAddresses, + DeviceSize, Version, VulkanObject, +}; /// # Commands to execute a bound pipeline /// @@ -658,4 +661,45 @@ impl RecordingCommandBuffer<'_> { self } + + pub unsafe fn trace_rays( + &mut self, + shader_binding_table_addresses: &ShaderBindingTableAddresses, + width: u32, + height: u32, + depth: u32, + ) -> Result<&mut Self> { + Ok(unsafe { + self.trace_rays_unchecked(shader_binding_table_addresses, width, height, depth) + }) + } + + pub unsafe fn trace_rays_unchecked( + &mut self, + shader_binding_table_addresses: &ShaderBindingTableAddresses, + width: u32, + height: u32, + depth: u32, + ) -> &mut Self { + let fns = self.device().fns(); + + let raygen = shader_binding_table_addresses.raygen.to_vk(); + let miss = shader_binding_table_addresses.miss.to_vk(); + let hit = shader_binding_table_addresses.hit.to_vk(); + let callable = shader_binding_table_addresses.callable.to_vk(); + unsafe { + (fns.khr_ray_tracing_pipeline.cmd_trace_rays_khr)( + self.handle(), + &raygen, + &miss, + &hit, + &callable, + width, + height, + depth, + ); + } + + self + } } diff --git a/vulkano-taskgraph/src/resource.rs b/vulkano-taskgraph/src/resource.rs index 3613453bda..8fb33129bb 100644 --- a/vulkano-taskgraph/src/resource.rs +++ b/vulkano-taskgraph/src/resource.rs @@ -1600,69 +1600,61 @@ access_types! 
{ // valid_for: IMAGE, // } - // TODO: - // RayTracingShaderUniformRead { - // stage_mask: RAY_TRACING_SHADER, - // access_mask: UNIFORM_READ, - // image_layout: Undefined, - // valid_for: BUFFER, - // } + RayTracingShaderUniformRead { + stage_mask: RAY_TRACING_SHADER, + access_mask: UNIFORM_READ, + image_layout: Undefined, + valid_for: BUFFER, + } - // TODO: - // RayTracingShaderColorInputAttachmentRead { - // stage_mask: RAY_TRACING_SHADER, - // access_mask: INPUT_ATTACHMENT_READ, - // image_layout: ShaderReadOnlyOptimal, - // valid_for: IMAGE, - // } + RayTracingShaderColorInputAttachmentRead { + stage_mask: RAY_TRACING_SHADER, + access_mask: INPUT_ATTACHMENT_READ, + image_layout: ShaderReadOnlyOptimal, + valid_for: IMAGE, + } - // TODO: - // RayTracingShaderDepthStencilInputAttachmentRead { - // stage_mask: RAY_TRACING_SHADER, - // access_mask: INPUT_ATTACHMENT_READ, - // image_layout: DepthStencilReadOnlyOptimal, - // valid_for: IMAGE, - // } + RayTracingShaderDepthStencilInputAttachmentRead { + stage_mask: RAY_TRACING_SHADER, + access_mask: INPUT_ATTACHMENT_READ, + image_layout: DepthStencilReadOnlyOptimal, + valid_for: IMAGE, + } - // TODO: - // RayTracingShaderSampledRead { - // stage_mask: RAY_TRACING_SHADER, - // access_mask: SHADER_SAMPLED_READ, - // image_layout: ShaderReadOnlyOptimal, - // valid_for: BUFFER | IMAGE, - // } + RayTracingShaderSampledRead { + stage_mask: RAY_TRACING_SHADER, + access_mask: SHADER_SAMPLED_READ, + image_layout: ShaderReadOnlyOptimal, + valid_for: BUFFER | IMAGE, + } - // TODO: - // RayTracingShaderStorageRead { - // stage_mask: RAY_TRACING_SHADER, - // access_mask: SHADER_STORAGE_READ, - // image_layout: General, - // valid_for: BUFFER | IMAGE, - // } + RayTracingShaderStorageRead { + stage_mask: RAY_TRACING_SHADER, + access_mask: SHADER_STORAGE_READ, + image_layout: General, + valid_for: BUFFER | IMAGE, + } - // TODO: - // RayTracingShaderStorageWrite { - // stage_mask: RAY_TRACING_SHADER, - // access_mask: 
SHADER_STORAGE_WRITE, - // image_layout: General, - // valid_for: BUFFER | IMAGE, - // } + RayTracingShaderStorageWrite { + stage_mask: RAY_TRACING_SHADER, + access_mask: SHADER_STORAGE_WRITE, + image_layout: General, + valid_for: BUFFER | IMAGE, + } - // TODO: - // RayTracingShaderBindingTableRead { - // stage_mask: RAY_TRACING_SHADER, - // access_mask: SHADER_BINDING_TABLE_READ, - // image_layout: Undefined, - // valid_for: BUFFER, - // } + RayTracingShaderBindingTableRead { + stage_mask: RAY_TRACING_SHADER, + access_mask: SHADER_BINDING_TABLE_READ, + image_layout: Undefined, + valid_for: BUFFER, + } - // TODO: - // RayTracingShaderAccelerationStructureRead { - // stage_mask: RAY_TRACING_SHADER, - // access_mask: ACCELERATION_STRUCTURE_READ, - // image_layout: Undefined, - // valid_for: BUFFER, - // } + RayTracingShaderAccelerationStructureRead { + stage_mask: RAY_TRACING_SHADER, + access_mask: ACCELERATION_STRUCTURE_READ, + image_layout: Undefined, + valid_for: BUFFER, + } TaskShaderUniformRead { stage_mask: TASK_SHADER, diff --git a/vulkano/src/buffer/usage.rs b/vulkano/src/buffer/usage.rs index 7b57595561..72f9eb0a19 100644 --- a/vulkano/src/buffer/usage.rs +++ b/vulkano/src/buffer/usage.rs @@ -97,13 +97,13 @@ vulkan_bitflags! 
{ RequiresAllOf([DeviceExtension(khr_acceleration_structure)]), ]), - /* TODO: enable + // TODO: document SHADER_BINDING_TABLE = SHADER_BINDING_TABLE_KHR RequiresOneOf([ RequiresAllOf([DeviceExtension(khr_ray_tracing_pipeline)]), RequiresAllOf([DeviceExtension(nv_ray_tracing)]), - ]),*/ + ]), /* TODO: enable // TODO: document diff --git a/vulkano/src/command_buffer/auto/builder.rs b/vulkano/src/command_buffer/auto/builder.rs index 34a048d8b9..5ef2df87a2 100644 --- a/vulkano/src/command_buffer/auto/builder.rs +++ b/vulkano/src/command_buffer/auto/builder.rs @@ -29,6 +29,7 @@ use crate::{ vertex_input::VertexInputState, viewport::{Scissor, Viewport}, }, + ray_tracing::RayTracingPipeline, ComputePipeline, DynamicState, GraphicsPipeline, PipelineBindPoint, PipelineLayout, }, query::{QueryControlFlags, QueryPool, QueryType}, @@ -1292,6 +1293,7 @@ pub(in crate::command_buffer) struct CommandBufferBuilderState { pub(in crate::command_buffer) index_buffer: Option, pub(in crate::command_buffer) pipeline_compute: Option>, pub(in crate::command_buffer) pipeline_graphics: Option>, + pub(in crate::command_buffer) pipeline_ray_tracing: Option>, pub(in crate::command_buffer) vertex_buffers: HashMap>, pub(in crate::command_buffer) push_constants: RangeSet, pub(in crate::command_buffer) push_constants_pipeline_layout: Option>, diff --git a/vulkano/src/command_buffer/commands/bind_push.rs b/vulkano/src/command_buffer/commands/bind_push.rs index 37d54ac6df..5dfdea2926 100644 --- a/vulkano/src/command_buffer/commands/bind_push.rs +++ b/vulkano/src/command_buffer/commands/bind_push.rs @@ -10,8 +10,8 @@ use crate::{ device::{DeviceOwned, QueueFlags}, memory::is_aligned, pipeline::{ - graphics::vertex_input::VertexBuffersCollection, ComputePipeline, GraphicsPipeline, - PipelineBindPoint, PipelineLayout, + graphics::vertex_input::VertexBuffersCollection, ray_tracing::RayTracingPipeline, + ComputePipeline, GraphicsPipeline, PipelineBindPoint, PipelineLayout, }, DeviceSize, Requires, 
RequiresAllOf, RequiresOneOf, ValidationError, Version, VulkanObject, }; @@ -378,6 +378,31 @@ impl AutoCommandBufferBuilder { self } + pub fn bind_pipeline_ray_tracing( + &mut self, + pipeline: Arc, + ) -> Result<&mut Self, Box> { + self.inner.validate_bind_pipeline_ray_tracing(&pipeline)?; + Ok(unsafe { self.bind_pipeline_ray_tracing_unchecked(pipeline) }) + } + + #[cfg_attr(not(feature = "document_unchecked"), doc(hidden))] + pub unsafe fn bind_pipeline_ray_tracing_unchecked( + &mut self, + pipeline: Arc, + ) -> &mut Self { + self.builder_state.pipeline_ray_tracing = Some(pipeline.clone()); + self.add_command( + "bind_pipeline_ray_tracing", + Default::default(), + move |out: &mut RecordingCommandBuffer| { + out.bind_pipeline_ray_tracing_unchecked(&pipeline); + }, + ); + + self + } + /// Binds vertex buffers for future draw calls. pub fn bind_vertex_buffers( &mut self, @@ -794,6 +819,25 @@ impl RecordingCommandBuffer { })); } } + PipelineBindPoint::RayTracing => { + if !queue_family_properties + .queue_flags + .intersects(QueueFlags::COMPUTE) + { + return Err(Box::new(ValidationError { + context: "pipeline_bind_point".into(), + problem: "is `PipelineBindPoint::RayTracing`, but \ + the queue family of the command buffer does not support \ + compute operations" + .into(), + vuids: &[ + "VUID-vkCmdBindDescriptorSets-pipelineBindPoint-02391", + "VUID-vkCmdBindDescriptorSets-commandBuffer-cmdpool", + ], + ..Default::default() + })); + } + } } if first_set + descriptor_sets as u32 > pipeline_layout.set_layouts().len() as u32 { @@ -1018,6 +1062,55 @@ impl RecordingCommandBuffer { self } + pub unsafe fn bind_pipeline_ray_tracing( + &mut self, + pipeline: &RayTracingPipeline, + ) -> Result<&mut Self, Box> { + self.validate_bind_pipeline_ray_tracing(pipeline)?; + Ok(self.bind_pipeline_ray_tracing_unchecked(pipeline)) + } + + fn validate_bind_pipeline_ray_tracing( + &self, + pipeline: &RayTracingPipeline, + ) -> Result<(), Box> { + if !self + .queue_family_properties() + 
.queue_flags + .intersects(QueueFlags::COMPUTE) + { + return Err(Box::new(ValidationError { + problem: "the queue family of the command buffer does not support \ + compute operations" + .into(), + vuids: &["VUID-vkCmdBindPipeline-pipelineBindPoint-02391"], + ..Default::default() + })); + } + + // VUID-vkCmdBindPipeline-commonparent + assert_eq!(self.device(), pipeline.device()); + + // TODO: VUID-vkCmdBindPipeline-pipelineBindPoint-06721 + + Ok(()) + } + + #[cfg_attr(not(feature = "document_unchecked"), doc(hidden))] + pub unsafe fn bind_pipeline_ray_tracing_unchecked( + &mut self, + pipeline: &RayTracingPipeline, + ) -> &mut Self { + let fns = self.device().fns(); + (fns.v1_0.cmd_bind_pipeline)( + self.handle(), + ash::vk::PipelineBindPoint::RAY_TRACING_KHR, + pipeline.handle(), + ); + + self + } + #[inline] pub unsafe fn bind_vertex_buffers( &mut self, @@ -1395,6 +1488,25 @@ impl RecordingCommandBuffer { })); } } + PipelineBindPoint::RayTracing => { + if !queue_family_properties + .queue_flags + .intersects(QueueFlags::COMPUTE) + { + return Err(Box::new(ValidationError { + context: "self".into(), + problem: + "`pipeline_bind_point` is `PipelineBindPoint::RayTracing`, and the \ + queue family does not support compute operations" + .into(), + vuids: &[ + "VUID-vkCmdPushDescriptorSetKHR-pipelineBindPoint-02391", + "VUID-vkCmdPushDescriptorSetKHR-commandBuffer-cmdpool", + ], + ..Default::default() + })); + } + } } // VUID-vkCmdPushDescriptorSetKHR-commonparent diff --git a/vulkano/src/command_buffer/commands/pipeline.rs b/vulkano/src/command_buffer/commands/pipeline.rs index e85bd48635..d88364d988 100644 --- a/vulkano/src/command_buffer/commands/pipeline.rs +++ b/vulkano/src/command_buffer/commands/pipeline.rs @@ -22,6 +22,7 @@ use crate::{ subpass::PipelineSubpassType, vertex_input::{RequiredVertexInputsVUIDs, VertexInputRate}, }, + ray_tracing::ShaderBindingTableAddresses, DynamicState, GraphicsPipeline, Pipeline, PipelineLayout, }, query::QueryType, @@ -1592,6 
+1593,39 @@ impl AutoCommandBufferBuilder { self } + pub unsafe fn trace_rays( + &mut self, + shader_binding_table_addresses: ShaderBindingTableAddresses, + width: u32, + height: u32, + depth: u32, + ) -> Result<&mut Self, Box> { + self.inner + .validate_trace_rays(&shader_binding_table_addresses, width, height, depth)?; + + Ok(self.trace_rays_unchecked(shader_binding_table_addresses, width, height, depth)) + } + + #[cfg_attr(not(feature = "document_unchecked"), doc(hidden))] + pub unsafe fn trace_rays_unchecked( + &mut self, + shader_binding_table_addresses: ShaderBindingTableAddresses, + width: u32, + height: u32, + depth: u32, + ) -> &mut Self { + let pipeline = self.builder_state.pipeline_ray_tracing.as_deref().unwrap(); + + let mut used_resources = Vec::new(); + self.add_descriptor_sets_resources(&mut used_resources, pipeline); + + self.add_command("trace_rays", used_resources, move |out| { + out.trace_rays_unchecked(&shader_binding_table_addresses, width, height, depth); + }); + + self + } + fn validate_pipeline_descriptor_sets( &self, vuid_type: VUIDType, @@ -4947,6 +4981,134 @@ impl RecordingCommandBuffer { self } + + pub unsafe fn trace_rays( + &mut self, + shader_binding_table_addresses: &ShaderBindingTableAddresses, + width: u32, + height: u32, + depth: u32, + ) -> Result<&mut Self, Box> { + self.validate_trace_rays(shader_binding_table_addresses, width, height, depth)?; + + Ok(self.trace_rays_unchecked(shader_binding_table_addresses, width, height, depth)) + } + + fn validate_trace_rays( + &self, + _shader_binding_table_addresses: &ShaderBindingTableAddresses, + width: u32, + height: u32, + depth: u32, + ) -> Result<(), Box> { + if !self.device().enabled_features().ray_tracing_pipeline { + return Err(Box::new(ValidationError { + requires_one_of: RequiresOneOf(&[RequiresAllOf(&[Requires::DeviceFeature( + "ray_tracing_pipeline", + )])]), + ..Default::default() + })); + } + + if !self + .queue_family_properties() + .queue_flags + 
.intersects(QueueFlags::COMPUTE) + { + return Err(Box::new(ValidationError { + problem: "the queue family of the command buffer does not support \ + compute operations" + .into(), + vuids: &["VUID-vkCmdTraceRaysKHR-commandBuffer-cmdpool"], + ..Default::default() + })); + } + + let device_properties = self.device().physical_device().properties(); + + let width = width as u64; + let height = height as u64; + let depth = depth as u64; + + let max_width = device_properties.max_compute_work_group_count[0] as u64 + * device_properties.max_compute_work_group_size[0] as u64; + + if width > max_width { + return Err(Box::new(ValidationError { + context: "width".into(), + problem: "exceeds maxComputeWorkGroupCount[0] * maxComputeWorkGroupSize[0]".into(), + vuids: &["VUID-vkCmdTraceRaysKHR-width-03638"], + ..Default::default() + })); + } + + let max_height = device_properties.max_compute_work_group_count[1] as u64 + * device_properties.max_compute_work_group_size[1] as u64; + + if height > max_height { + return Err(Box::new(ValidationError { + context: "height".into(), + problem: "exceeds maxComputeWorkGroupCount[1] * maxComputeWorkGroupSize[1]".into(), + vuids: &["VUID-vkCmdTraceRaysKHR-height-03639"], + ..Default::default() + })); + } + + let max_depth = device_properties.max_compute_work_group_count[2] as u64 + * device_properties.max_compute_work_group_size[2] as u64; + + if depth > max_depth { + return Err(Box::new(ValidationError { + context: "depth".into(), + problem: "exceeds maxComputeWorkGroupCount[2] * maxComputeWorkGroupSize[2]".into(), + vuids: &["VUID-vkCmdTraceRaysKHR-depth-03640"], + ..Default::default() + })); + } + + let total_invocations = width * height * depth; + let max_invocations = device_properties.max_ray_dispatch_invocation_count.unwrap() as u64; + + if total_invocations > max_invocations { + return Err(Box::new(ValidationError { + context: "width * height * depth".into(), + problem: "exceeds maxRayDispatchInvocationCount".into(), + vuids: 
&["VUID-vkCmdTraceRaysKHR-width-03641"], + ..Default::default() + })); + } + + Ok(()) + } + + #[cfg_attr(not(feature = "document_unchecked"), doc(hidden))] + pub unsafe fn trace_rays_unchecked( + &mut self, + shader_binding_table_addresses: &ShaderBindingTableAddresses, + width: u32, + height: u32, + depth: u32, + ) -> &mut Self { + let fns = self.device().fns(); + + let raygen = shader_binding_table_addresses.raygen.to_vk(); + let miss = shader_binding_table_addresses.miss.to_vk(); + let hit = shader_binding_table_addresses.hit.to_vk(); + let callable = shader_binding_table_addresses.callable.to_vk(); + + (fns.khr_ray_tracing_pipeline.cmd_trace_rays_khr)( + self.handle(), + &raygen, + &miss, + &hit, + &callable, + width, + height, + depth, + ); + + self + } } #[derive(Clone, Copy)] diff --git a/vulkano/src/command_buffer/mod.rs b/vulkano/src/command_buffer/mod.rs index 2043c6073e..e314d68b73 100644 --- a/vulkano/src/command_buffer/mod.rs +++ b/vulkano/src/command_buffer/mod.rs @@ -1617,6 +1617,7 @@ pub enum ResourceInCommand { SecondaryCommandBuffer { index: u32 }, Source, VertexBuffer { binding: u32 }, + ShaderBindingTableBuffer, } #[doc(hidden)] diff --git a/vulkano/src/device/mod.rs b/vulkano/src/device/mod.rs index 62f8f40006..6d22b516fa 100644 --- a/vulkano/src/device/mod.rs +++ b/vulkano/src/device/mod.rs @@ -114,6 +114,7 @@ use crate::{ instance::{Instance, InstanceOwned, InstanceOwnedDebugWrapper}, macros::{impl_id_counter, vulkan_bitflags}, memory::{ExternalMemoryHandleType, MemoryFdProperties, MemoryRequirements}, + pipeline::ray_tracing::RayTracingPipeline, Requires, RequiresAllOf, RequiresOneOf, Validated, ValidationError, Version, VulkanError, VulkanObject, }; @@ -1304,6 +1305,94 @@ impl Device { Ok(()) } + + pub fn ray_tracing_shader_group_handles( + &self, + ray_tracing_pipeline: &RayTracingPipeline, + first_group: u32, + group_count: u32, + ) -> Result> { + self.validate_ray_tracing_pipeline_properties( + ray_tracing_pipeline, + first_group, + 
group_count, + )?; + + unsafe { + Ok(self.ray_tracing_shader_group_handles_unchecked( + ray_tracing_pipeline, + first_group, + group_count, + )?) + } + } + + fn validate_ray_tracing_pipeline_properties( + &self, + ray_tracing_pipeline: &RayTracingPipeline, + first_group: u32, + group_count: u32, + ) -> Result<(), Box> { + if !self.enabled_features().ray_tracing_pipeline + || self + .physical_device() + .properties() + .shader_group_handle_size + .is_none() + { + Err(Box::new(ValidationError { + problem: "device property `shader_group_handle_size` is empty".into(), + requires_one_of: RequiresOneOf(&[RequiresAllOf(&[Requires::DeviceFeature( + "ray_tracing_pipeline", + )])]), + ..Default::default() + }))?; + }; + + if first_group as u64 + group_count as u64 > ray_tracing_pipeline.groups().len() as u64 { + Err(Box::new(ValidationError { + problem: "the sum of `first_group` and `group_count` must be less than or equal \ + to the number of shader groups in pipeline" + .into(), + vuids: &["VUID-vkGetRayTracingShaderGroupHandlesKHR-firstGroup-02419"], + ..Default::default() + }))? 
+ } + // TODO: VUID-vkGetRayTracingShaderGroupHandlesKHR-pipeline-07828 + + Ok(()) + } + + #[cfg_attr(not(feature = "document_unchecked"), doc(hidden))] + pub unsafe fn ray_tracing_shader_group_handles_unchecked( + &self, + ray_tracing_pipeline: &RayTracingPipeline, + first_group: u32, + group_count: u32, + ) -> Result { + let handle_size = self + .physical_device() + .properties() + .shader_group_handle_size + .unwrap(); + + let mut data = vec![0u8; handle_size as usize * group_count as usize]; + let fns = self.fns(); + unsafe { + (fns.khr_ray_tracing_pipeline + .get_ray_tracing_shader_group_handles_khr)( + self.handle, + ray_tracing_pipeline.handle(), + first_group, + group_count, + data.len(), + data.as_mut_ptr().cast(), + ) + .result() + .map_err(VulkanError::from)?; + } + Ok(ShaderGroupHandlesData { data, handle_size }) + } } impl Debug for Device { @@ -2134,6 +2223,30 @@ impl Deref for DeviceOwnedDebugWrapper { } } +/// Holds the data returned by [`Device::ray_tracing_shader_group_handles`]. +#[derive(Clone, Debug)] +pub struct ShaderGroupHandlesData { + data: Vec, + handle_size: u32, +} + +impl ShaderGroupHandlesData { + pub fn data(&self) -> &[u8] { + &self.data + } + + pub fn handle_size(&self) -> u32 { + self.handle_size + } +} + +impl ShaderGroupHandlesData { + /// Returns an iterator over the handles in the data. + pub fn iter(&self) -> impl ExactSizeIterator { + self.data().chunks_exact(self.handle_size as usize) + } +} + #[cfg(test)] mod tests { use crate::device::{ diff --git a/vulkano/src/lib.rs b/vulkano/src/lib.rs index 288e8c1009..0bc52f8114 100644 --- a/vulkano/src/lib.rs +++ b/vulkano/src/lib.rs @@ -179,6 +179,25 @@ pub use ash::vk::DeviceAddress; /// A [`DeviceAddress`] that is known not to equal zero. pub type NonNullDeviceAddress = NonZeroU64; +/// Represents a region of device addresses with a stride. 
+#[derive(Debug, Copy, Clone, Default)] +pub struct StridedDeviceAddressRegion { + pub device_address: DeviceAddress, + pub stride: DeviceSize, + pub size: DeviceSize, +} + +impl StridedDeviceAddressRegion { + #[doc(hidden)] + pub fn to_vk(&self) -> ash::vk::StridedDeviceAddressRegionKHR { + ash::vk::StridedDeviceAddressRegionKHR { + device_address: self.device_address, + stride: self.stride, + size: self.size, + } + } +} + /// Holds 24 bits in the least significant bits of memory, /// and 8 bytes in the most significant bits of that memory, /// occupying a single [`u32`] in total. diff --git a/vulkano/src/pipeline/compute.rs b/vulkano/src/pipeline/compute.rs index c7cdc2e008..e067e8516a 100644 --- a/vulkano/src/pipeline/compute.rs +++ b/vulkano/src/pipeline/compute.rs @@ -57,7 +57,7 @@ impl ComputePipeline { cache: Option>, create_info: ComputePipelineCreateInfo, ) -> Result, Validated> { - Self::validate_new(&device, cache.as_ref().map(AsRef::as_ref), &create_info)?; + Self::validate_new(&device, cache.as_deref(), &create_info)?; Ok(unsafe { Self::new_unchecked(device, cache, create_info) }?) } diff --git a/vulkano/src/pipeline/graphics/mod.rs b/vulkano/src/pipeline/graphics/mod.rs index 526916cf92..b06d341222 100644 --- a/vulkano/src/pipeline/graphics/mod.rs +++ b/vulkano/src/pipeline/graphics/mod.rs @@ -178,7 +178,7 @@ impl GraphicsPipeline { cache: Option>, create_info: GraphicsPipelineCreateInfo, ) -> Result, Validated> { - Self::validate_new(&device, cache.as_ref().map(AsRef::as_ref), &create_info)?; + Self::validate_new(&device, cache.as_deref(), &create_info)?; Ok(unsafe { Self::new_unchecked(device, cache, create_info) }?) 
} diff --git a/vulkano/src/pipeline/mod.rs b/vulkano/src/pipeline/mod.rs index 7fe33ce398..76b85ee0a7 100644 --- a/vulkano/src/pipeline/mod.rs +++ b/vulkano/src/pipeline/mod.rs @@ -23,6 +23,7 @@ pub mod cache; pub mod compute; pub mod graphics; pub mod layout; +pub mod ray_tracing; pub(crate) mod shader; /// A trait for operations shared between pipeline types. @@ -60,13 +61,13 @@ vulkan_enum! { // TODO: document Graphics = GRAPHICS, - /* TODO: enable + // TODO: document RayTracing = RAY_TRACING_KHR RequiresOneOf([ RequiresAllOf([DeviceExtension(khr_ray_tracing_pipeline)]), RequiresAllOf([DeviceExtension(nv_ray_tracing)]), - ]),*/ + ]), /* TODO: enable // TODO: document diff --git a/vulkano/src/pipeline/ray_tracing.rs b/vulkano/src/pipeline/ray_tracing.rs new file mode 100644 index 0000000000..abee90419c --- /dev/null +++ b/vulkano/src/pipeline/ray_tracing.rs @@ -0,0 +1,957 @@ +//! Ray tracing pipeline functionality for GPU-accelerated ray tracing. +//! +//! # Overview +//! Ray tracing pipelines enable high-performance ray tracing by defining a set of shader stages +//! that handle ray generation, intersection testing, and shading calculations. The pipeline +//! consists of different shader stages organized into shader groups. +//! +//! # Shader Types +//! +//! ## Ray Generation Shader +//! - Entry point for ray tracing +//! - Generates and traces primary rays +//! - Controls the overall ray tracing process +//! +//! ## Intersection Shaders +//! - **Built-in Triangle Intersection**: Handles standard triangle geometry intersection +//! - **Custom Intersection**: Implements custom geometry intersection testing +//! +//! ## Hit Shaders +//! - **Closest Hit**: Executes when a ray finds its closest intersection +//! - **Any Hit**: Optional shader that runs on every potential intersection +//! +//! ## Miss Shader +//! - Executes when a ray doesn't intersect any geometry +//! - Typically handles environment mapping or background colors +//! +//! ## Callable Shader +//! 
- Utility shader that can be called from other shader stages +//! - Enables code reuse across different shader stages +//! +//! # Pipeline Organization +//! Shaders are organized into groups: +//! - General groups: Contains ray generation, miss, or callable shaders +//! - Triangle hit groups: Contains closest-hit and optional any-hit shaders +//! - Procedural hit groups: Contains intersection, closest-hit, and optional any-hit shaders +//! +//! The ray tracing pipeline uses a Shader Binding Table (SBT) to organize and access +//! these shader groups during execution. + +use super::{ + cache::PipelineCache, DynamicState, Pipeline, PipelineBindPoint, PipelineCreateFlags, + PipelineLayout, PipelineShaderStageCreateInfo, PipelineShaderStageCreateInfoExtensionsVk, + PipelineShaderStageCreateInfoFields1Vk, PipelineShaderStageCreateInfoFields2Vk, +}; +use crate::{ + buffer::{Buffer, BufferCreateInfo, BufferUsage, Subbuffer}, + device::{Device, DeviceOwned, DeviceOwnedDebugWrapper}, + instance::InstanceOwnedDebugWrapper, + macros::impl_id_counter, + memory::{ + allocator::{align_up, AllocationCreateInfo, MemoryAllocator, MemoryTypeFilter}, + DeviceAlignment, + }, + shader::{spirv::ExecutionModel, DescriptorBindingRequirements}, + StridedDeviceAddressRegion, Validated, ValidationError, VulkanError, VulkanObject, +}; +use foldhash::{HashMap, HashSet}; +use smallvec::SmallVec; +use std::{collections::hash_map::Entry, mem::MaybeUninit, num::NonZeroU64, ptr, sync::Arc}; + +/// Defines how the implementation should perform ray tracing operations. +/// +/// This object uses the `VK_KHR_ray_tracing_pipeline` extension. 
+#[derive(Debug)] +pub struct RayTracingPipeline { + handle: ash::vk::Pipeline, + device: InstanceOwnedDebugWrapper>, + id: NonZeroU64, + + flags: PipelineCreateFlags, + layout: DeviceOwnedDebugWrapper>, + + descriptor_binding_requirements: HashMap<(u32, u32), DescriptorBindingRequirements>, + num_used_descriptor_sets: u32, + + groups: SmallVec<[RayTracingShaderGroupCreateInfo; 5]>, + stages: SmallVec<[PipelineShaderStageCreateInfo; 5]>, +} + +impl RayTracingPipeline { + /// Creates a new `RayTracingPipeline`. + #[inline] + pub fn new( + device: Arc, + cache: Option>, + create_info: RayTracingPipelineCreateInfo, + ) -> Result, Validated> { + Self::validate_new(&device, cache.as_deref(), &create_info)?; + + unsafe { Ok(Self::new_unchecked(device, cache, create_info)?) } + } + + fn validate_new( + device: &Arc, + cache: Option<&PipelineCache>, + create_info: &RayTracingPipelineCreateInfo, + ) -> Result<(), Validated> { + if let Some(cache) = &cache { + assert_eq!(device, cache.device()); + } + create_info + .validate(device) + .map_err(|err| err.add_context("create_info"))?; + + Ok(()) + } + + #[cfg_attr(not(feature = "document_unchecked"), doc(hidden))] + pub unsafe fn new_unchecked( + device: Arc, + cache: Option>, + create_info: RayTracingPipelineCreateInfo, + ) -> Result, VulkanError> { + let handle = { + let fields3_vk = create_info.to_vk_fields3(); + let fields2_vk = create_info.to_vk_fields2(&fields3_vk); + let mut fields1_extensions_vk = create_info.to_vk_fields1_extensions(); + let fields1_vk = create_info.to_vk_fields1(&fields2_vk, &mut fields1_extensions_vk); + let create_infos_vk = create_info.to_vk(&fields1_vk); + + let fns = device.fns(); + let mut output = MaybeUninit::uninit(); + + (fns.khr_ray_tracing_pipeline + .create_ray_tracing_pipelines_khr)( + device.handle(), + ash::vk::DeferredOperationKHR::null(), // TODO: RayTracing: deferred_operation + cache.map_or(ash::vk::PipelineCache::null(), |c| c.handle()), + 1, + &create_infos_vk, + ptr::null(), + 
output.as_mut_ptr(), + ) + .result() + .map_err(VulkanError::from)?; + output.assume_init() + }; + + Ok(Self::from_handle(device, handle, create_info)) + } + + /// Creates a new `RayTracingPipeline` from a raw object handle. + /// + /// # Safety + /// + /// - `handle` must be a valid Vulkan object handle created from `device`. + /// - `create_info` must match the info used to create the object. + pub unsafe fn from_handle( + device: Arc, + handle: ash::vk::Pipeline, + create_info: RayTracingPipelineCreateInfo, + ) -> Arc { + let RayTracingPipelineCreateInfo { + flags, + stages, + groups, + layout, + .. + } = create_info; + + let mut descriptor_binding_requirements: HashMap< + (u32, u32), + DescriptorBindingRequirements, + > = HashMap::default(); + for stage in &stages { + for (&loc, reqs) in stage + .entry_point + .info() + .descriptor_binding_requirements + .iter() + { + match descriptor_binding_requirements.entry(loc) { + Entry::Occupied(entry) => { + entry.into_mut().merge(reqs).expect("Could not produce an intersection of the shader descriptor requirements"); + } + Entry::Vacant(entry) => { + entry.insert(reqs.clone()); + } + } + } + } + let num_used_descriptor_sets = descriptor_binding_requirements + .keys() + .map(|loc| loc.0) + .max() + .map(|x| x + 1) + .unwrap_or(0); + Arc::new(Self { + handle, + device: InstanceOwnedDebugWrapper(device), + id: Self::next_id(), + + flags, + layout: DeviceOwnedDebugWrapper(layout), + + descriptor_binding_requirements, + num_used_descriptor_sets, + + groups, + stages, + }) + } + + /// Returns the shader groups that the pipeline was created with. + pub fn groups(&self) -> &[RayTracingShaderGroupCreateInfo] { + &self.groups + } + + /// Returns the shader stages that the pipeline was created with. + pub fn stages(&self) -> &[PipelineShaderStageCreateInfo] { + &self.stages + } + + /// Returns the `Device` that the pipeline was created with. 
+ pub fn device(&self) -> &Arc { + &self.device + } + + /// Returns the flags that the pipeline was created with. + pub fn flags(&self) -> PipelineCreateFlags { + self.flags + } +} + +impl Pipeline for RayTracingPipeline { + #[inline] + fn bind_point(&self) -> PipelineBindPoint { + PipelineBindPoint::RayTracing + } + + #[inline] + fn layout(&self) -> &Arc { + &self.layout + } + + #[inline] + fn num_used_descriptor_sets(&self) -> u32 { + self.num_used_descriptor_sets + } + + #[inline] + fn descriptor_binding_requirements( + &self, + ) -> &HashMap<(u32, u32), DescriptorBindingRequirements> { + &self.descriptor_binding_requirements + } +} + +impl_id_counter!(RayTracingPipeline); + +unsafe impl VulkanObject for RayTracingPipeline { + type Handle = ash::vk::Pipeline; + + #[inline] + fn handle(&self) -> Self::Handle { + self.handle + } +} + +unsafe impl DeviceOwned for RayTracingPipeline { + #[inline] + fn device(&self) -> &Arc { + self.device() + } +} + +impl Drop for RayTracingPipeline { + #[inline] + fn drop(&mut self) { + unsafe { + let fns = self.device.fns(); + (fns.v1_0.destroy_pipeline)(self.device.handle(), self.handle, ptr::null()); + } + } +} + +/// Parameters to create a new `RayTracingPipeline`. +#[derive(Clone, Debug)] +pub struct RayTracingPipelineCreateInfo { + /// Additional properties of the pipeline. + /// + /// The default value is empty. + pub flags: PipelineCreateFlags, + + /// The ray tracing shader stages to use. + /// + /// The default value is empty, which must be overridden. + pub stages: SmallVec<[PipelineShaderStageCreateInfo; 5]>, + + /// The shader groups to use. They reference the shader stages in `stages`. + /// + /// The default value is empty, which must be overridden. + pub groups: SmallVec<[RayTracingShaderGroupCreateInfo; 5]>, + + /// The maximum recursion depth of the pipeline. + /// + /// The default value is 1. + pub max_pipeline_ray_recursion_depth: u32, + + /// The dynamic state to use. 
+ /// + /// May only contain `DynamicState::RayTracingPipelineStackSize`. + /// + /// The default value is empty. + pub dynamic_state: HashSet, + + /// The pipeline layout to use. + /// + /// There is no default value. + pub layout: Arc, + + /// The pipeline to use as a base when creating this pipeline. + /// + /// If this is `Some`, then `flags` must contain [`PipelineCreateFlags::DERIVATIVE`], + /// and the `flags` of the provided pipeline must contain + /// [`PipelineCreateFlags::ALLOW_DERIVATIVES`]. + /// + /// The default value is `None`. + pub base_pipeline: Option>, + + pub _ne: crate::NonExhaustive, +} + +impl RayTracingPipelineCreateInfo { + pub fn layout(layout: Arc) -> Self { + Self { + flags: PipelineCreateFlags::empty(), + stages: SmallVec::new(), + groups: SmallVec::new(), + max_pipeline_ray_recursion_depth: 1, + dynamic_state: Default::default(), + + layout, + + base_pipeline: None, + _ne: crate::NonExhaustive(()), + } + } + + fn validate(&self, device: &Arc) -> Result<(), Box> { + let &Self { + flags, + ref stages, + ref groups, + ref layout, + ref base_pipeline, + ref dynamic_state, + max_pipeline_ray_recursion_depth, + _ne: _, + } = self; + + flags.validate_device(device).map_err(|err| { + err.add_context("flags") + .set_vuids(&["VUID-VkRayTracingPipelineCreateInfoKHR-flags-parameter"]) + })?; + + if flags.intersects(PipelineCreateFlags::DERIVATIVE) { + let base_pipeline = base_pipeline.as_ref().ok_or_else(|| { + Box::new(ValidationError { + problem: "`flags` contains `PipelineCreateFlags::DERIVATIVE`, but \ + `base_pipeline` is `None`" + .into(), + vuids: &["VUID-VkRayTracingPipelineCreateInfoKHR-flags-07984" +], + ..Default::default() + }) + })?; + + if !base_pipeline + .flags() + .intersects(PipelineCreateFlags::ALLOW_DERIVATIVES) + { + return Err(Box::new(ValidationError { + context: "base_pipeline.flags()".into(), + problem: "does not contain `PipelineCreateFlags::ALLOW_DERIVATIVES`".into(), + vuids: 
&["VUID-vkCreateRayTracingPipelinesKHR-flags-03416"], + ..Default::default() + })); + } + } else if base_pipeline.is_some() { + return Err(Box::new(ValidationError { + problem: "`flags` does not contain `PipelineCreateFlags::DERIVATIVE`, but \ + `base_pipeline` is `Some`" + .into(), + ..Default::default() + })); + } + + if stages.is_empty() { + return Err(Box::new(ValidationError { + problem: "`stages` is empty".into(), + vuids: &["VUID-VkRayTracingPipelineCreateInfoKHR-pLibraryInfo-07999"], + ..Default::default() + })); + } + for stage in stages { + stage.validate(device).map_err(|err| { + err.add_context("stages") + .set_vuids(&["VUID-VkRayTracingPipelineCreateInfoKHR-pStages-parameter"]) + })?; + + let entry_point_info = stage.entry_point.info(); + + layout + .ensure_compatible_with_shader( + entry_point_info + .descriptor_binding_requirements + .iter() + .map(|(k, v)| (*k, v)), + entry_point_info.push_constant_requirements.as_ref(), + ) + .map_err(|err| { + Box::new(ValidationError { + context: "stage.entry_point".into(), + vuids: &[ + "VUID-VkRayTracingPipelineCreateInfoKHR-layout-07987", + "VUID-VkRayTracingPipelineCreateInfoKHR-layout-07988", + "VUID-VkRayTracingPipelineCreateInfoKHR-layout-07990", + "VUID-VkRayTracingPipelineCreateInfoKHR-layout-07991", + ], + ..ValidationError::from_error(err) + }) + })?; + } + + if groups.is_empty() { + return Err(Box::new(ValidationError { + problem: "`groups` is empty".into(), + vuids: &["VUID-VkRayTracingPipelineCreateInfoKHR-flags-08700"], + ..Default::default() + })); + } + for group in groups { + group.validate(stages).map_err(|err| { + err.add_context("groups") + .set_vuids(&["VUID-VkRayTracingPipelineCreateInfoKHR-pGroups-parameter"]) + })?; + } + + // TODO: Enable + // if dynamic_state + // .iter() + // .any(|&state| state != DynamicState::RayTracingPipelineStackSize) + // { + // return Err(Box::new(ValidationError { + // problem: + // format!("`dynamic_state` contains a dynamic state other than + // 
RayTracingPipelineStackSize: {:?}", dynamic_state).into(), vuids: + // &["VUID-VkRayTracingPipelineCreateInfoKHR-pDynamicStates-03602"], + // ..Default::default() + // })); + // } + if !dynamic_state.is_empty() { + todo!("Dynamic state for ray tracing pipelines is not yet supported"); + } + + let max_ray_recursion_depth = device + .physical_device() + .properties() + .max_ray_recursion_depth + .unwrap(); + if max_pipeline_ray_recursion_depth > max_ray_recursion_depth { + return Err(Box::new(ValidationError { + problem: format!( + "`max_pipeline_ray_recursion_depth` is greater than the device's max value of {}", + max_ray_recursion_depth + ).into(), + vuids: &["VUID-VkRayTracingPipelineCreateInfoKHR-maxPipelineRayRecursionDepth-03589"], + ..Default::default() + })); + } + + Ok(()) + } + + pub(crate) fn to_vk<'a>( + &self, + fields1_vk: &'a RayTracingPipelineCreateInfoFields1Vk<'_>, + ) -> ash::vk::RayTracingPipelineCreateInfoKHR<'a> { + let &Self { + flags, + max_pipeline_ray_recursion_depth, + + ref layout, + ref base_pipeline, + .. + } = self; + + let RayTracingPipelineCreateInfoFields1Vk { + stages_vk, + groups_vk, + dynamic_state_vk, + } = fields1_vk; + + let mut val_vk = ash::vk::RayTracingPipelineCreateInfoKHR::default() + .flags(flags.into()) + .stages(stages_vk) + .groups(groups_vk) + .layout(layout.handle()) + .max_pipeline_ray_recursion_depth(max_pipeline_ray_recursion_depth) + .base_pipeline_handle( + base_pipeline + .as_ref() + .map_or(ash::vk::Pipeline::null(), |p| p.handle()), + ) + .base_pipeline_index(-1); + + if let Some(dynamic_state_vk) = dynamic_state_vk { + val_vk = val_vk.dynamic_state(dynamic_state_vk); + } + + val_vk + } + + pub(crate) fn to_vk_fields1<'a>( + &self, + fields2_vk: &'a RayTracingPipelineCreateInfoFields2Vk<'_>, + extensions_vk: &'a mut RayTracingPipelineCreateInfoFields1ExtensionsVk, + ) -> RayTracingPipelineCreateInfoFields1Vk<'a> { + let Self { stages, groups, .. 
} = self; + let RayTracingPipelineCreateInfoFields2Vk { + stages_fields1_vk, + dynamic_states_vk, + } = fields2_vk; + let RayTracingPipelineCreateInfoFields1ExtensionsVk { + stages_extensions_vk, + } = extensions_vk; + + let stages_vk: SmallVec<[_; 5]> = stages + .iter() + .zip(stages_fields1_vk) + .zip(stages_extensions_vk) + .map(|((stage, fields1), fields1_extensions_vk)| { + stage.to_vk(fields1, fields1_extensions_vk) + }) + .collect(); + + let groups_vk = groups + .iter() + .map(RayTracingShaderGroupCreateInfo::to_vk) + .collect(); + + let dynamic_state_vk = (!dynamic_states_vk.is_empty()).then(|| { + ash::vk::PipelineDynamicStateCreateInfo::default() + .flags(ash::vk::PipelineDynamicStateCreateFlags::empty()) + .dynamic_states(dynamic_states_vk) + }); + + RayTracingPipelineCreateInfoFields1Vk { + stages_vk, + groups_vk, + dynamic_state_vk, + } + } + + pub(crate) fn to_vk_fields1_extensions( + &self, + ) -> RayTracingPipelineCreateInfoFields1ExtensionsVk { + let Self { stages, .. } = self; + + let stages_extensions_vk = stages + .iter() + .map(|stage| stage.to_vk_extensions()) + .collect(); + + RayTracingPipelineCreateInfoFields1ExtensionsVk { + stages_extensions_vk, + } + } + + pub(crate) fn to_vk_fields2<'a>( + &self, + fields3_vk: &'a RayTracingPipelineCreateInfoFields3Vk, + ) -> RayTracingPipelineCreateInfoFields2Vk<'a> { + let Self { + stages, + dynamic_state, + .. + } = self; + + let stages_fields1_vk = stages + .iter() + .zip(fields3_vk.stages_fields2_vk.iter()) + .map(|(stage, fields3)| stage.to_vk_fields1(fields3)) + .collect(); + + let dynamic_states_vk = dynamic_state.iter().copied().map(Into::into).collect(); + + RayTracingPipelineCreateInfoFields2Vk { + stages_fields1_vk, + dynamic_states_vk, + } + } + + pub(crate) fn to_vk_fields3(&self) -> RayTracingPipelineCreateInfoFields3Vk { + let Self { stages, .. 
} = self;

        let stages_fields2_vk = stages.iter().map(|stage| stage.to_vk_fields2()).collect();

        RayTracingPipelineCreateInfoFields3Vk { stages_fields2_vk }
    }
}

/// Enum representing different types of Ray Tracing Shader Groups.
///
/// Contains the index of the shader to use for each type of shader group.
/// The index corresponds to the position of the shader in the `stages` field of the
/// `RayTracingPipelineCreateInfo`.
#[derive(Debug, Clone)]
pub enum RayTracingShaderGroupCreateInfo {
    /// General shader group type, typically used for ray generation and miss shaders.
    ///
    /// Contains a single shader stage that can be:
    /// - Ray generation shader
    /// - Miss shader
    /// - Callable shader
    General {
        /// Index of the general shader stage
        general_shader: u32,
    },

    /// Procedural hit shader group type, used for custom intersection testing.
    ///
    /// Used when implementing custom intersection shapes or volumes.
    /// Requires an intersection shader and can optionally include closest hit
    /// and any hit shaders.
    ProceduralHit {
        /// Optional index of the closest hit shader stage
        closest_hit_shader: Option<u32>,
        /// Optional index of the any hit shader stage
        any_hit_shader: Option<u32>,
        /// Index of the intersection shader stage
        intersection_shader: u32,
    },

    /// Triangle hit shader group type, used for built-in triangle intersection.
    ///
    /// Used for standard triangle geometry intersection testing.
    /// Can optionally include closest hit and any hit shaders.
    TrianglesHit {
        /// Optional index of the closest hit shader stage
        closest_hit_shader: Option<u32>,
        /// Optional index of the any hit shader stage
        any_hit_shader: Option<u32>,
    },
}

impl RayTracingShaderGroupCreateInfo {
    /// Checks that every shader index of this group refers to an element of `stages` whose
    /// execution model is allowed in that slot.
    ///
    /// An index that lies outside of `stages` is reported as a validation error rather than
    /// causing an out-of-bounds panic; the cited VUIDs require a valid index of the right
    /// shader type, so the same error is used for both failures.
    fn validate(
        &self,
        stages: &[PipelineShaderStageCreateInfo],
    ) -> Result<(), Box<ValidationError>> {
        // `None` when `shader` is not a valid index into `stages`.
        let execution_model_of = |shader: u32| {
            stages
                .get(shader as usize)
                .map(|stage| stage.entry_point.info().execution_model)
        };

        // The optional any hit / closest hit slots follow the same rules in both hit group
        // types, so they are checked in one place. Any hit is checked before closest hit.
        let validate_hit_shaders = |closest_hit_shader: Option<u32>,
                                    any_hit_shader: Option<u32>|
         -> Result<(), Box<ValidationError>> {
            if let Some(any_hit_shader) = any_hit_shader {
                if execution_model_of(any_hit_shader) != Some(ExecutionModel::AnyHitKHR) {
                    return Err(Box::new(ValidationError {
                        problem: "any hit shader must be an AnyHit shader".into(),
                        vuids: &["VUID-VkRayTracingShaderGroupCreateInfoKHR-anyHitShader-03479"],
                        ..Default::default()
                    }));
                }
            }

            if let Some(closest_hit_shader) = closest_hit_shader {
                if execution_model_of(closest_hit_shader) != Some(ExecutionModel::ClosestHitKHR) {
                    return Err(Box::new(ValidationError {
                        problem: "closest hit shader must be a ClosestHit shader".into(),
                        vuids: &[
                            "VUID-VkRayTracingShaderGroupCreateInfoKHR-closestHitShader-03478",
                        ],
                        ..Default::default()
                    }));
                }
            }

            Ok(())
        };

        match self {
            RayTracingShaderGroupCreateInfo::General { general_shader } => {
                match execution_model_of(*general_shader) {
                    Some(
                        ExecutionModel::RayGenerationKHR
                        | ExecutionModel::MissKHR
                        | ExecutionModel::CallableKHR,
                    ) => Ok(()),
                    _ => Err(Box::new(ValidationError {
                        problem: "general shader in GENERAL group must be a RayGeneration, Miss, or Callable shader".into(),
                        vuids: &["VUID-VkRayTracingShaderGroupCreateInfoKHR-type-03474"],
                        ..Default::default()
                    })),
                }
            }
            RayTracingShaderGroupCreateInfo::ProceduralHit {
                closest_hit_shader,
                any_hit_shader,
                intersection_shader,
            } => {
                if execution_model_of(*intersection_shader)
                    != Some(ExecutionModel::IntersectionKHR)
                {
                    return Err(Box::new(ValidationError {
                        problem: "intersection shader in PROCEDURAL_HIT_GROUP must be an Intersection shader".into(),
                        vuids: &["VUID-VkRayTracingShaderGroupCreateInfoKHR-type-03476"],
                        ..Default::default()
                    }));
                }

                validate_hit_shaders(*closest_hit_shader, *any_hit_shader)
            }
            RayTracingShaderGroupCreateInfo::TrianglesHit {
                closest_hit_shader,
                any_hit_shader,
            } => validate_hit_shaders(*closest_hit_shader, *any_hit_shader),
        }
    }

    /// Converts this group into the raw Vulkan structure.
    ///
    /// Every slot that the variant does not provide (including optional slots that are `None`)
    /// is set to `SHADER_UNUSED_KHR`.
    pub(crate) fn to_vk(&self) -> ash::vk::RayTracingShaderGroupCreateInfoKHR<'static> {
        match self {
            RayTracingShaderGroupCreateInfo::General { general_shader } => {
                ash::vk::RayTracingShaderGroupCreateInfoKHR::default()
                    .ty(ash::vk::RayTracingShaderGroupTypeKHR::GENERAL)
                    .general_shader(*general_shader)
                    .closest_hit_shader(ash::vk::SHADER_UNUSED_KHR)
                    .any_hit_shader(ash::vk::SHADER_UNUSED_KHR)
                    .intersection_shader(ash::vk::SHADER_UNUSED_KHR)
            }
            RayTracingShaderGroupCreateInfo::ProceduralHit {
                closest_hit_shader,
                any_hit_shader,
                intersection_shader,
            } => ash::vk::RayTracingShaderGroupCreateInfoKHR::default()
                .ty(ash::vk::RayTracingShaderGroupTypeKHR::PROCEDURAL_HIT_GROUP)
                .general_shader(ash::vk::SHADER_UNUSED_KHR)
                .closest_hit_shader(closest_hit_shader.unwrap_or(ash::vk::SHADER_UNUSED_KHR))
                .any_hit_shader(any_hit_shader.unwrap_or(ash::vk::SHADER_UNUSED_KHR))
                .intersection_shader(*intersection_shader),
            RayTracingShaderGroupCreateInfo::TrianglesHit {
                closest_hit_shader,
                any_hit_shader,
            } => ash::vk::RayTracingShaderGroupCreateInfoKHR::default()
                .ty(ash::vk::RayTracingShaderGroupTypeKHR::TRIANGLES_HIT_GROUP)
                .general_shader(ash::vk::SHADER_UNUSED_KHR)
                .closest_hit_shader(closest_hit_shader.unwrap_or(ash::vk::SHADER_UNUSED_KHR))
                .any_hit_shader(any_hit_shader.unwrap_or(ash::vk::SHADER_UNUSED_KHR))
                .intersection_shader(ash::vk::SHADER_UNUSED_KHR),
        }
    }
}

pub(crate) struct RayTracingPipelineCreateInfoFields1Vk<'a> {
    pub(crate) stages_vk: SmallVec<[ash::vk::PipelineShaderStageCreateInfo<'a>; 5]>,
    pub(crate) groups_vk: SmallVec<[ash::vk::RayTracingShaderGroupCreateInfoKHR<'static>; 5]>,
    pub(crate) dynamic_state_vk: Option<ash::vk::PipelineDynamicStateCreateInfo<'a>>,
}

pub(crate) struct RayTracingPipelineCreateInfoFields1ExtensionsVk {
    pub(crate) stages_extensions_vk: SmallVec<[PipelineShaderStageCreateInfoExtensionsVk; 5]>,
}

pub(crate) struct RayTracingPipelineCreateInfoFields2Vk<'a> {
    pub(crate) stages_fields1_vk: SmallVec<[PipelineShaderStageCreateInfoFields1Vk<'a>; 5]>,
    pub(crate) dynamic_states_vk: SmallVec<[ash::vk::DynamicState; 4]>,
}

pub(crate) struct RayTracingPipelineCreateInfoFields3Vk {
    pub(crate) stages_fields2_vk: SmallVec<[PipelineShaderStageCreateInfoFields2Vk; 5]>,
}

/// An object that holds the strided addresses of the shader groups in a shader binding table.
#[derive(Debug, Clone)]
pub struct ShaderBindingTableAddresses {
    /// The address of the ray generation shader group handle.
    pub raygen: StridedDeviceAddressRegion,
    /// The address of the miss shader group handles.
    pub miss: StridedDeviceAddressRegion,
    /// The address of the hit shader group handles.
    pub hit: StridedDeviceAddressRegion,
    /// The address of the callable shader group handles.
    pub callable: StridedDeviceAddressRegion,
}

/// An object that holds the shader binding table buffer and its addresses.
+#[derive(Debug, Clone)] +pub struct ShaderBindingTable { + addresses: ShaderBindingTableAddresses, + _buffer: Subbuffer<[u8]>, +} + +impl ShaderBindingTable { + /// Returns the addresses of the shader groups in the shader binding table. + pub fn addresses(&self) -> &ShaderBindingTableAddresses { + &self.addresses + } + + /// Automatically creates a shader binding table from a ray tracing pipeline. + pub fn new( + allocator: Arc, + ray_tracing_pipeline: &RayTracingPipeline, + ) -> Result> { + let mut miss_shader_count: u64 = 0; + let mut hit_shader_count: u64 = 0; + let mut callable_shader_count: u64 = 0; + + for group in ray_tracing_pipeline.groups() { + match group { + RayTracingShaderGroupCreateInfo::General { general_shader } => { + match ray_tracing_pipeline.stages()[*general_shader as usize] + .entry_point + .info() + .execution_model + { + ExecutionModel::RayGenerationKHR => {} + ExecutionModel::MissKHR => miss_shader_count += 1, + ExecutionModel::CallableKHR => callable_shader_count += 1, + _ => { + panic!("Unexpected shader type in general shader group"); + } + } + } + RayTracingShaderGroupCreateInfo::ProceduralHit { .. } + | RayTracingShaderGroupCreateInfo::TrianglesHit { .. 
} => { + hit_shader_count += 1; + } + } + } + + let handle_data = ray_tracing_pipeline + .device() + .ray_tracing_shader_group_handles( + ray_tracing_pipeline, + 0, + ray_tracing_pipeline.groups().len() as u32, + )?; + + let properties = ray_tracing_pipeline.device().physical_device().properties(); + let handle_size_aligned = align_up( + handle_data.handle_size() as u64, + DeviceAlignment::new(properties.shader_group_handle_alignment.unwrap() as u64).unwrap(), + ); + + let shader_group_base_alignment = + DeviceAlignment::new(properties.shader_group_base_alignment.unwrap() as u64).unwrap(); + + let raygen_stride = align_up(handle_size_aligned, shader_group_base_alignment); + + let mut raygen = StridedDeviceAddressRegion { + stride: raygen_stride, + size: raygen_stride, + device_address: 0, + }; + let mut miss = StridedDeviceAddressRegion { + stride: handle_size_aligned, + size: align_up( + handle_size_aligned * miss_shader_count, + shader_group_base_alignment, + ), + device_address: 0, + }; + let mut hit = StridedDeviceAddressRegion { + stride: handle_size_aligned, + size: align_up( + handle_size_aligned * hit_shader_count, + shader_group_base_alignment, + ), + device_address: 0, + }; + let mut callable = StridedDeviceAddressRegion { + stride: handle_size_aligned, + size: align_up( + handle_size_aligned * callable_shader_count, + shader_group_base_alignment, + ), + device_address: 0, + }; + + let sbt_buffer = Buffer::new_slice::( + allocator, + BufferCreateInfo { + usage: BufferUsage::TRANSFER_SRC + | BufferUsage::SHADER_DEVICE_ADDRESS + | BufferUsage::SHADER_BINDING_TABLE, + ..Default::default() + }, + AllocationCreateInfo { + memory_type_filter: MemoryTypeFilter::HOST_SEQUENTIAL_WRITE + | MemoryTypeFilter::PREFER_DEVICE, + ..Default::default() + }, + raygen.size + miss.size + hit.size + callable.size, + ) + .expect("todo: raytracing: better error type for buffer errors"); + + raygen.device_address = sbt_buffer.buffer().device_address().unwrap().get(); + 
miss.device_address = raygen.device_address + raygen.size; + hit.device_address = miss.device_address + miss.size; + callable.device_address = hit.device_address + hit.size; + + { + let mut sbt_buffer_write = sbt_buffer.write().unwrap(); + + let mut handle_iter = handle_data.iter(); + + let handle_size = handle_data.handle_size() as usize; + sbt_buffer_write[..handle_size].copy_from_slice(handle_iter.next().unwrap()); + let mut offset = raygen.size as usize; + for _ in 0..miss_shader_count { + sbt_buffer_write[offset..offset + handle_size] + .copy_from_slice(handle_iter.next().unwrap()); + offset += miss.stride as usize; + } + offset = (raygen.size + miss.size) as usize; + for _ in 0..hit_shader_count { + sbt_buffer_write[offset..offset + handle_size] + .copy_from_slice(handle_iter.next().unwrap()); + offset += hit.stride as usize; + } + offset = (raygen.size + miss.size + hit.size) as usize; + for _ in 0..callable_shader_count { + sbt_buffer_write[offset..offset + handle_size] + .copy_from_slice(handle_iter.next().unwrap()); + offset += callable.stride as usize; + } + } + + Ok(Self { + addresses: ShaderBindingTableAddresses { + raygen, + miss, + hit, + callable, + }, + _buffer: sbt_buffer, + }) + } +}