Initial commit for an EESSI PyTorch test that uses torchvision models
Caspar van Leeuwen committed Mar 27, 2024
1 parent ede05e4 commit 2e7bf95
Showing 2 changed files with 366 additions and 0 deletions.
153 changes: 153 additions & 0 deletions eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py
@@ -0,0 +1,153 @@
import reframe as rfm
import reframe.utility.sanity as sn

from eessi.testsuite import hooks
from eessi.testsuite.constants import SCALES, TAGS, DEVICE_TYPES, COMPUTE_UNIT, CPU, CPU_SOCKET, GPU
from eessi.testsuite.utils import find_modules, log

class PyTorch_torchvision(rfm.RunOnlyRegressionTest):
nn_model = parameter(['vgg16', 'resnet50', 'resnet152', 'densenet121', 'mobilenet_v3_large'])
# The number of processes is determined by the scale (see SCALES), so no separate n_processes parameter is needed
#n_processes = parameter([1, 2, 4, 8, 16])
scale = parameter(SCALES.keys())
# It is not yet clear how to ensure that the Horovod module is _also_ loaded, so 'horovod' is left out for now
# parallel_strategy = parameter([None, 'horovod', 'ddp'])
parallel_strategy = parameter([None, 'ddp'])
compute_device = variable(str)
# module_name = parameter(find_modules('PyTorch-bundle'))
module_name = parameter(find_modules('torchvision'))

descr = 'Benchmark that runs a selected torchvision model on synthetic data'

executable = 'python'

valid_prog_environs = ['default']
valid_systems = ['*']

time_limit = '30m'

@run_after('init')
def prepare_test(self):

# Set nn_model as executable option
self.executable_opts = ['pytorch_synthetic_benchmark.py', '--model', self.nn_model]

# If not a GPU run, disable CUDA
if self.compute_device != DEVICE_TYPES[GPU]:
self.executable_opts += ['--no-cuda']



@run_after('init')
def apply_init_hooks(self):
# Filter on which scales are supported by the partitions defined in the ReFrame configuration
hooks.filter_supported_scales(self)

# Make sure that GPU tests run in partitions that support running on a GPU,
# and that CPU-only tests run in partitions that support running CPU-only.
# Also support setting valid_systems on the cmd line.
hooks.filter_valid_systems_by_device_type(self, required_device_type=self.compute_device)

# Support selecting modules on the cmd line.
hooks.set_modules(self)

# Support selecting scales on the cmd line via tags.
hooks.set_tag_scale(self)

@run_after('init')
def set_tag_ci(self):
if self.nn_model == 'resnet50':
self.tags.add(TAGS['CI'])

@run_after('setup')
def apply_setup_hooks(self):
if self.compute_device == DEVICE_TYPES[GPU]:
hooks.assign_tasks_per_compute_unit(test=self, compute_unit=COMPUTE_UNIT[GPU])
else:
# Hybrid code, so launch 1 rank per socket.
# Launching 1 task per NUMA domain would probably be even better, but the current hook does not support that
hooks.assign_tasks_per_compute_unit(test=self, compute_unit=COMPUTE_UNIT[CPU_SOCKET])

# This is a hybrid test, binding is important for performance
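# Compact binding keeps each rank and its threads on a contiguous set of cores, close to its own memory,
# instead of letting them migrate across sockets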
hooks.set_compact_process_binding(self)

@run_after('setup')
def set_ddp_env_vars(self):
# Set environment variables for PyTorch DDP
# TODO: this only works with SLURM; we should add a skip_if based on the scheduler
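# MASTER_PORT is derived from the last 4 digits of the SLURM job ID (offset by 10000), so concurrent jobs
# are unlikely to pick the same port; MASTER_ADDR is the first hostname in the job's node list, so all
# ranks rendezvous on the same node and port.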
if self.parallel_strategy == 'ddp':
self.prerun_cmds = [
'export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4))',
'export WORLD_SIZE=%s' % self.num_tasks,
'echo "WORLD_SIZE="${WORLD_SIZE}',
'master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)',
'export MASTER_ADDR=${master_addr}',
'echo "MASTER_ADDR"=${master_addr}',
]


@run_after('setup')
def filter_invalid_parameter_combinations(self):
# This situation cannot be detected before the setup phase, because it requires self.num_tasks,
# which depends on the core count of the node and is therefore only known after setup.
msg=f"Skipping test: parallel strategy is 'None', but requested process count is larger than one ({self.num_tasks})"
self.skip_if(self.num_tasks > 1 and self.parallel_strategy is None, msg)
msg=f"Skipping test: parallel strategy is {self.parallel_strategy}, but only one process is requested"
self.skip_if(self.num_tasks == 1 and not self.parallel_strategy is None, msg)

@run_after('setup')
def pass_parallel_strategy(self):
# Set parallelization strategy when using more than one process
if self.num_tasks != 1:
self.executable_opts += ['--use-%s' % self.parallel_strategy]

@run_after('setup')
def avoid_horovod_cpu_contention(self):
# Horovod had issues with CPU performance, see https://github.com/horovod/horovod/issues/2804
# The root cause is Horovod having two threads with very high utilization, which interferes with
# the compute threads. It was fixed, but seems to be broken again in Horovod 0.28.1
# The easiest workaround is to reduce the number of compute threads by 2
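# For example, with 16 CPUs per task this sets OMP_NUM_THREADS=14, leaving 2 cores for Horovod's background threads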
if self.compute_device == DEVICE_TYPES[CPU] and self.parallel_strategy == 'horovod':
self.env_vars['OMP_NUM_THREADS'] = max(self.num_cpus_per_task-2, 2) # Never go below 2 compute threads

@sanity_function
def assert_num_ranks(self):
'''Assert that the number of reported CPUs/GPUs used is correct'''
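# Matches the final line printed by pytorch_synthetic_benchmark.py, e.g. 'Total img/sec on 4 GPU(s): 1234.5 +-67.8'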
return sn.assert_found(r'Total img/sec on %s .PU\(s\):.*' % self.num_tasks, self.stdout)


@performance_function('img/sec')
def total_throughput(self):
'''Total training throughput, aggregated over all CPUs/GPUs'''
return sn.extractsingle(r'Total img/sec on [0-9]+ .PU\(s\):\s+(?P<perf>\S+)', self.stdout, 'perf', float)

@performance_function('img/sec')
def throughput_per_CPU(self):
'''Training throughput per CPU (per GPU for GPU runs)'''
if self.compute_device == DEVICE_TYPES[CPU]:
return sn.extractsingle(r'Img/sec per CPU:\s+(?P<perf_per_cpu>\S+)', self.stdout, 'perf_per_cpu', float)
else:
return sn.extractsingle(r'Img/sec per GPU:\s+(?P<perf_per_gpu>\S+)', self.stdout, 'perf_per_gpu', float)

@rfm.simple_test
class PyTorch_torchvision_CPU(PyTorch_torchvision):
compute_device = DEVICE_TYPES[CPU]


@rfm.simple_test
class PyTorch_torchvision_GPU(PyTorch_torchvision):
compute_device = DEVICE_TYPES[GPU]
precision = parameter(['default', 'mixed'])

@run_after('init')
def prepare_gpu_test(self):
# Set precision
if self.precision == 'mixed':
self.executable_opts += ['--use-amp']

@run_after('init')
def skip_hvd_plus_amp(self):
'''Skip the combination of Horovod and AMP; it does not work, see https://github.com/horovod/horovod/issues/1417'''
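# An empty valid_systems list means no system matches, so ReFrame drops this test variant during selection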
if self.parallel_strategy == 'horovod' and self.precision == 'mixed':
self.valid_systems = []

213 changes: 213 additions & 0 deletions eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py
@@ -0,0 +1,213 @@
from __future__ import print_function

import argparse
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data.distributed
from torchvision import models
import timeit
import numpy as np
import os

# Benchmark settings
parser = argparse.ArgumentParser(description='PyTorch Synthetic Benchmark',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--fp16-allreduce', action='store_true', default=False,
help='use fp16 compression during allreduce')

parser.add_argument('--model', type=str, default='resnet50',
help='model to benchmark')
parser.add_argument('--batch-size', type=int, default=32,
help='input batch size')

parser.add_argument('--num-warmup-batches', type=int, default=10,
help='number of warm-up batches that don\'t count towards benchmark')
parser.add_argument('--num-batches-per-iter', type=int, default=10,
help='number of batches per benchmark iteration')
parser.add_argument('--num-iters', type=int, default=10,
help='number of benchmark iterations')

parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')

parser.add_argument('--use-adasum', action='store_true', default=False,
help='use adasum algorithm to do reduction')
parser.add_argument('--use-horovod', action='store_true', default=False, help='use Horovod for distributed training')
parser.add_argument('--use-ddp', action='store_true', default=False, help='use PyTorch DistributedDataParallel (DDP) for distributed training')

parser.add_argument('--use-amp', action='store_true', default=False,
help='Use PyTorch Automatic Mixed Precision (AMP)')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

if args.use_horovod and args.use_ddp:
print("You can't specify to use both Horovod and Pytorch DDP, exiting...")
exit(1)

# Set a default rank and world size, also for when ddp and horovod are not used
rank = 0
world_size = 1
if args.use_horovod:
import horovod.torch as hvd
hvd.init()
rank = hvd.rank()  # global rank, so that only a single rank in the whole job logs output
world_size = hvd.size()

if args.cuda:
# If launched with srun, you are in a CGROUP with only 1 GPU, so you don't need to set it.
# If launched with mpirun, you see ALL local GPUs on the node, and you need to set which one
# this rank should use.
visible_gpus = torch.cuda.device_count()
# Horovod: pin GPU to local rank.
if visible_gpus > 1:
torch.cuda.set_device(hvd.local_rank())

# Should only be uncommented for debugging
# In ReFrame tests, a print from each rank can mess up the output file, causing
# performance and sanity patterns to not be found
# print(f"hvd.local_rank: {rank}", flush=True)


if args.use_ddp:
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from socket import gethostname

def setup(rank, world_size):
# initialize the process group
if args.cuda:
dist.init_process_group("nccl", rank=rank, world_size=world_size)
else:
dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
# clean up the distributed environment
dist.destroy_process_group()

world_size = int(os.environ["SLURM_NTASKS"])
# If launched with mpirun, get rank from this
rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", -1))
if rank == -1:
# Else it's launched with srun, get rank from this
rank = int(os.environ["SLURM_PROCID"])

setup(rank, world_size)
# log(f"Group initialized? {dist.is_initialized()}", rank)
if rank == 0: print(f"Group initialized? {dist.is_initialized()}", flush=True)

# If launched with srun, you are in a CGROUP with only 1 GPU, so you don't need to set it.
# If launched with mpirun, you see ALL local GPUs on the node, and you need to set which one
# this rank should use.
visible_gpus = torch.cuda.device_count()
if visible_gpus > 1:
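# Map the global rank to a local GPU index on this node (equivalent to rank % visible_gpus)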
local_rank = rank - visible_gpus * (rank // visible_gpus)
torch.cuda.set_device(local_rank)
print(f"host: {gethostname()}, rank: {rank}, local_rank: {local_rank}")
else:
print(f"host: {gethostname()}, rank: {rank}")

# This relies on the 'rank' set in the if args.use_horovod or args.use_ddp sections
def log(s, nl=True):
if (args.use_horovod or args.use_ddp) and rank != 0:
return
print(s, end='\n' if nl else '', flush=True)

log(f"World size: {world_size}")

# Used to be needed, but now seems that different SLURM tasks run within their own cgroup
# Each cgroup only contains a single GPU, which has GPU ID 0. So no longer needed to set
# one of the ranks to GPU 0 and one to GPU 1
#if args.cuda and args.use_horovod:
# # Horovod: pin GPU to local rank.
# torch.cuda.set_device(hvd.local_rank())

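# Intra-op parallelism follows the OMP_NUM_THREADS value set by the launcher / ReFrame test;
# inter-op parallelism is limited to 2 threads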
torch.set_num_threads(int(os.environ['OMP_NUM_THREADS']))
torch.set_num_interop_threads(2)

cudnn.benchmark = True

# Set up standard model.
model = getattr(models, args.model)()

# By default, Adasum doesn't need scaling up learning rate.
lr_scaler = hvd.size() if not args.use_adasum and args.use_horovod else 1

if args.cuda:
# Move model to GPU.
model.cuda()
# If using GPU Adasum allreduce, scale learning rate by local_size.
if args.use_horovod and args.use_adasum and hvd.nccl_built():
lr_scaler = hvd.local_size()

# If using DDP, wrap model
if args.use_ddp:
model = DDP(model)

optimizer = optim.SGD(model.parameters(), lr=0.01 * lr_scaler)

# Horovod: (optional) compression algorithm.
if args.use_horovod:
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

# Horovod: wrap optimizer with DistributedOptimizer.
if args.use_horovod:
optimizer = hvd.DistributedOptimizer(optimizer,
named_parameters=model.named_parameters(),
compression=compression,
op=hvd.Adasum if args.use_adasum else hvd.Average)

# Horovod: broadcast parameters & optimizer state.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

# Set up fixed fake data
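# (random 224x224 RGB images and random labels for ImageNet's 1000 classes)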
data = torch.randn(args.batch_size, 3, 224, 224)
target = torch.LongTensor(args.batch_size).random_() % 1000
if args.cuda:
data, target = data.cuda(), target.cuda()

# Create GradScaler for automatic mixed precision
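# With enabled=False the scaler is a no-op, so the same benchmark_step works for full-precision runs as well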
scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)

# Set device_type for AMP
device_type = "cuda" if args.cuda else "cpu"

def benchmark_step():
optimizer.zero_grad()
with torch.autocast(device_type=device_type, enabled=args.use_amp):
output = model(data)
loss = F.cross_entropy(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

log('Model: %s' % args.model)
log('Batch size: %d' % args.batch_size)
device = 'GPU' if args.cuda else 'CPU'
if args.use_horovod:
log('Number of %ss: %d' % (device, hvd.size()))

# Warm-up
log('Running warmup...')
timeit.timeit(benchmark_step, number=args.num_warmup_batches)

# Benchmark
log('Running benchmark...')
img_secs = []
for x in range(args.num_iters):
time = timeit.timeit(benchmark_step, number=args.num_batches_per_iter)
img_sec = args.batch_size * args.num_batches_per_iter / time
log('Iter #%d: %.1f img/sec per %s' % (x, img_sec, device))
img_secs.append(img_sec)

# Results
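# 1.96 x the standard deviation spans ~95% of a normal distribution, giving the +- spread reported below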
img_sec_mean = np.mean(img_secs)
img_sec_conf = 1.96 * np.std(img_secs)
log('Img/sec per %s: %.1f +-%.1f' % (device, img_sec_mean, img_sec_conf))
log('Total img/sec on %d %s(s): %.1f +-%.1f' %
(world_size, device, world_size * img_sec_mean, world_size * img_sec_conf))
