JIT: initial support for reinforcement learning of CSE heuristic #96880
Conversation
Tagging subscribers to this area: @JulieLeeMSFT, @jakobbotsch

Issue Details

Initial support for a reinforcement-learning based CSE heuristic.
How feasible would it be to have a general infrastructure driven by features/parameters, so that any optimization can plug into it? I want to do something similar for LSRA.
Adds special CSE heuristic modes to the JIT to support learning a good CSE heuristic via Policy Gradient, a form of reinforcement learning. The learning must be orchestrated by an external process, but the JIT does all of the actual gradient computations. The orchestration program will be added to jitutils. The overall process also relies on SPMI, and the goal is to minimize perf score.

Introduce two new CSE heuristic policies:
* Replay: simply perform the indicated sequence of CSEs
* RL: used for the Policy Gradient, with 3 modes:
  * Stochastic: based on current parameters, but allows random variation
  * Greedy: based on current parameters, deterministic
  * Update: compute updated parameters per Policy Gradient

Also rework the Random policy to be a bit more random: it now alters both the CSEs performed and the order in which they are performed.

Add the ability to have jit config options that specify sequences of ints or doubles.

Add the ability to just dump metric info for a jitted method, and add more details (perhaps considerably more) for CSEs. This is all still simple text format.

Also factor out a common check for "non-viable" candidates -- these are CSE candidates that won't actually be CSEs. This leads to some minor diffs, as the check is now slightly different for CSEs with zero uses and/or zero weighted uses.

Contributes to dotnet#92915.
6ef9785 to 14ba4ea (Compare)
@dotnet/jit-contrib FYI. Not sure who wants to review this one. Any volunteers?
Somewhat? The basic structure is common to lots of problems; the tricky bits are figuring out the right state/action model, and either handling this across a jit/host/orchestrator boundary or externalizing all the info from the jit so it can be processed entirely by outside code. Let me describe briefly how this all works and maybe we can brainstorm about how to leverage it for your case. The "RL" mode for CSEs has 3 behaviors:

* Stochastic: choose CSEs based on the current parameters, but allow random variation (exploration)
* Greedy: choose CSEs deterministically from the current parameters (evaluation)
* Update: compute updated parameters per Policy Gradient
The orchestration process repeatedly cycles through evaluation/exploration + update steps. This process should converge to a set of parameters that (via the greedy policy) obtains the optimal perf score for that method (or scores for sets of methods). In the background the orchestrator also computes "V" and "Q" estimates using the data from each run; these are used to compute increasingly accurate per-step rewards.
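To make the update step concrete, here is a minimal sketch of the softmax policy gradient update this kind of scheme relies on. Everything here is illustrative, not the JIT's actual implementation: the feature vectors, parameter layout, and learning rate are invented for the example, and the `advantage` argument stands in for the per-step reward derived from the V/Q estimates mentioned above.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Preference for a candidate action: dot product of parameters and features.
double preference(const std::vector<double>& params, const std::vector<double>& features) {
    double p = 0.0;
    for (size_t i = 0; i < params.size(); i++) p += params[i] * features[i];
    return p;
}

// Softmax probability of choosing candidate `chosen` among `candidates`.
double softmaxProb(const std::vector<double>& params,
                   const std::vector<std::vector<double>>& candidates, size_t chosen) {
    double sum = 0.0;
    for (const auto& f : candidates) sum += std::exp(preference(params, f));
    return std::exp(preference(params, candidates[chosen])) / sum;
}

// One REINFORCE-style step: theta += alpha * advantage * grad(log pi(chosen)).
// For a softmax policy, grad(log pi(chosen)) = features(chosen) - E[features].
void policyGradientStep(std::vector<double>& params,
                        const std::vector<std::vector<double>>& candidates,
                        size_t chosen, double advantage, double alpha) {
    std::vector<double> expected(params.size(), 0.0);
    for (size_t a = 0; a < candidates.size(); a++) {
        double p = softmaxProb(params, candidates, a);
        for (size_t i = 0; i < params.size(); i++) expected[i] += p * candidates[a][i];
    }
    for (size_t i = 0; i < params.size(); i++) {
        params[i] += alpha * advantage * (candidates[chosen][i] - expected[i]);
    }
}
```

With a positive advantage, one step shifts the parameters so the chosen candidate becomes more probable under the greedy policy on the next run.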
Diff results for #96880

Assembly diffs

Assembly diffs for linux/arm64 ran on windows/x64
Diffs are based on 2,501,661 contexts (1,003,806 MinOpts, 1,497,855 FullOpts). MISSED contexts: base: 3,546 (0.14%), diff: 3,556 (0.14%)
Overall (+1,884 bytes)
FullOpts (+1,884 bytes)

Assembly diffs for linux/x64 ran on windows/x64
Diffs are based on 2,595,039 contexts (1,052,329 MinOpts, 1,542,710 FullOpts). MISSED contexts: 3,596 (0.14%)
Overall (-236 bytes)
FullOpts (-236 bytes)

Assembly diffs for osx/arm64 ran on windows/x64
Diffs are based on 2,263,032 contexts (930,876 MinOpts, 1,332,156 FullOpts). MISSED contexts: base: 2,925 (0.13%), diff: 2,933 (0.13%)
Overall (+1,512 bytes)
FullOpts (+1,512 bytes)

Assembly diffs for windows/arm64 ran on windows/x64
Diffs are based on 2,318,296 contexts (931,543 MinOpts, 1,386,753 FullOpts). MISSED contexts: base: 2,587 (0.11%), diff: 2,598 (0.11%)
Overall (-952 bytes)
FullOpts (-952 bytes)

Assembly diffs for windows/x64 ran on windows/x64
Diffs are based on 2,492,949 contexts (983,689 MinOpts, 1,509,260 FullOpts). MISSED contexts: base: 3,859 (0.15%), diff: 3,862 (0.15%)
Overall (-2,082 bytes)
FullOpts (-2,082 bytes)

Details here

Assembly diffs for linux/arm ran on windows/x86
Diffs are based on 2,238,212 contexts (827,812 MinOpts, 1,410,400 FullOpts). MISSED contexts: base: 74,052 (3.20%), diff: 74,066 (3.20%)
Overall (-2,444 bytes)
FullOpts (-2,444 bytes)

Assembly diffs for windows/x86 ran on windows/x86
Diffs are based on 2,299,277 contexts (841,817 MinOpts, 1,457,460 FullOpts). MISSED contexts: base: 2,090 (0.09%), diff: 2,093 (0.09%)
Overall (-81 bytes)
FullOpts (-81 bytes)

Details here

Throughput diffs

Throughput diffs for linux/arm64 ran on windows/x64
FullOpts (-0.01% to +0.00%)

Throughput diffs for windows/x64 ran on windows/x64
Overall (-0.01% to +0.00%)
FullOpts (-0.01% to +0.00%)

Details here

Throughput diffs for linux/arm64 ran on linux/x64
FullOpts (-0.01% to -0.00%)

Details here
@EgorBo can you take a look?

Sure, need to rewatch your internal talk that I missed first 🙂
Diff results for #96880

Throughput diffs

Throughput diffs for linux/arm64 ran on windows/x64
FullOpts (-0.01% to +0.00%)

Throughput diffs for windows/x64 ran on windows/x64
Overall (-0.01% to +0.00%)
FullOpts (-0.01% to +0.00%)

Details here
Diff results for #96880

Assembly diffs

Assembly diffs for linux/arm64 ran on windows/x64
Diffs are based on 2,501,147 contexts (1,003,806 MinOpts, 1,497,341 FullOpts). MISSED contexts: base: 4,060 (0.16%), diff: 4,070 (0.16%)
Overall (+3,884 bytes)
FullOpts (+3,884 bytes)

Assembly diffs for linux/x64 ran on windows/x64
Diffs are based on 2,595,007 contexts (1,052,329 MinOpts, 1,542,678 FullOpts). MISSED contexts: 3,628 (0.14%)
Overall (-365 bytes)
FullOpts (-365 bytes)

Assembly diffs for osx/arm64 ran on windows/x64
Diffs are based on 2,262,701 contexts (930,876 MinOpts, 1,331,825 FullOpts). MISSED contexts: base: 3,256 (0.14%), diff: 3,264 (0.14%)
Overall (+2,152 bytes)
FullOpts (+2,152 bytes)

Assembly diffs for windows/arm64 ran on windows/x64
Diffs are based on 2,318,196 contexts (931,543 MinOpts, 1,386,653 FullOpts). MISSED contexts: base: 2,687 (0.12%), diff: 2,698 (0.12%)
Overall (-216 bytes)
FullOpts (-216 bytes)

Assembly diffs for windows/x64 ran on windows/x64
Diffs are based on 2,492,909 contexts (983,689 MinOpts, 1,509,220 FullOpts). MISSED contexts: base: 3,899 (0.16%), diff: 3,902 (0.16%)
Overall (-2,169 bytes)
FullOpts (-2,169 bytes)

Details here

Assembly diffs for linux/arm ran on windows/x86
Diffs are based on 2,237,676 contexts (827,812 MinOpts, 1,409,864 FullOpts). MISSED contexts: base: 74,588 (3.23%), diff: 74,602 (3.23%)
Overall (-2,188 bytes)
FullOpts (-2,188 bytes)

Assembly diffs for windows/x86 ran on windows/x86
Diffs are based on 2,296,274 contexts (841,817 MinOpts, 1,454,457 FullOpts). MISSED contexts: base: 5,093 (0.22%), diff: 5,096 (0.22%)
Overall (-79 bytes)
FullOpts (-79 bytes)

Details here

Throughput diffs

Throughput diffs for linux/arm64 ran on windows/x64
FullOpts (-0.01% to +0.00%)

Throughput diffs for windows/x64 ran on windows/x64
Overall (-0.01% to +0.00%)
FullOpts (-0.01% to +0.00%)

Details here

Throughput diffs for linux/arm64 ran on linux/x64
FullOpts (-0.01% to -0.00%)

Details here
Add a tool that can use ML techniques to explore the JIT's CSE heuristic. Some parts of this are very specific to CSEs, others are general and could be repurposed for use with other heuristics. This is still work in progress. Depends on jit changes in dotnet/runtime#96880
@EgorBo ping
printf("\n");
}

printf("Total bytes of code %d, prolog size %d, PerfScore %.2f, instruction count %d, allocated bytes for "
nit: looks like "Total bytes of code" is no longer prefixed with ; (comments in asm)
Will fix in a subsequent change.
jit-analyze is looking for this string (https://github.com/dotnet/jitutils/blob/e30e004fee30f6da62e2ddf856e31e677cec2955/src/jit-analyze/Program.cs#L294), so diffs are semi-broken now.
Added that back in #97677
// 10. cse costEx is <= MIN_CSE_COST (0/1)
// 11. cse is a constant and live across call (0/1)
// 12. cse is a constant and min cost (0/1)
// 13. cse is a constant and NOT min cost (0/1)
Just wondering: are you going to take platform features into account, such as the number of callee-saved registers (for GPRs and floats)?
Yes, we will need to add something like this -- right now the mechanisms to decide not to do a CSE are too weak.
I have follow-on changes that add some, but I'm not happy with them yet.
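For illustration, a register-pressure feature along the lines discussed might look like the sketch below. The struct, register counts, and normalization are all hypothetical, invented for this example; they are not the JIT's actual target tables or the follow-on changes mentioned above.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical per-target register info; the values used by callers are
// illustrative, not the JIT's actual tables.
struct TargetInfo {
    int calleeSavedGpr;   // callee-saved general-purpose registers
    int calleeSavedFloat; // callee-saved float/SIMD registers
};

// Normalize live CSE candidates against the callee-saved budget into a
// [0,1] pressure feature, so targets with more callee-saved registers can
// bias the policy toward accepting more CSEs.
double regPressureFeature(const TargetInfo& t, int liveCseCandidates) {
    int budget = t.calleeSavedGpr + t.calleeSavedFloat;
    if (budget <= 0) return 1.0; // no budget: treat as maximum pressure
    double ratio = static_cast<double>(liveCseCandidates) / budget;
    return ratio > 1.0 ? 1.0 : ratio;
}
```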
LGTM, looking forward to seeing the actual changes! Sorry for the delayed review
Adds special CSE heuristic modes to the JIT to support learning a good CSE
heuristic via Policy Gradient, a form of reinforcement learning. The learning
must be orchestrated by an external process, but the JIT does all of the
actual gradient computations.
The orchestration program will be added to jitutils. The overall process
also relies on SPMI and the goal is to minimize perf score.
Introduce two new CSE heuristic policies:
* Replay: simply perform the indicated sequence of CSEs
* RL: used for the Policy Gradient, with 3 modes:
  * Stochastic: based on current parameters, but allows random variation
  * Greedy: based on current parameters, deterministic
  * Update: compute updated parameters per Policy Gradient
Also rework the Random policy to be a bit more random: it now alters
both the CSEs performed and the order in which they are performed.
Add the ability to have jit config options that specify sequences of ints
or doubles.
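As a sketch of what a sequence-valued config option might carry, the following parses a comma-separated list of doubles. The comma format and function name are assumptions for illustration; the actual JIT plumbs such values through its own config machinery.

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <vector>

// Parse a config value like "0.5,-1.25,3" into a sequence of doubles.
// Illustrative only; not the JIT's actual JitConfig parsing code.
std::vector<double> parseDoubleSequence(const std::string& value) {
    std::vector<double> result;
    std::stringstream ss(value);
    std::string item;
    while (std::getline(ss, item, ',')) {
        if (!item.empty()) result.push_back(std::stod(item));
    }
    return result;
}
```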
Add the ability to just dump metric info for a jitted method, and add
more details (perhaps considerably more) for CSEs. This is all still
simple text format.
Also factor out a common check for "non-viable" candidates -- these are
CSE candidates that won't actually be CSEs. This leads to some minor
diffs as the check is now slightly different for CSEs with zero uses
and/or zero weighted uses.
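A minimal sketch of such a viability predicate, under the conditions the paragraph above describes; the struct and field names are illustrative, not the JIT's actual CSE descriptor layout.

```cpp
#include <cassert>

// Illustrative stand-in for a CSE candidate's use counts.
struct CseCandidate {
    int    useCount;     // number of uses of the expression
    double weightedUses; // block-weight-scaled use count
};

// A candidate with zero uses or zero weighted uses can never pay for
// itself as a CSE, so it is "non-viable".
bool isNonViable(const CseCandidate& c) {
    return c.useCount == 0 || c.weightedUses <= 0.0;
}
```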
Contributes to #92915.