Skip to content

Commit

Permalink
fix(vm): allow only nvidia and amd gpus (#2362)
Browse files Browse the repository at this point in the history
  • Loading branch information
kmd-fl authored Sep 10, 2024
1 parent 0e3ee8a commit 96cd1f4
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
14 changes: 13 additions & 1 deletion crates/gpu-utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ pub enum PciError {
UnsupportedProperty,
}

const AMD_VENDOR_ID: u16 = 0x1002;
const NVIDIA_VENDOR_ID: u16 = 0x10de;

pub fn get_gpu_pci() -> Result<HashSet<PciLocation>, PciError> {
let info = PciInfo::enumerate_pci()?;
// List of GPU devices
Expand All @@ -29,12 +32,14 @@ pub fn get_gpu_pci() -> Result<HashSet<PciLocation>, PciError> {
let device = device?;
let device_class = process_property_result(device.device_class())?;
let device_location = process_property_result(device.location())?;
if device_class == DisplayController {
if device_class == DisplayController && is_vendor_allowed(device.vendor_id()) {
gpu_devices.insert(device_location);
}
pci_devices.insert(device_location, device_class);
}

tracing::info!(target: "gpu-utils", "Found GPU devices: {:?}", gpu_devices);

let result = match get_iommu_groups() {
Ok(iommu_groups) => {
// Find all devices that are in the same IOMMU group as the GPU devices
Expand All @@ -54,9 +59,16 @@ pub fn get_gpu_pci() -> Result<HashSet<PciLocation>, PciError> {
gpu_devices
}
};

tracing::info!(target: "gpu-utils", "Importing PCI devices: {:?}", result);

Ok(result)
}

fn is_vendor_allowed(vendor_id: u16) -> bool {
vendor_id == AMD_VENDOR_ID || vendor_id == NVIDIA_VENDOR_ID
}

// AFAIK the bridge devices are the only non-endpoint devices
// May require to update this function if there are other non-endpoint devices
fn is_endpoint_device(device_class: &PciDeviceClass) -> bool {
Expand Down
1 change: 1 addition & 0 deletions crates/vm-utils/src/vm_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ pub fn create_domain(uri: &str, params: &CreateVMDomainParams) -> Result<(), VmE
tracing::info!(target: "vm-utils","Domain with name {} doesn't exists. Creating", params.name);
// There's certainly better places to do this, but RN it doesn't really matter
let gpu_pci_locations = if params.allow_gpu {
tracing::info!(target: "gpu-utils", "Collecting info about GPU devices...");
gpu_utils::get_gpu_pci()?.into_iter().collect::<Vec<_>>()
} else {
vec![]
Expand Down

0 comments on commit 96cd1f4

Please sign in to comment.