Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Profiler annotations & tutorial #582

Merged
merged 7 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "Enzyme/MLIR/Dialect/Ops.h"
#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h"
#include "Enzyme/MLIR/Passes/Passes.h"
#include "mlir/CAPI/Support.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand All @@ -35,7 +36,6 @@
#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h"
#include "src/enzyme_ad/jax/Passes/Passes.h"
#include "llvm/Support/TargetSelect.h"
#include "mlir/CAPI/Support.h"

#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
#include "stablehlo/dialect/ChloOps.h"
Expand All @@ -54,8 +54,9 @@
#include "xla/pjrt/status_casters.h"

#include "tsl/profiler/lib/profiler_session.h"
#include "xla/tsl/profiler/rpc/profiler_server.h"
#include "tsl/profiler/lib/traceme.h"
#include "xla/tsl/profiler/rpc/client/capture_profile.h"
#include "xla/tsl/profiler/rpc/profiler_server.h"

#include "xla/python/ifrt/hlo/hlo_program.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
Expand Down Expand Up @@ -240,12 +241,21 @@ extern "C" void ProfilerSessionDelete(tsl::ProfilerSession *session) {
delete session;
}

extern "C" void* ProfilerServerStart(int32_t port) {
// C ABI wrapper around tsl::profiler::TraceMe::ActivityStart.
// Forwards `name` and `level` unchanged and returns the activity id produced
// by TraceMe; that id is what ProfilerActivityEnd takes as its argument.
extern "C" int64_t ProfilerActivityStart(const char *name, int level) {
  const int64_t activityId =
      tsl::profiler::TraceMe::ActivityStart(name, level);
  return activityId;
}

extern "C" void ProfilerActivityEnd(int64_t id) {
tsl::profiler::TraceMe::ActivityEnd(id);
}

// Heap-allocates a tsl::profiler::ProfilerServer, starts it on `port`, and
// returns the raw pointer. Ownership passes to the caller, who releases it
// through ProfilerServerStop.
extern "C" tsl::profiler::ProfilerServer *ProfilerServerStart(int32_t port) {
  using tsl::profiler::ProfilerServer;
  ProfilerServer *const server = new ProfilerServer();
  server->StartProfilerServer(port);
  return server;
}
extern "C" void* ProfilerServerStop(tsl::profiler::ProfilerServer* server) {

// Destroys a server object previously returned by ProfilerServerStart.
// A null `server` is accepted and ignored (same as `delete nullptr`).
extern "C" void ProfilerServerStop(tsl::profiler::ProfilerServer *server) {
  if (server == nullptr) {
    return;
  }
  delete server;
}

Expand Down Expand Up @@ -448,14 +458,12 @@ static void noop() {}
#ifdef REACTANT_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
// CUDA build: queries the driver version via cuDriverGetVersion and returns
// it. The CUresult is passed to ReactantHandleCuResult (defined elsewhere in
// this file; presumably reports/handles failures -- confirm against its
// definition). `data` is zero-initialised so a failed query cannot return an
// indeterminate value.
extern "C" int32_t ReactantCudaDriverGetVersion() {
  int32_t data = 0;
  ReactantHandleCuResult(cuDriverGetVersion(&data));
  return data;
}
#else
// Non-CUDA build: there is no driver to query, so report version 0.
extern "C" int32_t ReactantCudaDriverGetVersion() { return 0; }
#endif

extern "C" void *UnsafeBufferPointer(PjRtBuffer *buffer) {
Expand Down Expand Up @@ -752,8 +760,8 @@ static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op,
return success();
}

extern "C" void ReactantFuncSetArgAttr(MlirOperation op, intptr_t pos, MlirStringRef name,
MlirAttribute attr) {
// Sets attribute `name` = `attr` on argument number `pos` of the function-like
// operation `op`. `op` is cast with llvm::cast, so it must implement
// mlir::FunctionOpInterface.
extern "C" void ReactantFuncSetArgAttr(MlirOperation op, intptr_t pos,
                                       MlirStringRef name, MlirAttribute attr) {
  auto funcOp = llvm::cast<mlir::FunctionOpInterface>(unwrap(op));
  funcOp.setArgAttr(pos, unwrap(name), unwrap(attr));
}
Expand Down
3 changes: 3 additions & 0 deletions deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,8 @@ cc_library(
"-Wl,-exported_symbol,_ProfilerSessionDelete",
"-Wl,-exported_symbol,_ProfilerServerStart",
"-Wl,-exported_symbol,_ProfilerServerStop",
"-Wl,-exported_symbol,_ProfilerActivityStart",
"-Wl,-exported_symbol,_ProfilerActivityEnd",
"-Wl,-exported_symbol,_ReactantFuncSetArgAttr",
"-Wl,-exported_symbol,_ReactantCudaDriverGetVersion"
]}),
Expand Down Expand Up @@ -522,6 +524,7 @@ cc_library(
"@tsl//tsl/profiler/lib:profiler_session_impl",
"@tsl//tsl/profiler/lib:profiler_factory_impl",
"@tsl//tsl/profiler/lib:profiler_controller",
"@tsl//tsl/profiler/lib:traceme",
"@xla//xla/tsl/profiler/rpc:profiler_server_impl",
"@xla//xla/tsl/profiler/rpc/client:capture_profile",
"@xla//xla/tsl/profiler/rpc/client:profiler_client",
Expand Down
3 changes: 2 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ examples = [
pages = [
"Reactant.jl" => "index.md",
"Introduction" => ["Getting Started" => "introduction/index.md"],
"Tutorials" => ["Overview" => "tutorials/index.md"],
"Tutorials" =>
["Overview" => "tutorials/index.md", "Profiling" => "tutorials/profiling.md"],
"API Reference" => [
"Reactant API" => "api/api.md",
"Ops" => "api/ops.md",
Expand Down
9 changes: 8 additions & 1 deletion docs/src/.vitepress/config.mts
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,13 @@ export default defineConfig({
{ text: "Home", link: "/" },
{ text: "Getting Started", link: "/introduction" },
{ text: "Benchmarks", link: "https://enzymead.github.io/Reactant.jl/benchmarks/" },
{ text: "Tutorials", link: "/tutorials/" },
{
text: "Tutorials",
items: [
{text: "Overview", link: "/tutorials/"},
{text: "Profiling", link: "/tutorials/profiling"},
],
},
{
text: "API",
items: [
Expand Down Expand Up @@ -105,6 +111,7 @@ export default defineConfig({
collapsed: false,
items: [
{ text: "Overview", link: "/tutorials/" },
{ text: "Profiling", link: "/tutorials/profiling" },
],
},
"/api/": {
Expand Down
21 changes: 5 additions & 16 deletions docs/src/api/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,13 @@ Reactant.@jit
@code_hlo
```

```@raw html
<br>
```

## Profile XLA

```@docs
Reactant.Profiler.with_profiler
```

# Internal Functionality

!!! danger "Private"

These functions are not part of the public API and are subject to change at any time.
Reactant can hook into XLA's profiler to generate compilation and execution traces.
See the [profiling tutorial](@ref profiling) for more details.

```@docs
Reactant.Compiler.codegen_unflatten!
Reactant.Compiler.codegen_flatten!
Reactant.Compiler.codegen_xla_call
Reactant.Profiler.with_profiler
Reactant.Profiler.annotate
Reactant.Profiler.@annotate
```
7 changes: 6 additions & 1 deletion docs/src/api/internal.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@ CollapsedDocStrings = true

# Internal API

These functions are not part of the public API and are subject to change at any time.
!!! danger "Private"

These functions are not part of the public API and are subject to change at any time.

```@docs
Reactant.REDUB_ARGUMENTS_NAME
Reactant.Compiler.codegen_unflatten!
Reactant.Compiler.codegen_flatten!
Reactant.Compiler.codegen_xla_call
```
Binary file added docs/src/tutorials/images/perfetto.png
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are these screenshots by us or are they taken from somewhere?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

by us, those are from the example

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/src/tutorials/images/tensorboard.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 3 additions & 1 deletion docs/src/tutorials/index.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Tutorials

We are currently working on adding tutorials to Reactant!! Please check back soon!
- [Profiling](@ref profiling).

We are currently working on adding more tutorials to Reactant. Please check back soon!
84 changes: 84 additions & 0 deletions docs/src/tutorials/profiling.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# [Profiling](@id profiling)

## Capturing traces

When running Reactant, it is possible to capture traces using the [XLA profiler](https://jax.readthedocs.io/en/latest/profiling.html).
These traces can provide information about where the XLA-specific parts of a program spend time during compilation or execution.

Let's set up a simple function that we can then profile:

```@example profiling
using Reactant

x = Reactant.to_rarray(randn(Float32, 100, 2))
W = Reactant.to_rarray(randn(Float32, 10, 100))
b = Reactant.to_rarray(randn(Float32, 10))

linear(x, W, b) = (W * x) .+ b
```

The profiler can be accessed using the [`Reactant.with_profiler`](@ref Reactant.Profiler.with_profiler) function.

```@example profiling
Reactant.with_profiler("./") do
mylinear = Reactant.@compile linear(x, W, b)
mylinear(x, W, b)
end
```

Running this function should create a folder called `plugins` in the folder provided to `Reactant.with_profiler` which will
contain the trace files. The traces can then be visualized in different ways.

!!! note
For more insights about the current state of Reactant, it is possible to fetch device information about allocations using the [`Reactant.XLA.allocatorstats`](@ref) function.

## Perfetto UI

![The perfetto interface](images/perfetto.png)

The first and easiest way to visualize a captured trace is to use the online [`perfetto.dev`](https://ui.perfetto.dev/) tool.
[`Reactant.with_profiler`](@ref Reactant.Profiler.with_profiler) has a keyword parameter called `create_perfetto_link` which will create a usable perfetto URL for the generated trace.
The function will block execution until the URL has been clicked and the trace is visualized. The URL only works once.

```julia
Reactant.with_profiler("./"; create_perfetto_link=true) do
mylinear = Reactant.@compile linear(x, W, b)
mylinear(x, W, b)
end
```

!!! note
It is recommended to use the Chrome browser to open the perfetto URL.

## Tensorboard

![The tensorboard interface](images/tensorboard.png)

Another option to visualize the generated trace files is to use the [tensorboard profiler plugin](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras).
The tensorboard viewer can offer more details than the timeline view, such as visualizations of compute graphs.

First install tensorboard and its profiler plugin:

```bash
pip install tensorboard tensorboard-plugin-profile
```

And then run the following in the folder where the `plugins` folder was generated:

```bash
tensorboard --logdir ./
```

## Adding Custom Annotations

By default, the traces contain only information captured from within XLA.
The [`Reactant.Profiler.annotate`](@ref) function can be used to annotate traces for Julia code evaluated *during tracing*.

```julia
Reactant.Profiler.annotate("my_annotation") do
# Do things...
end
```

The added annotations will be captured in the traces and can be seen in the different viewers along with the default XLA annotations.
When the profiler is not active, the custom annotations have no effect, so they can safely be left in the code at all times.
51 changes: 48 additions & 3 deletions src/Profiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ function with_profiler(
trace_host=true,
create_perfetto_link=false,
)
# TODO: we should be able to inject traces from Julia to fill in the blank spots in the trace.

device_tracer_level = UInt32(trace_device ? 1 : 0)
host_tracer_level = UInt32(trace_host ? 2 : 0)
profiler = @ccall Reactant.MLIR.API.mlir_c.CreateProfilerSession(
Expand Down Expand Up @@ -55,7 +53,53 @@ function with_profiler(
return results
end

export with_profiler
# TraceMe levels accepted as the `level` argument of `annotate`, passed to the
# C API as a `Cint`. The numeric values are taken from TSL's `traceme.h`:
# https://github.com/google/tsl/blob/ffeadbc9111309a845ab07df3ff41d59cb005afb/tsl/profiler/lib/traceme.h#L49-L53
const TRACE_ME_LEVEL_CRITICAL = Cint(1)
const TRACE_ME_LEVEL_INFO = Cint(2)
const TRACE_ME_LEVEL_VERBOSE = Cint(3)

"""
    annotate(f, name, [level=TRACE_ME_LEVEL_CRITICAL])

Run `f()` inside a profiler activity called `name` and return the value of `f()`.

The activity is opened with `ProfilerActivityStart(name, level)` before `f` is called and
closed with `ProfilerActivityEnd` afterwards, including when `f` throws. Because `f` is the
first argument, `do`-block syntax can be used:

```julia
annotate("my_annotation") do
    # Do things...
end
```
"""
function annotate(f, name, level=TRACE_ME_LEVEL_CRITICAL)
    activity_id = @ccall Reactant.MLIR.API.mlir_c.ProfilerActivityStart(
        name::Cstring, level::Cint
    )::Int64
    return try
        f()
    finally
        # Always close the activity, even if `f` raised.
        @ccall Reactant.MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid
    end
end

# Strip `where` clauses and a return-type annotation from a definition signature,
# returning the innermost expression (a `:call` for ordinary named functions).
function _unwrap_signature(sig)
    while Meta.isexpr(sig, :where) || Meta.isexpr(sig, :(::), 2)
        sig = sig.args[1]
    end
    return sig
end

# Default annotation label for a signature: the text of the called name, e.g. "foo"
# for `foo(a, b) where {T}`.
function _annotation_name(sig)
    inner = _unwrap_signature(sig)
    return string(Meta.isexpr(inner, :call) ? inner.args[1] : inner)
end

"""
    @annotate [name] function foo(a, b, c)
        ...
    end
    @annotate [name] foo(a, b, c) = ...

Define the given function so that its body runs inside [`annotate`](@ref); every call
then generates an annotation in the captured XLA profiles.

If `name` is omitted, the function's own name is used as the annotation label. Signatures
with `where` clauses or a return-type annotation are supported. Anything that is not a
function definition raises an error at macro-expansion time.
"""
macro annotate(name, func_def=nothing)
    noname = isnothing(func_def)
    func_def = something(func_def, name)

    # Require exactly (signature, body): this also rejects `function foo end`.
    is_long_form = Meta.isexpr(func_def, :function, 2)
    is_short_form =
        Meta.isexpr(func_def, :(=), 2) &&
        Meta.isexpr(_unwrap_signature(func_def.args[1]), :call)
    if !(is_long_form || is_short_form)
        error("not a function definition: $func_def")
    end

    signature, body = func_def.args
    label = noname ? _annotation_name(signature) : name

    # `annotate` is deliberately left unescaped so it resolves in this module;
    # the user's body and label are escaped so they resolve in the caller's scope.
    wrapped = quote
        annotate(() -> $(esc(body)), $(esc(label)))
    end

    return Expr(:function, esc(signature), wrapped)
end

export with_profiler, annotate, @annotate

function serve_to_perfetto(path_to_trace_file)
port_hint = 9001
Expand Down Expand Up @@ -141,4 +185,5 @@ mutable struct ProfileServer
return finalizer(free_profiler, new(exec))
end
end

end # module Profiler
6 changes: 3 additions & 3 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ end
include("mlir/MLIR.jl")
include("XLA.jl")
include("Interpreter.jl")
include("Profiler.jl")

const with_profiler = Profiler.with_profiler

include("utils.jl")

Expand Down Expand Up @@ -247,9 +250,6 @@ const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue}
include("ControlFlow.jl")
include("Tracing.jl")
include("Compiler.jl")
include("Profiler.jl")

const with_profiler = Profiler.with_profiler

include("Overlay.jl")

Expand Down
26 changes: 26 additions & 0 deletions src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,24 @@ struct JLAllocatorStats
peak_pool_bytes::Int64
end

"""
AllocatorStats()

Contains the following fields:
- `num_allocs`
- `bytes_in_use`
- `peak_bytes_in_use`
- `largest_alloc_size`
- `bytes_limit`
- `bytes_reserved`
- `peak_bytes_reserved`
- `bytes_reservable_limit`
- `largest_free_block_bytes`
- `pool_bytes`
- `peak_pool_bytes`

It should be constructed using the [`allocatorstats`](@ref) function.
"""
struct AllocatorStats
num_allocs::Int64
bytes_in_use::Int64
Expand All @@ -260,6 +278,14 @@ struct AllocatorStats
peak_pool_bytes::Union{Nothing,Int64}
end

"""
allocatorstats([device])

Return an [`AllocatorStats`](@ref) instance with information about the device specific allocator.

!!! warning
This method is currently not implemented for the CPU device.
"""
function allocatorstats(
device::Device=ClientGetDevice(default_backend[], default_device_idx[])
)
Expand Down
Loading