From aff3a99ff65d95b454d5c40175f68576393682c9 Mon Sep 17 00:00:00 2001
From: Rishi Puri
Date: Thu, 16 Nov 2023 11:24:35 -0800
Subject: [PATCH] Add cuda to multigpu (xpu) bench (#8386)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Damian Szwichtenberg
---
 CHANGELOG.md                              |  2 +-
 benchmark/multi_gpu/training/README.md    | 17 ++++--
 .../{training_benchmark.py => common.py}  | 56 +++++--------------
 .../training/training_benchmark_cuda.py   | 51 +++++++++++++++++
 .../training/training_benchmark_xpu.py    | 53 ++++++++++++++++++
 5 files changed, 129 insertions(+), 50 deletions(-)
 rename benchmark/multi_gpu/training/{training_benchmark.py => common.py} (84%)
 create mode 100644 benchmark/multi_gpu/training/training_benchmark_cuda.py
 create mode 100644 benchmark/multi_gpu/training/training_benchmark_xpu.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6e02c9ad29ed..196f79ba6f69 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,7 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for `torch.compile` in `MultiAggregation` ([#8345](https://github.com/pyg-team/pytorch_geometric/pull/8345))
 - Added support for `torch.compile` in `HeteroConv` ([#8344](https://github.com/pyg-team/pytorch_geometric/pull/8344))
 - Added support for weighted `sparse_cross_entropy` ([#8340](https://github.com/pyg-team/pytorch_geometric/pull/8340))
-- Added a multi GPU training benchmarks for XPU device ([#8288](https://github.com/pyg-team/pytorch_geometric/pull/8288))
+- Added multi-GPU training benchmarks for CUDA and XPU devices ([#8288](https://github.com/pyg-team/pytorch_geometric/pull/8288), [#8386](https://github.com/pyg-team/pytorch_geometric/pull/8386))
 - Support MRR computation in `KGEModel.test()` ([#8298](https://github.com/pyg-team/pytorch_geometric/pull/8298))
 - Added an example for model parallelism (`examples/multi_gpu/model_parallel.py`) ([#8309](https://github.com/pyg-team/pytorch_geometric/pull/8309))
 - Added a tutorial for multi-node multi-GPU training with pure PyTorch ([#8071](https://github.com/pyg-team/pytorch_geometric/pull/8071))
diff --git a/benchmark/multi_gpu/training/README.md b/benchmark/multi_gpu/training/README.md
index 43d75b53922f..56d0c58984bb 100644
--- a/benchmark/multi_gpu/training/README.md
+++ b/benchmark/multi_gpu/training/README.md
@@ -1,16 +1,21 @@
 # Training Benchmark
 
-## Environment setup
+## Running benchmark on CUDA GPU
 
-Optional, XPU only:
+Run benchmark, e.g. assuming you have `n` NVIDIA GPUs:
+```
+python training_benchmark_cuda.py --dataset ogbn-products --model edge_cnn --num-epochs 3 --n-gpus n
+```
+
+## Running benchmark on Intel GPU
+
+### Environment setup
 ```
 install intel_extension_for_pytorch
 install oneccl_bindings_for_pytorch
 ```
 
-## Running benchmark
-
-Run benchmark, e.g. assuming you have 2 GPUs:
+Run benchmark, e.g. assuming you have `n` XPUs:
 ```
-mpirun -np 2 python training_benchmark.py --dataset ogbn-products --model edge_cnn --num-epochs 3
+mpirun -np n python training_benchmark_xpu.py --dataset ogbn-products --model edge_cnn --num-epochs 3
 ```
diff --git a/benchmark/multi_gpu/training/training_benchmark.py b/benchmark/multi_gpu/training/common.py
similarity index 84%
rename from benchmark/multi_gpu/training/training_benchmark.py
rename to benchmark/multi_gpu/training/common.py
index 391bdfbd9dac..452a27f28bdd 100644
--- a/benchmark/multi_gpu/training/training_benchmark.py
+++ b/benchmark/multi_gpu/training/common.py
@@ -1,17 +1,15 @@
 import argparse
 import ast
-import os
 from time import perf_counter
-from typing import Any, Tuple, Union
+from typing import Any, Callable, Tuple, Union
 
-import intel_extension_for_pytorch as ipex
-import oneccl_bindings_for_pytorch  # noqa
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-from benchmark.utils import get_dataset, get_model, get_split_masks, test
+from benchmark.utils import get_model, get_split_masks, test
+from torch_geometric.data import Data, HeteroData
 from torch_geometric.loader import NeighborLoader
 from torch_geometric.nn import PNAConv
 
@@ -24,6 +22,7 @@
 
 device_conditions = {
     'xpu': (lambda: torch.xpu.is_available()),
+    'cuda': (lambda: torch.cuda.is_available()),
 }
 
 
@@ -63,6 +62,8 @@ def train_hetero(model: Any, loader: NeighborLoader,
 def maybe_synchronize(device: str):
     if device == 'xpu' and torch.xpu.is_available():
         torch.xpu.synchronize()
+    if device == 'cuda' and torch.cuda.is_available():
+        torch.cuda.synchronize()
 
 
 def create_mask_per_rank(
@@ -83,7 +84,9 @@ def create_mask_per_rank(
     return mask_per_rank
 
 
-def run(rank: int, world_size: int, args: argparse.ArgumentParser):
+def run(rank: int, world_size: int, args: argparse.ArgumentParser,
+        num_classes: int, data: Union[Data, HeteroData],
+        custom_optimizer: Callable[[Any, Any], Tuple[Any, Any]] = None):
     if not device_conditions[args.device]():
         raise RuntimeError(f'{args.device.upper()} is not available')
 
@@ -92,13 +95,8 @@ def run(rank: int, world_size: int, args: argparse.ArgumentParser):
     if rank == 0:
         print('BENCHMARK STARTS')
         print(f'Running on {args.device.upper()}')
-
-    assert args.dataset in supported_sets.keys(
-    ), f"Dataset {args.dataset} isn't supported."
-    if rank == 0:
         print(f'Dataset: {args.dataset}')
-    data, num_classes = get_dataset(args.dataset, args.root)
 
     hetero = True if args.dataset == 'ogbn-mag' else False
     mask, val_mask, test_mask = get_split_masks(data, args.dataset)
     mask = create_mask_per_rank(mask, rank, world_size, hetero)
@@ -192,8 +190,8 @@ def run(rank: int, world_size: int, args: argparse.ArgumentParser):
 
     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
 
-    if args.device == 'xpu':
-        model, optimizer = ipex.optimize(model, optimizer=optimizer)
+    if custom_optimizer:
+        model, optimizer = custom_optimizer(model, optimizer)
 
     train = train_hetero if hetero else train_homo
 
@@ -248,37 +246,11 @@ def run(rank: int, world_size: int, args: argparse.ArgumentParser):
     dist.destroy_process_group()
 
 
-def get_dist_params() -> Tuple[int, int, str]:
-    master_addr = "127.0.0.1"
-    master_port = "29500"
-    os.environ["MASTER_ADDR"] = master_addr
-    os.environ["MASTER_PORT"] = master_port
-
-    mpi_rank = int(os.environ.get("PMI_RANK", -1))
-    mpi_world_size = int(os.environ.get("PMI_SIZE", -1))
-    rank = mpi_rank if mpi_world_size > 0 else os.environ.get("RANK", 0)
-    world_size = (mpi_world_size if mpi_world_size > 0 else os.environ.get(
-        "WORLD_SIZE", 1))
-
-    os.environ["RANK"] = str(rank)
-    os.environ["WORLD_SIZE"] = str(world_size)
-
-    init_method = f"tcp://{master_addr}:{master_port}"
-
-    return rank, world_size, init_method
-
-
-if __name__ == '__main__':
-    rank, world_size, init_method = get_dist_params()
-    dist.init_process_group(backend="ccl", init_method=init_method,
-                            world_size=world_size, rank=rank)
-
+def get_predefined_args() -> argparse.ArgumentParser:
     argparser = argparse.ArgumentParser(
         'GNN distributed (DDP) training benchmark')
     add = argparser.add_argument
-    add('--device', choices=['xpu'], default='xpu',
-        help='Device to run benchmark on')
     add('--dataset', choices=['ogbn-mag', 'ogbn-products', 'Reddit'],
         default='Reddit', type=str)
     add('--model',
@@ -297,6 +269,4 @@ def get_dist_params() -> Tuple[int, int, str]:
     add('--num-epochs', default=1, type=int)
     add('--evaluate', action='store_true')
 
-    args = argparser.parse_args()
-
-    run(rank, world_size, args)
+    return argparser
diff --git a/benchmark/multi_gpu/training/training_benchmark_cuda.py b/benchmark/multi_gpu/training/training_benchmark_cuda.py
new file mode 100644
index 000000000000..7c429ecb1363
--- /dev/null
+++ b/benchmark/multi_gpu/training/training_benchmark_cuda.py
@@ -0,0 +1,51 @@
+import argparse
+import os
+from typing import Union
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from benchmark.multi_gpu.training.common import (
+    get_predefined_args,
+    run,
+    supported_sets,
+)
+from benchmark.utils import get_dataset
+from torch_geometric.data import Data, HeteroData
+
+
+def run_cuda(rank: int, world_size: int, args: argparse.ArgumentParser,
+             num_classes: int, data: Union[Data, HeteroData]):
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12355'
+    dist.init_process_group('nccl', rank=rank, world_size=world_size)
+    run(rank, world_size, args, num_classes, data)
+
+
+if __name__ == '__main__':
+    argparser = get_predefined_args()
+    argparser.add_argument('--n-gpus', default=1, type=int)
+    args = argparser.parse_args()
+    setattr(args, 'device', 'cuda')
+
+    assert args.dataset in supported_sets.keys(), \
+        f"Dataset {args.dataset} isn't supported."
+    data, num_classes = get_dataset(args.dataset, args.root)
+
+    max_world_size = torch.cuda.device_count()
+    chosen_world_size = args.n_gpus
+    if chosen_world_size <= max_world_size:
+        world_size = chosen_world_size
+    else:
+        print(f'User selected {chosen_world_size} GPUs '
+              f'but only {max_world_size} GPUs are available')
+        world_size = max_world_size
+    print(f'Let\'s use {world_size} GPUs!')
+
+    mp.spawn(
+        run_cuda,
+        args=(world_size, args, num_classes, data),
+        nprocs=world_size,
+        join=True,
+    )
diff --git a/benchmark/multi_gpu/training/training_benchmark_xpu.py b/benchmark/multi_gpu/training/training_benchmark_xpu.py
new file mode 100644
index 000000000000..300914429bfa
--- /dev/null
+++ b/benchmark/multi_gpu/training/training_benchmark_xpu.py
@@ -0,0 +1,53 @@
+import os
+from typing import Any, Tuple
+
+import intel_extension_for_pytorch as ipex
+import oneccl_bindings_for_pytorch  # noqa
+import torch.distributed as dist
+
+from benchmark.multi_gpu.training.common import (
+    get_predefined_args,
+    run,
+    supported_sets,
+)
+from benchmark.utils import get_dataset
+
+
+def get_dist_params() -> Tuple[int, int, str]:
+    master_addr = "127.0.0.1"
+    master_port = "29500"
+    os.environ["MASTER_ADDR"] = master_addr
+    os.environ["MASTER_PORT"] = master_port
+
+    mpi_rank = int(os.environ.get("PMI_RANK", -1))
+    mpi_world_size = int(os.environ.get("PMI_SIZE", -1))
+    rank = mpi_rank if mpi_world_size > 0 else os.environ.get("RANK", 0)
+    world_size = (mpi_world_size if mpi_world_size > 0 else os.environ.get(
+        "WORLD_SIZE", 1))
+
+    os.environ["RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+
+    init_method = f"tcp://{master_addr}:{master_port}"
+
+    return rank, world_size, init_method
+
+
+def custom_optimizer(model: Any, optimizer: Any) -> Tuple[Any, Any]:
+    return ipex.optimize(model, optimizer=optimizer)
+
+
+if __name__ == '__main__':
+    rank, world_size, init_method = get_dist_params()
+    dist.init_process_group(backend="ccl", init_method=init_method,
+                            world_size=world_size, rank=rank)
+
+    argparser = get_predefined_args()
+    args = argparser.parse_args()
+    setattr(args, 'device', 'xpu')
+
+    assert args.dataset in supported_sets.keys(), \
+        f"Dataset {args.dataset} isn't supported."
+    data, num_classes = get_dataset(args.dataset, args.root)
+
+    run(rank, world_size, args, num_classes, data, custom_optimizer)
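The refactor above leaves `common.py` device-agnostic: each launcher parses the shared arguments via `get_predefined_args()`, sets a `device` key registered in `device_conditions`, loads the dataset once, initializes a `torch.distributed` process group, and calls `run()`, optionally passing a `custom_optimizer` hook (the XPU launcher passes `ipex.optimize` through it). The sketch below illustrates that contract with a minimal, hypothetical single-rank launcher; it is not part of the patch, and the no-op `my_optimizer` hook, the single-process NCCL group, and port `29501` are assumptions made purely for illustration.

```
# Illustrative sketch only (not from the patch): a minimal single-rank
# launcher mirroring training_benchmark_cuda.py / training_benchmark_xpu.py.
import os
from typing import Any, Tuple

import torch.distributed as dist

from benchmark.multi_gpu.training.common import (
    get_predefined_args,
    run,
    supported_sets,
)
from benchmark.utils import get_dataset


def my_optimizer(model: Any, optimizer: Any) -> Tuple[Any, Any]:
    # Optional hook: run() applies it where the old script hard-coded
    # ipex.optimize(); the XPU launcher passes ipex.optimize here.
    return model, optimizer


if __name__ == '__main__':
    # Shared argument parser from common.py; the device key must be one
    # of the entries registered in device_conditions ('cuda' or 'xpu').
    argparser = get_predefined_args()
    args = argparser.parse_args()
    setattr(args, 'device', 'cuda')

    assert args.dataset in supported_sets.keys(), \
        f"Dataset {args.dataset} isn't supported."
    data, num_classes = get_dataset(args.dataset, args.root)

    # One rank for illustration; the real launchers create one rank per
    # device via mp.spawn (CUDA) or mpirun (XPU).
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29501'
    dist.init_process_group('nccl', rank=0, world_size=1)
    run(0, 1, args, num_classes, data, my_optimizer)
```

Note the design choice the patch makes: the dataset is now loaded in the launcher and handed to `run()` as an argument, instead of being fetched inside `run()` as before, so every spawned rank works on the already-materialized `Data`/`HeteroData` object rather than calling `get_dataset()` on its own.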