diff --git a/drivers/iommu/iommufd/device.c b/drivers/iommu/iommufd/device.c index 3214a4c17c6b..bcbcd6d94062 100644 --- a/drivers/iommu/iommufd/device.c +++ b/drivers/iommu/iommufd/device.c @@ -327,8 +327,9 @@ static int iommufd_group_setup_msi(struct iommufd_group *igroup, return 0; } -static int iommufd_hwpt_paging_attach(struct iommufd_hwpt_paging *hwpt_paging, - struct iommufd_device *idev) +static int +iommufd_device_attach_reserved_iova(struct iommufd_device *idev, + struct iommufd_hwpt_paging *hwpt_paging) { int rc; @@ -354,6 +355,7 @@ static int iommufd_hwpt_paging_attach(struct iommufd_hwpt_paging *hwpt_paging, int iommufd_hw_pagetable_attach(struct iommufd_hw_pagetable *hwpt, struct iommufd_device *idev) { + struct iommufd_hwpt_paging *hwpt_paging = find_hwpt_paging(hwpt); int rc; mutex_lock(&idev->igroup->lock); @@ -363,8 +365,8 @@ int iommufd_hw_pagetable_attach(struct iommufd_hw_pagetable *hwpt, goto err_unlock; } - if (hwpt_is_paging(hwpt)) { - rc = iommufd_hwpt_paging_attach(to_hwpt_paging(hwpt), idev); + if (hwpt_paging) { + rc = iommufd_device_attach_reserved_iova(idev, hwpt_paging); if (rc) goto err_unlock; } @@ -387,9 +389,8 @@ int iommufd_hw_pagetable_attach(struct iommufd_hw_pagetable *hwpt, mutex_unlock(&idev->igroup->lock); return 0; err_unresv: - if (hwpt_is_paging(hwpt)) - iopt_remove_reserved_iova(&to_hwpt_paging(hwpt)->ioas->iopt, - idev->dev); + if (hwpt_paging) + iopt_remove_reserved_iova(&hwpt_paging->ioas->iopt, idev->dev); err_unlock: mutex_unlock(&idev->igroup->lock); return rc; @@ -399,6 +400,7 @@ struct iommufd_hw_pagetable * iommufd_hw_pagetable_detach(struct iommufd_device *idev) { struct iommufd_hw_pagetable *hwpt = idev->igroup->hwpt; + struct iommufd_hwpt_paging *hwpt_paging = find_hwpt_paging(hwpt); mutex_lock(&idev->igroup->lock); list_del(&idev->group_item); @@ -406,9 +408,8 @@ iommufd_hw_pagetable_detach(struct iommufd_device *idev) iommufd_hwpt_detach_device(hwpt, idev); idev->igroup->hwpt = NULL; } - if (hwpt_is_paging(hwpt)) - iopt_remove_reserved_iova(&to_hwpt_paging(hwpt)->ioas->iopt, - idev->dev); + if (hwpt_paging) + iopt_remove_reserved_iova(&hwpt_paging->ioas->iopt, idev->dev); mutex_unlock(&idev->igroup->lock); /* Caller must destroy hwpt */ @@ -440,17 +441,17 @@ iommufd_group_remove_reserved_iova(struct iommufd_group *igroup, } static int -iommufd_group_do_replace_paging(struct iommufd_group *igroup, - struct iommufd_hwpt_paging *hwpt_paging) +iommufd_group_do_replace_reserved_iova(struct iommufd_group *igroup, + struct iommufd_hwpt_paging *hwpt_paging) { - struct iommufd_hw_pagetable *old_hwpt = igroup->hwpt; + struct iommufd_hwpt_paging *old_hwpt_paging; struct iommufd_device *cur; int rc; lockdep_assert_held(&igroup->lock); - if (!hwpt_is_paging(old_hwpt) || - hwpt_paging->ioas != to_hwpt_paging(old_hwpt)->ioas) { + old_hwpt_paging = find_hwpt_paging(igroup->hwpt); + if (!old_hwpt_paging || hwpt_paging->ioas != old_hwpt_paging->ioas) { list_for_each_entry(cur, &igroup->device_list, group_item) { rc = iopt_table_enforce_dev_resv_regions( &hwpt_paging->ioas->iopt, cur->dev, NULL); @@ -473,6 +474,8 @@ static struct iommufd_hw_pagetable * iommufd_device_do_replace(struct iommufd_device *idev, struct iommufd_hw_pagetable *hwpt) { + struct iommufd_hwpt_paging *hwpt_paging = find_hwpt_paging(hwpt); + struct iommufd_hwpt_paging *old_hwpt_paging; struct iommufd_group *igroup = idev->igroup; struct iommufd_hw_pagetable *old_hwpt; unsigned int num_devices; @@ -491,9 +494,8 @@ iommufd_device_do_replace(struct iommufd_device *idev, } old_hwpt = igroup->hwpt; - if (hwpt_is_paging(hwpt)) { - rc = iommufd_group_do_replace_paging(igroup, - to_hwpt_paging(hwpt)); + if (hwpt_paging) { + rc = iommufd_group_do_replace_reserved_iova(igroup, hwpt_paging); if (rc) goto err_unlock; } @@ -502,11 +504,10 @@ iommufd_device_do_replace(struct iommufd_device *idev, if (rc) goto err_unresv; - if (hwpt_is_paging(old_hwpt) && - (!hwpt_is_paging(hwpt) || - to_hwpt_paging(hwpt)->ioas != to_hwpt_paging(old_hwpt)->ioas)) - iommufd_group_remove_reserved_iova(igroup, - to_hwpt_paging(old_hwpt)); + old_hwpt_paging = find_hwpt_paging(old_hwpt); + if (old_hwpt_paging && + (!hwpt_paging || hwpt_paging->ioas != old_hwpt_paging->ioas)) + iommufd_group_remove_reserved_iova(igroup, old_hwpt_paging); igroup->hwpt = hwpt; @@ -524,9 +525,8 @@ iommufd_device_do_replace(struct iommufd_device *idev, /* Caller must destroy old_hwpt */ return old_hwpt; err_unresv: - if (hwpt_is_paging(hwpt)) - iommufd_group_remove_reserved_iova(igroup, - to_hwpt_paging(hwpt)); + if (hwpt_paging) + iommufd_group_remove_reserved_iova(igroup, hwpt_paging); err_unlock: mutex_unlock(&idev->igroup->lock); return ERR_PTR(rc); diff --git a/drivers/iommu/iommufd/iommufd_private.h b/drivers/iommu/iommufd/iommufd_private.h index 6eed84674919..40d6a6badc28 100644 --- a/drivers/iommu/iommufd/iommufd_private.h +++ b/drivers/iommu/iommufd/iommufd_private.h @@ -304,6 +304,25 @@ to_hwpt_paging(struct iommufd_hw_pagetable *hwpt) return container_of(hwpt, struct iommufd_hwpt_paging, common); } +static inline struct iommufd_hwpt_nested * +to_hwpt_nested(struct iommufd_hw_pagetable *hwpt) +{ + return container_of(hwpt, struct iommufd_hwpt_nested, common); +} + +static inline struct iommufd_hwpt_paging * +find_hwpt_paging(struct iommufd_hw_pagetable *hwpt) +{ + switch (hwpt->obj.type) { + case IOMMUFD_OBJ_HWPT_PAGING: + return to_hwpt_paging(hwpt); + case IOMMUFD_OBJ_HWPT_NESTED: + return to_hwpt_nested(hwpt)->parent; + default: + return NULL; + } +} + static inline struct iommufd_hwpt_paging * iommufd_get_hwpt_paging(struct iommufd_ucmd *ucmd, u32 id) {