Skip to content

Commit

Permalink
os: Implement newtype for HRESULT with must_use and hexadecimal format (
Browse files Browse the repository at this point in the history
#29)

Unify the use of `com_rs::HResult` into a newtyped `HRESULT` that
allows us to place a `#[must_use]` annotation and in turn prevent
missing errors from the API, which could lead to worse Undefined
Behaviour.

In addition we can now use hexadecimal formatting for the `Display` and
`Debug` traits making them much easier to read in `unwrap()` and similar
scenarios.
  • Loading branch information
MarijnS95 authored Jan 25, 2022
1 parent db22c06 commit ae6a5d1
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 43 deletions.
38 changes: 38 additions & 0 deletions src/os.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,41 @@ mod os_defs {
}

pub use os_defs::*;

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(transparent)]
#[must_use]
pub struct HRESULT(pub os_defs::HRESULT);
impl HRESULT {
pub fn is_err(&self) -> bool {
self.0 < 0
}
}

impl From<i32> for HRESULT {
fn from(v: i32) -> Self {
Self(v)
}
}

impl std::fmt::Debug for HRESULT {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
<Self as std::fmt::Display>::fmt(self, f)
}
}

impl std::fmt::Display for HRESULT {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("{:#x}", self))
}
}

impl std::fmt::LowerHex for HRESULT {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let prefix = if f.alternate() { "0x" } else { "" };
let bare_hex = format!("{:x}", self.0.abs());
// https://stackoverflow.com/a/44712309
f.pad_integral(self.0 >= 0, prefix, &bare_hex)
// <i32 as std::fmt::LowerHex>::fmt(&self.0, f)
}
}
10 changes: 7 additions & 3 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl DxcIncludeHandler for DefaultIncludeHandler {

#[derive(Error, Debug)]
pub enum HassleError {
#[error("Win32 error: {0:X}")]
#[error("Win32 error: {0:x}")]
Win32Error(HRESULT),
#[error("{0}")]
CompileError(String),
Expand Down Expand Up @@ -117,7 +117,9 @@ pub fn compile_hlsl(
.get_error_buffer()
.map_err(HassleError::Win32Error)?;
Err(HassleError::CompileError(
library.get_blob_as_string(&error_blob),
library
.get_blob_as_string(&error_blob)
.map_err(HassleError::Win32Error)?,
))
}
Ok(result) => {
Expand Down Expand Up @@ -152,7 +154,9 @@ pub fn validate_dxil(data: &[u8]) -> Result<Vec<u8>, HassleError> {
.get_error_buffer()
.map_err(HassleError::Win32Error)?;
Err(HassleError::ValidationError(
library.get_blob_as_string(&error_blob),
library
.get_blob_as_string(&error_blob)
.map_err(HassleError::Win32Error)?,
))
}
}
Expand Down
78 changes: 38 additions & 40 deletions src/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::pin::Pin;
macro_rules! check_hr {
($hr:expr, $v: expr) => {{
let hr = $hr;
if hr == 0 {
if !hr.is_err() {
Ok($v)
} else {
Err(hr)
Expand All @@ -28,7 +28,7 @@ macro_rules! check_hr {
macro_rules! check_hr_wrapped {
($hr:expr, $v: expr) => {{
let hr = $hr;
if hr == 0 {
if !hr.is_err() {
Ok($v)
} else {
Err(HassleError::Win32Error(hr))
Expand Down Expand Up @@ -143,15 +143,14 @@ struct DxcIncludeHandlerWrapperVtbl {
*const com_rs::IUnknown,
&com_rs::IID,
*mut *mut core::ffi::c_void,
) -> com_rs::HResult,
) -> HRESULT,
add_ref: extern "system" fn(*const com_rs::IUnknown) -> HRESULT,
release: extern "system" fn(*const com_rs::IUnknown) -> HRESULT,
#[cfg(not(windows))]
complete_object_destructor: extern "system" fn(*const com_rs::IUnknown) -> HRESULT,
#[cfg(not(windows))]
deleting_destructor: extern "system" fn(*const com_rs::IUnknown) -> HRESULT,
load_source:
extern "system" fn(*mut com_rs::IUnknown, LPCWSTR, *mut *mut IDxcBlob) -> com_rs::HResult,
load_source: extern "system" fn(*mut com_rs::IUnknown, LPCWSTR, *mut *mut IDxcBlob) -> HRESULT,
}

#[repr(C)]
Expand All @@ -167,19 +166,19 @@ impl<'a, 'i> DxcIncludeHandlerWrapper<'a, 'i> {
_me: *const com_rs::IUnknown,
_rrid: &com_rs::IID,
_ppv_obj: *mut *mut core::ffi::c_void,
) -> com_rs::HResult {
0 // dummy impl
) -> HRESULT {
HRESULT(0) // dummy impl
}

extern "system" fn dummy(_me: *const com_rs::IUnknown) -> HRESULT {
0 // dummy impl
HRESULT(0) // dummy impl
}

extern "system" fn load_source(
me: *mut com_rs::IUnknown,
filename: LPCWSTR,
include_source: *mut *mut IDxcBlob,
) -> com_rs::HResult {
) -> HRESULT {
let me = me as *mut DxcIncludeHandlerWrapper;

let filename = crate::utils::from_wide(filename as *mut _);
Expand All @@ -205,6 +204,7 @@ impl<'a, 'i> DxcIncludeHandlerWrapper<'a, 'i> {
} else {
-2_147_024_894 // ERROR_FILE_NOT_FOUND / 0x80070002
}
.into()
}
}

Expand Down Expand Up @@ -316,11 +316,9 @@ impl DxcCompiler {
};

let mut compile_error = 0u32;
unsafe {
result.get_status(&mut compile_error);
}
let status_hr = unsafe { result.get_status(&mut compile_error) };

if result_hr == 0 && compile_error == 0 {
if !result_hr.is_err() && !status_hr.is_err() && compile_error == 0 {
Ok(DxcOperationResult::new(result))
} else {
Err((DxcOperationResult::new(result), result_hr))
Expand Down Expand Up @@ -371,11 +369,9 @@ impl DxcCompiler {
};

let mut compile_error = 0u32;
unsafe {
result.get_status(&mut compile_error);
}
let status_hr = unsafe { result.get_status(&mut compile_error) };

if result_hr == 0 && compile_error == 0 {
if !result_hr.is_err() && !status_hr.is_err() && compile_error == 0 {
Ok((
DxcOperationResult::new(result),
from_wide(debug_filename),
Expand Down Expand Up @@ -421,11 +417,9 @@ impl DxcCompiler {
};

let mut compile_error = 0u32;
unsafe {
result.get_status(&mut compile_error);
}
let status_hr = unsafe { result.get_status(&mut compile_error) };

if result_hr == 0 && compile_error == 0 {
if !result_hr.is_err() && !status_hr.is_err() && compile_error == 0 {
Ok(DxcOperationResult::new(result))
} else {
Err((DxcOperationResult::new(result), result_hr))
Expand Down Expand Up @@ -489,22 +483,25 @@ impl DxcLibrary {
)
}

pub fn get_blob_as_string(&self, blob: &DxcBlobEncoding) -> String {
pub fn get_blob_as_string(&self, blob: &DxcBlobEncoding) -> Result<String, HRESULT> {
let mut blob_utf8: ComPtr<IDxcBlobEncoding> = ComPtr::new();

unsafe {
self.inner
.get_blob_as_utf8(blob.inner.as_ptr(), blob_utf8.as_mut_ptr())
};

let slice = unsafe {
std::slice::from_raw_parts(
blob_utf8.get_buffer_pointer() as *const u8,
blob_utf8.get_buffer_size(),
)
};

String::from_utf8(slice.to_vec()).unwrap()
check_hr!(
unsafe {
self.inner
.get_blob_as_utf8(blob.inner.as_ptr(), blob_utf8.as_mut_ptr())
},
{
let slice = unsafe {
std::slice::from_raw_parts(
blob_utf8.get_buffer_pointer() as *const u8,
blob_utf8.get_buffer_size(),
)
};

String::from_utf8(slice.to_vec()).unwrap()
}
)
}
}

Expand Down Expand Up @@ -596,12 +593,13 @@ impl DxcValidator {
pub fn version(&self) -> Result<DxcValidatorVersion, HRESULT> {
let mut version: ComPtr<IDxcVersionInfo> = ComPtr::new();

let result_hr = unsafe {
let result_hr: HRESULT = unsafe {
self.inner
.query_interface(&IID_IDxcVersionInfo, version.as_mut_ptr())
};
}
.into();

if result_hr != 0 {
if result_hr.is_err() {
return Err(result_hr);
}

Expand All @@ -625,9 +623,9 @@ impl DxcValidator {
};

let mut validate_status = 0u32;
unsafe { result.get_status(&mut validate_status) };
let status_hr = unsafe { result.get_status(&mut validate_status) };

if result_hr == 0 && validate_status == 0 {
if !result_hr.is_err() && !status_hr.is_err() && validate_status == 0 {
Ok(blob)
} else {
Err((DxcOperationResult::new(result), result_hr))
Expand Down

0 comments on commit ae6a5d1

Please sign in to comment.