👥 Compare data loading between kvikIO and Zarr engine (#7)
* 👔 Get mean time by averaging over ten epochs

Benchmark results for the Zarr and kvikIO engines were too close to call over a single epoch, so looping over 10 epochs and reporting the average instead (see the diff to 1_benchmark_kvikIOzarr.py below). Not printing the MSE Loss anymore, to declutter the console output.

* ➕ Add ipywidgets

Jupyter Interactive Widgets! Repo at https://github.com/jupyter-widgets/ipywidgets

* ♻️ Use tqdm.auto to also work in notebooks

Will be reusing some of this code in a Jupyter Notebook, so refactoring to use tqdm.auto instead of standard tqdm.
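
A minimal sketch of the behaviour (hypothetical loop sizes; the real values are in the benchmark script below):

from tqdm.auto import tqdm, trange  # widget-based bars in notebooks, plain-text bars in terminals

for epoch in trange(10, desc="epoch"):
    for batch in tqdm(range(23), desc="batch", leave=False):
        pass  # mini-batch processing would happen here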

* 🔊 Report median time and standard deviation across epochs

Save the time taken to complete each epoch, and compute the median, mean and standard deviation across all epochs. Needed because the time to process one epoch can vary by a few seconds across the ten epochs depending on various factors (e.g. caching), so computing the average time as total_time / num_epochs can give misleading results. Also updated the main README.md to be more specific about the reported total/median/mean/std benchmark times and the size of the ERA5 subset dataset.
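
A quick illustration of why the median is reported (made-up timings, with a single cold-cache outlier in the first epoch):

import numpy as np

# Hypothetical epoch timings in seconds; epoch 0 pays a cold-cache penalty
epoch_timings = [18.2, 9.1, 9.3, 9.0, 9.2, 9.1, 9.4, 9.0, 9.2, 9.1]

print(np.mean(epoch_timings))    # ~10.06, dragged upwards by the outlier
print(np.median(epoch_timings))  # 9.15, closer to a typical epoch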

* 👥 Compare data loading between kvikIO and Zarr engine

Reporting the actual numbers on which engine is faster - kvikIO or Zarr! Reusing some code from 1_benchmark_kvikIOzarr.py, but now displaying the total/median/mean/std times. The final cell calculates the speedup of kvikIO to be ~20% over Zarr, though note that this speedup fluctuates depending on many factors (values from 10% to 30% seen over multiple runs).
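
The final cell's calculation is along these lines (a sketch with made-up numbers, not the actual benchmark results):

# Hypothetical median times in seconds/epoch for each engine
kvikio_time = 9.2
zarr_time = 11.0

speedup = zarr_time / kvikio_time - 1
print(f"kvikIO is {speedup:.0%} faster than Zarr")  # ~20% with these numbers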

* ➕ Add seaborn

Statistical data visualization in Python!

* 💫 Plot bar graph showing time taken between kvikio and zarr engine

A bar plot (with error bars) to visually compare the kvikio (with GPUDirect Storage) and zarr (no GPUDirect Storage) xarray backend engines in terms of data loading speed. Speedup results still fluctuate between runs, but are mostly around the 20% mark.

Also did some slight refactoring to use pandas instead of numpy for the mean/median/std calculations. Using ddof=1 for the standard deviation.
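
A sketch of the pandas-based statistics (hypothetical long-form data; note that pandas' std() already uses ddof=1 by default):

import pandas as pd

# Hypothetical per-epoch timings, one row per (engine, epoch)
df = pd.DataFrame({
    "engine": ["kvikio"] * 3 + ["zarr"] * 3,
    "seconds": [9.1, 9.3, 9.0, 11.2, 11.6, 11.4],
})

# 'std' here is the sample standard deviation (ddof=1)
stats = df.groupby("engine")["seconds"].agg(["median", "mean", "std"])
print(stats)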

* 🐛 Change barplot estimator to median instead of mean

Seaborn plots the mean value by default, but changing to median instead. The kvikIO engine is now reported as 35% faster than the Zarr engine.
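
A sketch of such a bar plot, assuming seaborn >= 0.12 (for the errorbar parameter) and the hypothetical long-form df from the pandas sketch above:

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

ax = sns.barplot(
    data=df,              # long-form DataFrame with 'engine' and 'seconds' columns
    x="engine",
    y="seconds",
    estimator=np.median,  # seaborn's default estimator is the mean
    errorbar="sd",        # error bars show one standard deviation
)
ax.set_ylabel("Time per epoch (s)")
plt.show()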

* 💬 Report as percentage less time, not speedup

Speed equals distance (or epochs) over time. It makes more sense to report 'less time' (an absolute measure) instead of 'faster speed' (an inverse measure), so fixing the formulation. The previous speedup calculation may actually have been incorrect, since a percentage reduction in time is not the same number as a percentage increase in speed.
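
The distinction matters numerically; a quick check with made-up numbers:

# Hypothetical median seconds/epoch for each engine
kvikio_time, zarr_time = 6.5, 10.0

# Absolute measure: kvikIO takes 35% less time than Zarr
pct_less_time = (zarr_time - kvikio_time) / zarr_time  # 0.35

# Inverse measure: 35% less time is a ~1.54x speedup, not 1.35x
speedup = zarr_time / kvikio_time  # ~1.538
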
weiji14 committed Oct 13, 2023
1 parent 551bdb1 commit cc76120
Showing 5 changed files with 1,278 additions and 375 deletions.
41 changes: 28 additions & 13 deletions 1_benchmark_kvikIOzarr.py
@@ -10,12 +10,13 @@

 import cupy
 import lightning as L
+import numpy as np
 import torch
 import torchdata
 import torchdata.dataloader2
-import tqdm
 import xarray as xr
 import zen3geo
+from tqdm.auto import tqdm, trange


# %%
@@ -167,16 +168,30 @@ def train_dataloader(self) -> torchdata.dataloader2.DataLoader2:
 datamodule.setup()
 train_dataloader = datamodule.train_dataloader()

-# Start timing
-tic = time.perf_counter()
-
-# Training loop
-for i, batch in tqdm.tqdm(iterable=enumerate(train_dataloader), total=23):
-    input, target, metadata = batch
-    # Compute Mean Squared Error loss between t=0 and t=1, just for fun
-    loss: torch.Tensor = torch.functional.F.mse_loss(input=input, target=target)
-    print(f"Batch {i}, MSE Loss: {loss}")
-
-# Stop timing
-toc = time.perf_counter()
-print(f"Total: {toc - tic:0.4f} seconds")
+num_epochs: int = 10
+epoch_timings: list = []
+for epoch in trange(num_epochs):
+    # Start timing
+    tic: float = time.perf_counter()
+
+    # Mini-batch processing
+    for i, batch in tqdm(iterable=enumerate(train_dataloader), total=23):
+        input, target, metadata = batch
+        # Compute Mean Squared Error loss between t=0 and t=1, just for fun
+        loss: torch.Tensor = torch.functional.F.mse_loss(input=input, target=target)
+        # print(f"Batch {i}, MSE Loss: {loss}")
+
+    # Stop timing
+    toc: float = time.perf_counter()
+    epoch_timings.append(toc - tic)
+
+total_time: float = np.sum(a=epoch_timings)
+median_time: float = np.median(a=epoch_timings)
+mean_time: float = np.mean(a=epoch_timings)
+std_time: float = np.std(a=epoch_timings, ddof=1)
+print(
+    f"Total: {total_time:0.4f} seconds, "
+    f"Median: {median_time:0.4f} seconds/epoch, "
+    f"Mean: {mean_time:0.4f} ± {std_time:0.4f} seconds/epoch"
+)
