Replies: 3 comments
---
Thanks Ben for the write-up. Since I am looking at this part right now, I am going to leave comments as I read through it, so some comments might become outdated as I read further.
I would actually go one step further and say that the … (Posting these comments now; still reading through the rest of the post.)
---
That's a bit TL;DR for me right now :) Can you make the codegen meeting tomorrow so we can go over the key points?
---
So my understanding so far is this. This is more about what we have right now; I'll come around to what we might want to support. Take an example op: when you tile it you get a tile loop, and that tile loop can be distributed in a block-cyclic fashion across workgroups, ignoring the inner (intra-tile) loop - see the sketch below.
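A hedged reconstruction of the steps just described (loop structure approximated with `scf` plus the flow symbols; `%n` and the tile size `%ts` are placeholders, not values from the original examples):

```mlir
// Tiling the original 1-D parallel op over %n elements gives a tile loop:
%c0 = constant 0 : index
scf.for %iv = %c0 to %n step %ts {
  // ... inner (intra-tile) loop elided ...
}

// Block-cyclic distribution of that tile loop across workgroups: each
// workgroup starts at its own tile and strides by the total workgroup count.
%id = flow.dispatch.workgroup.id[0] : index
%count = flow.dispatch.workgroup.count[0] : index
%lb = muli %id, %ts : index
%step = muli %count, %ts : index
scf.for %iv = %lb to %n step %step {
  // ... inner (intra-tile) loop elided ...
}
```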
The missing piece here is the tile size. I think it is a reasonable assumption to make that the …
The generated code does not make any assumptions about the number of workgroups used by the runtime. I have a WIP change locally that sets the value of … The remaining question is what heuristics the runtime can use to decide the number of workgroups. Maybe have a way to compute a "min number of workgroups", which is what I think the …
---
(following up from chat / @MaheshRavishankar @ThomasRaoux)
I wanted to brainstorm a bit more about workgroup counts: how and when we know them, what we know, etc. This is mostly a brain dump so that we can chat about it, hopefully point people here in the future, or use it to produce some real docs :)
As background, today (in the new linalg-on-tensors world) at the flow dialect level we have IR like this:
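A minimal sketch of such IR (op syntax, the `shapex` dims, and the `tensor<?x4xf32>` types are approximated from the discussion below, not copied from real compiler output):

```mlir
// Workload: the shape of the output, one SSA value per dimension.
%dim0 = shapex.ranked_dim %ret0_shape[0] : !shapex.ranked_shape<[?,4]> -> index
%dim1 = shapex.ranked_dim %ret0_shape[1] : !shapex.ranked_shape<[?,4]> -> index
%ret0 = flow.dispatch.workgroups[%dim0, %dim1](%arg0) : (tensor<?x4xf32>) -> tensor<?x4xf32> =
    (%input: !flow.dispatch.tensor<readonly:?x4xf32>,
     %result: !flow.dispatch.tensor<writeonly:?x4xf32>) {
  // Symbolic values available inside the region; none are defined yet:
  %id = flow.dispatch.workgroup.id[0] : index
  %count = flow.dispatch.workgroup.count[0] : index
  %size = flow.dispatch.workgroup.size[0] : index
  // ... tiled linalg-on-tensors code using the symbols above ...
  flow.return
}
```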
What's important here is that, outside of the region, the workload (any 1+D set of values) represents the domain of some function `compute_workgroup_count(workload)` that produces the workgroup count passed to the host-side API (vkCmdDispatch, cudaLaunchKernel, etc). This implies that the definition of (and thus the actual workgroup count produced by) that `compute_workgroup_count` function is not knowable, as we don't yet know which (of possibly many) target backends/device classes/etc we are targeting. In this example the workload is chosen as the shape of the output, but again that does not indicate what the API will end up getting, as it has not yet run through `compute_workgroup_count` - it's essentially just tracking for the flow dialect to know which SSA values will be required later on in lowering and must be kept live even through canonicalization/CSE/etc.

Within the region the workgroup size has not been decided and may vary for each HAL target backend. How that workgroup size factors into `compute_workgroup_count` is - because we cannot define that function yet - also unknowable. It might not be related at all, or it could be some variant of the classic `workload / workgroup_size`, etc. (a one-line sketch of that classic form follows this passage). We do, however, have symbols to represent the workgroup count and size, as well as all of the input shapes to the region, and we can also pass arbitrary values into the region. This means that so long as `compute_workgroup_count` only depends on that available set of symbols and provided values (derived from shapes, etc) we can also call `compute_workgroup_count` within the region itself.

Tiles are something that also doesn't matter here outside of the region: in fact, ideally the tile count processed within a workgroup should not be fixed yet, as I strongly suspect it will be something that gets tuned - possibly even at runtime (via specialization constants/etc). Regardless of the number of tiles or the tile size, the workload does not change (it's again only the inputs to `compute_workgroup_count`, nothing more), nor does the value of `flow.dispatch.workgroup.count` or any of the other symbols. They are completely independent.

The goal here is to remain fully parametric while in the flow dialect - even if immediately after producing the dispatch op folding/canonicalization runs and starts to make things more static. For example, `shapex.ranked_dim %ret0_shape[1]` on both the host (outside of the region) and device (inside of the region) is known to be `constant 4`. Folding that does not change the workload or workgroup count, though, for the same reason as above: the SSA values passed to `compute_workgroup_count` may just have some of their values be constants.

I think it's important we all agree on this representation as otherwise the rest falls apart (not implying it's not already lying on the floor in pieces, but at least this part should be ok :). The goal again is that at the flow level we are backend- and device-agnostic and just representing dataflow (of which the SSA values used as the workload are data, too). I think this should be fine: while still at the flow level, the intent is that the inside of the region produced here is not touched anymore, as there's nothing we can really do in there beyond standard safe stuff like CSE/DCE/etc - any assumption we try to make about the values and how they are used breaks the fundamental assertion that we cannot make assumptions about the target.
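To make the "classic" form above concrete, here is a hedged one-liner of what a backend might eventually pick for `compute_workgroup_count` (the `%workload_x`/`%wgsize_x` names are placeholders; nothing at the flow level commits to this formula):

```mlir
// Per dimension: count = ceildiv(workload, workgroup_size).
%count_x = affine.apply affine_map<()[s0, s1] -> (s0 ceildiv s1)>()[%workload_x, %wgsize_x]
```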
After leaving flow and going into the HAL level there are two major phases: executable translation and dispatch recording.
First we translate the input (the region contents above) for one or more target backends, and within each target backend into one or more possible variants. Example backends here are LLVM IR, SPIR-V, NVVM, and WASM, each of which may have radically different representations requiring different levels of specificity on the parameters above. Targeting execution on a CPU, for example, will likely use much higher tile counts/sizes per workgroup than targeting a GPU, as the overhead of processing each workgroup is higher (it goes to the OS thread scheduler and not the hardware workgroup scheduler, may have vector and scalar codepaths, etc). If we picked the tile counts/sizes already at the flow level we wouldn't be able to vary that. Even for a particular backend - like Vulkan + SPIR-V - we may have many different strategies that completely change tiling and workgroup parameters: whether cooperative matrix is available (and used in the particular executable), whether subgroup operations are available (and there's effectively an infinite number of combinations of voting/ballot/shuffle/etc support, dependent on not just the device but the executable being translated), whether specific image formats are supported for optimal tiling (for when we want to sample images), etc.
Second - once we have completed translation and know all of the variants of each executable - we need to record the dispatch operations. This involves taking the original `flow.dispatch` op, with its workload and symbolic reference to the executable, and turning it into one or more `hal.command_buffer.dispatch` ops, each with a (possibly different!) workgroup count. This is where the yet-undefined `compute_workgroup_count` function comes in: there's not just one but one per unique translated executable (meaning one per target backend and per variant within that). In practice a lot of them (particularly within a specific target backend) may turn out the same, but that is not a guarantee. As the `hal.command_buffer.dispatch` is recorded, `compute_workgroup_count` is (effectively) fed the translated executable and the workload SSA values and produces the standard 3D workgroup count used by the dispatch API.

We should expect that the large majority of tuning and specialization will happen only once, at the HAL level, with specialization based on the executable contents and target device capabilities, and that a single compiled IREE output module may support many of these backends and variants. What we don't want is for that to multiply the binary size for each new variant added. To prevent that we need to be able to elide duplicate translation results (what `hal.interface` and several other details are designed to aid with) as well as duplicate dispatch ops (what `hal.device.switch` is designed to aid with). The deduplication can be effective because we are still symbolic here and haven't yet specialized (much). Over time I suspect we'll tune that balance.

To summarize so far: everything that can be parametric at the flow dialect level should be parametric, and the only assumptions we can make are those involving safe operations (propagating constants, etc). Once translating/recording at the HAL level, those parameters can be specialized.
This does raise a problem that I think is what is being hit here (and has come up before): if translation happens before recording, what if you need to know something about the dispatch parameters while translating (such as which are static/etc)? Spoiler: that's where the `flow.dispatch` workload comes in :)

The assertion is this: the code that creates the `flow.dispatch.workgroups` op defines the workload and the region that represents a single workgroup invocation, including all of the arguments captured and provided to each invocation (in addition to the ones available for things like shapes of inputs/outputs). Valid IR can, for example, pass the workload SSA values into the region so they are (at the flow level here) symbolically available; if they are constants, RematerializeDispatchConstants (should, when updated) inline them. Both forms are sketched below:
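A hedged sketch of both forms (op syntax approximated; the `rand`-derived workload is illustrative, chosen because such a value cannot be recomputed inside the region and so must be captured):

```mlir
// Workload computed from an opaque value: capture it and pass it in.
%workload = call @rand() : () -> index
%ret0 = flow.dispatch.workgroups[%workload](%arg0, %workload)
    : (tensor<?xf32>, index) -> tensor<?xf32> =
    (%input: !flow.dispatch.tensor<readonly:?xf32>, %workload_arg: index,
     %result: !flow.dispatch.tensor<writeonly:?xf32>) {
  // %workload_arg makes the workload symbolically available in the region.
  flow.return
}
```

And after constant inlining, when the workload happens to be constant:

```mlir
%c128 = constant 128 : index
%ret0 = flow.dispatch.workgroups[%c128](%arg0) : (tensor<?xf32>) -> tensor<?xf32> =
    (%input: !flow.dispatch.tensor<readonly:?xf32>,
     %result: !flow.dispatch.tensor<writeonly:?xf32>) {
  // The captured operand is dropped; the constant is rematerialized inside.
  %workload = constant 128 : index
  flow.return
}
```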
I think the above may be what's missing from the IR examples today that helps explain how this all connects and how I can say things like "you can always get the workload in the region" - because you can literally pass the workload to the region :)
Returning to the example at the very top, it may help to think of this as how image processing is normally done in compute/graphics (it's super common to handle padding/differing workgroup counts from image sizes - especially when rescaling):
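A hedged sketch of that pattern (types and ops approximated): the workload is the output image's width and height, both of which can be re-derived inside the region from the result shape:

```mlir
%w = shapex.ranked_dim %out_shape[0] : !shapex.ranked_shape<[?,?]> -> index
%h = shapex.ranked_dim %out_shape[1] : !shapex.ranked_shape<[?,?]> -> index
%out = flow.dispatch.workgroups[%w, %h](%image) : (tensor<?x?xf32>) -> tensor<?x?xf32> =
    (%src: !flow.dispatch.tensor<readonly:?x?xf32>,
     %dst: !flow.dispatch.tensor<writeonly:?x?xf32>) {
  // The workload (%w, %h) is re-derived here from %dst's shape rather than
  // captured as an argument - same source values, no extra captures.
  flow.return
}
```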
Note that the workload is just derived from the output width/height, and that's also "available" inside of the region because it can be computed in the same way it was computed on the outside, from the same source values - no need to capture it and pass it in as an argument. You still could capture it, but that'd be wasteful (capturing variables is not free!). Also note that this still hasn't started talking about workgroup sizes - or tile sizes/counts/etc - as this is still in the flow dialect: it just demonstrates that there's no loss of information, and that we have ways to pass the workload in when it's too hard to recover (as in the previous `rand()` example) or ways to recover it (as in this example, where the same values are available on both sides).

If comparing this to things like GLSL/VK, you'd see code like this for recording:
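A hedged sketch of the host side (plain Vulkan, not IREE; the 8x8 workgroup size and the variable names are assumptions):

```c
#include <vulkan/vulkan.h>

// Classic ceildiv: pad the grid so count * workgroup_size covers the image.
uint32_t groupsX = (outputWidth + 8 - 1) / 8;
uint32_t groupsY = (outputHeight + 8 - 1) / 8;
vkCmdBindPipeline(cmd, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
vkCmdDispatch(cmd, groupsX, groupsY, 1);
```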
And the shader would be:
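A hedged sketch of the matching shader (the binding layout and `Params` block are assumptions):

```glsl
#version 450
layout(local_size_x = 8, local_size_y = 8) in;
layout(set = 0, binding = 0) uniform Params { uint width; uint height; };

void main() {
  uvec2 xy = gl_GlobalInvocationID.xy;
  // Guard against the padded edge where count * size exceeds the image.
  if (xy.x >= width || xy.y >= height) return;
  // ... process pixel (xy.x, xy.y) ...
}
```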
That's what we are mimicking from the host side (workload + fn) and device side (all the symbolic parameters and shapes and such).
Now to connect it up (working backwards) in IREE: when recording the HAL dispatch for a particular translated executable we (effectively) pass all this information to the `compute_workgroup_count` function on the host side. And it has all of the information it could need: the workload SSA values from the `flow.dispatch` op and all attributes of the translated executable produced during translation (workgroup size, tile size, tile count, etc). Computing the workgroup count to pass to the API is then just some math. This is what `TargetBackend::calculateDispatchWorkgroupCount` is meant to do (sketched below): `entryPointOp` is the translated entry point, to which you can add whatever attributes you want (tile size/count/phase of the moon/etc), and `workload` is the SSA values provided to `flow.dispatch`. The important thing here is that because this only runs on an executable translated by the same backend, custom attributes are fine - they just have to be self-consistent - as if the SPIRV target backend adds an attribute, it's safe for the SPIRV target backend's recordDispatch function to assume that attribute is present.
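A hedged sketch of the hook's shape (signature approximated from this discussion, not copied from the actual TargetBackend header):

```cpp
// Builds IR computing the 3D workgroup count for one translated entry point.
// entryPointOp carries whatever backend-specific attributes translation added
// (tile sizes/counts/etc); workload is the SSA values from flow.dispatch.
virtual std::array<Value, 3> calculateDispatchWorkgroupCount(
    Location loc, IREE::HAL::ExecutableOp executableOp,
    IREE::HAL::ExecutableEntryPointOp entryPointOp, ValueRange workload,
    OpBuilder &builder);
```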
Note that `TargetBackend::calculateDispatchWorkgroupCount` does not try to inline an (MLIR) function or anything else, but instead constructs IR: that's because the computed workgroup count should be a (mathematical) function of the values we have here, able to be synthesized from that set of values - there's no need to generate it in one place and then reuse it in another, and if there were, it would mean that wasn't the case and all of this falls apart!

But then what if you need to know the actual final workgroup count statically earlier, during translation (what #4528 is doing)? Run that same `compute_workgroup_count`-generating function inside of the executable! The workload is available (as above, it can be either passed in or recovered), the tile size/count/etc are available (they're determined during the translation that's happening), and you can produce the exact same values that the recording does.

And that's the whole (simple) story: both host (recording dispatches) and device (translating executables) have access to the same information and can produce the same values, so if that's something the linalg lowering system needs by design, it's just code that needs to be written. There will be cases when you want that information and cases when you don't - but there's nothing else that can be provided by `flow.dispatch` that does not start to make assumptions about that information in a way we want to leave flexible for the backends.

Concretely this means that instead of something like the current `getNumWorkgroupsFn` - which adds that function to the executable to later be plucked out - there should be a `computeWorkgroupCount(..., OpBuilder &builder) -> ValueRange` function that can be called to insert that logic anywhere: either during translation to insert into the device region, or during recording (`TargetBackend::calculateDispatchWorkgroupCount`) to insert into the host dispatch logic.
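A hedged sketch of what that could look like (names from this thread; the ceildiv body is one illustrative choice, and it returns an owning `SmallVector` rather than a `ValueRange` since the built values need storage):

```cpp
// Emits workgroup-count math at the current insertion point; callable from
// both executable translation (device side) and dispatch recording (host).
SmallVector<Value, 3> computeWorkgroupCount(Location loc,
                                            ArrayRef<int64_t> workgroupSize,
                                            ValueRange workload,
                                            OpBuilder &builder) {
  SmallVector<Value, 3> count;
  for (int i = 0; i < 3; ++i) {
    Value dim;
    if (i < (int)workload.size()) dim = workload[i];
    else dim = builder.create<ConstantIndexOp>(loc, 1);  // unused dims -> 1
    Value size = builder.create<ConstantIndexOp>(loc, workgroupSize[i]);
    Value pad = builder.create<ConstantIndexOp>(loc, workgroupSize[i] - 1);
    // ceildiv: (dim + size - 1) / size
    Value sum = builder.create<AddIOp>(loc, dim, pad);
    count.push_back(builder.create<UnsignedDivIOp>(loc, sum, size));
  }
  return count;
}
```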
HTH, and I would love to get some more IR examples together so we can see what missing pieces there may be (mostly on the order of the attributes emitted or the `calculateDispatchWorkgroupCount` implementation that uses them, etc).