Skip to content

Commit

Permalink
a few convenience tweaks to make the reproduce script smoother (#34)
Browse files Browse the repository at this point in the history
* reproducing notebook: fix the -g option

before that change, was always skipping gpu runs
also the default is now to run all, use -g to skip the runs that require a GPU

* reproduce: configure jax to enable 64 bits

* cosmetic
  • Loading branch information
parmentelat authored Jun 29, 2024
1 parent 36b10ae commit b27201b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
21 changes: 13 additions & 8 deletions paper/reproducing-results-notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,18 @@
# catch low-hanging fruits first
SKIP_RUNS_LONGER_THAN = 0

# by default, only on CPU
RUN_ON_GPUS = False
# by default, run all runs
SKIP_GPU = False

#
DRY_RUN = False


# %%
# Allow JAX to use 64-bit floating point precision.
import jax
jax.config.update("jax_enable_x64", True)

# %%
# provide a way to choose the output from the command line
# but argparse won't work from Jupyter, so:
Expand All @@ -59,19 +64,19 @@
parser.add_argument("-s", "--skip-runs-longer-than", default=SKIP_RUNS_LONGER_THAN,
action="store", type=int,
help="speed up: skip runs that had taken longer than, in seconds")
parser.add_argument("-g", "--gpu", default=RUN_ON_GPUS,
parser.add_argument("-g", "--skip-gpu", default=SKIP_GPU,
action="store_true",
help="enable runs that require a GPU")
help="skip runs that require a GPU")
parser.add_argument("-n", "--dry-run", default=DRY_RUN,
action="store_true",
help="just show the commands to run, do not actually trigger them")
args = parser.parse_args()
OUR_TIMINGS = args.output
SKIP_RUNS_LONGER_THAN = args.skip_runs_longer_than
RUN_RUN_ON_GPUS = args.gpu
SKIP_GPU = args.skip_gpu
DRY_RUN = args.dry_run

print(f"using {OUR_TIMINGS=} {SKIP_RUNS_LONGER_THAN=} {RUN_ON_GPUS=} {DRY_RUN=}")
print(f"using {OUR_TIMINGS=} {SKIP_RUNS_LONGER_THAN=} {SKIP_GPU=} {DRY_RUN=}")

# %% [markdown]
# ## loading the paper timings
Expand Down Expand Up @@ -157,7 +162,7 @@ def status(message):
# ### isolating lines doable on a CPU (optional)

# %%
if not RUN_ON_GPUS:
if SKIP_GPU:
todo = todo[ ~ todo.model.str.contains('jax')]
status("removed GPU-only runs")

Expand All @@ -169,7 +174,7 @@ def status(message):

if SKIP_RUNS_LONGER_THAN:
todo = todo[todo.time < SKIP_RUNS_LONGER_THAN]
status(f"keeping only runs shorter than {SKIP_RUNS_LONGER_THAN}s")
status(f"skipping runs over {SKIP_RUNS_LONGER_THAN}s")

# %% [markdown]
# ### ignore entries whose estimation cannot be automated
Expand Down
3 changes: 3 additions & 0 deletions paper/time_pls.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
single_fit_gpu_pls,
)

import jax
jax.config.update("jax_enable_x64", True)

def main():
parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down

0 comments on commit b27201b

Please sign in to comment.