Refine JAX integration, example, and docs (#99)
jaywonchung authored Jul 19, 2024
1 parent 5526da2 commit d61b5fb
Showing 16 changed files with 218 additions and 144 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -15,6 +15,7 @@
---
**Project News**

- \[2024/07\] Added AMD GPU, CPU, and DRAM energy measurement support, and preliminary JAX support!
- \[2024/05\] Zeus is now a PyTorch ecosystem project. Read the PyTorch blog post [here](https://pytorch.org/blog/zeus/)!
- \[2024/02\] Zeus was selected as a [2024 Mozilla Technology Fund awardee](https://foundation.mozilla.org/en/blog/open-source-AI-for-environmental-justice/)!
- \[2023/12\] We released Perseus, an energy optimizer for large model training: [Preprint](https://arxiv.org/abs/2312.06902) | [Blog](https://ml.energy/zeus/research_overview/perseus) | [Optimizer](https://ml.energy/zeus/optimize/pipeline_frequency_optimizer)
1 change: 1 addition & 0 deletions docs/index.md
@@ -14,6 +14,7 @@ hide:
---
**Project News**

- \[2024/07\] Added AMD GPU, CPU, and DRAM energy measurement support, and preliminary JAX support!
- \[2024/05\] Zeus is now a PyTorch ecosystem project. Read the PyTorch blog post [here](https://pytorch.org/blog/zeus/){.external}!
- \[2024/02\] Zeus was selected as a [2024 Mozilla Technology Fund awardee](https://foundation.mozilla.org/en/blog/open-source-AI-for-environmental-justice/){.external}!
- \[2023/12\] We released Perseus, an energy optimizer for large model training: [Preprint](https://arxiv.org/abs/2312.06902){.external} | [Blog](research_overview/perseus.md) | [Optimizer](optimize/pipeline_frequency_optimizer.md)
20 changes: 20 additions & 0 deletions docs/measure/index.md
@@ -54,6 +54,26 @@ if __name__ == "__main__":
In general, energy optimizers measure the energy of the GPU through a [`ZeusMonitor`][zeus.monitor.ZeusMonitor] instance that is passed to their constructor.
Thus, only the GPUs specified by `gpu_indices` will be the target of optimization.

### Synchronizing CPU and GPU computations

Deep learning frameworks typically run computations on the GPU asynchronously.
That is, the CPU (the Python interpreter) dispatches a computation to the GPU and immediately moves on to dispatch the next one, without waiting for the GPU to finish.
This helps GPUs achieve higher utilization with less idle time.
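
For instance, a minimal PyTorch sketch of this behavior (assuming a CUDA-capable GPU is available):

```python
import torch

x = torch.randn(4096, 4096, device="cuda")

# Returns almost immediately: the matmul is only *dispatched* to the GPU here.
y = x @ x

# Blocks until all dispatched GPU work has finished.
torch.cuda.synchronize()
```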

Because of this asynchronous execution, we need to be careful when taking time and energy measurements of GPU execution.
We want *only and all of* the computations dispatched between `begin_window` and `end_window` to be captured by our time and energy measurement.
That's what the `sync_execution_with` parameter in [`ZeusMonitor`][zeus.monitor.ZeusMonitor] and the `sync_execution` parameter in [`begin_window`][zeus.monitor.ZeusMonitor.begin_window] and [`end_window`][zeus.monitor.ZeusMonitor.end_window] are for.
Depending on the deep learning framework you're using (currently PyTorch and JAX are supported), [`ZeusMonitor`][zeus.monitor.ZeusMonitor] will automatically synchronize CPU and GPU execution to make sure that all and only the computations dispatched within the window are captured.
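
As a rough sketch, assuming a single NVIDIA GPU at index 0 and that `"torch"` is the value that selects PyTorch-based synchronization:

```python
from zeus.monitor import ZeusMonitor

# Synchronize CPU and GPU execution around each measurement window.
monitor = ZeusMonitor(gpu_indices=[0], sync_execution_with="torch")  # or "jax"

monitor.begin_window("training_step")
# ... dispatch GPU computations here ...
measurement = monitor.end_window("training_step")
```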

!!! Tip
    Zeus has one function used globally across the codebase for device synchronization: [`sync_execution`][zeus.utils.framework.sync_execution].
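
If you need to call it directly, a rough sketch could look like the following; the argument names here are assumptions made for illustration, so check the [`sync_execution`][zeus.utils.framework.sync_execution] reference for the actual signature:

```python
from zeus.utils.framework import sync_execution

# Block until computations dispatched to GPU 0 have finished.
# NOTE: the `sync_with` keyword argument name is an assumption for illustration.
sync_execution([0], sync_with="jax")
```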

!!! Warning
    [`ZeusMonitor`][zeus.monitor.ZeusMonitor] covers only the common and simple case of device synchronization, where each GPU index in `gpu_indices` corresponds to one whole physical device.
    This is usually what you want, except when using more advanced device partitioning (e.g., using `--xla_force_host_platform_device_count` in JAX to split CPUs into more devices).
    In such cases, you probably want to opt out of automatic synchronization and handle it manually at the appropriate granularity, as in the sketch below.
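
For example, a rough sketch of opting out of Zeus-side synchronization and synchronizing manually in JAX (the workload here is purely illustrative):

```python
import jax
import jax.numpy as jnp

from zeus.monitor import ZeusMonitor

monitor = ZeusMonitor(gpu_indices=[0])

# Opt out of automatic synchronization for this window.
monitor.begin_window("partitioned_work", sync_execution=False)

out = jnp.ones((1000, 1000)) @ jnp.ones((1000, 1000))

# Synchronize yourself, at the granularity you need, before closing the window.
jax.block_until_ready(out)
measurement = monitor.end_window("partitioned_work", sync_execution=False)
```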


## CLI power and energy monitor

The energy monitor measures the total energy consumed by the GPU during the lifetime of the monitor process.
28 changes: 28 additions & 0 deletions examples/jax/README.md
@@ -0,0 +1,28 @@
# Measuring the energy consumption of JAX

`ZeusMonitor` officially supports JAX:

```python
from zeus.monitor import ZeusMonitor

monitor = ZeusMonitor(sync_execution_with="jax")

monitor.begin_window("computations")
# Run computation
measurement = monitor.end_window("computations")
```

The `sync_execution_with` parameter in `ZeusMonitor` tells the monitor to use JAX mechanisms to wait for GPU computations to complete.
GPU computations typically run asynchronously with your Python code (in both PyTorch and JAX), so waiting for them to finish is important to ensure that we measure exactly the computations we intend to.
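
For reference, here is roughly what that synchronization means in plain JAX, sketched with `jax.block_until_ready` (the monitor's actual internals may differ):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1000, 1000))

# Dispatches the matmul and returns before the GPU has finished.
y = x @ x

# Blocks the Python thread until all work producing `y` is done.
jax.block_until_ready(y)
```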

## Running the example

Install dependencies:

```sh
pip install -r requirements.txt
```

Run the example:

```sh
python measure_energy.py
```
37 changes: 37 additions & 0 deletions examples/jax/measure_energy.py
@@ -0,0 +1,37 @@
import jax
import jax.numpy as jnp

from zeus.monitor import ZeusMonitor

@jax.jit
def mat_prod(B):
    A = jnp.ones((1000, 1000))
    return A @ B

def main():
    # Monitor the GPU with index 0.
    # The monitor will use a JAX-specific method to wait for the GPU
    # to finish computations when `end_window` is called.
    monitor = ZeusMonitor(gpu_indices=[0], sync_execution_with="jax")

    # Mark the beginning of a measurement window.
    monitor.begin_window("all_computations")

    # Actual work
    key = jax.random.PRNGKey(0)
    B = jax.random.uniform(key, (1000, 1000))
    for i in range(50000):
        B = mat_prod(B)

    # Mark the end of a measurement window and retrieve the measurement result.
    measurement = monitor.end_window("all_computations")

    # Print the measurement result.
    print("Measurement object:", measurement)
    print(f"Took {measurement.time} seconds.")
    for gpu_idx, gpu_energy in measurement.gpu_energy.items():
        print(f"GPU {gpu_idx} consumed {gpu_energy} Joules.")


if __name__ == "__main__":
    main()
2 changes: 2 additions & 0 deletions examples/jax/requirements.txt
@@ -0,0 +1,2 @@
zeus-ml
jax[cuda12]==0.4.30
35 changes: 0 additions & 35 deletions examples/jax/simple_monitoring.py

This file was deleted.

2 changes: 1 addition & 1 deletion tests/optimizer/test_power_limit_optimizer.py
@@ -149,7 +149,7 @@ def test_power_limit_optimizer(

monitor = ReplayZeusMonitor(
log_file=replay_log.log_file,
ignore_sync_cuda=True,
ignore_sync_execution=True,
match_window_name=False,
)
assert monitor.gpu_indices == replay_log.gpu_indices
62 changes: 31 additions & 31 deletions tests/test_monitor.py
@@ -278,79 +278,79 @@ def assert_measurement(
pynvml_mock.nvmlDeviceGetTotalEnergyConsumption.reset_mock()

# Serial non-overlapping windows.
monitor.begin_window("window1", sync_cuda=False)
monitor.begin_window("window1", sync_execution=False)
assert_window_begin("window1", 4)

tick()

# Calling `begin_window` again with the same name should raise an error.
with pytest.raises(ValueError, match="already exists"):
monitor.begin_window("window1", sync_cuda=False)
monitor.begin_window("window1", sync_execution=False)

measurement = monitor.end_window("window1", sync_cuda=False)
measurement = monitor.end_window("window1", sync_execution=False)
assert_measurement("window1", measurement, begin_time=4, elapsed_time=2)

tick()
tick()

monitor.begin_window("window2", sync_cuda=False)
monitor.begin_window("window2", sync_execution=False)
assert_window_begin("window2", 9)

tick()
tick()
tick()

measurement = monitor.end_window("window2", sync_cuda=False)
measurement = monitor.end_window("window2", sync_execution=False)
assert_measurement("window2", measurement, begin_time=9, elapsed_time=4)

# Calling `end_window` again with the same name should raise an error.
with pytest.raises(ValueError, match="does not exist"):
monitor.end_window("window2", sync_cuda=False)
monitor.end_window("window2", sync_execution=False)

# Calling `end_window` with a name that doesn't exist should raise an error.
with pytest.raises(ValueError, match="does not exist"):
monitor.end_window("window3", sync_cuda=False)
monitor.end_window("window3", sync_execution=False)

# Overlapping windows.
monitor.begin_window("window3", sync_cuda=False)
monitor.begin_window("window3", sync_execution=False)
assert_window_begin("window3", 14)

tick()

monitor.begin_window("window4", sync_cuda=False)
monitor.begin_window("window4", sync_execution=False)
assert_window_begin("window4", 16)

tick()
tick()

measurement = monitor.end_window("window3", sync_cuda=False)
measurement = monitor.end_window("window3", sync_execution=False)
assert_measurement("window3", measurement, begin_time=14, elapsed_time=5)

tick()
tick()
tick()

measurement = monitor.end_window("window4", sync_cuda=False)
measurement = monitor.end_window("window4", sync_execution=False)
assert_measurement("window4", measurement, begin_time=16, elapsed_time=7)

# Nested windows.
monitor.begin_window("window5", sync_cuda=False)
monitor.begin_window("window5", sync_execution=False)
assert_window_begin("window5", 24)

monitor.begin_window("window6", sync_cuda=False)
monitor.begin_window("window6", sync_execution=False)
assert_window_begin("window6", 25)

tick()
tick()

measurement = monitor.end_window("window6", sync_cuda=False)
measurement = monitor.end_window("window6", sync_execution=False)
assert_measurement("window6", measurement, begin_time=25, elapsed_time=3)

tick()
tick()
tick()

measurement = monitor.end_window("window5", sync_cuda=False)
measurement = monitor.end_window("window5", sync_execution=False)
assert_measurement("window5", measurement, begin_time=24, elapsed_time=8)

########################################
@@ -397,58 +397,58 @@ def assert_log_file_row(row: str, name: str, begin_time: int, elapsed_time: int)
if any(is_old_nvml.values()):
return

replay_monitor.begin_window("window1", sync_cuda=False)
replay_monitor.begin_window("window1", sync_execution=False)

# Calling `begin_window` again with the same name should raise an error.
with pytest.raises(RuntimeError, match="is already ongoing"):
replay_monitor.begin_window("window1", sync_cuda=False)
replay_monitor.begin_window("window1", sync_execution=False)

measurement = replay_monitor.end_window("window1", sync_cuda=False)
measurement = replay_monitor.end_window("window1", sync_execution=False)
assert_measurement(
"window1", measurement, begin_time=5, elapsed_time=2, assert_calls=False
)

# Calling `end_window` with a non-existent window name should raise an error.
with pytest.raises(RuntimeError, match="is not ongoing"):
replay_monitor.end_window("window2", sync_cuda=False)
replay_monitor.end_window("window2", sync_execution=False)

replay_monitor.begin_window("window2", sync_cuda=False)
measurement = replay_monitor.end_window("window2", sync_cuda=False)
replay_monitor.begin_window("window2", sync_execution=False)
measurement = replay_monitor.end_window("window2", sync_execution=False)
assert_measurement(
"window2", measurement, begin_time=10, elapsed_time=4, assert_calls=False
)

replay_monitor.begin_window("window3", sync_cuda=False)
replay_monitor.begin_window("window4", sync_cuda=False)
replay_monitor.begin_window("window3", sync_execution=False)
replay_monitor.begin_window("window4", sync_execution=False)

measurement = replay_monitor.end_window("window3", sync_cuda=False)
measurement = replay_monitor.end_window("window3", sync_execution=False)
assert_measurement(
"window3", measurement, begin_time=15, elapsed_time=5, assert_calls=False
)
measurement = replay_monitor.end_window("window4", sync_cuda=False)
measurement = replay_monitor.end_window("window4", sync_execution=False)
assert_measurement(
"window4", measurement, begin_time=17, elapsed_time=7, assert_calls=False
)

replay_monitor.begin_window("window5", sync_cuda=False)
replay_monitor.begin_window("window6", sync_cuda=False)
measurement = replay_monitor.end_window("window6", sync_cuda=False)
replay_monitor.begin_window("window5", sync_execution=False)
replay_monitor.begin_window("window6", sync_execution=False)
measurement = replay_monitor.end_window("window6", sync_execution=False)
assert_measurement(
"window6", measurement, begin_time=26, elapsed_time=3, assert_calls=False
)
measurement = replay_monitor.end_window("window5", sync_cuda=False)
measurement = replay_monitor.end_window("window5", sync_execution=False)
assert_measurement(
"window5", measurement, begin_time=25, elapsed_time=8, assert_calls=False
)

# Calling `end_window` when the energy consumption of one or more GPUs was measured as zero should raise a warning.
pynvml_mock.nvmlDeviceGetTotalEnergyConsumption.side_effect = lambda handle: 0.0

monitor.begin_window("window0", sync_cuda=False)
monitor.begin_window("window0", sync_execution=False)

with pytest.warns(
match="The energy consumption of one or more GPUs was measured as zero. This means that the time duration of the measurement window was shorter than the GPU's energy counter update period. Consider turning on the `approx_instant_energy` option in `ZeusMonitor`, which approximates the energy consumption of a short time window as instant power draw x window duration.",
):
test_measurement = monitor.end_window("window0", sync_cuda=False)
test_measurement = monitor.end_window("window0", sync_execution=False)

assert all(value == 0.0 for value in test_measurement.gpu_energy.values())
8 changes: 4 additions & 4 deletions zeus/device/cpu/rapl.py
@@ -14,10 +14,10 @@
from __future__ import annotations

import os
import contextlib
from typing import Sequence
import warnings
from glob import glob
from typing import Sequence
from functools import lru_cache

import zeus.device.cpu.common as cpu_common
from zeus.device.cpu.common import CpuDramMeasurement
@@ -29,6 +29,7 @@
RAPL_DIR = "/sys/class/powercap/intel-rapl"


@lru_cache(maxsize=1)
def rapl_is_available() -> bool:
"""Check if RAPL is available."""
if not os.path.exists(RAPL_DIR):
@@ -172,5 +173,4 @@ def _init_cpus(self) -> None:

def __del__(self) -> None:
"""Shuts down the Intel CPU monitoring."""
with contextlib.suppress(Exception):
logger.info("Shutting down RAPL CPU monitoring.")
pass
6 changes: 4 additions & 2 deletions zeus/device/gpu/amd.py
@@ -5,6 +5,7 @@
import os
import contextlib
from typing import Sequence
from functools import lru_cache

try:
import amdsmi # type: ignore
@@ -32,6 +33,7 @@ def __getattr__(self, name):
logger = get_logger(name=__name__)


@lru_cache(maxsize=1)
def amdsmi_is_available() -> bool:
"""Check if amdsmi is available."""
try:
@@ -43,8 +45,8 @@ def amdsmi_is_available() -> bool:
amdsmi.amdsmi_init()
logger.info("amdsmi is available and initialized")
return True
except amdsmi.AmdSmiLibraryException:
logger.info("amdsmi is available but could not initialize.")
except amdsmi.AmdSmiLibraryException as e:
logger.info("amdsmi is available but could not initialize: %s", e)
return False

