Revert "Update to latest CachedLoader design"
This reverts commit 4cc5de3.
DamianSzwichtenberg committed Aug 23, 2023
1 parent 4652255 commit e437576
Showing 2 changed files with 7 additions and 6 deletions.
benchmark/inference/inference_benchmark.py (4 additions, 5 deletions)

@@ -15,7 +15,7 @@
     write_to_csv,
 )
 from torch_geometric.loader import NeighborLoader
-from torch_geometric.nn import PNAConv
+from torch_geometric.nn import PNAConv, make_batches_cacheable
 from torch_geometric.profile import (
     rename_profile_file,
     timeit,
@@ -180,9 +180,11 @@ def run(args: argparse.ArgumentParser):
         print(f'Calculated degree for {dataset_name}.')
         params['degree'] = degree
 
+    model_decorator = make_batches_cacheable if args.cached_loader else None
     model = get_model(
         model_name, params,
-        metadata=data.metadata() if hetero else None)
+        metadata=data.metadata() if hetero else None,
+        model_decorator=model_decorator)
     model = model.to(device)
     # TODO: Migrate to ModelHubMixin.
     if args.ckpt_path:
@@ -215,9 +217,6 @@
     inference_kwargs = {}
     if args.reuse_device_for_embeddings and not hetero:
         inference_kwargs['embedding_device'] = device
-    if args.cached_loader:
-        inference_kwargs[
-            'maybe_with_cached_loader'] = True
     for _ in range(args.warmup):
         if args.full_batch:
             full_batch_inference(model, data)
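
The revert restores the decorator-based cached-loader design: when the benchmark runs with --cached-loader, make_batches_cacheable (imported from torch_geometric.nn) is applied to the model class before instantiation, instead of passing a 'maybe_with_cached_loader' flag through inference_kwargs. As a rough sketch of that decorator pattern only, and not the actual torch_geometric implementation, such a class decorator could look like the following; the name cache_batches, the inference hook, and the list-based cache are all assumptions:

    # Illustrative sketch of a batch-caching class decorator, in the
    # spirit of make_batches_cacheable. NOT the torch_geometric
    # implementation: the `inference` hook and the list-based cache
    # are assumptions made for this example.
    import functools


    def cache_batches(model_cls):
        """Wrap a model class so `inference` reuses materialized batches."""

        class CachingModel(model_cls):
            def inference(self, loader, *args, **kwargs):
                # Materialize the loader's batches once; later passes skip
                # neighbor sampling and replay the stored batches instead,
                # trading memory for sampling time across warmup/test runs.
                if not hasattr(self, '_cached_batches'):
                    self._cached_batches = list(loader)
                return super().inference(self._cached_batches, *args,
                                         **kwargs)

        functools.update_wrapper(CachingModel, model_cls, updated=[])
        return CachingModel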
benchmark/utils/utils.py (3 additions, 1 deletion)

@@ -86,8 +86,10 @@ def get_dataset(name, root, use_sparse_tensor=False, bf16=False):
     return data, num_classes
 
 
-def get_model(name, params, metadata=None):
+def get_model(name, params, metadata=None, model_decorator=None):
     Model = models_dict.get(name, None)
+    if model_decorator is not None:
+        Model = model_decorator(Model)
     assert Model is not None, f'Model {name} not supported!'
 
     if name == 'rgat':
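
With the restored signature, callers opt in by passing the decorator through get_model, which applies it to the model class before construction. A hypothetical invocation; the model name 'gcn' and the params dict are placeholders, not the benchmark's real configuration:

    from torch_geometric.nn import make_batches_cacheable

    cached_loader = True  # stand-in for args.cached_loader
    model_decorator = make_batches_cacheable if cached_loader else None
    model = get_model('gcn', params={'hidden_channels': 64},
                      model_decorator=model_decorator)

One quirk visible in the restored code: the decorator is applied before the `assert Model is not None` check, so an unsupported name combined with a decorator may fail inside the decorator before the assert's clearer error message is reached.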
