Torch-MLIR Eager Mode + GPU + IREE #8878
-
I think it should "just" be a matter of keeping the DeviceArray on device and only transferring it to host in case of a fallback.
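A minimal sketch of that idea in Python, assuming `iree.runtime.DeviceArray` is the handle type coming back from IREE and that a NumPy round-trip is the (expensive) export path; the helper name is illustrative, not an existing API:

```python
# Hedged sketch: keep IREE results as DeviceArray between ops, and only pay the
# device->host transfer when the fallback path actually needs a torch tensor.
import numpy as np
import torch
import iree.runtime as ireert

def to_torch_for_fallback(value):
    """Illustrative helper: materialize a host-side torch tensor only on fallback."""
    if isinstance(value, ireert.DeviceArray):
        # np.asarray forces the device->host copy, so reach this only when we
        # must hand the value back to stock PyTorch eager.
        return torch.from_numpy(np.asarray(value))
    return value  # already a torch tensor (or other host value); pass through
```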
-
Correct, Sean! (Some thoughts here mostly as an overview/explainer and a way to start poking at this usage in more detail, but the TL;DR is what Sean says :)

To use IREE effectively you'll want to treat the inputs/outputs IREE is passed as opaque handles that can be converted (at a cost) to a real host-visible buffer you can manipulate. To be efficient you need to avoid that cost unless you absolutely require it, and even then defer it until the last possible moment. That means that when IREE returns you a DeviceArray you don't convert it to any other form until the absolute last moment and otherwise just pass it back into IREE without modification - the hope being that in the bulk of cases you never need to convert at all.

There are a few reasons to do this - you hit one in that there's a bunch of extra work happening when importing/exporting between each IREE invocation - but the major reason is that as we enable asynchronous execution the handles are how we pipeline things. You can think of them like promises/futures: when you call an IREE function you're getting the promise that a buffer will be available at some point in the future, and you can optionally block and wait for it to be available now. The key design point in a system like that is that you want to be chaining/joining/forking promises instead of blocking on each one, especially if you don't actually need the result immediately.

This is important in full model execution but much more so in op-by-op execution, as the overheads are significantly higher: by going op-by-op you are preventing the compiler from doing any internal pipelining, so the handles need to be... handled... correctly to enable external pipelining. I'm not a python person (and I like types), so I think of it like:

```c
// IREE-compiled functions:
iree_handle_t async_model_func_a();
iree_handle_t async_model_func_b(iree_handle_t a, iree_handle_t b);
iree_handle_t async_model_func_c(iree_handle_t a, iree_handle_t b);
// Some external code using void*:
void sync_external_func(void* ptr);
// IREE handles remain opaque while moving between IREE stuff:
iree_handle_t t0 = async_model_func_a();
iree_handle_t t1 = async_model_func_a();
iree_handle_t t2 = async_model_func_b(t0, t1);
// Until they need to be used externally; only the handle that needs to be used is exported:
// (this forces a synchronization point, but all the work above can run in parallel/pipelined until here)
void* t3 = iree_export_handle(t2);
sync_external_func(t3);
// And then to use them in IREE we have to import them again:
// (this begins a new synchronization scope, and subsequent work can now use the handles freely)
iree_handle_t t4 = iree_import_handle(t3);
iree_handle_t t5 = async_model_func_c(t4, t0); // note that we never exported t0
```

In this example the equivalent handle in our python wrapper is DeviceArray - today the impl is missing the async logic, but its API is such that it can be added with no user-visible changes in behavior. This means that as long as something works with DeviceArray round-tripping from IREE output to IREE input, the program will instantly benefit when we enable async.

How you make this interop with existing python stuff is going to vary - you could do things like DeviceArray itself does by implementing the required interfaces with a proxy object and swapping out/shimming the implementation as needed (e.g. https://github.com/google/iree/blob/844e208682a0875016b792cc55239be2efbc1845/bindings/python/iree/runtime/array_interop.py#L77-L82, though note this is prototype code and how it's mapping memory isn't great), but there may be other approaches.

Another thing that factors in here would be batching: if you can get a list of the DeviceArrays you need to import/export then you can do them as a batch to make things much faster. That's an optimization, but if you're building a system around on-demand transfers, having it support a "gather me all the inputs I need to run" design will make it easier to add in the future. Think:
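As a rough sketch of that "gather then batch" shape (with entirely hypothetical helper names, not IREE's actual API):

```python
# Hedged sketch: before invoking an op, collect every argument that is already
# an IREE DeviceArray and import/synchronize the whole batch at once instead of
# doing N separate on-demand transfers. Both helpers below are hypothetical.
def run_op(compiled_op, torch_args):
    device_args = [a for a in torch_args if is_device_array(a)]  # hypothetical check
    batch_import(device_args)  # hypothetical one-shot import/sync for the batch
    # With the batch handled, the invocation consumes the handles directly.
    return compiled_op(*torch_args)
```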
As for providing output buffers: you may not need that here, but it's something to watch for. The behavior of transient buffers returned by IREE would give you optimal performance and memory consumption in the (hopefully) common case of IREE->IREE, whereas passing in an output buffer would only help you at transition points of IREE->external, and even then may not help you because it ties the compiler's hands. Either way, getting things working with the opaque handles and on-demand transfers is needed and is a good place to start - after that, specifying output buffers is a possible optimization in some specific cases vs. load-bearing to getting things working reasonably. I think you'd be able to build an efficient eager executor without it, and the cases where it'd benefit are probably not ones where performance matters much anyway, as you never want to be running your CPU and GPU in lockstep or op-by-op - it can be ok at model boundaries, but op boundaries are too fine grained.

HTH and happy to dig into any of the specifics - someone like @stellaraccident would be better to talk python tricks with, though, and may have some good pointers based on her research into this area while building DeviceArray.
-
I've been working on eager mode for Torch-MLIR. This mode (currently) works by compiling PyTorch models op by op. The current "beta" version only supports CPU tensors and only on `RefBackendLinalgOnTensorsBackend`. In order for this eager mode to be compelling to users, we hope to eventually support GPU (hence this preliminary discussion/exploration).

A naive implementation (of extending to GPU) involves merely copying host->device and device->host, both on the PyTorch side and on the IREE side (sketched below). While straightforward and workable, it's probably suboptimal, particularly when considering that PyTorch eager mode doesn't incur device->host copies unless the tensor is inspected.
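For reference, a minimal sketch of that naive flow, assuming `compiled_op` is an IREE-compiled kernel reachable through the Python bindings and that results come back as a `DeviceArray`; the helper is illustrative and the exact iree.runtime conversions may differ:

```python
# Hedged sketch of the naive approach: copy everything host<->device around
# every single op. `compiled_op` stands in for an IREE-compiled kernel.
import numpy as np
import torch

def naive_eager_call(compiled_op, *torch_args):
    # PyTorch side: device -> host copy for every input tensor.
    host_inputs = [t.detach().cpu().numpy() for t in torch_args]
    # IREE side: host -> device copy when the invocation consumes the arrays;
    # the result comes back as a device-resident DeviceArray.
    device_result = compiled_op(*host_inputs)
    # Device -> host again so PyTorch can wrap it, then host -> device once
    # more if the caller wants a CUDA tensor: two avoidable copies per op.
    host_result = np.asarray(device_result)
    return torch.from_numpy(host_result).to(torch_args[0].device)
```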
The central challenge (as far as I can see so far) is how to maintain buffers on device in between IREE compilations/invocations while keeping those buffers accessible to the PyTorch runtime. The reason is that this eager mode should operate in a "fail-safe" fashion, i.e., in case of any malfunction, exception, or miscompilation, we dispatch to conventional PyTorch eager.
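A minimal sketch of that fail-safe pattern (hypothetical helpers, not the actual Torch-MLIR implementation):

```python
# Hedged sketch: try the IREE path first; on any failure, dispatch to
# conventional PyTorch eager on plain torch tensors. Both run_via_iree and
# to_torch_for_fallback are hypothetical placeholders.
def eager_dispatch(op, *torch_args):
    try:
        return run_via_iree(op, *torch_args)  # hypothetical compile-and-invoke path
    except Exception:
        # The fallback needs ordinary torch tensors, which is why device-resident
        # IREE buffers must remain convertible back to torch on demand.
        host_args = [to_torch_for_fallback(a) for a in torch_args]
        return op(*host_args)
```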
In discussing this on the IREE Discord, @benvanik informed me that this is in fact supported in some form, in that IREE allows for providing output buffers (and that the IREE runtime/compiler interface is HAL buffers in general). In addition, he pointed out that some host<->device interop is already implemented in the `DeviceArray` abstraction. With that in mind, I wonder if implementing the CUDA Array Interface (`__cuda_array_interface__`) for HAL buffers is the right way to approach this (since PyTorch supposedly also supports this interface).

All that being said, I'm interested to hear anyone's ideas/thoughts/concerns about any aspect of the feature.
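For concreteness, a minimal sketch of the CUDA Array Interface idea: the protocol only requires exposing a `__cuda_array_interface__` dict describing the device pointer, shape, and dtype, which PyTorch can consume (e.g. via `torch.as_tensor`) without a host round-trip. How the raw device pointer is obtained from an IREE HAL buffer, and how that buffer is kept alive, is the open question; `hal_buffer_device_ptr` below is a purely hypothetical placeholder, not an existing IREE API:

```python
# Hedged sketch: wrap an IREE HAL buffer in an object that speaks the
# __cuda_array_interface__ protocol (dict layout per the published spec, v2).
# Everything about obtaining the device pointer is a placeholder assumption.
class HalBufferCudaArray:
    def __init__(self, hal_buffer, shape, typestr):
        self._hal_buffer = hal_buffer  # keep the underlying buffer alive
        self._shape = tuple(shape)
        self._typestr = typestr        # e.g. "<f4" for little-endian float32

    @property
    def __cuda_array_interface__(self):
        return {
            "shape": self._shape,
            "typestr": self._typestr,
            # (device_pointer, read_only); hal_buffer_device_ptr is hypothetical.
            "data": (hal_buffer_device_ptr(self._hal_buffer), False),
            "version": 2,
            "strides": None,  # None => C-contiguous
        }

# Example usage (hypothetical): t = torch.as_tensor(HalBufferCudaArray(buf, (4, 4), "<f4"))
```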