Skip to content

Commit

Permalink
Improve lib.rs error handling and return types
Browse files Browse the repository at this point in the history
  • Loading branch information
FlannyH committed May 6, 2024
1 parent 61ea925 commit a0e03e5
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 52 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ members = [

[dependencies]
libloading = "0.8"
thiserror = "1.0.59"
22 changes: 10 additions & 12 deletions examples/convert-dxil.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use std::ffi::CStr;

use saxaboom::{
IRComparisonFunction, IRCompilerFactory, IRFilter, IRMetalLibBinary, IRObject,
IRReflectionVersion, IRRootConstants, IRRootParameter1, IRRootParameter1_u,
IRRootParameterType, IRRootSignature, IRRootSignatureDescriptor1, IRRootSignatureFlags,
IRRootSignatureVersion, IRShaderReflection, IRShaderStage, IRShaderVisibility,
IRStaticBorderColor, IRStaticSamplerDescriptor, IRTextureAddressMode,
IRComparisonFunction, IRCompilerFactory, IRFilter, IRObject, IRReflectionVersion,
IRRootConstants, IRRootParameter1, IRRootParameter1_u, IRRootParameterType, IRRootSignature,
IRRootSignatureDescriptor1, IRRootSignatureFlags, IRRootSignatureVersion, IRShaderStage,
IRShaderVisibility, IRStaticBorderColor, IRStaticSamplerDescriptor, IRTextureAddressMode,
IRVersionedRootSignatureDescriptor, IRVersionedRootSignatureDescriptor_u,
};

Expand Down Expand Up @@ -62,13 +61,13 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {

// Load DXIL
let dxil = include_bytes!("assets/memcpy.cs.dxil");
let obj = IRObject::create_from_dxil(&compiler, dxil)?;
let obj = IRObject::create_from_dxil(&compiler, dxil);

// Convert to Metal
let mut mtl_binary = IRMetalLibBinary::new(&compiler)?;
let mtllib = compiler
.alloc_compile_and_link(CStr::from_bytes_with_nul_unchecked(b"main\0"), &obj)?;
mtllib.get_metal_lib_binary(IRShaderStage::IRShaderStageCompute, &mut mtl_binary);
let mtl_binary =
mtllib.get_metal_lib_binary(&compiler, IRShaderStage::IRShaderStageCompute)?;

// Get Metal bytecode
let metal_bytecode = mtl_binary.get_byte_code();
Expand All @@ -77,12 +76,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
dbg!(mtllib.get_metal_ir_shader_stage());

// Get reflection from the shader
let mut mtl_reflection = IRShaderReflection::new(&compiler)?;
mtllib.get_reflection(IRShaderStage::IRShaderStageCompute, &mut mtl_reflection);
let mtl_reflection =
mtllib.get_reflection(&compiler, IRShaderStage::IRShaderStageCompute)?;

let compute_info = mtl_reflection
.get_compute_info(IRReflectionVersion::IRReflectionVersion_1_0)
.unwrap()
.get_compute_info(IRReflectionVersion::IRReflectionVersion_1_0)?
.u_1
.info_1_0;
dbg!(compute_info);
Expand Down
103 changes: 63 additions & 40 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub use bindings::{
IRVersionedRootSignatureDescriptor__bindgen_ty_1 as IRVersionedRootSignatureDescriptor_u,
};
use bindings::{IRError, IRErrorCode};
use thiserror::Error;

pub struct IRShaderReflection {
me: *mut bindings::IRShaderReflection,
Expand All @@ -39,32 +40,33 @@ impl Drop for IRShaderReflection {
}

impl IRShaderReflection {
pub fn new(compiler: &IRCompiler) -> Result<IRShaderReflection, Box<dyn std::error::Error>> {
pub fn new(compiler: &IRCompiler) -> IRShaderReflection {
unsafe {
let me = compiler.funcs.IRShaderReflectionCreate();
Ok(Self {
Self {
me,
funcs: compiler.funcs.clone(),
})
}
}
}

pub fn get_compute_info(
&self,
version: IRReflectionVersion,
) -> Result<IRVersionedCSInfo, String> {
pub fn get_compute_info(&self, version: IRReflectionVersion) -> Option<IRVersionedCSInfo> {
let mut info = MaybeUninit::uninit();
if unsafe {
if dbg!(unsafe {
self.funcs
.IRShaderReflectionCopyComputeInfo(self.me, version, info.as_mut_ptr())
} {
Ok(unsafe { info.assume_init() })
}) {
Some(unsafe { info.assume_init() })
} else {
Err("Failed to get compute shader reflection info".to_string())
None
}
}
}

#[derive(Error, Debug)]
#[error("Failed to get MetalLib bytecode from IRObject")]
pub struct MetalLibNoBytecodeFoundError;

pub struct IRObject {
me: *mut bindings::IRObject,
funcs: Arc<bindings::metal_irconverter>,
Expand All @@ -77,21 +79,18 @@ impl Drop for IRObject {
}

impl IRObject {
pub fn create_from_dxil(
compiler: &IRCompiler,
bytecode: &[u8],
) -> Result<IRObject, Box<dyn std::error::Error>> {
pub fn create_from_dxil(compiler: &IRCompiler, bytecode: &[u8]) -> IRObject {
unsafe {
let me = compiler.funcs.IRObjectCreateFromDXIL(
bytecode.as_ptr(),
bytecode.len(),
bindings::IRBytecodeOwnership::IRBytecodeOwnershipNone,
);

Ok(Self {
Self {
funcs: compiler.funcs.clone(),
me,
})
}
}
}

Expand All @@ -112,23 +111,33 @@ impl IRObject {

pub fn get_metal_lib_binary(
&self,
compiler: &IRCompiler,
shader_stage: IRShaderStage,
dest_lib: &mut IRMetalLibBinary,
) -> bool {
unsafe {
) -> Result<IRMetalLibBinary, MetalLibNoBytecodeFoundError> {
let mtl_lib = IRMetalLibBinary::new(compiler);
if unsafe {
self.funcs
.IRObjectGetMetalLibBinary(self.me, shader_stage, dest_lib.me)
.IRObjectGetMetalLibBinary(self.me, shader_stage, mtl_lib.me)
} {
Ok(mtl_lib)
} else {
Err(MetalLibNoBytecodeFoundError)
}
}

pub fn get_reflection(
&self,
compiler: &IRCompiler,
shader_stage: IRShaderStage,
reflection: &mut IRShaderReflection,
) -> bool {
unsafe {
) -> Option<IRShaderReflection> {
let reflection = IRShaderReflection::new(compiler);
if unsafe {
self.funcs
.IRObjectGetReflection(self.me, shader_stage, reflection.me)
} {
Some(reflection)
} else {
None
}
}
}
Expand All @@ -145,13 +154,13 @@ impl Drop for IRMetalLibBinary {
}

impl IRMetalLibBinary {
pub fn new(compiler: &IRCompiler) -> Result<IRMetalLibBinary, Box<dyn std::error::Error>> {
pub fn new(compiler: &IRCompiler) -> IRMetalLibBinary {
unsafe {
let me = compiler.funcs.IRMetalLibBinaryCreate();
Ok(Self {
Self {
funcs: compiler.funcs.clone(),
me,
})
}
}
}

Expand All @@ -178,13 +187,17 @@ impl Drop for IRRootSignature {
}
}

#[derive(Error, Debug)]
#[error("Failed to create IRRootSignature: {0:?}")]
pub struct RootSignatureCreateError(IRErrorCode);

impl IRRootSignature {
pub fn create_from_descriptor(
compiler: &IRCompiler,
desc: &IRVersionedRootSignatureDescriptor,
) -> Result<IRRootSignature, String> {
) -> Result<IRRootSignature, RootSignatureCreateError> {
unsafe {
let mut error: *mut IRError = std::ptr::null_mut::<IRError>();
let mut error = std::ptr::null_mut();

let me = compiler
.funcs
Expand All @@ -195,9 +208,7 @@ impl IRRootSignature {
// IRErrorCode is #[repr(u32)], so this transmute should be fine
let code: u32 = compiler.funcs.IRErrorGetCode(error);
let code: IRErrorCode = std::mem::transmute(code);
return Err(format!(
"Root Signature creation failed with error code {code:?}"
));
return Err(RootSignatureCreateError(code));
}

Ok(Self {
Expand Down Expand Up @@ -253,6 +264,16 @@ impl IRCompilerFactory {
}
}

#[derive(Error, Debug)]
pub enum CompilerError {
#[error("Failed to compile IRObject: ({0:?})")]
IRObjectCompileError(IRErrorCode),
#[error("Failed to synthesize indirect intersection function")]
SynthesizeIndirectIntersectionFunctionError,
#[error("Failed to synthesize indirect ray dispatch function")]
SynthesizeIndirectRayDispatchError,
}

/// This object is not thread-safe, refer to [the Metal shader converter documentation], the "Multithreading considerations" chapter.
///
/// [the Metal shader converter documentation]: https://developer.apple.com/metal/shader-converter/
Expand Down Expand Up @@ -309,31 +330,31 @@ impl IRCompiler {

pub fn synthesize_indirect_intersection_function(
&mut self,
) -> Result<IRMetalLibBinary, String> {
let target_metallib = IRMetalLibBinary::new(&self).unwrap();
) -> Result<IRMetalLibBinary, CompilerError> {
let target_metallib = IRMetalLibBinary::new(&self);
if unsafe {
self.funcs
.IRMetalLibSynthesizeIndirectIntersectionFunction(self.me, target_metallib.me)
} {
return Ok(target_metallib);
}
{
Err("Failed to synthezize indirect intersection function".to_string())
Err(CompilerError::SynthesizeIndirectIntersectionFunctionError)
}
}

pub fn synthesize_indirect_ray_dispatch_function(
&mut self,
) -> Result<IRMetalLibBinary, String> {
let target_metallib = IRMetalLibBinary::new(&self).unwrap();
) -> Result<IRMetalLibBinary, CompilerError> {
let target_metallib = IRMetalLibBinary::new(&self);
if unsafe {
self.funcs
.IRMetalLibSynthesizeIndirectRayDispatchFunction(self.me, target_metallib.me)
} {
return Ok(target_metallib);
}
{
Err("Failed to synthezize indirect intersection function".to_string())
Err(CompilerError::SynthesizeIndirectRayDispatchError)
}
}

Expand All @@ -348,7 +369,7 @@ impl IRCompiler {
&mut self,
entry_point: &CStr,
input: &IRObject,
) -> Result<IRObject, Box<dyn std::error::Error>> {
) -> Result<IRObject, CompilerError> {
let mut error: *mut IRError = std::ptr::null_mut::<IRError>();

let v = unsafe {
Expand All @@ -366,7 +387,9 @@ impl IRCompiler {
funcs: input.funcs.clone(),
})
} else {
panic!("{:?}", error);
let code: u32 = unsafe { self.funcs.IRErrorGetCode(error) };
let code: IRErrorCode = unsafe { std::mem::transmute(code) };
Err(CompilerError::IRObjectCompileError(code))
}
}
}

0 comments on commit a0e03e5

Please sign in to comment.