From bfc7f90265e6feddbd54a2cdd4cbdd561e906674 Mon Sep 17 00:00:00 2001 From: Marijn Suijten Date: Tue, 25 Jan 2022 18:14:21 +0100 Subject: [PATCH] Consistently return `HassleError` and remove unneeded check macros Our public API becomes a lot cleaner when returning `HassleError` everywhere, helped by the fact that `HRESULT`s from the low-level COM API is a newtype with support for functions to convert to `Result<>` easily. These can be short-circuit-returned in a more Rust'y way through the questionmark operator instead of wrapping everything in large macro calls. This change is based on a similar approach in the Ash crate: https://github.com/MaikKlein/ash/pull/339/files --- examples/intellisense-tu.rs | 14 +- src/intellisense/wrapper.rs | 471 +++++++++++++++--------------------- src/lib.rs | 6 +- src/utils.rs | 58 +++-- src/wrapper.rs | 225 +++++++---------- 5 files changed, 321 insertions(+), 453 deletions(-) diff --git a/examples/intellisense-tu.rs b/examples/intellisense-tu.rs index 8df503e..3ba37b8 100644 --- a/examples/intellisense-tu.rs +++ b/examples/intellisense-tu.rs @@ -43,17 +43,11 @@ fn main() { let child_cursors = cursor.get_all_children().unwrap(); + assert_eq!(child_cursors[0].get_display_name().unwrap(), "g_input"); + assert_eq!(child_cursors[1].get_display_name().unwrap(), "g_output"); assert_eq!( - child_cursors[0].get_display_name(), - Ok("g_input".to_owned()) - ); - assert_eq!( - child_cursors[1].get_display_name(), - Ok("g_output".to_owned()) - ); - assert_eq!( - child_cursors[2].get_display_name(), - Ok("copyCs(uint3)".to_owned()) + child_cursors[2].get_display_name().unwrap(), + "copyCs(uint3)" ); for child_cursor in child_cursors { diff --git a/src/intellisense/wrapper.rs b/src/intellisense/wrapper.rs index f240118..f04a6b6 100644 --- a/src/intellisense/wrapper.rs +++ b/src/intellisense/wrapper.rs @@ -1,6 +1,6 @@ use crate::intellisense::ffi::*; -use crate::os::{CoTaskMemFree, BSTR, HRESULT, LPSTR}; -use crate::utils::HassleError; +use crate::os::{CoTaskMemFree, BSTR, LPSTR}; +use crate::utils::Result; use crate::wrapper::Dxc; use com_rs::ComPtr; use std::ffi::CString; @@ -15,46 +15,33 @@ impl DxcIntellisense { Self { inner } } - pub fn get_default_editing_tu_options(&self) -> Result { + pub fn get_default_editing_tu_options(&self) -> Result { let mut options: DxcTranslationUnitFlags = DxcTranslationUnitFlags::NONE; - unsafe { - check_hr!( - self.inner.get_default_editing_tu_options(&mut options), - options - ) - } + unsafe { self.inner.get_default_editing_tu_options(&mut options) } + .result_with_success(options) } - pub fn create_index(&self) -> Result { + pub fn create_index(&self) -> Result { let mut index: ComPtr = ComPtr::new(); - unsafe { - check_hr!( - self.inner.create_index(index.as_mut_ptr()), - DxcIndex::new(index) - ) - } + unsafe { self.inner.create_index(index.as_mut_ptr()) }.result()?; + Ok(DxcIndex::new(index)) } - pub fn create_unsaved_file( - &self, - file_name: &str, - contents: &str, - ) -> Result { + pub fn create_unsaved_file(&self, file_name: &str, contents: &str) -> Result { let c_file_name = CString::new(file_name).expect("Failed to convert `file_name`"); let c_contents = CString::new(contents).expect("Failed to convert `contents`"); let mut file: ComPtr = ComPtr::new(); unsafe { - check_hr!( - self.inner.create_unsaved_file( - c_file_name.as_ptr(), - c_contents.as_ptr(), - contents.len() as u32, - file.as_mut_ptr() - ), - DxcUnsavedFile::new(file) + self.inner.create_unsaved_file( + c_file_name.as_ptr(), + c_contents.as_ptr(), + contents.len() as u32, + file.as_mut_ptr(), ) } + .result()?; + Ok(DxcUnsavedFile::new(file)) } } @@ -76,7 +63,7 @@ impl DxcIndex { args: &[&str], unsaved_files: &[&DxcUnsavedFile], options: DxcTranslationUnitFlags, - ) -> Result { + ) -> Result { let c_source_filename = CString::new(source_filename).expect("Failed to convert `source_filename`"); @@ -86,30 +73,30 @@ impl DxcIndex { uf.push(unsaved_file.inner.as_ptr()); } - unsafe { - let mut c_args: Vec = vec![]; - let mut cliargs = vec![]; + let mut c_args: Vec = vec![]; + let mut cliargs = vec![]; - for arg in args.iter() { - let c_arg = CString::new(*arg).expect("Failed to convert `arg`"); - cliargs.push(c_arg.as_ptr() as *const u8); - c_args.push(c_arg); - } + for arg in args.iter() { + let c_arg = CString::new(*arg).expect("Failed to convert `arg`"); + cliargs.push(c_arg.as_ptr() as *const u8); + c_args.push(c_arg); + } + + let mut tu: ComPtr = ComPtr::new(); - let mut tu: ComPtr = ComPtr::new(); - check_hr!( - self.inner.parse_translation_unit( - c_source_filename.as_ptr() as *const u8, - cliargs.as_ptr(), - cliargs.len() as i32, - uf.as_ptr(), - uf.len() as u32, - options, - tu.as_mut_ptr() - ), - DxcTranslationUnit::new(tu) + unsafe { + self.inner.parse_translation_unit( + c_source_filename.as_ptr() as *const u8, + cliargs.as_ptr(), + cliargs.len() as i32, + uf.as_ptr(), + uf.len() as u32, + options, + tu.as_mut_ptr(), ) } + .result()?; + Ok(DxcTranslationUnit::new(tu)) } } @@ -119,9 +106,9 @@ pub struct DxcUnsavedFile { } impl DxcUnsavedFile { - pub fn get_length(&self) -> Result { + pub fn get_length(&self) -> Result { let mut length: u32 = 0; - unsafe { check_hr!(self.inner.get_length(&mut length), length) } + unsafe { self.inner.get_length(&mut length) }.result_with_success(length) } fn new(inner: ComPtr) -> Self { @@ -139,24 +126,16 @@ impl DxcTranslationUnit { DxcTranslationUnit { inner } } - pub fn get_file(&self, name: &[u8]) -> Result { + pub fn get_file(&self, name: &[u8]) -> Result { let mut file: ComPtr = ComPtr::new(); - unsafe { - check_hr!( - self.inner.get_file(name.as_ptr(), file.as_mut_ptr()), - DxcFile::new(file) - ) - } + unsafe { self.inner.get_file(name.as_ptr(), file.as_mut_ptr()) }.result()?; + Ok(DxcFile::new(file)) } - pub fn get_cursor(&self) -> Result { + pub fn get_cursor(&self) -> Result { let mut cursor: ComPtr = ComPtr::new(); - unsafe { - check_hr!( - self.inner.get_cursor(cursor.as_mut_ptr()), - DxcCursor::new(cursor) - ) - } + unsafe { self.inner.get_cursor(cursor.as_mut_ptr()) }.result()?; + Ok(DxcCursor::new(cursor)) } } @@ -170,32 +149,30 @@ impl DxcCursor { DxcCursor { inner } } - pub fn get_children(&self, skip: u32, max_count: u32) -> Result, HRESULT> { + pub fn get_children(&self, skip: u32, max_count: u32) -> Result> { + let mut result: *mut *mut IDxcCursor = std::ptr::null_mut(); + let mut result_length: u32 = 0; + unsafe { - let mut result: *mut *mut IDxcCursor = std::ptr::null_mut(); - let mut result_length: u32 = 0; - - check_hr!( - self.inner - .get_children(skip, max_count, &mut result_length, &mut result), - { - // get_children allocates a buffer to pass the result in. - let child_cursors = std::slice::from_raw_parts(result, result_length as usize) - .iter() - .map(|&ptr| { - let mut childcursor = ComPtr::::new(); - *childcursor.as_mut_ptr() = ptr; - DxcCursor::new(childcursor) - }) - .collect::>(); - CoTaskMemFree(result as *mut _); - child_cursors - } - ) + self.inner + .get_children(skip, max_count, &mut result_length, &mut result) } - } - - pub fn get_all_children(&self) -> Result, HRESULT> { + .result()?; + + // get_children allocates a buffer to pass the result in. + let child_cursors = unsafe { std::slice::from_raw_parts(result, result_length as usize) } + .iter() + .map(|&ptr| { + let mut childcursor = ComPtr::::new(); + *childcursor.as_mut_ptr() = ptr; + DxcCursor::new(childcursor) + }) + .collect::>(); + unsafe { CoTaskMemFree(result as *mut _) }; + Ok(child_cursors) + } + + pub fn get_all_children(&self) -> Result> { const MAX_CHILDREN_PER_CHUNK: u32 = 10; let mut children = vec![]; @@ -209,139 +186,91 @@ impl DxcCursor { } } - pub fn get_extent(&self) -> Result { - unsafe { - let mut range: ComPtr = ComPtr::new(); - check_hr!( - self.inner.get_extent(range.as_mut_ptr()), - DxcSourceRange::new(range) - ) - } + pub fn get_extent(&self) -> Result { + let mut range: ComPtr = ComPtr::new(); + unsafe { self.inner.get_extent(range.as_mut_ptr()) }.result()?; + Ok(DxcSourceRange::new(range)) } - pub fn get_location(&self) -> Result { - unsafe { - let mut location: ComPtr = ComPtr::new(); - check_hr!( - self.inner.get_location(location.as_mut_ptr()), - DxcSourceLocation::new(location) - ) - } + pub fn get_location(&self) -> Result { + let mut location: ComPtr = ComPtr::new(); + unsafe { self.inner.get_location(location.as_mut_ptr()) }.result()?; + Ok(DxcSourceLocation::new(location)) } - pub fn get_display_name(&self) -> Result { - unsafe { - let mut name: BSTR = std::ptr::null_mut(); - check_hr!( - self.inner.get_display_name(&mut name), - crate::utils::from_bstr(name) - ) - } + pub fn get_display_name(&self) -> Result { + let mut name: BSTR = std::ptr::null_mut(); + unsafe { self.inner.get_display_name(&mut name) }.result()?; + Ok(crate::utils::from_bstr(name)) } - pub fn get_formatted_name(&self, formatting: DxcCursorFormatting) -> Result { - unsafe { - let mut name: BSTR = std::ptr::null_mut(); - check_hr!( - self.inner.get_formatted_name(formatting, &mut name), - crate::utils::from_bstr(name) - ) - } + pub fn get_formatted_name(&self, formatting: DxcCursorFormatting) -> Result { + let mut name: BSTR = std::ptr::null_mut(); + unsafe { self.inner.get_formatted_name(formatting, &mut name) }.result()?; + Ok(crate::utils::from_bstr(name)) } - pub fn get_qualified_name(&self, include_template_args: bool) -> Result { + pub fn get_qualified_name(&self, include_template_args: bool) -> Result { + let mut name: BSTR = std::ptr::null_mut(); unsafe { - let mut name: BSTR = std::ptr::null_mut(); - check_hr!( - self.inner - .get_qualified_name(include_template_args, &mut name), - crate::utils::from_bstr(name) - ) + self.inner + .get_qualified_name(include_template_args, &mut name) } + .result()?; + Ok(crate::utils::from_bstr(name)) } - pub fn get_kind(&self) -> Result { - unsafe { - let mut cursor_kind: DxcCursorKind = DxcCursorKind::UNEXPOSED_DECL; - check_hr!(self.inner.get_kind(&mut cursor_kind), cursor_kind) - } + pub fn get_kind(&self) -> Result { + let mut cursor_kind: DxcCursorKind = DxcCursorKind::UNEXPOSED_DECL; + unsafe { self.inner.get_kind(&mut cursor_kind) }.result_with_success(cursor_kind) } - pub fn get_kind_flags(&self) -> Result { - unsafe { - let mut cursor_kind_flags: DxcCursorKindFlags = DxcCursorKindFlags::NONE; - check_hr!( - self.inner.get_kind_flags(&mut cursor_kind_flags), - cursor_kind_flags - ) - } + pub fn get_kind_flags(&self) -> Result { + let mut cursor_kind_flags: DxcCursorKindFlags = DxcCursorKindFlags::NONE; + unsafe { self.inner.get_kind_flags(&mut cursor_kind_flags) } + .result_with_success(cursor_kind_flags) } - pub fn get_semantic_parent(&self) -> Result { - unsafe { - let mut inner = ComPtr::::new(); - check_hr!( - self.inner.get_semantic_parent(inner.as_mut_ptr()), - DxcCursor::new(inner) - ) - } + pub fn get_semantic_parent(&self) -> Result { + let mut inner = ComPtr::::new(); + unsafe { self.inner.get_semantic_parent(inner.as_mut_ptr()) }.result()?; + Ok(DxcCursor::new(inner)) } - pub fn get_lexical_parent(&self) -> Result { - unsafe { - let mut inner = ComPtr::::new(); - check_hr!( - self.inner.get_lexical_parent(inner.as_mut_ptr()), - DxcCursor::new(inner) - ) - } + pub fn get_lexical_parent(&self) -> Result { + let mut inner = ComPtr::::new(); + unsafe { self.inner.get_lexical_parent(inner.as_mut_ptr()) }.result()?; + Ok(DxcCursor::new(inner)) } - pub fn get_cursor_type(&self) -> Result { - unsafe { - let mut inner = ComPtr::::new(); - check_hr!( - self.inner.get_cursor_type(inner.as_mut_ptr()), - DxcType::new(inner) - ) - } + pub fn get_cursor_type(&self) -> Result { + let mut inner = ComPtr::::new(); + unsafe { self.inner.get_cursor_type(inner.as_mut_ptr()) }.result()?; + Ok(DxcType::new(inner)) } - pub fn get_num_arguments(&self) -> Result { - unsafe { - let mut result: i32 = 0; - check_hr!(self.inner.get_num_arguments(&mut result), result) - } + pub fn get_num_arguments(&self) -> Result { + let mut result: i32 = 0; + + unsafe { self.inner.get_num_arguments(&mut result) }.result_with_success(result) } - pub fn get_argument_at(&self, index: i32) -> Result { - unsafe { - let mut inner = ComPtr::::new(); - check_hr!( - self.inner.get_argument_at(index, inner.as_mut_ptr()), - DxcCursor::new(inner) - ) - } + pub fn get_argument_at(&self, index: i32) -> Result { + let mut inner = ComPtr::::new(); + unsafe { self.inner.get_argument_at(index, inner.as_mut_ptr()) }.result()?; + Ok(DxcCursor::new(inner)) } - pub fn get_referenced_cursor(&self) -> Result { - unsafe { - let mut inner = ComPtr::::new(); - check_hr!( - self.inner.get_referenced_cursor(inner.as_mut_ptr()), - DxcCursor::new(inner) - ) - } + pub fn get_referenced_cursor(&self) -> Result { + let mut inner = ComPtr::::new(); + unsafe { self.inner.get_referenced_cursor(inner.as_mut_ptr()) }.result()?; + Ok(DxcCursor::new(inner)) } - pub fn get_definition_cursor(&self) -> Result { - unsafe { - let mut inner = ComPtr::::new(); - check_hr!( - self.inner.get_definition_cursor(inner.as_mut_ptr()), - DxcCursor::new(inner) - ) - } + pub fn get_definition_cursor(&self) -> Result { + let mut inner = ComPtr::::new(); + unsafe { self.inner.get_definition_cursor(inner.as_mut_ptr()) }.result()?; + Ok(DxcCursor::new(inner)) } pub fn find_references_in_file( @@ -349,82 +278,67 @@ impl DxcCursor { file: &DxcFile, skip: u32, top: u32, - ) -> Result, HRESULT> { + ) -> Result> { + let mut result: *mut *mut IDxcCursor = std::ptr::null_mut(); + let mut result_length: u32 = 0; + unsafe { - let mut result: *mut *mut IDxcCursor = std::ptr::null_mut(); - let mut result_length: u32 = 0; - - check_hr!( - self.inner.find_references_in_file( - file.inner.as_ptr(), - skip, - top, - &mut result_length, - &mut result - ), - { - // find_references_in_file allocates a buffer to pass the result in. - let child_cursors = std::slice::from_raw_parts(result, result_length as usize) - .iter() - .map(|&ptr| { - let mut childcursor = ComPtr::::new(); - *childcursor.as_mut_ptr() = ptr; - DxcCursor::new(childcursor) - }) - .collect::>(); - CoTaskMemFree(result as *mut _); - child_cursors - } + self.inner.find_references_in_file( + file.inner.as_ptr(), + skip, + top, + &mut result_length, + &mut result, ) } + .result()?; + + // find_references_in_file allocates a buffer to pass the result in. + let child_cursors = unsafe { std::slice::from_raw_parts(result, result_length as usize) } + .iter() + .map(|&ptr| { + let mut childcursor = ComPtr::::new(); + *childcursor.as_mut_ptr() = ptr; + DxcCursor::new(childcursor) + }) + .collect::>(); + unsafe { CoTaskMemFree(result as *mut _) }; + Ok(child_cursors) } - pub fn get_spelling(&self) -> Result { - unsafe { - let mut spelling: LPSTR = std::ptr::null_mut(); - check_hr!( - self.inner.get_spelling(&mut spelling), - crate::utils::from_lpstr(spelling) - ) - } + pub fn get_spelling(&self) -> Result { + let mut spelling: LPSTR = std::ptr::null_mut(); + unsafe { self.inner.get_spelling(&mut spelling) }.result()?; + Ok(crate::utils::from_lpstr(spelling)) } - pub fn is_equal_to(&self, other: &DxcCursor) -> Result { - unsafe { - let mut result: bool = false; - check_hr!( - self.inner.is_equal_to(other.inner.as_ptr(), &mut result), - result - ) - } + pub fn is_equal_to(&self, other: &DxcCursor) -> Result { + let mut result: bool = false; + unsafe { self.inner.is_equal_to(other.inner.as_ptr(), &mut result) } + .result_with_success(result) } - pub fn is_null(&mut self) -> Result { - unsafe { - let mut result: bool = false; - check_hr!(IDxcCursor::is_null(&self.inner, &mut result), result) - } + pub fn is_null(&mut self) -> Result { + let mut result: bool = false; + unsafe { IDxcCursor::is_null(&self.inner, &mut result) }.result_with_success(result) } - pub fn is_definition(&self) -> Result { - unsafe { - let mut result: bool = false; - check_hr!(self.inner.is_definition(&mut result), result) - } + pub fn is_definition(&self) -> Result { + let mut result: bool = false; + unsafe { self.inner.is_definition(&mut result) }.result_with_success(result) } - pub fn get_snapped_child(&self, location: &DxcSourceLocation) -> Result { + pub fn get_snapped_child(&self, location: &DxcSourceLocation) -> Result { + let mut inner = ComPtr::::new(); unsafe { - let mut inner = ComPtr::::new(); - check_hr!( - self.inner - .get_snapped_child(location.inner.as_ptr(), inner.as_mut_ptr()), - DxcCursor::new(inner) - ) + self.inner + .get_snapped_child(location.inner.as_ptr(), inner.as_mut_ptr()) } + .result()?; + Ok(DxcCursor::new(inner)) } - pub fn get_source<'a>(&self, source: &'a str) -> Result<&'a str, HRESULT> { + pub fn get_source<'a>(&self, source: &'a str) -> Result<&'a str> { let range = self.get_extent()?; let DxcSourceOffsets { @@ -448,14 +362,10 @@ impl DxcType { DxcType { inner } } - pub fn get_spelling(&self) -> Result { - unsafe { - let mut spelling: LPSTR = std::ptr::null_mut(); - check_hr!( - self.inner.get_spelling(&mut spelling), - crate::utils::from_lpstr(spelling) - ) - } + pub fn get_spelling(&self) -> Result { + let mut spelling: LPSTR = std::ptr::null_mut(); + unsafe { self.inner.get_spelling(&mut spelling) } + .result_with_success(crate::utils::from_lpstr(spelling)) } } @@ -482,18 +392,15 @@ pub struct DxcSourceRange { } impl DxcSourceRange { - pub fn get_offsets(&self) -> Result { - unsafe { - let mut start_offset: u32 = 0; - let mut end_offset: u32 = 0; - check_hr!( - self.inner.get_offsets(&mut start_offset, &mut end_offset), - DxcSourceOffsets { - start_offset, - end_offset - } - ) - } + pub fn get_offsets(&self) -> Result { + let mut start_offset: u32 = 0; + let mut end_offset: u32 = 0; + unsafe { self.inner.get_offsets(&mut start_offset, &mut end_offset) }.result_with_success( + DxcSourceOffsets { + start_offset, + end_offset, + }, + ) } } @@ -515,15 +422,15 @@ impl DxcFile { } impl Dxc { - pub fn create_intellisense(&self) -> Result { + pub fn create_intellisense(&self) -> Result { let mut intellisense: ComPtr = ComPtr::new(); - check_hr_wrapped!( - self.get_dxc_create_instance()?( - &CLSID_DxcIntelliSense, - &IID_IDxcIntelliSense, - intellisense.as_mut_ptr(), - ), - DxcIntellisense::new(intellisense) + + self.get_dxc_create_instance()?( + &CLSID_DxcIntelliSense, + &IID_IDxcIntelliSense, + intellisense.as_mut_ptr(), ) + .result()?; + Ok(DxcIntellisense::new(intellisense)) } } diff --git a/src/lib.rs b/src/lib.rs index dc34cac..b2d65e8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,14 +43,10 @@ pub mod fake_sign; pub mod ffi; pub mod os; pub mod utils; - -#[macro_use] pub mod wrapper; pub mod intellisense; pub use crate::ffi::*; -pub use crate::utils::compile_hlsl; -pub use crate::utils::fake_sign_dxil_in_place; -pub use crate::utils::validate_dxil; +pub use crate::utils::{compile_hlsl, fake_sign_dxil_in_place, validate_dxil, HassleError, Result}; pub use crate::wrapper::*; diff --git a/src/utils.rs b/src/utils.rs index c160644..7252d53 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -76,6 +76,34 @@ pub enum HassleError { WindowsOnly(String), } +pub type Result = std::result::Result; + +impl HRESULT { + /// Turns an [`HRESULT`] from the COM [`crate::ffi`] API declaration + /// into a [`Result`] containing [`HassleError`]. + pub fn result(self) -> Result<()> { + self.result_with_success(()) + } + + /// Turns an [`HRESULT`] from the COM [`crate::ffi`] API declaration + /// into a [`Result`] containing [`HassleError`], with the desired value. + /// + /// Note that `v` is passed by value and is not a closure that is executed + /// lazily. Use the short-circuiting `?` operator for such cases: + /// ```no_run + /// let mut blob: ComPtr = ComPtr::new(); + /// unsafe { self.inner.get_result(blob.as_mut_ptr()) }.result()?; + /// Ok(DxcBlob::new(blob)) + /// ``` + pub fn result_with_success(self, v: T) -> Result { + if self.is_err() { + Err(HassleError::Win32Error(self)) + } else { + Ok(v) + } + } +} + /// Helper function to directly compile a HLSL shader to an intermediate language, /// this function expects `dxcompiler.dll` to be available in the current /// executable environment. @@ -90,15 +118,13 @@ pub fn compile_hlsl( target_profile: &str, args: &[&str], defines: &[(&str, Option<&str>)], -) -> Result, HassleError> { +) -> Result> { let dxc = Dxc::new(None)?; let compiler = dxc.create_compiler()?; let library = dxc.create_library()?; - let blob = library - .create_blob_with_encoding_from_str(shader_text) - .map_err(HassleError::Win32Error)?; + let blob = library.create_blob_with_encoding_from_str(shader_text)?; let result = compiler.compile( &blob, @@ -112,18 +138,13 @@ pub fn compile_hlsl( match result { Err(result) => { - let error_blob = result - .0 - .get_error_buffer() - .map_err(HassleError::Win32Error)?; + let error_blob = result.0.get_error_buffer()?; Err(HassleError::CompileError( - library - .get_blob_as_string(&error_blob) - .map_err(HassleError::Win32Error)?, + library.get_blob_as_string(&error_blob)?, )) } Ok(result) => { - let result_blob = result.get_result().map_err(HassleError::Win32Error)?; + let result_blob = result.get_result()?; Ok(result_blob.to_vec()) } @@ -142,21 +163,14 @@ pub fn validate_dxil(data: &[u8]) -> Result, HassleError> { let validator = dxil.create_validator()?; let library = dxc.create_library()?; - let blob_encoding = library - .create_blob_with_encoding(data) - .map_err(HassleError::Win32Error)?; + let blob_encoding = library.create_blob_with_encoding(data)?; match validator.validate(blob_encoding.into()) { Ok(blob) => Ok(blob.to_vec()), Err(result) => { - let error_blob = result - .0 - .get_error_buffer() - .map_err(HassleError::Win32Error)?; + let error_blob = result.0.get_error_buffer()?; Err(HassleError::ValidationError( - library - .get_blob_as_string(&error_blob) - .map_err(HassleError::Win32Error)?, + library.get_blob_as_string(&error_blob)?, )) } } diff --git a/src/wrapper.rs b/src/wrapper.rs index 5a83e1b..e033878 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -6,36 +6,13 @@ use crate::ffi::*; use crate::os::{HRESULT, LPCWSTR, LPWSTR, WCHAR}; -use crate::utils::{from_wide, to_wide, HassleError}; +use crate::utils::{from_wide, to_wide, HassleError, Result}; use com_rs::ComPtr; use libloading::{Library, Symbol}; use std::ffi::c_void; use std::path::{Path, PathBuf}; use std::pin::Pin; -#[macro_export] -macro_rules! check_hr { - ($hr:expr, $v: expr) => {{ - let hr = $hr; - if !hr.is_err() { - Ok($v) - } else { - Err(hr) - } - }}; -} - -macro_rules! check_hr_wrapped { - ($hr:expr, $v: expr) => {{ - let hr = $hr; - if !hr.is_err() { - Ok($v) - } else { - Err(HassleError::Win32Error(hr)) - } - }}; -} - #[derive(Debug)] pub struct DxcBlob { inner: ComPtr, @@ -111,25 +88,22 @@ impl DxcOperationResult { Self { inner } } - pub fn get_status(&self) -> Result { + pub fn get_status(&self) -> Result { let mut status: u32 = 0; - check_hr!(unsafe { self.inner.get_status(&mut status) }, status) + unsafe { self.inner.get_status(&mut status) }.result_with_success(status) } - pub fn get_result(&self) -> Result { + pub fn get_result(&self) -> Result { let mut blob: ComPtr = ComPtr::new(); - check_hr!( - unsafe { self.inner.get_result(blob.as_mut_ptr()) }, - DxcBlob::new(blob) - ) + unsafe { self.inner.get_result(blob.as_mut_ptr()) }.result()?; + Ok(DxcBlob::new(blob)) } - pub fn get_error_buffer(&self) -> Result { + pub fn get_error_buffer(&self) -> Result { let mut blob: ComPtr = ComPtr::new(); - check_hr!( - unsafe { self.inner.get_error_buffer(blob.as_mut_ptr()) }, - DxcBlobEncoding::new(blob) - ) + + unsafe { self.inner.get_error_buffer(blob.as_mut_ptr()) }.result()?; + Ok(DxcBlobEncoding::new(blob)) } } @@ -426,15 +400,14 @@ impl DxcCompiler { } } - pub fn disassemble(&self, blob: &DxcBlob) -> Result { + pub fn disassemble(&self, blob: &DxcBlob) -> Result { let mut result_blob: ComPtr = ComPtr::new(); - check_hr!( - unsafe { - self.inner - .disassemble(blob.inner.as_ptr(), result_blob.as_mut_ptr()) - }, - DxcBlobEncoding::new(result_blob) - ) + unsafe { + self.inner + .disassemble(blob.inner.as_ptr(), result_blob.as_mut_ptr()) + } + .result()?; + Ok(DxcBlobEncoding::new(result_blob)) } } @@ -448,60 +421,54 @@ impl DxcLibrary { Self { inner } } - pub fn create_blob_with_encoding(&self, data: &[u8]) -> Result { + pub fn create_blob_with_encoding(&self, data: &[u8]) -> Result { let mut blob: ComPtr = ComPtr::new(); - check_hr!( - unsafe { - self.inner.create_blob_with_encoding_from_pinned( - data.as_ptr() as *const c_void, - data.len() as u32, - 0, // Binary; no code page - blob.as_mut_ptr(), - ) - }, - DxcBlobEncoding::new(blob) - ) + + unsafe { + self.inner.create_blob_with_encoding_from_pinned( + data.as_ptr() as *const c_void, + data.len() as u32, + 0, // Binary; no code page + blob.as_mut_ptr(), + ) + } + .result()?; + Ok(DxcBlobEncoding::new(blob)) } - pub fn create_blob_with_encoding_from_str( - &self, - text: &str, - ) -> Result { + pub fn create_blob_with_encoding_from_str(&self, text: &str) -> Result { let mut blob: ComPtr = ComPtr::new(); const CP_UTF8: u32 = 65001; // UTF-8 translation - check_hr!( - unsafe { - self.inner.create_blob_with_encoding_from_pinned( - text.as_ptr() as *const c_void, - text.len() as u32, - CP_UTF8, - blob.as_mut_ptr(), - ) - }, - DxcBlobEncoding::new(blob) - ) + unsafe { + self.inner.create_blob_with_encoding_from_pinned( + text.as_ptr() as *const c_void, + text.len() as u32, + CP_UTF8, + blob.as_mut_ptr(), + ) + } + .result()?; + Ok(DxcBlobEncoding::new(blob)) } - pub fn get_blob_as_string(&self, blob: &DxcBlobEncoding) -> Result { + pub fn get_blob_as_string(&self, blob: &DxcBlobEncoding) -> Result { let mut blob_utf8: ComPtr = ComPtr::new(); - check_hr!( - unsafe { - self.inner - .get_blob_as_utf8(blob.inner.as_ptr(), blob_utf8.as_mut_ptr()) - }, - { - let slice = unsafe { - std::slice::from_raw_parts( - blob_utf8.get_buffer_pointer() as *const u8, - blob_utf8.get_buffer_size(), - ) - }; - - String::from_utf8(slice.to_vec()).unwrap() - } - ) + unsafe { + self.inner + .get_blob_as_utf8(blob.inner.as_ptr(), blob_utf8.as_mut_ptr()) + } + .result()?; + + let slice = unsafe { + std::slice::from_raw_parts( + blob_utf8.get_buffer_pointer() as *const u8, + blob_utf8.get_buffer_size(), + ) + }; + + Ok(String::from_utf8(slice.to_vec()).unwrap()) } } @@ -528,7 +495,7 @@ fn dxcompiler_lib_name() -> &'static Path { impl Dxc { /// `dxc_path` can point to a library directly or the directory containing the library, /// in which case the appended filename depends on the platform. - pub fn new(lib_path: Option) -> Result { + pub fn new(lib_path: Option) -> Result { let lib_path = if let Some(lib_path) = lib_path { if lib_path.is_file() { lib_path @@ -547,34 +514,28 @@ impl Dxc { Ok(Self { dxc_lib }) } - pub(crate) fn get_dxc_create_instance( - &self, - ) -> Result, HassleError> { + pub(crate) fn get_dxc_create_instance(&self) -> Result> { Ok(unsafe { self.dxc_lib.get(b"DxcCreateInstance\0")? }) } - pub fn create_compiler(&self) -> Result { + pub fn create_compiler(&self) -> Result { let mut compiler: ComPtr = ComPtr::new(); - check_hr_wrapped!( - self.get_dxc_create_instance()?( - &CLSID_DxcCompiler, - &IID_IDxcCompiler2, - compiler.as_mut_ptr(), - ), - DxcCompiler::new(compiler, self.create_library()?) + + self.get_dxc_create_instance()?( + &CLSID_DxcCompiler, + &IID_IDxcCompiler2, + compiler.as_mut_ptr(), ) + .result()?; + Ok(DxcCompiler::new(compiler, self.create_library()?)) } - pub fn create_library(&self) -> Result { + pub fn create_library(&self) -> Result { let mut library: ComPtr = ComPtr::new(); - check_hr_wrapped!( - self.get_dxc_create_instance()?( - &CLSID_DxcLibrary, - &IID_IDxcLibrary, - library.as_mut_ptr(), - ), - DxcLibrary::new(library) - ) + + self.get_dxc_create_instance()?(&CLSID_DxcLibrary, &IID_IDxcLibrary, library.as_mut_ptr()) + .result()?; + Ok(DxcLibrary::new(library)) } } @@ -590,29 +551,22 @@ impl DxcValidator { Self { inner } } - pub fn version(&self) -> Result { + pub fn version(&self) -> Result { let mut version: ComPtr = ComPtr::new(); - let result_hr: HRESULT = unsafe { + HRESULT::from(unsafe { self.inner .query_interface(&IID_IDxcVersionInfo, version.as_mut_ptr()) - } - .into(); - - if result_hr.is_err() { - return Err(result_hr); - } + }) + .result()?; let mut major = 0; let mut minor = 0; - check_hr! { - unsafe { version.get_version(&mut major, &mut minor) }, - (major, minor) - } + unsafe { version.get_version(&mut major, &mut minor) }.result_with_success((major, minor)) } - pub fn validate(&self, blob: DxcBlob) -> Result { + pub fn validate(&self, blob: DxcBlob) -> Result { let mut result: ComPtr = ComPtr::new(); let result_hr = unsafe { self.inner.validate( @@ -628,7 +582,10 @@ impl DxcValidator { if !result_hr.is_err() && !status_hr.is_err() && validate_status == 0 { Ok(blob) } else { - Err((DxcOperationResult::new(result), result_hr)) + Err(( + DxcOperationResult::new(result), + HassleError::Win32Error(result_hr), + )) } } } @@ -640,7 +597,7 @@ pub struct Dxil { impl Dxil { #[cfg(not(windows))] - pub fn new(_: Option) -> Result { + pub fn new(_: Option) -> Result { Err(HassleError::WindowsOnly( "DXIL Signing is only supported on Windows".to_string(), )) @@ -649,7 +606,7 @@ impl Dxil { /// `dxil_path` can point to a library directly or the directory containing the library, /// in which case `dxil.dll` is appended. #[cfg(windows)] - pub fn new(lib_path: Option) -> Result { + pub fn new(lib_path: Option) -> Result { let lib_path = if let Some(lib_path) = lib_path { if lib_path.is_file() { lib_path @@ -669,19 +626,19 @@ impl Dxil { Ok(Self { dxil_lib }) } - fn get_dxc_create_instance(&self) -> Result, HassleError> { + fn get_dxc_create_instance(&self) -> Result> { Ok(unsafe { self.dxil_lib.get(b"DxcCreateInstance\0")? }) } - pub fn create_validator(&self) -> Result { + pub fn create_validator(&self) -> Result { let mut validator: ComPtr = ComPtr::new(); - check_hr_wrapped!( - self.get_dxc_create_instance()?( - &CLSID_DxcValidator, - &IID_IDxcValidator, - validator.as_mut_ptr(), - ), - DxcValidator::new(validator) + + self.get_dxc_create_instance()?( + &CLSID_DxcValidator, + &IID_IDxcValidator, + validator.as_mut_ptr(), ) + .result()?; + Ok(DxcValidator::new(validator)) } }