Skip to content

Commit

Permalink
Add Deref implementation for HSTRING (#3291)
Browse files Browse the repository at this point in the history
  • Loading branch information
kennykerr authored Sep 23, 2024
1 parent bca9a76 commit 0f7466c
Show file tree
Hide file tree
Showing 10 changed files with 134 additions and 133 deletions.
32 changes: 11 additions & 21 deletions crates/libs/result/src/bstr.rs
Original file line number Diff line number Diff line change
@@ -1,36 +1,26 @@
use super::*;
use core::ops::Deref;

#[repr(transparent)]
pub struct BasicString(*const u16);

impl BasicString {
pub fn is_empty(&self) -> bool {
self.len() == 0
}
impl Deref for BasicString {
type Target = [u16];

pub fn len(&self) -> usize {
if self.0.is_null() {
fn deref(&self) -> &[u16] {
let len = if self.0.is_null() {
0
} else {
unsafe { SysStringLen(self.0) as usize }
}
}

pub fn as_wide(&self) -> &[u16] {
let len = self.len();
if len != 0 {
unsafe { core::slice::from_raw_parts(self.as_ptr(), len) }
} else {
&[]
}
}
};

pub fn as_ptr(&self) -> *const u16 {
if !self.is_empty() {
self.0
if len > 0 {
unsafe { core::slice::from_raw_parts(self.0, len) }
} else {
// This ensures that if `as_ptr` is called on the slice that the resulting pointer
// will still refer to a null-terminated string.
const EMPTY: [u16; 1] = [0];
EMPTY.as_ptr()
&EMPTY[..0]
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/libs/result/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ mod error_info {
}
}

Some(String::from_utf16_lossy(wide_trim_end(message.as_wide())))
Some(String::from_utf16_lossy(wide_trim_end(&message)))
}

pub(crate) fn as_ptr(&self) -> *mut core::ffi::c_void {
Expand Down
64 changes: 27 additions & 37 deletions crates/libs/strings/src/bstr.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::*;
use core::ops::Deref;

/// A BSTR string ([BSTR](https://learn.microsoft.com/en-us/previous-versions/windows/desktop/automat/string-manipulation-functions))
/// is a length-prefixed wide string.
Expand All @@ -13,35 +14,6 @@ impl BSTR {
Self(core::ptr::null_mut())
}

/// Returns `true` if the string is empty.
pub fn is_empty(&self) -> bool {
self.len() == 0
}

/// Returns the length of the string.
pub fn len(&self) -> usize {
if self.0.is_null() {
0
} else {
unsafe { bindings::SysStringLen(self.0) as usize }
}
}

/// Get the string as 16-bit wide characters (wchars).
pub fn as_wide(&self) -> &[u16] {
unsafe { core::slice::from_raw_parts(self.as_ptr(), self.len()) }
}

/// Returns a raw pointer to the `BSTR` buffer.
pub fn as_ptr(&self) -> *const u16 {
if !self.is_empty() {
self.0
} else {
const EMPTY: [u16; 1] = [0];
EMPTY.as_ptr()
}
}

/// Create a `BSTR` from a slice of 16 bit characters (wchars).
pub fn from_wide(value: &[u16]) -> Self {
if value.is_empty() {
Expand Down Expand Up @@ -75,9 +47,30 @@ impl BSTR {
}
}

impl Deref for BSTR {
type Target = [u16];

fn deref(&self) -> &[u16] {
let len = if self.0.is_null() {
0
} else {
unsafe { bindings::SysStringLen(self.0) as usize }
};

if len > 0 {
unsafe { core::slice::from_raw_parts(self.0, len) }
} else {
// This ensures that if `as_ptr` is called on the slice that the resulting pointer
// will still refer to a null-terminated string.
const EMPTY: [u16; 1] = [0];
&EMPTY[..0]
}
}
}

impl Clone for BSTR {
fn clone(&self) -> Self {
Self::from_wide(self.as_wide())
Self::from_wide(self)
}
}

Expand All @@ -104,7 +97,7 @@ impl<'a> TryFrom<&'a BSTR> for String {
type Error = alloc::string::FromUtf16Error;

fn try_from(value: &BSTR) -> core::result::Result<Self, Self::Error> {
String::from_utf16(value.as_wide())
String::from_utf16(value)
}
}

Expand All @@ -127,7 +120,7 @@ impl core::fmt::Display for BSTR {
core::write!(
f,
"{}",
Decode(|| core::char::decode_utf16(self.as_wide().iter().cloned()))
Decode(|| core::char::decode_utf16(self.iter().cloned()))
)
}
}
Expand All @@ -140,7 +133,7 @@ impl core::fmt::Debug for BSTR {

impl PartialEq for BSTR {
fn eq(&self, other: &Self) -> bool {
self.as_wide() == other.as_wide()
self.deref() == other.deref()
}
}

Expand All @@ -160,10 +153,7 @@ impl PartialEq<BSTR> for String {

impl<T: AsRef<str> + ?Sized> PartialEq<T> for BSTR {
fn eq(&self, other: &T) -> bool {
self.as_wide()
.iter()
.copied()
.eq(other.as_ref().encode_utf16())
self.iter().copied().eq(other.as_ref().encode_utf16())
}
}

Expand Down
65 changes: 25 additions & 40 deletions crates/libs/strings/src/hstring.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::*;
use core::ops::Deref;

/// An ([HSTRING](https://docs.microsoft.com/en-us/windows/win32/winrt/hstring))
/// is a reference-counted and immutable UTF-16 string type.
Expand All @@ -13,50 +14,20 @@ impl HSTRING {
Self(core::ptr::null_mut())
}

/// Returns `true` if the string is empty.
pub fn is_empty(&self) -> bool {
// An empty HSTRING is represented by a null pointer.
self.0.is_null()
}

/// Returns the length of the string. The length is measured in `u16`s (UTF-16 code units), not including the terminating null character.
pub fn len(&self) -> usize {
if let Some(header) = self.as_header() {
header.len as usize
} else {
0
}
}

/// Get the string as 16-bit wide characters (wchars).
pub fn as_wide(&self) -> &[u16] {
unsafe { core::slice::from_raw_parts(self.as_ptr(), self.len()) }
}

/// Returns a raw pointer to the `HSTRING` buffer.
pub fn as_ptr(&self) -> *const u16 {
if let Some(header) = self.as_header() {
header.data
} else {
const EMPTY: [u16; 1] = [0];
EMPTY.as_ptr()
}
}

/// Create a `HSTRING` from a slice of 16 bit characters (wchars).
pub fn from_wide(value: &[u16]) -> Self {
unsafe { Self::from_wide_iter(value.iter().copied(), value.len()) }
}

/// Get the contents of this `HSTRING` as a String lossily.
pub fn to_string_lossy(&self) -> String {
String::from_utf16_lossy(self.as_wide())
String::from_utf16_lossy(self)
}

/// Get the contents of this `HSTRING` as a OsString.
#[cfg(feature = "std")]
pub fn to_os_string(&self) -> std::ffi::OsString {
std::os::windows::ffi::OsStringExt::from_wide(self.as_wide())
std::os::windows::ffi::OsStringExt::from_wide(self)
}

/// # Safety
Expand Down Expand Up @@ -87,6 +58,21 @@ impl HSTRING {
}
}

impl Deref for HSTRING {
type Target = [u16];

fn deref(&self) -> &[u16] {
if let Some(header) = self.as_header() {
unsafe { core::slice::from_raw_parts(header.data, header.len as usize) }
} else {
// This ensures that if `as_ptr` is called on the slice that the resulting pointer
// will still refer to a null-terminated string.
const EMPTY: [u16; 1] = [0];
&EMPTY[..0]
}
}
}

impl Default for HSTRING {
fn default() -> Self {
Self::new()
Expand Down Expand Up @@ -125,7 +111,7 @@ impl core::fmt::Display for HSTRING {
write!(
f,
"{}",
Decode(|| core::char::decode_utf16(self.as_wide().iter().cloned()))
Decode(|| core::char::decode_utf16(self.iter().cloned()))
)
}
}
Expand Down Expand Up @@ -191,13 +177,13 @@ impl Eq for HSTRING {}

impl Ord for HSTRING {
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
self.as_wide().cmp(other.as_wide())
self.deref().cmp(other)
}
}

impl core::hash::Hash for HSTRING {
fn hash<H: core::hash::Hasher>(&self, hasher: &mut H) {
self.as_wide().hash(hasher)
self.deref().hash(hasher)
}
}

Expand All @@ -209,7 +195,7 @@ impl PartialOrd for HSTRING {

impl PartialEq for HSTRING {
fn eq(&self, other: &Self) -> bool {
*self.as_wide() == *other.as_wide()
self.deref() == other.deref()
}
}

Expand All @@ -233,7 +219,7 @@ impl PartialEq<&String> for HSTRING {

impl PartialEq<str> for HSTRING {
fn eq(&self, other: &str) -> bool {
self.as_wide().iter().copied().eq(other.encode_utf16())
self.iter().copied().eq(other.encode_utf16())
}
}

Expand Down Expand Up @@ -309,8 +295,7 @@ impl PartialEq<&std::ffi::OsString> for HSTRING {
#[cfg(feature = "std")]
impl PartialEq<std::ffi::OsStr> for HSTRING {
fn eq(&self, other: &std::ffi::OsStr) -> bool {
self.as_wide()
.iter()
self.iter()
.copied()
.eq(std::os::windows::ffi::OsStrExt::encode_wide(other))
}
Expand Down Expand Up @@ -376,7 +361,7 @@ impl<'a> TryFrom<&'a HSTRING> for String {
type Error = alloc::string::FromUtf16Error;

fn try_from(hstring: &HSTRING) -> core::result::Result<Self, Self::Error> {
String::from_utf16(hstring.as_wide())
String::from_utf16(hstring)
}
}

Expand Down
4 changes: 1 addition & 3 deletions crates/tests/misc/literals/tests/win.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ fn test() {
fn into() {
let a = h!("");
assert!(a.is_empty());
assert!(!a.as_ptr().is_null());
assert!(a.as_wide().is_empty());
let b = PCWSTR(a.as_ptr());
// Even though an empty HSTRING is internally represented by a null pointer, the PCWSTR
// will still be a non-null pointer to a null terminated empty string.
Expand Down Expand Up @@ -80,7 +78,7 @@ fn assert_hstring(left: &HSTRING, right: &[u16]) {
unsafe { wcslen(PCWSTR::from_raw(left.as_ptr())) },
right.len() - 1
);
let left = unsafe { std::slice::from_raw_parts(left.as_wide().as_ptr(), right.len()) };
let left = unsafe { std::slice::from_raw_parts(left.as_ptr(), right.len()) };
assert_eq!(left, right);
}

Expand Down
18 changes: 1 addition & 17 deletions crates/tests/misc/string_param/tests/pwstr.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,4 @@
use windows::{core::*, Win32::Foundation::*, Win32::UI::Shell::*};

#[test]
fn error() {
unsafe {
SetLastError(ERROR_BUSY_DRIVE);

let utf8 = "test\0".as_bytes();
let utf16 = HSTRING::from("test\0");
let utf16 = utf16.as_wide();
let len = 5;
assert_eq!(utf8.len(), len);
assert_eq!(utf16.len(), len);

assert_eq!(GetLastError(), ERROR_BUSY_DRIVE);
}
}
use windows::{core::*, Win32::UI::Shell::*};

#[test]
fn convert() {
Expand Down
Loading

0 comments on commit 0f7466c

Please sign in to comment.