diff --git a/examples/pytorch/transformer/.gitmodules b/examples/pytorch/transformer/.gitmodules
deleted file mode 100644
index e69de29bb2d1..000000000000
diff --git a/examples/pytorch/transformer/dataset/__init__.py b/examples/pytorch/transformer/dataset/__init__.py
index 3ac4c667146c..6a63044298de 100644
--- a/examples/pytorch/transformer/dataset/__init__.py
+++ b/examples/pytorch/transformer/dataset/__init__.py
@@ -2,14 +2,14 @@ from .fields import *
 from .utils import prepare_dataset
 import os
-import numpy as np
+import random
 
-class ClassificationDataset:
+class ClassificationDataset(object):
     "Dataset class for classification task."
     def __init__(self):
         raise NotImplementedError
 
-class TranslationDataset:
+class TranslationDataset(object):
     '''
     Dataset class for translation task.
     By default, the source language shares the same vocabulary with the target language.
@@ -22,17 +22,17 @@ def __init__(self, path, exts, train='train', valid='valid', test='test', vocab=
         vocab_path = os.path.join(path, vocab)
         self.src = {}
         self.tgt = {}
-        with open(os.path.join(path, train + '.' + exts[0]), 'r') as f:
+        with open(os.path.join(path, train + '.' + exts[0]), 'r', encoding='utf-8') as f:
             self.src['train'] = f.readlines()
-        with open(os.path.join(path, train + '.' + exts[1]), 'r') as f:
+        with open(os.path.join(path, train + '.' + exts[1]), 'r', encoding='utf-8') as f:
            self.tgt['train'] = f.readlines()
-        with open(os.path.join(path, valid + '.' + exts[0]), 'r') as f:
+        with open(os.path.join(path, valid + '.' + exts[0]), 'r', encoding='utf-8') as f:
             self.src['valid'] = f.readlines()
-        with open(os.path.join(path, valid + '.' + exts[1]), 'r') as f:
+        with open(os.path.join(path, valid + '.' + exts[1]), 'r', encoding='utf-8') as f:
             self.tgt['valid'] = f.readlines()
-        with open(os.path.join(path, test + '.' + exts[0]), 'r') as f:
+        with open(os.path.join(path, test + '.' + exts[0]), 'r', encoding='utf-8') as f:
             self.src['test'] = f.readlines()
-        with open(os.path.join(path, test + '.' + exts[1]), 'r') as f:
+        with open(os.path.join(path, test + '.' + exts[1]), 'r', encoding='utf-8') as f:
             self.tgt['test'] = f.readlines()
 
         if not os.path.exists(vocab_path):
@@ -90,20 +90,30 @@ def sos_id(self):
     def eos_id(self):
         return self.vocab[self.EOS_TOKEN]
 
-    def __call__(self, graph_pool, mode='train', batch_size=32, k=1, devices=['cpu']):
+    def __call__(self, graph_pool, mode='train', batch_size=32, k=1,
+                 device='cpu', dev_rank=0, ndev=1):
         '''
         Create a batched graph correspond to the mini-batch of the dataset.
         args:
            graph_pool: a GraphPool object for accelerating.
            mode: train/valid/test
            batch_size: batch size
-           devices: ['cpu'] or a list of gpu ids.
-           k: beam size(only required for test)
+           k: beam size(only required for test)
+           device: str or torch.device
+           dev_rank: rank (id) of current device
+           ndev: number of devices
         '''
-        dev_id, gs = 0, []
         src_data, tgt_data = self.src[mode], self.tgt[mode]
         n = len(src_data)
-        order = np.random.permutation(n) if mode == 'train' else range(n)
+        # make sure all devices have the same number of batches
+        n = n // ndev * ndev
+
+        # XXX: partition then shuffle may not be equivalent to shuffle then
+        # partition
+        order = list(range(dev_rank, n, ndev))
+        if mode == 'train':
+            random.shuffle(order)
+
         src_buf, tgt_buf = [], []
 
         for idx in order:
@@ -115,22 +125,16 @@ def __call__(self, graph_pool, mode='train', batch_size=32, k=1, devices=['cpu']
             tgt_buf.append(tgt_sample)
             if len(src_buf) == batch_size:
                 if mode == 'test':
-                    assert len(devices) == 1 # we only allow single gpu for inference
-                    yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=devices[0])
+                    yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=device)
                 else:
-                    gs.append(graph_pool(src_buf, tgt_buf, device=devices[dev_id]))
-                    dev_id += 1
-                    if dev_id == len(devices):
-                        yield gs if len(devices) > 1 else gs[0]
-                        dev_id, gs = 0, []
+                    yield graph_pool(src_buf, tgt_buf, device=device)
                 src_buf, tgt_buf = [], []
 
         if len(src_buf) != 0:
             if mode == 'test':
-                yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=devices[0])
+                yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=device)
             else:
-                gs.append(graph_pool(src_buf, tgt_buf, device=devices[dev_id]))
-                yield gs if len(devices) > 1 else gs[0]
+                yield graph_pool(src_buf, tgt_buf, device=device)
 
     def get_sequence(self, batch):
         "return a list of sequence from a list of index arrays"
@@ -151,8 +155,8 @@ def get_dataset(dataset):
         raise NotImplementedError
     elif dataset == 'copy' or dataset == 'sort':
         return TranslationDataset(
-            'data/{}'.format(dataset),
-            ('in', 'out'),
+            'data/{}'.format(dataset),
+            ('in', 'out'),
             train='train',
             valid='valid',
             test='test',
diff --git a/examples/pytorch/transformer/dataset/fields.py b/examples/pytorch/transformer/dataset/fields.py
index d37c395bab73..712bbad1a6ad 100644
--- a/examples/pytorch/transformer/dataset/fields.py
+++ b/examples/pytorch/transformer/dataset/fields.py
@@ -16,7 +16,7 @@ def load(self, path):
             self.vocab_lst.append(self.pad_token)
         if self.unk_token is not None:
             self.vocab_lst.append(self.unk_token)
-        with open(path, 'r') as f:
+        with open(path, 'r', encoding='utf-8') as f:
             for token in f.readlines():
                 token = token.strip()
                 self.vocab_lst.append(token)
diff --git a/examples/pytorch/transformer/loss/__init__.py b/examples/pytorch/transformer/loss/__init__.py
index 35d7603ec9ab..b0ede9a716b6 100644
--- a/examples/pytorch/transformer/loss/__init__.py
+++ b/examples/pytorch/transformer/loss/__init__.py
@@ -1,6 +1,7 @@
 import torch as T
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributed as dist
 
 class LabelSmoothing(nn.Module):
     """
@@ -37,14 +38,37 @@ def forward(self, x, target):
 
 class SimpleLossCompute(nn.Module):
     eps=1e-8
-    def __init__(self, criterion, opt=None):
-        """
-        opt is required during training
+    def __init__(self, criterion, grad_accum, opt=None):
+        """Loss function and optimizer for a single device
+
+        Parameters
+        ----------
+        criterion: torch.nn.Module
+            criterion to compute loss
+        grad_accum: int
+            number of batches to accumulate gradients
+        opt: Optimizer
+            Model optimizer to use. If None, no backward pass or update will be
+            performed
         """
         super(SimpleLossCompute, self).__init__()
         self.criterion = criterion
         self.opt = opt
-        self.reset()
+        self.acc_loss = 0
+        self.n_correct = 0
+        self.norm_term = 0
+        self.loss = 0
+        self.batch_count = 0
+        self.grad_accum = grad_accum
+
+    def __enter__(self):
+        self.batch_count = 0
+
+    def __exit__(self, type, value, traceback):
+        # if gradients have accumulated but not yet been applied,
+        # do one more step
+        if self.batch_count > 0:
+            self.step()
 
     @property
     def avg_loss(self):
@@ -54,32 +78,56 @@ def avg_loss(self):
     def accuracy(self):
         return (self.n_correct + self.eps) / (self.norm_term + self.eps)
 
-    def reset(self):
-        self.acc_loss = 0
-        self.n_correct = 0
-        self.norm_term = 0
+    def step(self):
+        self.opt.step()
+        self.opt.optimizer.zero_grad()
+
+    def backward_and_step(self):
+        self.loss.backward()
+        self.batch_count += 1
+        # accumulate self.grad_accum times then synchronize and update
+        if self.batch_count == self.grad_accum:
+            self.step()
+            self.batch_count = 0
 
     def __call__(self, y_pred, y, norm):
         y_pred = y_pred.contiguous().view(-1, y_pred.shape[-1])
         y = y.contiguous().view(-1)
-        loss = self.criterion(
-            y_pred, y
-        ) / norm
+        self.loss = self.criterion(y_pred, y) / norm
         if self.opt is not None:
-            loss.backward()
-            self.opt.step()
-            self.opt.optimizer.zero_grad()
+            self.backward_and_step()
         self.n_correct += ((y_pred.max(dim=-1)[1] == y) & (y != self.criterion.padding_idx)).sum().item()
-        self.acc_loss += loss.item() * norm
+        self.acc_loss += self.loss.item() * norm
         self.norm_term += norm
-        return loss.item() * norm
+        return self.loss.item() * norm
 
 class MultiGPULossCompute(SimpleLossCompute):
-    def __init__(self, criterion, devices, opt=None, chunk_size=5):
-        self.criterion = criterion
-        self.opt = opt
-        self.devices = devices
-        self.chunk_size = chunk_size
+    def __init__(self, criterion, ndev, grad_accum, model, opt=None):
+        """Loss function and optimizer for multiple devices
+
+        Parameters
+        ----------
+        criterion: torch.nn.Module
+            criterion to compute loss
+        ndev: int
+            number of devices used
+        grad_accum: int
+            number of batches to accumulate gradients
+        model: torch.nn.Module
+            model to optimize (needed to iterate over and synchronize all parameters)
+        opt: Optimizer
+            Model optimizer to use. If None, no backward pass or update will be
+            performed
+        """
+        super(MultiGPULossCompute, self).__init__(criterion, grad_accum, opt=opt)
+        self.ndev = ndev
+        self.model = model
 
-    def __call__(self, y_preds, ys, norms):
-        pass
+    def step(self):
+        # multi-gpu synchronize gradients
+        for param in self.model.parameters():
+            if param.requires_grad and param.grad is not None:
+                dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
+                param.grad.data /= self.ndev
+        self.opt.step()
+        self.opt.optimizer.zero_grad()
diff --git a/examples/pytorch/transformer/modules/attention.py b/examples/pytorch/transformer/modules/attention.py
index b4bb1fe79f62..d86fd1dc9a1d 100644
--- a/examples/pytorch/transformer/modules/attention.py
+++ b/examples/pytorch/transformer/modules/attention.py
@@ -12,7 +12,7 @@ def __init__(self, h, dim_model):
         self.h = h
         # W_q, W_k, W_v, W_o
         self.linears = clones(
-            nn.Linear(dim_model, dim_model), 4
+            nn.Linear(dim_model, dim_model, bias=False), 4
         )
 
     def get(self, x, fields='qkv'):
diff --git a/examples/pytorch/transformer/modules/models.py b/examples/pytorch/transformer/modules/models.py
index 1810fb57d9e4..896e20cfed9f 100644
--- a/examples/pytorch/transformer/modules/models.py
+++ b/examples/pytorch/transformer/modules/models.py
@@ -46,11 +46,11 @@ def pre_func(self, i, fields='qkv', l=0):
         layer = self.layers[i]
         def func(nodes):
             x = nodes.data['x']
-            if fields == 'kv':
-                norm_x = x # In enc-dec attention, x has already been normalized.
+            norm_x = layer.sublayer[l].norm(x) if fields.startswith('q') else x
+            if fields != 'qkv':
+                return layer.src_attn.get(norm_x, fields)
             else:
-                norm_x = layer.sublayer[l].norm(x)
-            return layer.self_attn.get(norm_x, fields)
+                return layer.self_attn.get(norm_x, fields)
         return func
 
     def post_func(self, i, l=0):
@@ -64,8 +64,6 @@ def func(nodes):
             return {'x': x if i < self.N - 1 else self.norm(x)}
         return func
 
-lock = threading.Lock()
-
 class Transformer(nn.Module):
     def __init__(self, encoder, decoder, src_embed, tgt_embed, pos_enc, generator, h, d_k):
         super(Transformer, self).__init__()
@@ -124,9 +122,10 @@ def forward(self, graph):
             self.update_graph(g, edges, [(pre_q, nodes), (pre_kv, nodes_e)], [(post_func, nodes)])
 
         # visualize attention
-        with lock:
+        """
             if self.att_weight_map is None:
                 self._register_att_map(g, graph.nid_arr['enc'][VIZ_IDX], graph.nid_arr['dec'][VIZ_IDX])
+        """
 
         return self.generator(g.ndata['x'][nids['dec']])
 
diff --git a/examples/pytorch/transformer/parallel.py b/examples/pytorch/transformer/parallel.py
deleted file mode 100644
index 74023923e74c..000000000000
--- a/examples/pytorch/transformer/parallel.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Mostly then same with PyTorch
-import threading
-import torch
-
-def get_a_var(obj):
-    if isinstance(obj, torch.Tensor):
-        return obj
-
-    if isinstance(obj, list) or isinstance(obj, tuple):
-        for result in map(get_a_var, obj):
-            if isinstance(result, torch.Tensor):
-                return result
-    if isinstance(obj, dict):
-        for result in map(get_a_var, obj.items()):
-            if isinstance(result, torch.Tensor):
-                return result
-    return None
-
-
-def parallel_apply(modules, inputs):
-    assert len(modules) == len(inputs)
-    lock = threading.Lock()
-    results = {}
-    grad_enabled = torch.is_grad_enabled()
-
-    def _worker(i, module, input):
-        torch.set_grad_enabled(grad_enabled)
-        try:
-            #with torch.cuda.device(device):
-            output = module(input)
-            with lock:
-                results[i] = output
-        except Exception as e:
-            with lock:
-                results[i] = e
-
-    if len(modules) > 1:
-        threads = [threading.Thread(target=_worker,
-                                    args=(i, module, input))
-                   for i, (module, input) in
-                   enumerate(zip(modules, inputs))]
-
-        for thread in threads:
-            thread.start()
-        for thread in threads:
-            thread.join()
-    else:
-        _worker(0, modules[0], inputs[0])
-
-    outputs = []
-    for i in range(len(inputs)):
-        output = results[i]
-        if isinstance(output, Exception):
-            raise output
-        outputs.append(output)
-    return outputs
diff --git a/examples/pytorch/transformer/translation_train.py b/examples/pytorch/transformer/translation_train.py
index 6614c884df42..a4cdcb63063c 100644
--- a/examples/pytorch/transformer/translation_train.py
+++ b/examples/pytorch/transformer/translation_train.py
@@ -1,28 +1,20 @@
-"""
-In current version we use multi30k as the default training and validation set.
-Multi-GPU support is required to train the model on WMT14.
-"""
 from modules import *
-from parallel import *
-from loss import *
+from loss import *
 from optims import *
 from dataset import *
 from modules.config import *
-from modules.viz import *
-from tqdm import tqdm
+#from modules.viz import *
 import numpy as np
 import argparse
+import torch
+from functools import partial
+import torch.distributed as dist
 
-def run_epoch(data_iter, model, loss_compute, is_train=True):
+def run_epoch(epoch, data_iter, dev_rank, ndev, model, loss_compute, is_train=True):
     universal = isinstance(model, UTransformer)
-    for i, g in tqdm(enumerate(data_iter)):
-        with T.set_grad_enabled(is_train):
-            if isinstance(model, list):
-                model = model[:len(gs)]
-                output = parallel_apply(model, g)
-                tgt_y = [g.tgt_y for g in gs]
-                n_tokens = [g.n_tokens for g in gs]
-            else:
+    with loss_compute:
+        for i, g in enumerate(data_iter):
+            with T.set_grad_enabled(is_train):
                 if universal:
                     output, loss_act = model(g)
                     if is_train: loss_act.backward(retain_graph=True)
@@ -30,70 +22,134 @@ def run_epoch(data_iter, model, loss_compute, is_train=True):
                     output = model(g)
                 tgt_y = g.tgt_y
                 n_tokens = g.n_tokens
-            loss = loss_compute(output, tgt_y, n_tokens)
+                loss = loss_compute(output, tgt_y, n_tokens)
 
     if universal:
         for step in range(1, model.MAX_DEPTH + 1):
             print("nodes entering step {}: {:.2f}%".format(step, (1.0 * model.stat[step] / model.stat[0])))
        model.reset_stat()
-    print('average loss: {}'.format(loss_compute.avg_loss))
-    print('accuracy: {}'.format(loss_compute.accuracy))
+    print('Epoch {} {}: Dev {} average loss: {}, accuracy {}'.format(
+        epoch, "Training" if is_train else "Evaluating",
+        dev_rank, loss_compute.avg_loss, loss_compute.accuracy))
 
-if __name__ == '__main__':
-    if not os.path.exists('checkpoints'):
-        os.makedirs('checkpoints')
-    np.random.seed(1111)
-    argparser = argparse.ArgumentParser('training translation model')
-    argparser.add_argument('--gpus', default='-1', type=str, help='gpu id')
-    argparser.add_argument('--N', default=6, type=int, help='enc/dec layers')
-    argparser.add_argument('--dataset', default='multi30k', help='dataset')
-    argparser.add_argument('--batch', default=128, type=int, help='batch size')
-    argparser.add_argument('--viz', action='store_true', help='visualize attention')
-    argparser.add_argument('--universal', action='store_true', help='use universal transformer')
-    args = argparser.parse_args()
-    args_filter = ['batch', 'gpus', 'viz']
-    exp_setting = '-'.join('{}'.format(v) for k, v in vars(args).items() if k not in args_filter)
-    devices = ['cpu'] if args.gpus == '-1' else [int(gpu_id) for gpu_id in args.gpus.split(',')]
+def run(dev_id, args):
+    dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
+        master_ip=args.master_ip, master_port=args.master_port)
+    world_size = args.ngpu
+    torch.distributed.init_process_group(backend="nccl",
+                                         init_method=dist_init_method,
+                                         world_size=world_size,
+                                         rank=dev_id)
+    gpu_rank = torch.distributed.get_rank()
+    assert gpu_rank == dev_id
+    main(dev_id, args)
+
+def main(dev_id, args):
+    if dev_id == -1:
+        device = torch.device('cpu')
+    else:
+        device = torch.device('cuda:{}'.format(dev_id))
+        # Set current device
+        th.cuda.set_device(device)
+    # Prepare dataset
     dataset = get_dataset(args.dataset)
-
     V = dataset.vocab_size
     criterion = LabelSmoothing(V, padding_idx=dataset.pad_id, smoothing=0.1)
     dim_model = 512
-
+    # Build graph pool
     graph_pool = GraphPool()
-    model = make_model(V, V, N=args.N, dim_model=dim_model, universal=args.universal)
-
+    # Create model
+    model = make_model(V, V, N=args.N, dim_model=dim_model,
+                       universal=args.universal)
     # Sharing weights between Encoder & Decoder
     model.src_embed.lut.weight = model.tgt_embed.lut.weight
     model.generator.proj.weight = model.tgt_embed.lut.weight
+    # Move model to corresponding device
+    model, criterion = model.to(device), criterion.to(device)
+    # Loss function
+    if args.ngpu > 1:
+        dev_rank = dev_id # current device id
+        ndev = args.ngpu # number of devices (including cpu)
+        loss_compute = partial(MultiGPULossCompute, criterion, args.ngpu,
+                               args.grad_accum, model)
+    else: # cpu or single gpu case
+        dev_rank = 0
+        ndev = 1
+        loss_compute = partial(SimpleLossCompute, criterion, args.grad_accum)
 
-    model, criterion = model.to(devices[0]), criterion.to(devices[0])
-    model_opt = NoamOpt(dim_model, 1, 400,
-                        T.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9))
-    if len(devices) > 1:
-        model, criterion = map(nn.parallel.replicate, [model, criterion], [devices, devices])
-    loss_compute = SimpleLossCompute if len(devices) == 1 else MultiGPULossCompute
+    if ndev > 1:
+        for param in model.parameters():
+            dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
+            param.data /= ndev
+    # Optimizer
+    model_opt = NoamOpt(dim_model, 1, 4000,
+                        T.optim.Adam(model.parameters(), lr=1e-3,
+                                     betas=(0.9, 0.98), eps=1e-9))
+
+    # Train & evaluate
     for epoch in range(100):
-        train_iter = dataset(graph_pool, mode='train', batch_size=args.batch, devices=devices)
-        valid_iter = dataset(graph_pool, mode='valid', batch_size=args.batch, devices=devices)
-        print('Epoch: {} Training...'.format(epoch))
+        start = time.time()
+        train_iter = dataset(graph_pool, mode='train', batch_size=args.batch,
+                             device=device, dev_rank=dev_rank, ndev=ndev)
         model.train(True)
-        run_epoch(train_iter, model,
-                  loss_compute(criterion, model_opt), is_train=True)
-        print('Epoch: {} Evaluating...'.format(epoch))
-        model.att_weight_map = None
-        model.eval()
-        run_epoch(valid_iter, model,
-                  loss_compute(criterion, None), is_train=False)
-        # Visualize attention
-        if args.viz:
-            src_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='src')
-            tgt_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='tgt')[:-1]
-            draw_atts(model.att_weight_map, src_seq, tgt_seq, exp_setting, 'epoch_{}'.format(epoch))
+        run_epoch(epoch, train_iter, dev_rank, ndev, model,
+                  loss_compute(opt=model_opt), is_train=True)
+        if dev_rank == 0:
+            model.att_weight_map = None
+            model.eval()
+            valid_iter = dataset(graph_pool, mode='valid', batch_size=args.batch,
+                                 device=device, dev_rank=dev_rank, ndev=1)
+            run_epoch(epoch, valid_iter, dev_rank, 1, model,
+                      loss_compute(opt=None), is_train=False)
+            end = time.time()
+            print("epoch time: {}".format(end - start))
+
+            # Visualize attention
+            if args.viz:
+                src_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='src')
+                tgt_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='tgt')[:-1]
+                draw_atts(model.att_weight_map, src_seq, tgt_seq, exp_setting, 'epoch_{}'.format(epoch))
+            args_filter = ['batch', 'gpus', 'viz', 'master_ip', 'master_port', 'grad_accum', 'ngpu']
+            exp_setting = '-'.join('{}'.format(v) for k, v in vars(args).items() if k not in args_filter)
+            with open('checkpoints/{}-{}.pkl'.format(exp_setting, epoch), 'wb') as f:
+                torch.save(model.state_dict(), f)
+
+if __name__ == '__main__':
+    if not os.path.exists('checkpoints'):
+        os.makedirs('checkpoints')
+    np.random.seed(1111)
+    argparser = argparse.ArgumentParser('training translation model')
+    argparser.add_argument('--gpus', default='-1', type=str, help='gpu id')
+    argparser.add_argument('--N', default=6, type=int, help='enc/dec layers')
+    argparser.add_argument('--dataset', default='multi30k', help='dataset')
+    argparser.add_argument('--batch', default=128, type=int, help='batch size')
+    argparser.add_argument('--viz', action='store_true',
+                           help='visualize attention')
+    argparser.add_argument('--universal', action='store_true',
+                           help='use universal transformer')
+    argparser.add_argument('--master-ip', type=str, default='127.0.0.1',
+                           help='master ip address')
+    argparser.add_argument('--master-port', type=str, default='12345',
+                           help='master port')
+    argparser.add_argument('--grad-accum', type=int, default=1,
+                           help='accumulate gradients for this many times '
+                                'then update weights')
+    args = argparser.parse_args()
+    print(args)
 
-        print('----------------------------------')
-        with open('checkpoints/{}-{}.pkl'.format(exp_setting, epoch), 'wb') as f:
-            th.save(model.state_dict(), f)
+    devices = list(map(int, args.gpus.split(',')))
+    if len(devices) == 1:
+        args.ngpu = 0 if devices[0] < 0 else 1
+        main(devices[0], args)
+    else:
+        args.ngpu = len(devices)
+        mp = torch.multiprocessing.get_context('spawn')
+        procs = []
+        for dev_id in devices:
+            procs.append(mp.Process(target=run, args=(dev_id, args),
+                                    daemon=True))
+            procs[-1].start()
+        for p in procs:
+            p.join()
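
The core of the multi-GPU path added in loss/__init__.py is MultiGPULossCompute.step, which sums each parameter's gradient across ranks with all_reduce and divides by the number of devices before the optimizer step. Below is a minimal, self-contained sketch of that pattern; the helper name average_gradients, the single-process gloo group, and the local port are illustrative assumptions, not part of the patch.

import torch
import torch.distributed as dist

def average_gradients(model, ndev):
    # Sum each parameter's gradient over all ranks, then divide by the
    # number of devices so every rank holds the same averaged gradient.
    for param in model.parameters():
        if param.requires_grad and param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= ndev

if __name__ == '__main__':
    # Single-process demo using the gloo backend and an arbitrary local port;
    # the patch itself creates one nccl process group per GPU worker in run().
    dist.init_process_group(backend='gloo',
                            init_method='tcp://127.0.0.1:29500',
                            rank=0, world_size=1)
    model = torch.nn.Linear(4, 2)
    model(torch.randn(8, 4)).sum().backward()
    average_gradients(model, ndev=dist.get_world_size())
    print(model.weight.grad)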
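
TranslationDataset.__call__ now shards each split across workers: it truncates the sample count to a multiple of ndev and gives rank r the strided indices r, r + ndev, r + 2*ndev, ..., shuffling only in training mode. A small sketch of the same indexing, with partition_indices as a hypothetical stand-alone helper:

import random

def partition_indices(n, dev_rank, ndev, shuffle=True):
    # Drop the tail so every device sees the same number of samples, then
    # take a strided slice: rank r gets indices r, r + ndev, r + 2*ndev, ...
    n = n // ndev * ndev
    order = list(range(dev_rank, n, ndev))
    if shuffle:
        random.shuffle(order)
    return order

if __name__ == '__main__':
    for rank in range(3):
        print(rank, partition_indices(10, rank, ndev=3, shuffle=False))
    # rank 0 -> [0, 3, 6], rank 1 -> [1, 4, 7], rank 2 -> [2, 5, 8];
    # sample 9 is dropped so all ranks yield the same number of batches.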
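
SimpleLossCompute also becomes a context manager: backward() runs on every batch, the optimizer steps only every grad_accum batches, and __exit__ flushes a partially filled window at the end of an epoch (run_epoch wraps the batch loop in "with loss_compute:"). The class below is a hypothetical stand-in that mirrors just that accumulation bookkeeping, not the example's full API.

import torch

class AccumulatingStepper:
    # Illustrative stand-in for the accumulation logic in SimpleLossCompute:
    # backward() every batch, optimizer.step() every grad_accum batches,
    # and a flush of any partial window on exit.
    def __init__(self, optimizer, grad_accum):
        self.optimizer = optimizer
        self.grad_accum = grad_accum
        self.batch_count = 0

    def __enter__(self):
        self.batch_count = 0
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.batch_count > 0:
            self.step()

    def step(self):
        self.optimizer.step()
        self.optimizer.zero_grad()

    def backward_and_maybe_step(self, loss):
        loss.backward()
        self.batch_count += 1
        if self.batch_count == self.grad_accum:
            self.step()
            self.batch_count = 0

if __name__ == '__main__':
    model = torch.nn.Linear(4, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    with AccumulatingStepper(opt, grad_accum=4) as stepper:
        for _ in range(10):  # steps after batches 4 and 8, plus a flush at exit
            loss = model(torch.randn(2, 4)).pow(2).mean()
            stepper.backward_and_maybe_step(loss)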