
[Model] Support Multi-GPU for Transformer model #356

Merged: 16 commits, Feb 12, 2019
examples/pytorch/transformer/dataset/__init__.py (21 additions, 20 deletions)
@@ -2,14 +2,14 @@
from .fields import *
from .utils import prepare_dataset
import os
import numpy as np
import random

class ClassificationDataset:
class ClassificationDataset(object):
"Dataset class for classification task."
def __init__(self):
raise NotImplementedError

class TranslationDataset:
class TranslationDataset(object):
'''
Dataset class for translation task.
By default, the source language shares the same vocabulary with the target language.
@@ -90,20 +90,27 @@ def sos_id(self):
def eos_id(self):
return self.vocab[self.EOS_TOKEN]

def __call__(self, graph_pool, mode='train', batch_size=32, k=1, devices=['cpu']):
def __call__(self, graph_pool, mode='train', batch_size=32, k=1,
device='cpu', dev_rank=0, ndev=1):
'''
Create a batched graph corresponding to the mini-batch of the dataset.
args:
graph_pool: a GraphPool object for accelerating.
mode: train/valid/test
batch_size: batch size
devices: ['cpu'] or a list of gpu ids.
k: beam size(only required for test)
device: torch.device
dev_rank: rank (device id) of the current process
ndev: total number of devices/processes
k: beam size (only required for test)
'''
dev_id, gs = 0, []
src_data, tgt_data = self.src[mode], self.tgt[mode]
n = len(src_data)
order = np.random.permutation(n) if mode == 'train' else range(n)
# make sure all devices have the same number of batches
n = n // ndev * ndev

# XXX: is partition then shuffle equivalent to shuffle then partition?
order = list(range(dev_rank, n, ndev))
if mode == 'train':
random.shuffle(order)

src_buf, tgt_buf = [], []

for idx in order:
@@ -115,22 +122,16 @@ def __call__(self, graph_pool, mode='train', batch_size=32, k=1, devices=['cpu']
tgt_buf.append(tgt_sample)
if len(src_buf) == batch_size:
if mode == 'test':
assert len(devices) == 1 # we only allow single gpu for inference
yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=devices[0])
yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=device)
else:
gs.append(graph_pool(src_buf, tgt_buf, device=devices[dev_id]))
dev_id += 1
if dev_id == len(devices):
yield gs if len(devices) > 1 else gs[0]
dev_id, gs = 0, []
yield graph_pool(src_buf, tgt_buf, device=device)
src_buf, tgt_buf = [], []

if len(src_buf) != 0:
if mode == 'test':
yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=devices[0])
yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=device)
else:
gs.append(graph_pool(src_buf, tgt_buf, device=devices[dev_id]))
yield gs if len(devices) > 1 else gs[0]
yield graph_pool(src_buf, tgt_buf, device=device)

def get_sequence(self, batch):
"return a list of sequence from a list of index arrays"
@@ -151,8 +152,8 @@ def get_dataset(dataset):
raise NotImplementedError
elif dataset == 'copy' or dataset == 'sort':
return TranslationDataset(
'data/{}'.format(dataset),
('in', 'out'),
'data/{}'.format(dataset),
('in', 'out'),
train='train',
valid='valid',
test='test',
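
A note on the dataset change above: __call__ now yields batches for a single process rather than cycling over a device list. Each rank takes a strided slice of the sample indices (dev_rank, dev_rank + ndev, dev_rank + 2*ndev, ...) after truncating so every rank sees the same number of samples, and shuffles only within its own slice. Below is a minimal sketch of that partitioning logic; rank_batches is a hypothetical helper name used only for illustration.

```python
import random

def rank_batches(n_samples, dev_rank, ndev, batch_size, shuffle=True):
    """Yield index batches for one rank out of ndev ranks (illustrative sketch)."""
    # Truncate so all ranks iterate over the same number of samples.
    n = n_samples // ndev * ndev
    # Rank r owns indices r, r + ndev, r + 2*ndev, ...
    order = list(range(dev_rank, n, ndev))
    if shuffle:
        random.shuffle(order)  # shuffle only within this rank's shard
    for i in range(0, len(order), batch_size):
        yield order[i:i + batch_size]

# Example: 10 samples split across 2 ranks, batch size 2, rank 0 shown.
for batch in rank_batches(10, dev_rank=0, ndev=2, batch_size=2, shuffle=False):
    print(batch)  # [0, 2] then [4, 6] then [8]
```

Regarding the XXX question in the diff: partitioning first and then shuffling within each shard is not literally the same as shuffling globally and then partitioning, because each rank sees a fixed subset of the samples every epoch, just in a different order.
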
examples/pytorch/transformer/loss/__init__.py (32 additions, 18 deletions)
@@ -1,6 +1,7 @@
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

class LabelSmoothing(nn.Module):
"""
@@ -44,7 +45,10 @@ def __init__(self, criterion, opt=None):
super(SimpleLossCompute, self).__init__()
self.criterion = criterion
self.opt = opt
self.reset()
self.acc_loss = 0
self.n_correct = 0
self.norm_term = 0
self.loss = 0

@property
def avg_loss(self):
@@ -54,32 +58,42 @@ def avg_loss(self):
def accuracy(self):
return (self.n_correct + self.eps) / (self.norm_term + self.eps)

def reset(self):
self.acc_loss = 0
self.n_correct = 0
self.norm_term = 0
def backward_and_step(self):
self.loss.backward()
self.opt.step()
self.opt.optimizer.zero_grad()

def __call__(self, y_pred, y, norm):
y_pred = y_pred.contiguous().view(-1, y_pred.shape[-1])
y = y.contiguous().view(-1)
loss = self.criterion(
self.loss = self.criterion(
y_pred, y
) / norm
if self.opt is not None:
loss.backward()
self.opt.step()
self.opt.optimizer.zero_grad()
self.backward_and_step()
self.n_correct += ((y_pred.max(dim=-1)[1] == y) & (y != self.criterion.padding_idx)).sum().item()
self.acc_loss += loss.item() * norm
self.acc_loss += self.loss.item() * norm
self.norm_term += norm
return loss.item() * norm
return self.loss.item() * norm

class MultiGPULossCompute(SimpleLossCompute):
def __init__(self, criterion, devices, opt=None, chunk_size=5):
self.criterion = criterion
self.opt = opt
self.devices = devices
self.chunk_size = chunk_size
def __init__(self, criterion, ndev, grad_accum, model, opt=None):
super(MultiGPULossCompute, self).__init__(criterion, opt)
self.ndev = ndev
self.grad_accum = grad_accum
self.model = model
self.count = 0

def __call__(self, y_preds, ys, norms):
pass
def backward_and_step(self):
# multi-gpu synchronous backward
self.loss.backward()
self.count += 1
# accumulate self.grad_accum times then synchronize and update
if self.count == self.grad_accum:
for param in self.model.parameters():
if param.requires_grad and param.grad is not None:
dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
param.grad.data /= self.ndev
self.opt.step()
self.opt.optimizer.zero_grad()
self.count = 0
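
The rewritten MultiGPULossCompute replaces the old stub: each process computes its own loss and gradients, and backward_and_step accumulates grad_accum local backward passes before all-reducing (summing) every gradient across processes and dividing by ndev. A condensed sketch of just that synchronization step, assuming the process group is already initialized and using a plain torch.optim optimizer instead of the NoamOpt wrapper used here:

```python
import torch
import torch.distributed as dist

def sync_grads_and_step(model, optimizer, ndev):
    """Average gradients across ndev processes, then update (illustrative sketch)."""
    for param in model.parameters():
        if param.requires_grad and param.grad is not None:
            # Sum this gradient over all processes in-place ...
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            # ... then divide by the number of processes to get the average.
            param.grad.data /= ndev
    optimizer.step()
    optimizer.zero_grad()
```

This is roughly the synchronization that nn.parallel.DistributedDataParallel performs automatically; doing it by hand makes it straightforward to delay the all-reduce until grad_accum local steps have been accumulated.
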
examples/pytorch/transformer/translation_train.py (105 additions, 44 deletions)
@@ -4,18 +4,21 @@
"""
from modules import *
from parallel import *
from loss import *
from loss import *
from optims import *
from dataset import *
from modules.config import *
from modules.viz import *
from tqdm import tqdm
#from modules.viz import *
#from tqdm import tqdm
import numpy as np
import argparse
import torch
from functools import partial

def run_epoch(data_iter, model, loss_compute, is_train=True):
def run_epoch(epoch, data_iter, dev_rank, ndev, model, loss_compute, is_train=True):
universal = isinstance(model, UTransformer)
for i, g in tqdm(enumerate(data_iter)):
for i, g in enumerate(data_iter):
#print("Dev {} start batch {}".format(dev_rank, i))
with T.set_grad_enabled(is_train):
if isinstance(model, list):
model = model[:len(gs)]
@@ -36,64 +39,122 @@ def run_epoch(data_iter, model, loss_compute, is_train=True):
for step in range(1, model.MAX_DEPTH + 1):
print("nodes entering step {}: {:.2f}%".format(step, (1.0 * model.stat[step] / model.stat[0])))
model.reset_stat()
print('average loss: {}'.format(loss_compute.avg_loss))
print('accuracy: {}'.format(loss_compute.accuracy))
print('Epoch {} {}: Dev {} average loss: {}, accuracy {}'.format(
epoch, "Training" if is_train else "Evaluating",
dev_rank, loss_compute.avg_loss, loss_compute.accuracy))

if __name__ == '__main__':
if not os.path.exists('checkpoints'):
os.makedirs('checkpoints')
np.random.seed(1111)
argparser = argparse.ArgumentParser('training translation model')
argparser.add_argument('--gpus', default='-1', type=str, help='gpu id')
argparser.add_argument('--N', default=6, type=int, help='enc/dec layers')
argparser.add_argument('--dataset', default='multi30k', help='dataset')
argparser.add_argument('--batch', default=128, type=int, help='batch size')
argparser.add_argument('--viz', action='store_true', help='visualize attention')
argparser.add_argument('--universal', action='store_true', help='use universal transformer')
args = argparser.parse_args()
args_filter = ['batch', 'gpus', 'viz']
exp_setting = '-'.join('{}'.format(v) for k, v in vars(args).items() if k not in args_filter)
devices = ['cpu'] if args.gpus == '-1' else [int(gpu_id) for gpu_id in args.gpus.split(',')]
def run(dev_id, args):
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip=args.master_ip, master_port=args.master_port)
world_size = args.ngpu
torch.distributed.init_process_group(backend="nccl",
init_method=dist_init_method,
world_size=world_size,
rank=dev_id)
gpu_rank = torch.distributed.get_rank()
assert gpu_rank == dev_id
main(dev_id, args)
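
run() is the per-process entry point: each spawned worker joins an NCCL process group through a TCP rendezvous at master_ip:master_port, using its GPU id as its rank, and then falls through to main(). A minimal sketch of that setup; the torch.cuda.set_device call is an extra assumption for clarity and is not in the code above.

```python
import torch
import torch.distributed as dist

def init_worker(dev_id, master_ip='127.0.0.1', master_port='12345', world_size=1):
    """Join the NCCL process group as rank dev_id (illustrative sketch)."""
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://{}:{}'.format(master_ip, master_port),
        world_size=world_size,
        rank=dev_id)
    torch.cuda.set_device(dev_id)  # assumption: not done explicitly above
    return torch.device('cuda:{}'.format(dev_id))
```
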

def main(dev_id, args):
if dev_id == -1:
device = torch.device('cpu')
else:
device = torch.device('cuda:{}'.format(dev_id))
dataset = get_dataset(args.dataset)

V = dataset.vocab_size
criterion = LabelSmoothing(V, padding_idx=dataset.pad_id, smoothing=0.1)
dim_model = 512

graph_pool = GraphPool()
model = make_model(V, V, N=args.N, dim_model=dim_model, universal=args.universal)
model = make_model(V, V, N=args.N, dim_model=dim_model,
universal=args.universal)

# Sharing weights between Encoder & Decoder
model.src_embed.lut.weight = model.tgt_embed.lut.weight
model.generator.proj.weight = model.tgt_embed.lut.weight

model, criterion = model.to(devices[0]), criterion.to(devices[0])
model, criterion = model.to(device), criterion.to(device)
model_opt = NoamOpt(dim_model, 1, 400,
T.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9))
if len(devices) > 1:
model, criterion = map(nn.parallel.replicate, [model, criterion], [devices, devices])
loss_compute = SimpleLossCompute if len(devices) == 1 else MultiGPULossCompute
T.optim.Adam(model.parameters(), lr=1e-3,
betas=(0.9, 0.98), eps=1e-9))

if args.ngpu > 1:
dev_rank = dev_id # current device id
ndev = args.ngpu # number of devices (including cpu)
loss_compute = partial(MultiGPULossCompute, criterion, args.ngpu,
args.grad_accum, model)
else: # cpu or single gpu case
dev_rank = 0
ndev = 1
loss_compute = partial(SimpleLossCompute, criterion)

for epoch in range(100):
train_iter = dataset(graph_pool, mode='train', batch_size=args.batch, devices=devices)
valid_iter = dataset(graph_pool, mode='valid', batch_size=args.batch, devices=devices)
print('Epoch: {} Training...'.format(epoch))
start = time.time()
train_iter = dataset(graph_pool, mode='train', batch_size=args.batch,
device=device, dev_rank=dev_rank, ndev=ndev)
valid_iter = dataset(graph_pool, mode='valid', batch_size=args.batch,
device=device, dev_rank=dev_rank, ndev=ndev)
model.train(True)
run_epoch(train_iter, model,
loss_compute(criterion, model_opt), is_train=True)
print('Epoch: {} Evaluating...'.format(epoch))
run_epoch(epoch, train_iter, dev_rank, ndev, model,
loss_compute(opt=model_opt), is_train=True)
model.att_weight_map = None
model.eval()
run_epoch(valid_iter, model,
loss_compute(criterion, None), is_train=False)
# Visualize attention
if args.viz:
src_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='src')
tgt_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='tgt')[:-1]
draw_atts(model.att_weight_map, src_seq, tgt_seq, exp_setting, 'epoch_{}'.format(epoch))
run_epoch(epoch, valid_iter, dev_rank, ndev, model,
loss_compute(opt=None), is_train=False)
end = time.time()
if dev_rank == 0:
print("epoch time: {}".format(end - start))

"""
# Visualize attention
if args.viz:
src_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='src')
tgt_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='tgt')[:-1]
draw_atts(model.att_weight_map, src_seq, tgt_seq, exp_setting, 'epoch_{}'.format(epoch))

print('----------------------------------')
with open('checkpoints/{}-{}.pkl'.format(exp_setting, epoch), 'wb') as f:
th.save(model.state_dict(), f)
print('----------------------------------')
with open('checkpoints/{}-{}.pkl'.format(exp_setting, epoch), 'wb') as f:
torch.save(model.state_dict(), f)
"""

if __name__ == '__main__':
if not os.path.exists('checkpoints'):
os.makedirs('checkpoints')
np.random.seed(1111)
argparser = argparse.ArgumentParser('training translation model')
argparser.add_argument('--gpus', default='-1', type=str, help='gpu id')
argparser.add_argument('--N', default=6, type=int, help='enc/dec layers')
argparser.add_argument('--dataset', default='multi30k', help='dataset')
argparser.add_argument('--batch', default=128, type=int, help='batch size')
argparser.add_argument('--viz', action='store_true',
help='visualize attention')
argparser.add_argument('--universal', action='store_true',
help='use universal transformer')
argparser.add_argument('--master-ip', type=str, default='127.0.0.1',
help='master ip address')
argparser.add_argument('--master-port', type=str, default='12345',
help='master port')
argparser.add_argument('--grad-accum', type=int, default=1,
help='accumulate gradients for this many batches '
'before updating weights')
args = argparser.parse_args()
print(args)
#args_filter = ['batch', 'gpus', 'viz']
#exp_setting = '-'.join('{}'.format(v) for k, v in vars(args).items() if k not in args_filter)
#devices = ['cpu'] if args.gpus == '-1' else [int(gpu_id) for gpu_id in args.gpus.split(',')]
devices = list(map(int, args.gpus.split(',')))
if len(devices) == 1:
args.ngpu = 0 if devices[0] < 0 else 1
main(devices[0], args)
else:
args.ngpu = len(devices)
mp = torch.multiprocessing.get_context('spawn')
procs = []
for dev_id in devices:
procs.append(mp.Process(target=run, args=(dev_id, args),
daemon=True))
procs[-1].start()
for p in procs:
p.join()
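
The new __main__ block either calls main() directly (CPU or a single GPU) or spawns one child process per GPU using the 'spawn' multiprocessing context, so each child starts with a clean CUDA state. A sketch of that launcher pattern; launch is a hypothetical helper shown only to isolate the idea.

```python
import torch

def launch(devices, worker_fn, args):
    """Start one worker process per GPU id and wait for all of them (sketch)."""
    mp = torch.multiprocessing.get_context('spawn')  # avoid forking CUDA state
    procs = [mp.Process(target=worker_fn, args=(dev_id, args), daemon=True)
             for dev_id in devices]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

# Roughly equivalent to the multi-GPU branch above: launch([0, 1, 2, 3], run, args)
```

Given the flags defined above, an assumed invocation would be python translation_train.py --gpus 0,1,2,3 --grad-accum 2 for four GPUs, while omitting --gpus (default '-1') runs the single-process CPU path.
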