Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: Metal timestamps #263

Merged
merged 16 commits into from
Apr 1, 2023
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ name = "window"
[[example]]
name = "headless-render"

[[example]]
name = "counters"

[[example]]
name = "library"

Expand Down
93 changes: 91 additions & 2 deletions examples/circle/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ fn main() {
let device = Device::system_default().expect("no device found");
println!("Your device is: {}", device.name(),);

let counter_sample_buffer = create_counter_sample_buffer(&device);
let counter_sampling_point = MTLCounterSamplingPoint::AtStageBoundary;
assert!(device.supports_counter_sampling(counter_sampling_point));

let binary_archive_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("examples/circle/binary_archive.metallib");

Expand Down Expand Up @@ -140,7 +144,14 @@ fn main() {

// Obtain a renderPassDescriptor generated from the view's drawable textures.
let render_pass_descriptor = RenderPassDescriptor::new();
prepare_render_pass_descriptor(&render_pass_descriptor, drawable.texture());
handle_render_pass_color_attachment(
&render_pass_descriptor,
drawable.texture(),
);
handle_render_pass_sample_buffer_attachment(
&render_pass_descriptor,
&counter_sample_buffer,
);

// Create a render command encoder.
let encoder =
Expand All @@ -152,11 +163,20 @@ fn main() {
encoder.draw_primitives(MTLPrimitiveType::TriangleStrip, 0, 1080);
encoder.end_encoding();

let timestamps_buffer = resolve_timestamps_into_buffer(
command_buffer,
&counter_sample_buffer,
&device,
);

// Schedule a present once the framebuffer is complete using the current drawable.
command_buffer.present_drawable(&drawable);

// Finalize rendering here & push the command buffer to the GPU.
command_buffer.commit();
command_buffer.wait_until_completed();

print_timestamps(&timestamps_buffer);
}
_ => (),
}
Expand Down Expand Up @@ -210,7 +230,20 @@ fn create_vertex_points_for_circle() -> Vec<AAPLVertex> {
v
}

fn prepare_render_pass_descriptor(descriptor: &RenderPassDescriptorRef, texture: &TextureRef) {
fn handle_render_pass_sample_buffer_attachment(
descriptor: &RenderPassDescriptorRef,
counter_sample_buffer: &CounterSampleBufferRef,
) {
let sample_buffer_attachment_descriptor =
descriptor.sample_buffer_attachments().object_at(0).unwrap();
sample_buffer_attachment_descriptor.set_sample_buffer(&counter_sample_buffer);
sample_buffer_attachment_descriptor.set_start_of_vertex_sample_index(0 as NSUInteger);
sample_buffer_attachment_descriptor.set_end_of_vertex_sample_index(1 as NSUInteger);
sample_buffer_attachment_descriptor.set_start_of_fragment_sample_index(2 as NSUInteger);
sample_buffer_attachment_descriptor.set_end_of_fragment_sample_index(3 as NSUInteger);
}

fn handle_render_pass_color_attachment(descriptor: &RenderPassDescriptorRef, texture: &TextureRef) {
let color_attachment = descriptor.color_attachments().object_at(0).unwrap();

color_attachment.set_texture(Some(texture));
Expand Down Expand Up @@ -248,3 +281,59 @@ fn prepare_pipeline_state(
.new_render_pipeline_state(&pipeline_state_descriptor)
.unwrap()
}

fn resolve_timestamps_into_buffer(
command_buffer: &CommandBufferRef,
counter_sample_buffer: &CounterSampleBufferRef,
device: &Device,
) -> Buffer {
let blit_encoder = command_buffer.new_blit_command_encoder();
let timestamps_buffer = device.new_buffer(
(std::mem::size_of::<u64>() * 4) as u64,
MTLResourceOptions::StorageModeShared,
);
blit_encoder.resolve_counters(
&counter_sample_buffer,
crate::NSRange::new(0_u64, 4_u64),
&timestamps_buffer,
0_u64,
);
blit_encoder.end_encoding();
timestamps_buffer
}

fn print_timestamps(timestamps_buffer: &BufferRef) {
let timestamps =
unsafe { std::slice::from_raw_parts(timestamps_buffer.contents() as *const u64, 4) };
println!("Start of vertex: {}", timestamps[0]);
println!("End of vertex: {}", timestamps[1]);
println!("Vertex elapsed: {}", timestamps[1] - timestamps[0]);
println!("Start of fragment: {}", timestamps[2]);
println!("End of fragment: {}", timestamps[3]);
println!("Fragment elapsed: {}\n", timestamps[3] - timestamps[2]);
}

fn create_counter_sample_buffer(device: &Device) -> CounterSampleBuffer {
let counter_sample_buffer_desc = metal::CounterSampleBufferDescriptor::new();
counter_sample_buffer_desc.set_storage_mode(metal::MTLStorageMode::Shared);
counter_sample_buffer_desc.set_sample_count(4_u64);
counter_sample_buffer_desc.set_counter_set(&fetch_timestamp_counter_set(device));

device
.new_counter_sample_buffer_with_descriptor(&counter_sample_buffer_desc)
.unwrap()
}

fn fetch_timestamp_counter_set(device: &Device) -> metal::CounterSet {
let counter_sets = device.counter_sets();
let mut timestamp_counter = None;
for cs in counter_sets.iter() {
if cs.name() == "timestamp" {
timestamp_counter = Some(cs);
break;
}
}
timestamp_counter
.expect("No timestamp counter found")
.clone()
}
151 changes: 151 additions & 0 deletions examples/counters/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
use metal::*;
fn main() {
let device = Device::system_default().expect("No device found");
FL33TW00D marked this conversation as resolved.
Show resolved Hide resolved

let counter_sample_buffer = create_counter_sample_buffer(&device);

let counter_sampling_point = MTLCounterSamplingPoint::AtStageBoundary;
assert!(device.supports_counter_sampling(counter_sampling_point));

let command_queue = device.new_command_queue();

let data = [1u32; 64 * 64];

let buffer = device.new_buffer_with_data(
unsafe { std::mem::transmute(data.as_ptr()) },
(data.len() * std::mem::size_of::<u32>()) as u64,
MTLResourceOptions::CPUCacheModeDefaultCache,
);

let sum = {
let data = [0u32];
device.new_buffer_with_data(
unsafe { std::mem::transmute(data.as_ptr()) },
(data.len() * std::mem::size_of::<u32>()) as u64,
MTLResourceOptions::CPUCacheModeDefaultCache,
)
};

let command_buffer = command_queue.new_command_buffer();

let compute_pass_descriptor = ComputePassDescriptor::new();
handle_compute_pass_sample_buffer_attachment(&compute_pass_descriptor, &counter_sample_buffer);
let encoder = command_buffer.compute_command_encoder_with_descriptor(&compute_pass_descriptor);
let library_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("examples/compute/shaders.metallib");

let library = device.new_library_with_file(library_path).unwrap();
let kernel = library.get_function("sum", None).unwrap();

let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&kernel));

let pipeline_state = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();

encoder.set_compute_pipeline_state(&pipeline_state);
encoder.set_buffer(0, Some(&buffer), 0);
encoder.set_buffer(1, Some(&sum), 0);

let width = 16;

let thread_group_count = MTLSize {
width,
height: 1,
depth: 1,
};

let thread_group_size = MTLSize {
width: (data.len() as u64 + width) / width,
height: 1,
depth: 1,
};

encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();

let timestamps_buffer =
resolve_timestamps_into_buffer(&command_buffer, &counter_sample_buffer, &device);

command_buffer.commit();
command_buffer.wait_until_completed();

print_timestamps(&timestamps_buffer);

let ptr = sum.contents() as *mut u32;
println!("Compute shader sum: {}", unsafe { *ptr });

unsafe {
assert_eq!(4096, *ptr);
}
}

fn handle_compute_pass_sample_buffer_attachment(
compute_pass_descriptor: &ComputePassDescriptorRef,
counter_sample_buffer: &CounterSampleBufferRef,
) {
let sample_buffer_attachment_descriptor = compute_pass_descriptor
.sample_buffer_attachments()
.object_at(0)
.unwrap();

sample_buffer_attachment_descriptor.set_sample_buffer(&counter_sample_buffer);
sample_buffer_attachment_descriptor.set_start_of_encoder_sample_index(0);
sample_buffer_attachment_descriptor.set_end_of_encoder_sample_index(1);
}

fn resolve_timestamps_into_buffer(
command_buffer: &CommandBufferRef,
counter_sample_buffer: &CounterSampleBufferRef,
device: &Device,
) -> Buffer {
let blit_encoder = command_buffer.new_blit_command_encoder();
let timestamps_buffer = device.new_buffer(
(std::mem::size_of::<u64>() * 2) as u64,
MTLResourceOptions::StorageModeShared,
);
blit_encoder.resolve_counters(
&counter_sample_buffer,
crate::NSRange::new(0_u64, 2_u64),
&timestamps_buffer,
0_u64,
);
blit_encoder.end_encoding();
timestamps_buffer
}

fn print_timestamps(timestamps_buffer: &BufferRef) {
let timestamps =
unsafe { std::slice::from_raw_parts(timestamps_buffer.contents() as *const u64, 2) };
println!("Start timestamp: {}", timestamps[0]);
println!("End timestamp: {}", timestamps[1]);
println!("Elapsed time: {}", timestamps[1] - timestamps[0]);
}

fn create_counter_sample_buffer(device: &Device) -> CounterSampleBuffer {
let counter_sample_buffer_desc = metal::CounterSampleBufferDescriptor::new();
counter_sample_buffer_desc.set_storage_mode(metal::MTLStorageMode::Shared);
counter_sample_buffer_desc.set_sample_count(2_u64);
counter_sample_buffer_desc.set_counter_set(&fetch_timestamp_counter_set(device));

device
.new_counter_sample_buffer_with_descriptor(&counter_sample_buffer_desc)
.unwrap()
}

fn fetch_timestamp_counter_set(device: &Device) -> metal::CounterSet {
let counter_sets = device.counter_sets();
let mut timestamp_counter = None;
for cs in counter_sets.iter() {
if cs.name() == "timestamp" {
timestamp_counter = Some(cs);
break;
}
}
timestamp_counter
.expect("No timestamp counter found")
.clone()
}
7 changes: 7 additions & 0 deletions src/commandbuffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,13 @@ impl CommandBufferRef {
unsafe { msg_send![self, computeCommandEncoderWithDispatchType: ty] }
}

pub fn compute_command_encoder_with_descriptor(
&self,
descriptor: &ComputePassDescriptorRef,
) -> &ComputeCommandEncoderRef {
unsafe { msg_send![self, computeCommandEncoderWithDescriptor: descriptor] }
}

pub fn encode_signal_event(&self, event: &EventRef, new_value: u64) {
unsafe {
msg_send![self,
Expand Down
Loading