Skip to content

Commit

Permalink
fix: allocation device name lifetime
Browse files Browse the repository at this point in the history
so that was a lie
  • Loading branch information
decahedron1 committed Jan 13, 2025
1 parent 9bada32 commit 3ca14c2
Showing 1 changed file with 18 additions and 19 deletions.
37 changes: 18 additions & 19 deletions src/memory.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Types for managing memory & device allocations.
use std::{
ffi::{CString, c_char, c_int, c_void},
ffi::{c_char, c_int, c_void},
mem,
ptr::NonNull,
sync::Arc
Expand Down Expand Up @@ -248,22 +248,22 @@ impl Drop for AllocatedBlock<'_> {
pub struct AllocationDevice(&'static str);

impl AllocationDevice {
pub const CPU: AllocationDevice = AllocationDevice("Cpu");
pub const CUDA: AllocationDevice = AllocationDevice("Cuda");
pub const CUDA_PINNED: AllocationDevice = AllocationDevice("CudaPinned");
pub const CANN: AllocationDevice = AllocationDevice("Cann");
pub const CANN_PINNED: AllocationDevice = AllocationDevice("CannPinned");
pub const DIRECTML: AllocationDevice = AllocationDevice("DML");
pub const DIRECTML_CPU: AllocationDevice = AllocationDevice("DML CPU");
pub const HIP: AllocationDevice = AllocationDevice("Hip");
pub const HIP_PINNED: AllocationDevice = AllocationDevice("HipPinned");
pub const OPENVINO_CPU: AllocationDevice = AllocationDevice("OpenVINO_CPU");
pub const OPENVINO_GPU: AllocationDevice = AllocationDevice("OpenVINO_GPU");
pub const XNNPACK: AllocationDevice = AllocationDevice("XnnpackExecutionProvider");
pub const TVM: AllocationDevice = AllocationDevice("TVM");
pub const CPU: AllocationDevice = AllocationDevice("Cpu\0");
pub const CUDA: AllocationDevice = AllocationDevice("Cuda\0");
pub const CUDA_PINNED: AllocationDevice = AllocationDevice("CudaPinned\0");
pub const CANN: AllocationDevice = AllocationDevice("Cann\0");
pub const CANN_PINNED: AllocationDevice = AllocationDevice("CannPinned\0");
pub const DIRECTML: AllocationDevice = AllocationDevice("DML\0");
pub const DIRECTML_CPU: AllocationDevice = AllocationDevice("DML CPU\0");
pub const HIP: AllocationDevice = AllocationDevice("Hip\0");
pub const HIP_PINNED: AllocationDevice = AllocationDevice("HipPinned\0");
pub const OPENVINO_CPU: AllocationDevice = AllocationDevice("OpenVINO_CPU\0");
pub const OPENVINO_GPU: AllocationDevice = AllocationDevice("OpenVINO_GPU\0");
pub const XNNPACK: AllocationDevice = AllocationDevice("XnnpackExecutionProvider\0");
pub const TVM: AllocationDevice = AllocationDevice("TVM\0");

pub fn as_str(&self) -> &'static str {
self.0
&self.0[..self.0.len() - 1]
}
}

Expand Down Expand Up @@ -390,9 +390,8 @@ impl MemoryInfo {
/// ```
pub fn new(allocation_device: AllocationDevice, device_id: c_int, allocator_type: AllocatorType, memory_type: MemoryType) -> Result<Self> {
let mut memory_info_ptr: *mut ort_sys::OrtMemoryInfo = std::ptr::null_mut();
let allocator_name = CString::new(allocation_device.as_str()).unwrap_or_else(|_| unreachable!());
ortsys![
unsafe CreateMemoryInfo(allocator_name.as_ptr(), allocator_type.into(), device_id, memory_type.into(), &mut memory_info_ptr)?;
unsafe CreateMemoryInfo(allocation_device.as_str().as_ptr().cast(), allocator_type.into(), device_id, memory_type.into(), &mut memory_info_ptr)?;
nonNull(memory_info_ptr)
];
Ok(Self {
Expand Down Expand Up @@ -470,7 +469,7 @@ impl MemoryInfo {
ortsys![unsafe MemoryInfoGetName(self.ptr.as_ptr(), &mut name_ptr)];

// SAFETY: `name_ptr` can never be null - `CreateMemoryInfo` internally checks against builtin device names, erroring
// if a non-builtin device is passed, and ONNX Runtime will never supply a pointer to the C++ constructor
// if a non-builtin device is passed

let mut len = 0;
while unsafe { *name_ptr.add(len) } != 0x00 {
Expand All @@ -479,7 +478,7 @@ impl MemoryInfo {

// SAFETY: ONNX Runtime internally only ever defines allocation device names as ASCII. can't wait for this to blow up
// one day regardless
let name = unsafe { std::str::from_utf8_unchecked(std::slice::from_raw_parts(name_ptr.cast::<u8>(), len)) };
let name = unsafe { std::str::from_utf8_unchecked(std::slice::from_raw_parts(name_ptr.cast::<u8>(), len + 1)) };
AllocationDevice(name)
}

Expand Down

0 comments on commit 3ca14c2

Please sign in to comment.