From 154ab2b83450facd7106d4b32759a9bdfe3562b6 Mon Sep 17 00:00:00 2001 From: Joel Smith Date: Wed, 29 Jan 2025 17:16:02 +0000 Subject: [PATCH] Add flag to change caching attribute when pinning * This only works on X86. * This is for experimental purposes only. --- ioctl.h | 2 ++ memory.c | 22 ++++++++++++++++++++-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/ioctl.h b/ioctl.h index d56e1f6..33bc8a8 100644 --- a/ioctl.h +++ b/ioctl.h @@ -144,6 +144,8 @@ struct tenstorrent_reset_device { // tenstorrent_pin_pages_in.flags #define TENSTORRENT_PIN_PAGES_CONTIGUOUS 1 // app attests that the pages are physically contiguous +#define TENSTORRENT_PIN_PAGES_WC 2 +#define TENSTORRENT_PIN_PAGES_UC 4 struct tenstorrent_pin_pages_in { __u32 output_size_bytes; diff --git a/memory.c b/memory.c index a9d7c22..d9110ad 100644 --- a/memory.c +++ b/memory.c @@ -13,6 +13,10 @@ #include #include +#if CONFIG_X86 +#include +#endif + #include "chardev_private.h" #include "device.h" #include "memory.h" @@ -397,8 +401,8 @@ long ioctl_pin_pages(struct chardev_private *priv, if (!is_pin_pages_size_safe(in.size)) return -EINVAL; - if (in.flags != 0 && in.flags != TENSTORRENT_PIN_PAGES_CONTIGUOUS) - return -EINVAL; + // if (in.flags != 0 && in.flags != TENSTORRENT_PIN_PAGES_CONTIGUOUS) + // return -EINVAL; pinning = kmalloc(sizeof(*pinning), GFP_KERNEL); if (!pinning) @@ -490,6 +494,20 @@ long ioctl_pin_pages(struct chardev_private *priv, pinning->virtual_address = in.virtual_address; list_add(&pinning->list, &priv->pinnings); + +#if CONFIG_X86 + set_pages_array_wb(pages, nr_pages); + if (in.flags & TENSTORRENT_PIN_PAGES_WC) { + int r = set_pages_array_wc(pages, nr_pages); + if (r != 0) + pr_warn("set_pages_array_wc failed: %d (%lu pages)\n", r, nr_pages); + } else if (in.flags & TENSTORRENT_PIN_PAGES_UC) { + int r = set_pages_array_uc(pages, nr_pages); + if (r != 0) + pr_warn("set_pages_array_uc failed: %d (%lu pages)\n", r, nr_pages); + } +#endif + mutex_unlock(&priv->mutex); if (clear_user(&arg->out, in.output_size_bytes) != 0)