Skip to content

Commit

Permalink
Fix some issues with the speed tests in custom kernels (#117)
Browse files Browse the repository at this point in the history
* fixes for custom kernel speed tests

* update printing for speed tests
  • Loading branch information
lubbersnick authored Oct 29, 2024
1 parent 957de00 commit a0b5ca7
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
3 changes: 2 additions & 1 deletion hippynn/custom_kernels/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,5 @@ def get_available_implementations(self, hidden=False):
:param hidden: Show all implementations, even those which have no improved performance characteristics.
:return:
"""
return [k for k in self._registered_implementations.keys() if not k.startswith("_")]

return [k for k in self._registered_implementations.keys() if hidden or not k.startswith("_")]
22 changes: 16 additions & 6 deletions hippynn/custom_kernels/test_speed_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def parse_args():
parser.add_argument("--all-hidden", action="store_true", default=False, help="Use all implementations, even with _ beginning.")
parser.add_argument("--all-impl", action="store_true", default=False, help="Use all non-hidden implementations.")
parser.add_argument("--all-gpu", action="store_true", default=False, help="Use low-mem implementations suitable for GPU.")
parser.add_argument("--all-gpu", action="store_true", default=False, help="CPU-capable implementaitons.")
parser.add_argument("--all-cpu", action="store_true", default=False, help="CPU-capable implementaitons.")

for param_type in TEST_PARAMS.keys():
parser.add_argument(f"--{param_type}", type=int, default=0, help=f"Count for param type {param_type}")
Expand All @@ -43,7 +43,8 @@ def main(args=None):
setattr(args, k, default)

test_spec = {k: count for k in TEST_PARAMS if (count := getattr(args, k, 0)) > 0}
print(TEST_PARAMS.keys())

print("Testing specification:")
print(test_spec)
results = {}

Expand All @@ -60,6 +61,11 @@ def main(args=None):
implementations = MessagePassingKernels.get_available_implementations()
if args.all_hidden:
implementations = MessagePassingKernels.get_available_implementations(hidden=True)


print("Testing implementations:")
print(implementations)


# Error if implementation does not exist.
for impl in implementations:
Expand All @@ -82,10 +88,14 @@ def main(args=None):
for k, count in test_spec.items():
print(f"Testing {k} {count} times:")
np.random.seed(args.seed)
out0, out1 = tester.check_speed(
n_repetitions=count, device=torch.device(args.accelerator), data_size=TEST_PARAMS[k], compare_against=impl
)
impl_results[k] = dict(tested=out0, comparison=out1)
try:
out0, out1 = tester.check_speed(
n_repetitions=count, device=torch.device(args.accelerator), data_size=TEST_PARAMS[k], compare_against=impl)
impl_results[k] = dict(tested=out0, comparison=out1)
except (torch.OutOfMemoryError, RuntimeError) as toom:
print(toom)
print("Got out of memory for this test! Attempting to continue.")
impl_results[k] = "OUT OF MEMORY"

with open(path, "wt") as f:
json.dump(results, f)
Expand Down

0 comments on commit a0b5ca7

Please sign in to comment.