-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refine JAX integration, example, and docs (#99)
- Loading branch information
1 parent
5526da2
commit d61b5fb
Showing
16 changed files
with
218 additions
and
144 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Measuring the energy consumption of JAX | ||
|
||
`ZeusMonitor` officially supports JAX: | ||
|
||
```python | ||
monitor = ZeusMonitor(sync_execution_with="jax") | ||
|
||
monitor.begin_window("computations") | ||
# Run computation | ||
measurement = monitor.end_window("computations") | ||
``` | ||
|
||
The `sync_execution_with` parameter in `ZeusMonitor` tells the monitor that it should use JAX mechanisms to wait for GPU computations to complete. | ||
GPU computations typically run asynchronously with your Python code (in both PyTorch and JAX), so waiting for GPU computations to complete is important to ensure that we measure the right set of computations. | ||
|
||
## Running the example | ||
|
||
Install dependencies: | ||
|
||
```sh | ||
pip install -r requirements.txt | ||
``` | ||
|
||
Run the example: | ||
|
||
```sh | ||
python measure_energy.py | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
from zeus.monitor import ZeusMonitor | ||
|
||
@jax.jit | ||
def mat_prod(B): | ||
A = jnp.ones((1000, 1000)) | ||
return A @ B | ||
|
||
def main(): | ||
# Monitor the GPU with index 0. | ||
# The monitor will use a JAX-specific method to wait for the GPU | ||
# to finish computations when `end_window` is called. | ||
monitor = ZeusMonitor(gpu_indices=[0], sync_execution_with="jax") | ||
|
||
# Mark the beginning of a measurement window. | ||
monitor.begin_window("all_computations") | ||
|
||
# Actual work | ||
key = jax.random.PRNGKey(0) | ||
B = jax.random.uniform(key, (1000, 1000)) | ||
for i in range(50000): | ||
B = mat_prod(B) | ||
|
||
# Mark the end of a measurement window and retrieve the measurment result. | ||
measurement = monitor.end_window("all_computations") | ||
|
||
# Print the measurement result. | ||
print("Measurement object:", measurement) | ||
print(f"Took {measurement.time} seconds.") | ||
for gpu_idx, gpu_energy in measurement.gpu_energy.items(): | ||
print(f"GPU {gpu_idx} consumed {gpu_energy} Joules.") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
zeus-ml | ||
jax[cuda12]==0.4.30 |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.