forked from EESSI/test-suite
Initial commit for an EESSI PyTorch test that uses torchvision models

Caspar van Leeuwen committed Mar 27, 2024
1 parent ede05e4, commit 2e7bf95

Showing 2 changed files with 366 additions and 0 deletions.
153 changes: 153 additions & 0 deletions
eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py
@@ -0,0 +1,153 @@
import reframe as rfm
import reframe.utility.sanity as sn

from eessi.testsuite import hooks
from eessi.testsuite.constants import SCALES, TAGS, DEVICE_TYPES, COMPUTE_UNIT, CPU, CPU_SOCKET, GPU
from eessi.testsuite.utils import find_modules, log


class PyTorch_torchvision(rfm.RunOnlyRegressionTest):
    nn_model = parameter(['vgg16', 'resnet50', 'resnet152', 'densenet121', 'mobilenet_v3_large'])
    ### SHOULD BE DETERMINED BY SCALE
    # n_processes = parameter([1, 2, 4, 8, 16])
    scale = parameter(SCALES.keys())
    # Not sure how we would ensure the horovod module is _also_ loaded...
    # parallel_strategy = parameter([None, 'horovod', 'ddp'])
    parallel_strategy = parameter([None, 'ddp'])
    compute_device = variable(str)
    # module_name = parameter(find_modules('PyTorch-bundle'))
    module_name = parameter(find_modules('torchvision'))

    descr = 'Benchmark that runs a selected torchvision model on synthetic data'

    executable = 'python'

    valid_prog_environs = ['default']
    valid_systems = ['*']

    time_limit = '30m'
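
    # Note: ReFrame generates one test variant per combination of the parameters above
    # (nn_model x scale x parallel_strategy x module_name). Combinations that cannot run,
    # e.g. more than one task without a parallel strategy, are skipped after setup in
    # filter_invalid_parameter_combinations() below.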

    @run_after('init')
    def prepare_test(self):
        # Set nn_model as executable option
        self.executable_opts = ['pytorch_synthetic_benchmark.py --model %s' % self.nn_model]

        # If not a GPU run, disable CUDA
        if self.compute_device != DEVICE_TYPES[GPU]:
            self.executable_opts += ['--no-cuda']

    @run_after('init')
    def apply_init_hooks(self):
        # Filter on which scales are supported by the partitions defined in the ReFrame configuration
        hooks.filter_supported_scales(self)

        # Make sure that GPU tests run in partitions that support running on a GPU,
        # and that CPU-only tests run in partitions that support running CPU-only.
        # Also support setting valid_systems on the cmd line.
        hooks.filter_valid_systems_by_device_type(self, required_device_type=self.compute_device)

        # Support selecting modules on the cmd line.
        hooks.set_modules(self)

        # Support selecting scales on the cmd line via tags.
        hooks.set_tag_scale(self)

    @run_after('init')
    def set_tag_ci(self):
        if self.nn_model == 'resnet50':
            self.tags.add(TAGS['CI'])

    @run_after('setup')
    def apply_setup_hooks(self):
        if self.compute_device == DEVICE_TYPES[GPU]:
            hooks.assign_tasks_per_compute_unit(test=self, compute_unit=COMPUTE_UNIT[GPU])
        else:
            # Hybrid code, so launch 1 rank per socket.
            # Probably, launching 1 task per NUMA domain is even better, but the current hook doesn't support it.
            hooks.assign_tasks_per_compute_unit(test=self, compute_unit=COMPUTE_UNIT[CPU_SOCKET])

        # This is a hybrid test, binding is important for performance
        hooks.set_compact_process_binding(self)

    @run_after('setup')
    def set_ddp_env_vars(self):
        # Set environment variables for PyTorch DDP
        ### TODO: THIS WILL ONLY WORK WITH SLURM, WE SHOULD MAKE A SKIP_IF BASED ON THE SCHEDULER
        if self.parallel_strategy == 'ddp':
            self.prerun_cmds = [
                'export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4))',
                'export WORLD_SIZE=%s' % self.num_tasks,
                'echo "WORLD_SIZE="${WORLD_SIZE}',
                'master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)',
                'export MASTER_ADDR=${master_addr}',
                'echo "MASTER_ADDR"=${master_addr}',
            ]

    @run_after('setup')
    def filter_invalid_parameter_combinations(self):
        # We cannot detect this situation before the setup phase, because it requires self.num_tasks.
        # Thus, the core count of the node needs to be known, which is only the case after the setup phase.
        msg = f"Skipping test: parallel strategy is 'None', but requested process count is larger than one ({self.num_tasks})"
        self.skip_if(self.num_tasks > 1 and self.parallel_strategy is None, msg)
        msg = f"Skipping test: parallel strategy is {self.parallel_strategy}, but only one process is requested"
        self.skip_if(self.num_tasks == 1 and self.parallel_strategy is not None, msg)

    @run_after('setup')
    def pass_parallel_strategy(self):
        # Set parallelization strategy when using more than one process
        if self.num_tasks != 1:
            self.executable_opts += ['--use-%s' % self.parallel_strategy]

    @run_after('setup')
    def avoid_horovod_cpu_contention(self):
        # Horovod had issues with CPU performance, see https://github.com/horovod/horovod/issues/2804
        # The root cause is Horovod having two threads with very high utilization, which interferes with
        # the compute threads. It was fixed, but seems to be broken again in Horovod 0.28.1.
        # The easiest workaround is to reduce the number of compute threads by 2.
        if self.compute_device == DEVICE_TYPES[CPU] and self.parallel_strategy == 'horovod':
            self.env_vars['OMP_NUM_THREADS'] = max(self.num_cpus_per_task - 2, 2)  # Never go below 2 compute threads

    @sanity_function
    def assert_num_ranks(self):
        '''Assert that the number of reported CPUs/GPUs used is correct'''
        return sn.assert_found(r'Total img/sec on %s .PU\(s\):.*' % self.num_tasks, self.stdout)

    @performance_function('img/sec')
    def total_throughput(self):
        '''Total training throughput, aggregated over all CPUs/GPUs'''
        return sn.extractsingle(r'Total img/sec on [0-9]+ .PU\(s\):\s+(?P<perf>\S+)', self.stdout, 'perf', float)

    @performance_function('img/sec')
    def throughput_per_CPU(self):
        '''Training throughput per CPU (or per GPU for GPU runs)'''
        if self.compute_device == DEVICE_TYPES[CPU]:
            return sn.extractsingle(r'Img/sec per CPU:\s+(?P<perf_per_cpu>\S+)', self.stdout, 'perf_per_cpu', float)
        else:
            return sn.extractsingle(r'Img/sec per GPU:\s+(?P<perf_per_gpu>\S+)', self.stdout, 'perf_per_gpu', float)


@rfm.simple_test
class PyTorch_torchvision_CPU(PyTorch_torchvision):
    compute_device = DEVICE_TYPES[CPU]


@rfm.simple_test
class PyTorch_torchvision_GPU(PyTorch_torchvision):
    compute_device = DEVICE_TYPES[GPU]
    precision = parameter(['default', 'mixed'])

    @run_after('init')
    def prepare_gpu_test(self):
        # Set precision
        if self.precision == 'mixed':
            self.executable_opts += ['--use-amp']

    @run_after('init')
    def skip_hvd_plus_amp(self):
        '''Skip the combination of Horovod and AMP, it does not work, see https://github.com/horovod/horovod/issues/1417'''
        if self.parallel_strategy == 'horovod' and self.precision == 'mixed':
            self.valid_systems = []
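
The sanity and performance functions above are written against the summary lines that pytorch_synthetic_benchmark.py (added below) prints at the end of a run. A minimal sketch of how the total-throughput regex lines up with that output; the sample string is illustrative, not captured from a real run:

import re

sample = 'Total img/sec on 4 CPU(s): 1234.5 +-67.8'
match = re.search(r'Total img/sec on [0-9]+ .PU\(s\):\s+(?P<perf>\S+)', sample)
print(match.group('perf'))  # prints: 1234.5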
213 changes: 213 additions & 0 deletions
eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py
@@ -0,0 +1,213 @@
from __future__ import print_function

import argparse
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data.distributed
from torchvision import models
import timeit
import numpy as np
import os

# Benchmark settings
parser = argparse.ArgumentParser(description='PyTorch Synthetic Benchmark',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--fp16-allreduce', action='store_true', default=False,
                    help='use fp16 compression during allreduce')

parser.add_argument('--model', type=str, default='resnet50',
                    help='model to benchmark')
parser.add_argument('--batch-size', type=int, default=32,
                    help='input batch size')

parser.add_argument('--num-warmup-batches', type=int, default=10,
                    help='number of warm-up batches that don\'t count towards benchmark')
parser.add_argument('--num-batches-per-iter', type=int, default=10,
                    help='number of batches per benchmark iteration')
parser.add_argument('--num-iters', type=int, default=10,
                    help='number of benchmark iterations')

parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')

parser.add_argument('--use-adasum', action='store_true', default=False,
                    help='use adasum algorithm to do reduction')
parser.add_argument('--use-horovod', action='store_true', default=False)
parser.add_argument('--use-ddp', action='store_true', default=False)

parser.add_argument('--use-amp', action='store_true', default=False,
                    help='Use PyTorch Automatic Mixed Precision (AMP)')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

if args.use_horovod and args.use_ddp:
    print("You can't specify to use both Horovod and PyTorch DDP, exiting...")
    exit(1)

# Set a default rank and world size, also for when ddp and horovod are not used
rank = 0
world_size = 1
if args.use_horovod:
    import horovod.torch as hvd
    hvd.init()
    rank = hvd.local_rank()
    world_size = hvd.size()

    if args.cuda:
        # If launched with srun, you are in a CGROUP with only 1 GPU, so you don't need to set it.
        # If launched with mpirun, you see ALL local GPUs on the node, and you need to set which one
        # this rank should use.
        visible_gpus = torch.cuda.device_count()
        # Horovod: pin GPU to local rank.
        if visible_gpus > 1:
            torch.cuda.set_device(hvd.local_rank())

    # Should only be uncommented for debugging.
    # In ReFrame tests, a print from each rank can mess up the output file, causing
    # performance and sanity patterns to not be found.
    # print(f"hvd.local_rank: {rank}", flush=True)

if args.use_ddp:
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from socket import gethostname

    def setup(rank, world_size):
        # initialize the process group
        if args.cuda:
            dist.init_process_group("nccl", rank=rank, world_size=world_size)
        else:
            dist.init_process_group("gloo", rank=rank, world_size=world_size)

    def cleanup():
        # clean up the distributed environment
        dist.destroy_process_group()

    world_size = int(os.environ["SLURM_NTASKS"])
    # If launched with mpirun, get rank from this
    rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", -1))
    if rank == -1:
        # Else it's launched with srun, get rank from this
        rank = int(os.environ["SLURM_PROCID"])

    setup(rank, world_size)
    # log(f"Group initialized? {dist.is_initialized()}", rank)
    if rank == 0:
        print(f"Group initialized? {dist.is_initialized()}", flush=True)

    # If launched with srun, you are in a CGROUP with only 1 GPU, so you don't need to set it.
    # If launched with mpirun, you see ALL local GPUs on the node, and you need to set which one
    # this rank should use.
    visible_gpus = torch.cuda.device_count()
    if visible_gpus > 1:
        local_rank = rank - visible_gpus * (rank // visible_gpus)
        torch.cuda.set_device(local_rank)
        print(f"host: {gethostname()}, rank: {rank}, local_rank: {local_rank}")
    else:
        print(f"host: {gethostname()}, rank: {rank}")
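
# Note: dist.init_process_group() above is called without an explicit init_method, so it uses the
# default 'env://' rendezvous and expects MASTER_ADDR and MASTER_PORT in the environment. The
# ReFrame test exports both in its prerun_cmds when parallel_strategy is 'ddp'.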


# This relies on the 'rank' set in the args.use_horovod or args.use_ddp sections above
def log(s, nl=True):
    if (args.use_horovod or args.use_ddp) and rank != 0:
        return
    print(s, end='\n' if nl else '', flush=True)


log(f"World size: {world_size}")

# This used to be needed, but now it seems that different SLURM tasks run within their own cgroup.
# Each cgroup only contains a single GPU, which has GPU ID 0, so there is no longer a need to pin
# one of the ranks to GPU 0 and one to GPU 1.
# if args.cuda and args.use_horovod:
#     # Horovod: pin GPU to local rank.
#     torch.cuda.set_device(hvd.local_rank())

torch.set_num_threads(int(os.environ['OMP_NUM_THREADS']))
torch.set_num_interop_threads(2)

cudnn.benchmark = True

# Set up standard model.
model = getattr(models, args.model)()

# By default, Adasum doesn't need scaling up learning rate.
lr_scaler = hvd.size() if not args.use_adasum and args.use_horovod else 1

if args.cuda:
    # Move model to GPU.
    model.cuda()
    # If using GPU Adasum allreduce, scale learning rate by local_size.
    if args.use_horovod and args.use_adasum and hvd.nccl_built():
        lr_scaler = hvd.local_size()

# If using DDP, wrap model
if args.use_ddp:
    model = DDP(model)

optimizer = optim.SGD(model.parameters(), lr=0.01 * lr_scaler)

# Horovod: (optional) compression algorithm.
if args.use_horovod:
    compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

# Horovod: wrap optimizer with DistributedOptimizer.
if args.use_horovod:
    optimizer = hvd.DistributedOptimizer(optimizer,
                                         named_parameters=model.named_parameters(),
                                         compression=compression,
                                         op=hvd.Adasum if args.use_adasum else hvd.Average)

    # Horovod: broadcast parameters & optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

# Set up fixed fake data
data = torch.randn(args.batch_size, 3, 224, 224)
target = torch.LongTensor(args.batch_size).random_() % 1000
if args.cuda:
    data, target = data.cuda(), target.cuda()

# Create GradScaler for automatic mixed precision
scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)

# Set device_type for AMP
if args.cuda:
    device_type = "cuda"
else:
    device_type = "cpu"


def benchmark_step():
    optimizer.zero_grad()
    with torch.autocast(device_type=device_type, enabled=args.use_amp):
        output = model(data)
        loss = F.cross_entropy(output, target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
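
# When --use-amp is not given, both the autocast context and the GradScaler are constructed with
# enabled=False and act as no-ops, so the same benchmark_step() serves CPU, full-precision GPU,
# and mixed-precision GPU runs.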


log('Model: %s' % args.model)
log('Batch size: %d' % args.batch_size)
device = 'GPU' if args.cuda else 'CPU'
if args.use_horovod:
    log('Number of %ss: %d' % (device, hvd.size()))

# Warm-up
log('Running warmup...')
timeit.timeit(benchmark_step, number=args.num_warmup_batches)

# Benchmark
log('Running benchmark...')
img_secs = []
for x in range(args.num_iters):
    time = timeit.timeit(benchmark_step, number=args.num_batches_per_iter)
    img_sec = args.batch_size * args.num_batches_per_iter / time
    log('Iter #%d: %.1f img/sec per %s' % (x, img_sec, device))
    img_secs.append(img_sec)

# Results
img_sec_mean = np.mean(img_secs)
img_sec_conf = 1.96 * np.std(img_secs)
log('Img/sec per %s: %.1f +-%.1f' % (device, img_sec_mean, img_sec_conf))
log('Total img/sec on %d %s(s): %.1f +-%.1f' %
    (world_size, device, world_size * img_sec_mean, world_size * img_sec_conf))
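
The benchmark reads OMP_NUM_THREADS unconditionally, so it must be set even for a quick local run outside ReFrame. A minimal single-process smoke test, as a sketch; the thread count and model choice are arbitrary assumptions, not part of the test suite:

import os
import subprocess

# OMP_NUM_THREADS is required: the script calls torch.set_num_threads(int(os.environ['OMP_NUM_THREADS']))
env = dict(os.environ, OMP_NUM_THREADS='4')
subprocess.run(
    ['python', 'pytorch_synthetic_benchmark.py', '--model', 'resnet50', '--no-cuda'],
    env=env,
    check=True,
)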