From cccb6da06d9e37e27f9b7c00243416fee115cdfe Mon Sep 17 00:00:00 2001 From: Lingfan Yu Date: Mon, 14 Jan 2019 10:00:25 -0500 Subject: [PATCH 01/11] multi-process version of transformer --- .../pytorch/transformer/dataset/__init__.py | 26 +++---- examples/pytorch/transformer/loss/__init__.py | 38 ++++++---- .../pytorch/transformer/translation_train.py | 75 ++++++++++++++----- 3 files changed, 89 insertions(+), 50 deletions(-) diff --git a/examples/pytorch/transformer/dataset/__init__.py b/examples/pytorch/transformer/dataset/__init__.py index 3ac4c667146c..6961373174a4 100644 --- a/examples/pytorch/transformer/dataset/__init__.py +++ b/examples/pytorch/transformer/dataset/__init__.py @@ -90,17 +90,17 @@ def sos_id(self): def eos_id(self): return self.vocab[self.EOS_TOKEN] - def __call__(self, graph_pool, mode='train', batch_size=32, k=1, devices=['cpu']): + def __call__(self, graph_pool, mode='train', batch_size=32, k=1, + device='cpu'): ''' Create a batched graph correspond to the mini-batch of the dataset. args: graph_pool: a GraphPool object for accelerating. mode: train/valid/test batch_size: batch size - devices: ['cpu'] or a list of gpu ids. - k: beam size(only required for test) + device: torch.device + k: beam size(only required for test) ''' - dev_id, gs = 0, [] src_data, tgt_data = self.src[mode], self.tgt[mode] n = len(src_data) order = np.random.permutation(n) if mode == 'train' else range(n) @@ -115,22 +115,16 @@ def __call__(self, graph_pool, mode='train', batch_size=32, k=1, devices=['cpu'] tgt_buf.append(tgt_sample) if len(src_buf) == batch_size: if mode == 'test': - assert len(devices) == 1 # we only allow single gpu for inference - yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=devices[0]) + yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=device) else: - gs.append(graph_pool(src_buf, tgt_buf, device=devices[dev_id])) - dev_id += 1 - if dev_id == len(devices): - yield gs if len(devices) > 1 else gs[0] - dev_id, gs = 0, [] + yield graph_pool(src_buf, tgt_buf, device=device)) src_buf, tgt_buf = [], [] if len(src_buf) != 0: if mode == 'test': - yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=devices[0]) + yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=device) else: - gs.append(graph_pool(src_buf, tgt_buf, device=devices[dev_id])) - yield gs if len(devices) > 1 else gs[0] + yield graph_pool(src_buf, tgt_buf, device=device)) def get_sequence(self, batch): "return a list of sequence from a list of index arrays" @@ -151,8 +145,8 @@ def get_dataset(dataset): raise NotImplementedError elif dataset == 'copy' or dataset == 'sort': return TranslationDataset( - 'data/{}'.format(dataset), - ('in', 'out'), + 'data/{}'.format(dataset), + ('in', 'out'), train='train', valid='valid', test='test', diff --git a/examples/pytorch/transformer/loss/__init__.py b/examples/pytorch/transformer/loss/__init__.py index 35d7603ec9ab..728a01c26141 100644 --- a/examples/pytorch/transformer/loss/__init__.py +++ b/examples/pytorch/transformer/loss/__init__.py @@ -1,6 +1,7 @@ import torch as T import torch.nn as nn import torch.nn.functional as F +import torch.distributed as dist class LabelSmoothing(nn.Module): """ @@ -44,7 +45,10 @@ def __init__(self, criterion, opt=None): super(SimpleLossCompute, self).__init__() self.criterion = criterion self.opt = opt - self.reset() + self.acc_loss = 0 + self.n_correct = 0 + self.norm_term = 0 + self.loss = 0 @property def avg_loss(self): @@ -54,32 +58,34 @@ def avg_loss(self): def accuracy(self): return (self.n_correct + self.eps) / (self.norm_term + self.eps) - def reset(self): - self.acc_loss = 0 - self.n_correct = 0 - self.norm_term = 0 + def backward(self): + self.loss.backward() def __call__(self, y_pred, y, norm): y_pred = y_pred.contiguous().view(-1, y_pred.shape[-1]) y = y.contiguous().view(-1) - loss = self.criterion( + self.loss = self.criterion( y_pred, y ) / norm if self.opt is not None: - loss.backward() + self.backward() self.opt.step() self.opt.optimizer.zero_grad() self.n_correct += ((y_pred.max(dim=-1)[1] == y) & (y != self.criterion.padding_idx)).sum().item() - self.acc_loss += loss.item() * norm + self.acc_loss += self.loss.item() * norm self.norm_term += norm - return loss.item() * norm + return self.loss.item() * norm class MultiGPULossCompute(SimpleLossCompute): - def __init__(self, criterion, devices, opt=None, chunk_size=5): - self.criterion = criterion - self.opt = opt - self.devices = devices - self.chunk_size = chunk_size + def __init__(self, criterion, dev_id, ndev, params, opt=None): + super(MultiGPULossCompute, self).__init__(criterion, opt) + self.dev_id = dev_id + self.ndev = ndev + self.params = params - def __call__(self, y_preds, ys, norms): - pass + def backward(self): + # multi-gpu synchronous backward + self.loss.backward() + for param in self.params: + dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM) + prarm.grad.data /= size diff --git a/examples/pytorch/transformer/translation_train.py b/examples/pytorch/transformer/translation_train.py index 6614c884df42..b421dbcf44cb 100644 --- a/examples/pytorch/transformer/translation_train.py +++ b/examples/pytorch/transformer/translation_train.py @@ -4,18 +4,21 @@ """ from modules import * from parallel import * -from loss import * +from loss import * from optims import * from dataset import * from modules.config import * -from modules.viz import * -from tqdm import tqdm +#from modules.viz import * +#from tqdm import tqdm import numpy as np import argparse +import torch -def run_epoch(data_iter, model, loss_compute, is_train=True): +def run_epoch(data_iter, dev_rank, ndev, model, loss_compute, is_train=True): universal = isinstance(model, UTransformer) - for i, g in tqdm(enumerate(data_iter)): + for i, g in enumerate(data_iter): + if i % ndev != dev_rank: + continue with T.set_grad_enabled(is_train): if isinstance(model, list): model = model[:len(gs)] @@ -51,10 +54,42 @@ def run_epoch(data_iter, model, loss_compute, is_train=True): argparser.add_argument('--viz', action='store_true', help='visualize attention') argparser.add_argument('--universal', action='store_true', help='use universal transformer') args = argparser.parse_args() - args_filter = ['batch', 'gpus', 'viz'] - exp_setting = '-'.join('{}'.format(v) for k, v in vars(args).items() if k not in args_filter) - devices = ['cpu'] if args.gpus == '-1' else [int(gpu_id) for gpu_id in args.gpus.split(',')] + #args_filter = ['batch', 'gpus', 'viz'] + #exp_setting = '-'.join('{}'.format(v) for k, v in vars(args).items() if k not in args_filter) + #devices = ['cpu'] if args.gpus == '-1' else [int(gpu_id) for gpu_id in args.gpus.split(',')] + devices = list(map(int, args.gpus.split(','))) + if len(devices) == 1: + args.ngpu = 0 if devices[0] < 0 else 1 + main(devices[0], args) + else: + args.ngpu = len(devices) + mp = torch.multiprocessing.get_context('spawn') + procs = [] + for dev_id in devices: + procs.append(mp.Proces(target=run, args=(dev_id, args), daemon=True)) + procs[-1].start() + for p in procs: + p.join() +def run(dev_id, args): + # FIXME: make ip and port configurable + ip = "127.0.0.1" + port = "12321" + dist_init_method = 'tcp://{master_ip}:{master_port}'.format(ip, port) + world_size = len(devices) + torch.distributed.init_process_group(backend="nccl", + init_method=dist_init_method, + world_size=world_size + rank=dev_id) + gpu_rank = torch.distributed.get_rank() + assert gpu_rank == dev_id + main(dev_id, args) + +def main(dev_id, args): + if dev_id == -1: + device = torch.device('cpu') + else: + device = torch.device('cuda:{}'.format(dev_id)) dataset = get_dataset(args.dataset) V = dataset.vocab_size @@ -68,25 +103,28 @@ def run_epoch(data_iter, model, loss_compute, is_train=True): model.src_embed.lut.weight = model.tgt_embed.lut.weight model.generator.proj.weight = model.tgt_embed.lut.weight - model, criterion = model.to(devices[0]), criterion.to(devices[0]) + model, criterion = model.to(device), criterion.to(device) model_opt = NoamOpt(dim_model, 1, 400, T.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9)) - if len(devices) > 1: - model, criterion = map(nn.parallel.replicate, [model, criterion], [devices, devices]) - loss_compute = SimpleLossCompute if len(devices) == 1 else MultiGPULossCompute + if args.ngpu > 1: + loss_compute = MultiGPULossCompute(criterion, dev_id, args.ngpu, + model.parameters(), opt=model_opt) + dev_rank = dev_id + ndev = args.ngpu + else: + loss_compute = SimpleLossCompute(criterion, opt=model_opt) for epoch in range(100): - train_iter = dataset(graph_pool, mode='train', batch_size=args.batch, devices=devices) - valid_iter = dataset(graph_pool, mode='valid', batch_size=args.batch, devices=devices) + train_iter = dataset(graph_pool, mode='train', batch_size=args.batch, devices=device) + valid_iter = dataset(graph_pool, mode='valid', batch_size=args.batch, devices=device) print('Epoch: {} Training...'.format(epoch)) model.train(True) - run_epoch(train_iter, model, - loss_compute(criterion, model_opt), is_train=True) + run_epoch(train_iter, dev_rank, ndev, model, loss_compute, is_train=True) print('Epoch: {} Evaluating...'.format(epoch)) model.att_weight_map = None model.eval() - run_epoch(valid_iter, model, - loss_compute(criterion, None), is_train=False) + run_epoch(valid_iter, dev_rank, ndev, model, loss_compute, is_train=False) + """ # Visualize attention if args.viz: src_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='src') @@ -96,4 +134,5 @@ def run_epoch(data_iter, model, loss_compute, is_train=True): print('----------------------------------') with open('checkpoints/{}-{}.pkl'.format(exp_setting, epoch), 'wb') as f: th.save(model.state_dict(), f) + """ From 8107b506193c2bf8a9eb3d1b8b96590eaeec5f94 Mon Sep 17 00:00:00 2001 From: Lingfan Yu Date: Mon, 14 Jan 2019 16:47:29 +0000 Subject: [PATCH 02/11] lots of fix --- .../pytorch/transformer/dataset/__init__.py | 11 ++- examples/pytorch/transformer/loss/__init__.py | 15 ++-- .../pytorch/transformer/translation_train.py | 81 ++++++++++--------- 3 files changed, 59 insertions(+), 48 deletions(-) diff --git a/examples/pytorch/transformer/dataset/__init__.py b/examples/pytorch/transformer/dataset/__init__.py index 6961373174a4..2fefcedd7727 100644 --- a/examples/pytorch/transformer/dataset/__init__.py +++ b/examples/pytorch/transformer/dataset/__init__.py @@ -91,7 +91,7 @@ def eos_id(self): return self.vocab[self.EOS_TOKEN] def __call__(self, graph_pool, mode='train', batch_size=32, k=1, - device='cpu'): + device='cpu', ndev=1): ''' Create a batched graph correspond to the mini-batch of the dataset. args: @@ -103,6 +103,11 @@ def __call__(self, graph_pool, mode='train', batch_size=32, k=1, ''' src_data, tgt_data = self.src[mode], self.tgt[mode] n = len(src_data) + # make sure all devices have the same number of batches + n_ceil = (n + batch_size - 1) // batch_size * batch_size + sample_per_dev = batch_size * ndev + n = min(n, n_ceil // sample_per_dev * sample_per_dev) + order = np.random.permutation(n) if mode == 'train' else range(n) src_buf, tgt_buf = [], [] @@ -117,14 +122,14 @@ def __call__(self, graph_pool, mode='train', batch_size=32, k=1, if mode == 'test': yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=device) else: - yield graph_pool(src_buf, tgt_buf, device=device)) + yield graph_pool(src_buf, tgt_buf, device=device) src_buf, tgt_buf = [], [] if len(src_buf) != 0: if mode == 'test': yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=device) else: - yield graph_pool(src_buf, tgt_buf, device=device)) + yield graph_pool(src_buf, tgt_buf, device=device) def get_sequence(self, batch): "return a list of sequence from a list of index arrays" diff --git a/examples/pytorch/transformer/loss/__init__.py b/examples/pytorch/transformer/loss/__init__.py index 728a01c26141..573efaccc1dd 100644 --- a/examples/pytorch/transformer/loss/__init__.py +++ b/examples/pytorch/transformer/loss/__init__.py @@ -61,13 +61,13 @@ def accuracy(self): def backward(self): self.loss.backward() - def __call__(self, y_pred, y, norm): + def __call__(self, y_pred, y, norm, is_train=True): y_pred = y_pred.contiguous().view(-1, y_pred.shape[-1]) y = y.contiguous().view(-1) self.loss = self.criterion( y_pred, y ) / norm - if self.opt is not None: + if is_train: self.backward() self.opt.step() self.opt.optimizer.zero_grad() @@ -77,15 +77,16 @@ def __call__(self, y_pred, y, norm): return self.loss.item() * norm class MultiGPULossCompute(SimpleLossCompute): - def __init__(self, criterion, dev_id, ndev, params, opt=None): + def __init__(self, criterion, dev_id, ndev, model, opt=None): super(MultiGPULossCompute, self).__init__(criterion, opt) self.dev_id = dev_id self.ndev = ndev - self.params = params + self.model = model def backward(self): # multi-gpu synchronous backward self.loss.backward() - for param in self.params: - dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM) - prarm.grad.data /= size + for param in self.model.parameters(): + if param.requires_grad and param.grad is not None: + dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) + param.grad.data /= self.ndev diff --git a/examples/pytorch/transformer/translation_train.py b/examples/pytorch/transformer/translation_train.py index b421dbcf44cb..56d80edc8b5f 100644 --- a/examples/pytorch/transformer/translation_train.py +++ b/examples/pytorch/transformer/translation_train.py @@ -33,7 +33,7 @@ def run_epoch(data_iter, dev_rank, ndev, model, loss_compute, is_train=True): output = model(g) tgt_y = g.tgt_y n_tokens = g.n_tokens - loss = loss_compute(output, tgt_y, n_tokens) + loss = loss_compute(output, tgt_y, n_tokens, is_train=is_train) if universal: for step in range(1, model.MAX_DEPTH + 1): @@ -42,44 +42,16 @@ def run_epoch(data_iter, dev_rank, ndev, model, loss_compute, is_train=True): print('average loss: {}'.format(loss_compute.avg_loss)) print('accuracy: {}'.format(loss_compute.accuracy)) -if __name__ == '__main__': - if not os.path.exists('checkpoints'): - os.makedirs('checkpoints') - np.random.seed(1111) - argparser = argparse.ArgumentParser('training translation model') - argparser.add_argument('--gpus', default='-1', type=str, help='gpu id') - argparser.add_argument('--N', default=6, type=int, help='enc/dec layers') - argparser.add_argument('--dataset', default='multi30k', help='dataset') - argparser.add_argument('--batch', default=128, type=int, help='batch size') - argparser.add_argument('--viz', action='store_true', help='visualize attention') - argparser.add_argument('--universal', action='store_true', help='use universal transformer') - args = argparser.parse_args() - #args_filter = ['batch', 'gpus', 'viz'] - #exp_setting = '-'.join('{}'.format(v) for k, v in vars(args).items() if k not in args_filter) - #devices = ['cpu'] if args.gpus == '-1' else [int(gpu_id) for gpu_id in args.gpus.split(',')] - devices = list(map(int, args.gpus.split(','))) - if len(devices) == 1: - args.ngpu = 0 if devices[0] < 0 else 1 - main(devices[0], args) - else: - args.ngpu = len(devices) - mp = torch.multiprocessing.get_context('spawn') - procs = [] - for dev_id in devices: - procs.append(mp.Proces(target=run, args=(dev_id, args), daemon=True)) - procs[-1].start() - for p in procs: - p.join() - def run(dev_id, args): # FIXME: make ip and port configurable - ip = "127.0.0.1" - port = "12321" - dist_init_method = 'tcp://{master_ip}:{master_port}'.format(ip, port) - world_size = len(devices) + master_ip = "127.0.0.1" + master_port = "12321" + dist_init_method = 'tcp://{master_ip}:{master_port}'.format( + master_ip=master_ip, master_port=master_port) + world_size = args.ngpu torch.distributed.init_process_group(backend="nccl", init_method=dist_init_method, - world_size=world_size + world_size=world_size, rank=dev_id) gpu_rank = torch.distributed.get_rank() assert gpu_rank == dev_id @@ -109,14 +81,17 @@ def main(dev_id, args): if args.ngpu > 1: loss_compute = MultiGPULossCompute(criterion, dev_id, args.ngpu, - model.parameters(), opt=model_opt) + model, opt=model_opt) dev_rank = dev_id ndev = args.ngpu else: loss_compute = SimpleLossCompute(criterion, opt=model_opt) + dev_rank = 0 + ndev = 0 + for epoch in range(100): - train_iter = dataset(graph_pool, mode='train', batch_size=args.batch, devices=device) - valid_iter = dataset(graph_pool, mode='valid', batch_size=args.batch, devices=device) + train_iter = dataset(graph_pool, mode='train', batch_size=args.batch, device=device, ndev=ndev) + valid_iter = dataset(graph_pool, mode='valid', batch_size=args.batch, device=device, ndev=ndev) print('Epoch: {} Training...'.format(epoch)) model.train(True) run_epoch(train_iter, dev_rank, ndev, model, loss_compute, is_train=True) @@ -136,3 +111,33 @@ def main(dev_id, args): th.save(model.state_dict(), f) """ +if __name__ == '__main__': + if not os.path.exists('checkpoints'): + os.makedirs('checkpoints') + np.random.seed(1111) + argparser = argparse.ArgumentParser('training translation model') + argparser.add_argument('--gpus', default='-1', type=str, help='gpu id') + argparser.add_argument('--N', default=6, type=int, help='enc/dec layers') + argparser.add_argument('--dataset', default='multi30k', help='dataset') + argparser.add_argument('--batch', default=128, type=int, help='batch size') + argparser.add_argument('--viz', action='store_true', help='visualize attention') + argparser.add_argument('--universal', action='store_true', help='use universal transformer') + args = argparser.parse_args() + #args_filter = ['batch', 'gpus', 'viz'] + #exp_setting = '-'.join('{}'.format(v) for k, v in vars(args).items() if k not in args_filter) + #devices = ['cpu'] if args.gpus == '-1' else [int(gpu_id) for gpu_id in args.gpus.split(',')] + devices = list(map(int, args.gpus.split(','))) + if len(devices) == 1: + args.ngpu = 0 if devices[0] < 0 else 1 + main(devices[0], args) + else: + args.ngpu = len(devices) + mp = torch.multiprocessing.get_context('spawn') + procs = [] + for dev_id in devices: + procs.append(mp.Process(target=run, args=(dev_id, args), + daemon=True)) + procs[-1].start() + for p in procs: + p.join() + From 8e44663ed45fd5584549d9becce1e58f28f2e718 Mon Sep 17 00:00:00 2001 From: Lingfan Yu Date: Mon, 14 Jan 2019 20:50:27 +0000 Subject: [PATCH 03/11] fix bugs and accum gradients for multiple batches --- .../pytorch/transformer/dataset/__init__.py | 15 ++--- examples/pytorch/transformer/loss/__init__.py | 32 ++++++---- .../pytorch/transformer/translation_train.py | 61 +++++++++++-------- 3 files changed, 64 insertions(+), 44 deletions(-) diff --git a/examples/pytorch/transformer/dataset/__init__.py b/examples/pytorch/transformer/dataset/__init__.py index 2fefcedd7727..4bc1ca98c2b2 100644 --- a/examples/pytorch/transformer/dataset/__init__.py +++ b/examples/pytorch/transformer/dataset/__init__.py @@ -4,12 +4,12 @@ import os import numpy as np -class ClassificationDataset: +class ClassificationDataset(object): "Dataset class for classification task." def __init__(self): raise NotImplementedError -class TranslationDataset: +class TranslationDataset(object): ''' Dataset class for translation task. By default, the source language shares the same vocabulary with the target language. @@ -91,7 +91,7 @@ def eos_id(self): return self.vocab[self.EOS_TOKEN] def __call__(self, graph_pool, mode='train', batch_size=32, k=1, - device='cpu', ndev=1): + device='cpu', dev_rank=0, ndev=1): ''' Create a batched graph correspond to the mini-batch of the dataset. args: @@ -104,11 +104,12 @@ def __call__(self, graph_pool, mode='train', batch_size=32, k=1, src_data, tgt_data = self.src[mode], self.tgt[mode] n = len(src_data) # make sure all devices have the same number of batches - n_ceil = (n + batch_size - 1) // batch_size * batch_size - sample_per_dev = batch_size * ndev - n = min(n, n_ceil // sample_per_dev * sample_per_dev) + n = n // ndev * ndev + + #order = np.random.permutation(n) if mode == 'train' else range(n) + # FIXME: do not shuffle for mgpu + order = range(dev_rank, n, ndev) - order = np.random.permutation(n) if mode == 'train' else range(n) src_buf, tgt_buf = [], [] for idx in order: diff --git a/examples/pytorch/transformer/loss/__init__.py b/examples/pytorch/transformer/loss/__init__.py index 573efaccc1dd..7a48ab89aa71 100644 --- a/examples/pytorch/transformer/loss/__init__.py +++ b/examples/pytorch/transformer/loss/__init__.py @@ -58,35 +58,43 @@ def avg_loss(self): def accuracy(self): return (self.n_correct + self.eps) / (self.norm_term + self.eps) - def backward(self): + def backward_and_step(self): self.loss.backward() + self.opt.step() + self.opt.optimizer.zero_grad() - def __call__(self, y_pred, y, norm, is_train=True): + def __call__(self, y_pred, y, norm): y_pred = y_pred.contiguous().view(-1, y_pred.shape[-1]) y = y.contiguous().view(-1) self.loss = self.criterion( y_pred, y ) / norm - if is_train: - self.backward() - self.opt.step() - self.opt.optimizer.zero_grad() + if self.opt is not None: + self.backward_and_step() self.n_correct += ((y_pred.max(dim=-1)[1] == y) & (y != self.criterion.padding_idx)).sum().item() self.acc_loss += self.loss.item() * norm self.norm_term += norm return self.loss.item() * norm class MultiGPULossCompute(SimpleLossCompute): - def __init__(self, criterion, dev_id, ndev, model, opt=None): + def __init__(self, criterion, dev_id, ndev, accum_count, model, opt=None): super(MultiGPULossCompute, self).__init__(criterion, opt) self.dev_id = dev_id self.ndev = ndev + self.accum_count = accum_count self.model = model + self.count = 0 - def backward(self): + def backward_and_step(self): # multi-gpu synchronous backward self.loss.backward() - for param in self.model.parameters(): - if param.requires_grad and param.grad is not None: - dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) - param.grad.data /= self.ndev + self.count += 1 + # accumulate self.accum_count times then synchronize and update + if self.count == self.accum_count: + for param in self.model.parameters(): + if param.requires_grad and param.grad is not None: + dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) + param.grad.data /= self.ndev + self.opt.step() + self.opt.optimizer.zero_grad() + self.count = 0 diff --git a/examples/pytorch/transformer/translation_train.py b/examples/pytorch/transformer/translation_train.py index 56d80edc8b5f..5cb0211ffd7d 100644 --- a/examples/pytorch/transformer/translation_train.py +++ b/examples/pytorch/transformer/translation_train.py @@ -13,12 +13,12 @@ import numpy as np import argparse import torch +from functools import partial def run_epoch(data_iter, dev_rank, ndev, model, loss_compute, is_train=True): universal = isinstance(model, UTransformer) for i, g in enumerate(data_iter): - if i % ndev != dev_rank: - continue + #print("Dev {} start batch {}".format(dev_rank, i)) with T.set_grad_enabled(is_train): if isinstance(model, list): model = model[:len(gs)] @@ -33,14 +33,14 @@ def run_epoch(data_iter, dev_rank, ndev, model, loss_compute, is_train=True): output = model(g) tgt_y = g.tgt_y n_tokens = g.n_tokens - loss = loss_compute(output, tgt_y, n_tokens, is_train=is_train) + loss = loss_compute(output, tgt_y, n_tokens) if universal: for step in range(1, model.MAX_DEPTH + 1): print("nodes entering step {}: {:.2f}%".format(step, (1.0 * model.stat[step] / model.stat[0]))) model.reset_stat() - print('average loss: {}'.format(loss_compute.avg_loss)) - print('accuracy: {}'.format(loss_compute.accuracy)) + print('Dev {} average loss: {}'.format(dev_rank, loss_compute.avg_loss)) + print('Dev {} accuracy: {}'.format(dev_rank, loss_compute.accuracy)) def run(dev_id, args): # FIXME: make ip and port configurable @@ -80,36 +80,43 @@ def main(dev_id, args): T.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9)) if args.ngpu > 1: - loss_compute = MultiGPULossCompute(criterion, dev_id, args.ngpu, - model, opt=model_opt) + loss_compute = partial(MultiGPULossCompute, criterion, dev_id, + args.ngpu, args.accum, model) dev_rank = dev_id ndev = args.ngpu else: - loss_compute = SimpleLossCompute(criterion, opt=model_opt) + loss_compute = partial(SimpleLossCompute, criterion) dev_rank = 0 - ndev = 0 + ndev = 1 for epoch in range(100): - train_iter = dataset(graph_pool, mode='train', batch_size=args.batch, device=device, ndev=ndev) - valid_iter = dataset(graph_pool, mode='valid', batch_size=args.batch, device=device, ndev=ndev) - print('Epoch: {} Training...'.format(epoch)) + start = time.time() + train_iter = dataset(graph_pool, mode='train', batch_size=args.batch, + device=device, dev_rank=dev_rank, ndev=ndev) + valid_iter = dataset(graph_pool, mode='valid', batch_size=args.batch, + device=device, dev_rank=dev_rank, ndev=ndev) + print('Dev {} Epoch: {} Training...'.format(dev_rank, epoch)) model.train(True) - run_epoch(train_iter, dev_rank, ndev, model, loss_compute, is_train=True) - print('Epoch: {} Evaluating...'.format(epoch)) + run_epoch(train_iter, dev_rank, ndev, model, loss_compute(opt=model_opt)) + print('Dev {} Epoch: {} Evaluating...'.format(dev_rank, epoch)) model.att_weight_map = None model.eval() - run_epoch(valid_iter, dev_rank, ndev, model, loss_compute, is_train=False) - """ - # Visualize attention - if args.viz: - src_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='src') - tgt_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='tgt')[:-1] - draw_atts(model.att_weight_map, src_seq, tgt_seq, exp_setting, 'epoch_{}'.format(epoch)) + run_epoch(valid_iter, dev_rank, ndev, model, loss_compute(opt=None)) + end = time.time() + if dev_rank == 0: + print("epoch time: {}".format(end - start)) - print('----------------------------------') - with open('checkpoints/{}-{}.pkl'.format(exp_setting, epoch), 'wb') as f: - th.save(model.state_dict(), f) - """ + """ + # Visualize attention + if args.viz: + src_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='src') + tgt_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='tgt')[:-1] + draw_atts(model.att_weight_map, src_seq, tgt_seq, exp_setting, 'epoch_{}'.format(epoch)) + + print('----------------------------------') + with open('checkpoints/{}-{}.pkl'.format(exp_setting, epoch), 'wb') as f: + torch.save(model.state_dict(), f) + """ if __name__ == '__main__': if not os.path.exists('checkpoints'): @@ -122,7 +129,11 @@ def main(dev_id, args): argparser.add_argument('--batch', default=128, type=int, help='batch size') argparser.add_argument('--viz', action='store_true', help='visualize attention') argparser.add_argument('--universal', action='store_true', help='use universal transformer') + argparser.add_argument('--accum', type=int, default=1, + help='accumulate gradients for this many times ' + 'then update weights') args = argparser.parse_args() + print(args) #args_filter = ['batch', 'gpus', 'viz'] #exp_setting = '-'.join('{}'.format(v) for k, v in vars(args).items() if k not in args_filter) #devices = ['cpu'] if args.gpus == '-1' else [int(gpu_id) for gpu_id in args.gpus.split(',')] From 3cd6802b4da914dfe80c5323850a22eb4c8ec142 Mon Sep 17 00:00:00 2001 From: Lingfan Yu Date: Mon, 14 Jan 2019 21:15:42 +0000 Subject: [PATCH 04/11] many fixes --- .../pytorch/transformer/dataset/__init__.py | 9 ++-- examples/pytorch/transformer/loss/__init__.py | 8 ++-- .../pytorch/transformer/translation_train.py | 46 +++++++++++-------- 3 files changed, 35 insertions(+), 28 deletions(-) diff --git a/examples/pytorch/transformer/dataset/__init__.py b/examples/pytorch/transformer/dataset/__init__.py index 4bc1ca98c2b2..d0ea7dd69216 100644 --- a/examples/pytorch/transformer/dataset/__init__.py +++ b/examples/pytorch/transformer/dataset/__init__.py @@ -2,7 +2,7 @@ from .fields import * from .utils import prepare_dataset import os -import numpy as np +import random class ClassificationDataset(object): "Dataset class for classification task." @@ -106,9 +106,10 @@ def __call__(self, graph_pool, mode='train', batch_size=32, k=1, # make sure all devices have the same number of batches n = n // ndev * ndev - #order = np.random.permutation(n) if mode == 'train' else range(n) - # FIXME: do not shuffle for mgpu - order = range(dev_rank, n, ndev) + # XXX: is partition then shuffle equivalent to shuffle then partition? + order = list(range(dev_rank, n, ndev)) + if mode == 'train': + random.shuffle(order) src_buf, tgt_buf = [], [] diff --git a/examples/pytorch/transformer/loss/__init__.py b/examples/pytorch/transformer/loss/__init__.py index 7a48ab89aa71..5c5ae1198e0a 100644 --- a/examples/pytorch/transformer/loss/__init__.py +++ b/examples/pytorch/transformer/loss/__init__.py @@ -77,11 +77,11 @@ def __call__(self, y_pred, y, norm): return self.loss.item() * norm class MultiGPULossCompute(SimpleLossCompute): - def __init__(self, criterion, dev_id, ndev, accum_count, model, opt=None): + def __init__(self, criterion, dev_id, ndev, grad_accum, model, opt=None): super(MultiGPULossCompute, self).__init__(criterion, opt) self.dev_id = dev_id self.ndev = ndev - self.accum_count = accum_count + self.grad_accum = grad_accum self.model = model self.count = 0 @@ -89,8 +89,8 @@ def backward_and_step(self): # multi-gpu synchronous backward self.loss.backward() self.count += 1 - # accumulate self.accum_count times then synchronize and update - if self.count == self.accum_count: + # accumulate self.grad_accum times then synchronize and update + if self.count == self.grad_accum: for param in self.model.parameters(): if param.requires_grad and param.grad is not None: dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) diff --git a/examples/pytorch/transformer/translation_train.py b/examples/pytorch/transformer/translation_train.py index 5cb0211ffd7d..eb5b7ef6bc21 100644 --- a/examples/pytorch/transformer/translation_train.py +++ b/examples/pytorch/transformer/translation_train.py @@ -39,15 +39,13 @@ def run_epoch(data_iter, dev_rank, ndev, model, loss_compute, is_train=True): for step in range(1, model.MAX_DEPTH + 1): print("nodes entering step {}: {:.2f}%".format(step, (1.0 * model.stat[step] / model.stat[0]))) model.reset_stat() - print('Dev {} average loss: {}'.format(dev_rank, loss_compute.avg_loss)) - print('Dev {} accuracy: {}'.format(dev_rank, loss_compute.accuracy)) + print('{}: Dev {} average loss: {}, accuracy {}'.format( + "Training" if is_train else "Evaluting", + dev_rank, loss_compute.avg_loss, loss_compute.accuracy)) def run(dev_id, args): - # FIXME: make ip and port configurable - master_ip = "127.0.0.1" - master_port = "12321" dist_init_method = 'tcp://{master_ip}:{master_port}'.format( - master_ip=master_ip, master_port=master_port) + master_ip=args.master_ip, master_port=args.master_port) world_size = args.ngpu torch.distributed.init_process_group(backend="nccl", init_method=dist_init_method, @@ -69,7 +67,8 @@ def main(dev_id, args): dim_model = 512 graph_pool = GraphPool() - model = make_model(V, V, N=args.N, dim_model=dim_model, universal=args.universal) + model = make_model(V, V, N=args.N, dim_model=dim_model, + universal=args.universal) # Sharing weights between Encoder & Decoder model.src_embed.lut.weight = model.tgt_embed.lut.weight @@ -77,17 +76,18 @@ def main(dev_id, args): model, criterion = model.to(device), criterion.to(device) model_opt = NoamOpt(dim_model, 1, 400, - T.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9)) + T.optim.Adam(model.parameters(), lr=1e-3, + betas=(0.9, 0.98), eps=1e-9)) if args.ngpu > 1: + dev_rank = dev_id # current device id + ndev = args.ngpu # number of devices (including cpu) loss_compute = partial(MultiGPULossCompute, criterion, dev_id, - args.ngpu, args.accum, model) - dev_rank = dev_id - ndev = args.ngpu - else: - loss_compute = partial(SimpleLossCompute, criterion) + args.ngpu, args.grad_accum, model) + else: # cpu or single gpu case dev_rank = 0 ndev = 1 + loss_compute = partial(SimpleLossCompute, criterion) for epoch in range(100): start = time.time() @@ -95,13 +95,13 @@ def main(dev_id, args): device=device, dev_rank=dev_rank, ndev=ndev) valid_iter = dataset(graph_pool, mode='valid', batch_size=args.batch, device=device, dev_rank=dev_rank, ndev=ndev) - print('Dev {} Epoch: {} Training...'.format(dev_rank, epoch)) model.train(True) - run_epoch(train_iter, dev_rank, ndev, model, loss_compute(opt=model_opt)) - print('Dev {} Epoch: {} Evaluating...'.format(dev_rank, epoch)) + run_epoch(train_iter, dev_rank, ndev, model, + loss_compute(opt=model_opt), is_train=True) model.att_weight_map = None model.eval() - run_epoch(valid_iter, dev_rank, ndev, model, loss_compute(opt=None)) + run_epoch(valid_iter, dev_rank, ndev, model, + loss_compute(opt=None), is_train=False) end = time.time() if dev_rank == 0: print("epoch time: {}".format(end - start)) @@ -127,9 +127,15 @@ def main(dev_id, args): argparser.add_argument('--N', default=6, type=int, help='enc/dec layers') argparser.add_argument('--dataset', default='multi30k', help='dataset') argparser.add_argument('--batch', default=128, type=int, help='batch size') - argparser.add_argument('--viz', action='store_true', help='visualize attention') - argparser.add_argument('--universal', action='store_true', help='use universal transformer') - argparser.add_argument('--accum', type=int, default=1, + argparser.add_argument('--viz', action='store_true', + help='visualize attention') + argparser.add_argument('--universal', action='store_true', + help='use universal transformer') + argparser.add_argument('--master-ip', type=str, default='127.0.0.1', + help='master ip address') + argparser.add_argument('--master-port', type=str, default='12345', + help='master port') + argparser.add_argument('--grad-accum', type=int, default=1, help='accumulate gradients for this many times ' 'then update weights') args = argparser.parse_args() From 3f6ff9b768a5d22919c78a2705eea865c0f7442a Mon Sep 17 00:00:00 2001 From: Lingfan Yu Date: Mon, 14 Jan 2019 23:12:01 +0000 Subject: [PATCH 05/11] minor --- examples/pytorch/transformer/loss/__init__.py | 3 +-- examples/pytorch/transformer/translation_train.py | 14 +++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/examples/pytorch/transformer/loss/__init__.py b/examples/pytorch/transformer/loss/__init__.py index 5c5ae1198e0a..e14a80cd2cc5 100644 --- a/examples/pytorch/transformer/loss/__init__.py +++ b/examples/pytorch/transformer/loss/__init__.py @@ -77,9 +77,8 @@ def __call__(self, y_pred, y, norm): return self.loss.item() * norm class MultiGPULossCompute(SimpleLossCompute): - def __init__(self, criterion, dev_id, ndev, grad_accum, model, opt=None): + def __init__(self, criterion, ndev, grad_accum, model, opt=None): super(MultiGPULossCompute, self).__init__(criterion, opt) - self.dev_id = dev_id self.ndev = ndev self.grad_accum = grad_accum self.model = model diff --git a/examples/pytorch/transformer/translation_train.py b/examples/pytorch/transformer/translation_train.py index eb5b7ef6bc21..2db76e65e672 100644 --- a/examples/pytorch/transformer/translation_train.py +++ b/examples/pytorch/transformer/translation_train.py @@ -15,7 +15,7 @@ import torch from functools import partial -def run_epoch(data_iter, dev_rank, ndev, model, loss_compute, is_train=True): +def run_epoch(epoch, data_iter, dev_rank, ndev, model, loss_compute, is_train=True): universal = isinstance(model, UTransformer) for i, g in enumerate(data_iter): #print("Dev {} start batch {}".format(dev_rank, i)) @@ -39,8 +39,8 @@ def run_epoch(data_iter, dev_rank, ndev, model, loss_compute, is_train=True): for step in range(1, model.MAX_DEPTH + 1): print("nodes entering step {}: {:.2f}%".format(step, (1.0 * model.stat[step] / model.stat[0]))) model.reset_stat() - print('{}: Dev {} average loss: {}, accuracy {}'.format( - "Training" if is_train else "Evaluting", + print('Epoch {} {}: Dev {} average loss: {}, accuracy {}'.format( + epoch, "Training" if is_train else "Evaluating", dev_rank, loss_compute.avg_loss, loss_compute.accuracy)) def run(dev_id, args): @@ -82,8 +82,8 @@ def main(dev_id, args): if args.ngpu > 1: dev_rank = dev_id # current device id ndev = args.ngpu # number of devices (including cpu) - loss_compute = partial(MultiGPULossCompute, criterion, dev_id, - args.ngpu, args.grad_accum, model) + loss_compute = partial(MultiGPULossCompute, criterion, args.ngpu, + args.grad_accum, model) else: # cpu or single gpu case dev_rank = 0 ndev = 1 @@ -96,11 +96,11 @@ def main(dev_id, args): valid_iter = dataset(graph_pool, mode='valid', batch_size=args.batch, device=device, dev_rank=dev_rank, ndev=ndev) model.train(True) - run_epoch(train_iter, dev_rank, ndev, model, + run_epoch(epoch, train_iter, dev_rank, ndev, model, loss_compute(opt=model_opt), is_train=True) model.att_weight_map = None model.eval() - run_epoch(valid_iter, dev_rank, ndev, model, + run_epoch(epoch, valid_iter, dev_rank, ndev, model, loss_compute(opt=None), is_train=False) end = time.time() if dev_rank == 0: From 6e40f2996e664cc9f949ed8841dfc2b17180a2b8 Mon Sep 17 00:00:00 2001 From: yzh119 Date: Tue, 15 Jan 2019 09:42:03 +0000 Subject: [PATCH 06/11] upd --- .../pytorch/transformer/dataset/__init__.py | 14 ++--- .../pytorch/transformer/dataset/fields.py | 2 +- examples/pytorch/transformer/parallel.py | 56 ------------------- .../pytorch/transformer/translation_train.py | 32 +++++------ 4 files changed, 21 insertions(+), 83 deletions(-) delete mode 100644 examples/pytorch/transformer/parallel.py diff --git a/examples/pytorch/transformer/dataset/__init__.py b/examples/pytorch/transformer/dataset/__init__.py index d0ea7dd69216..c52f336685e7 100644 --- a/examples/pytorch/transformer/dataset/__init__.py +++ b/examples/pytorch/transformer/dataset/__init__.py @@ -22,17 +22,17 @@ def __init__(self, path, exts, train='train', valid='valid', test='test', vocab= vocab_path = os.path.join(path, vocab) self.src = {} self.tgt = {} - with open(os.path.join(path, train + '.' + exts[0]), 'r') as f: + with open(os.path.join(path, train + '.' + exts[0]), 'r', encoding='utf-8') as f: self.src['train'] = f.readlines() - with open(os.path.join(path, train + '.' + exts[1]), 'r') as f: + with open(os.path.join(path, train + '.' + exts[1]), 'r', encoding='utf-8') as f: self.tgt['train'] = f.readlines() - with open(os.path.join(path, valid + '.' + exts[0]), 'r') as f: + with open(os.path.join(path, valid + '.' + exts[0]), 'r', encoding='utf-8') as f: self.src['valid'] = f.readlines() - with open(os.path.join(path, valid + '.' + exts[1]), 'r') as f: + with open(os.path.join(path, valid + '.' + exts[1]), 'r', encoding='utf-8') as f: self.tgt['valid'] = f.readlines() - with open(os.path.join(path, test + '.' + exts[0]), 'r') as f: + with open(os.path.join(path, test + '.' + exts[0]), 'r', encoding='utf-8') as f: self.src['test'] = f.readlines() - with open(os.path.join(path, test + '.' + exts[1]), 'r') as f: + with open(os.path.join(path, test + '.' + exts[1]), 'r', encoding='utf-8') as f: self.tgt['test'] = f.readlines() if not os.path.exists(vocab_path): @@ -103,7 +103,7 @@ def __call__(self, graph_pool, mode='train', batch_size=32, k=1, ''' src_data, tgt_data = self.src[mode], self.tgt[mode] n = len(src_data) - # make sure all devices have the same number of batches + # make sure all devices have the same number of batch n = n // ndev * ndev # XXX: is partition then shuffle equivalent to shuffle then partition? diff --git a/examples/pytorch/transformer/dataset/fields.py b/examples/pytorch/transformer/dataset/fields.py index d37c395bab73..712bbad1a6ad 100644 --- a/examples/pytorch/transformer/dataset/fields.py +++ b/examples/pytorch/transformer/dataset/fields.py @@ -16,7 +16,7 @@ def load(self, path): self.vocab_lst.append(self.pad_token) if self.unk_token is not None: self.vocab_lst.append(self.unk_token) - with open(path, 'r') as f: + with open(path, 'r', encoding='utf-8') as f: for token in f.readlines(): token = token.strip() self.vocab_lst.append(token) diff --git a/examples/pytorch/transformer/parallel.py b/examples/pytorch/transformer/parallel.py deleted file mode 100644 index 74023923e74c..000000000000 --- a/examples/pytorch/transformer/parallel.py +++ /dev/null @@ -1,56 +0,0 @@ -# Mostly then same with PyTorch -import threading -import torch - -def get_a_var(obj): - if isinstance(obj, torch.Tensor): - return obj - - if isinstance(obj, list) or isinstance(obj, tuple): - for result in map(get_a_var, obj): - if isinstance(result, torch.Tensor): - return result - if isinstance(obj, dict): - for result in map(get_a_var, obj.items()): - if isinstance(result, torch.Tensor): - return result - return None - - -def parallel_apply(modules, inputs): - assert len(modules) == len(inputs) - lock = threading.Lock() - results = {} - grad_enabled = torch.is_grad_enabled() - - def _worker(i, module, input): - torch.set_grad_enabled(grad_enabled) - try: - #with torch.cuda.device(device): - output = module(input) - with lock: - results[i] = output - except Exception as e: - with lock: - results[i] = e - - if len(modules) > 1: - threads = [threading.Thread(target=_worker, - args=(i, module, input)) - for i, (module, input) in - enumerate(zip(modules, inputs))] - - for thread in threads: - thread.start() - for thread in threads: - thread.join() - else: - _worker(0, modules[0], inputs[0]) - - outputs = [] - for i in range(len(inputs)): - output = results[i] - if isinstance(output, Exception): - raise output - outputs.append(output) - return outputs diff --git a/examples/pytorch/transformer/translation_train.py b/examples/pytorch/transformer/translation_train.py index 2db76e65e672..d9f52f80e20f 100644 --- a/examples/pytorch/transformer/translation_train.py +++ b/examples/pytorch/transformer/translation_train.py @@ -3,7 +3,6 @@ Multi-GPU support is required to train the model on WMT14. """ from modules import * -from parallel import * from loss import * from optims import * from dataset import * @@ -20,19 +19,13 @@ def run_epoch(epoch, data_iter, dev_rank, ndev, model, loss_compute, is_train=Tr for i, g in enumerate(data_iter): #print("Dev {} start batch {}".format(dev_rank, i)) with T.set_grad_enabled(is_train): - if isinstance(model, list): - model = model[:len(gs)] - output = parallel_apply(model, g) - tgt_y = [g.tgt_y for g in gs] - n_tokens = [g.n_tokens for g in gs] + if universal: + output, loss_act = model(g) + if is_train: loss_act.backward(retain_graph=True) else: - if universal: - output, loss_act = model(g) - if is_train: loss_act.backward(retain_graph=True) - else: - output = model(g) - tgt_y = g.tgt_y - n_tokens = g.n_tokens + output = model(g) + tgt_y = g.tgt_y + n_tokens = g.n_tokens loss = loss_compute(output, tgt_y, n_tokens) if universal: @@ -75,7 +68,7 @@ def main(dev_id, args): model.generator.proj.weight = model.tgt_embed.lut.weight model, criterion = model.to(device), criterion.to(device) - model_opt = NoamOpt(dim_model, 1, 400, + model_opt = NoamOpt(dim_model, 1, 4000 * 1300 / (args.batch * max(1, args.ngpu)), T.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9)) @@ -98,12 +91,13 @@ def main(dev_id, args): model.train(True) run_epoch(epoch, train_iter, dev_rank, ndev, model, loss_compute(opt=model_opt), is_train=True) - model.att_weight_map = None - model.eval() - run_epoch(epoch, valid_iter, dev_rank, ndev, model, - loss_compute(opt=None), is_train=False) - end = time.time() if dev_rank == 0: + model.att_weight_map = None + model.eval() + run_epoch(epoch, valid_iter, dev_rank, 1, model, + loss_compute(opt=None), is_train=False) + end = time.time() + time.sleep(1) print("epoch time: {}".format(end - start)) """ From 59f1ca749a4f575800f3f151de178e98dd588440 Mon Sep 17 00:00:00 2001 From: yzh119 Date: Tue, 29 Jan 2019 09:03:49 +0000 Subject: [PATCH 07/11] set torch device --- .../pytorch/transformer/modules/attention.py | 2 +- .../pytorch/transformer/modules/models.py | 13 ++++---- .../pytorch/transformer/translation_train.py | 32 ++++++++----------- 3 files changed, 21 insertions(+), 26 deletions(-) diff --git a/examples/pytorch/transformer/modules/attention.py b/examples/pytorch/transformer/modules/attention.py index b4bb1fe79f62..d86fd1dc9a1d 100644 --- a/examples/pytorch/transformer/modules/attention.py +++ b/examples/pytorch/transformer/modules/attention.py @@ -12,7 +12,7 @@ def __init__(self, h, dim_model): self.h = h # W_q, W_k, W_v, W_o self.linears = clones( - nn.Linear(dim_model, dim_model), 4 + nn.Linear(dim_model, dim_model, bias=False), 4 ) def get(self, x, fields='qkv'): diff --git a/examples/pytorch/transformer/modules/models.py b/examples/pytorch/transformer/modules/models.py index 1810fb57d9e4..896e20cfed9f 100644 --- a/examples/pytorch/transformer/modules/models.py +++ b/examples/pytorch/transformer/modules/models.py @@ -46,11 +46,11 @@ def pre_func(self, i, fields='qkv', l=0): layer = self.layers[i] def func(nodes): x = nodes.data['x'] - if fields == 'kv': - norm_x = x # In enc-dec attention, x has already been normalized. + norm_x = layer.sublayer[l].norm(x) if fields.startswith('q') else x + if fields != 'qkv': + return layer.src_attn.get(norm_x, fields) else: - norm_x = layer.sublayer[l].norm(x) - return layer.self_attn.get(norm_x, fields) + return layer.self_attn.get(norm_x, fields) return func def post_func(self, i, l=0): @@ -64,8 +64,6 @@ def func(nodes): return {'x': x if i < self.N - 1 else self.norm(x)} return func -lock = threading.Lock() - class Transformer(nn.Module): def __init__(self, encoder, decoder, src_embed, tgt_embed, pos_enc, generator, h, d_k): super(Transformer, self).__init__() @@ -124,9 +122,10 @@ def forward(self, graph): self.update_graph(g, edges, [(pre_q, nodes), (pre_kv, nodes_e)], [(post_func, nodes)]) # visualize attention - with lock: + """ if self.att_weight_map is None: self._register_att_map(g, graph.nid_arr['enc'][VIZ_IDX], graph.nid_arr['dec'][VIZ_IDX]) + """ return self.generator(g.ndata['x'][nids['dec']]) diff --git a/examples/pytorch/transformer/translation_train.py b/examples/pytorch/transformer/translation_train.py index d9f52f80e20f..8a0b14069296 100644 --- a/examples/pytorch/transformer/translation_train.py +++ b/examples/pytorch/transformer/translation_train.py @@ -1,14 +1,9 @@ -""" -In current version we use multi30k as the default training and validation set. -Multi-GPU support is required to train the model on WMT14. -""" from modules import * from loss import * from optims import * from dataset import * from modules.config import * #from modules.viz import * -#from tqdm import tqdm import numpy as np import argparse import torch @@ -53,25 +48,28 @@ def main(dev_id, args): device = torch.device('cpu') else: device = torch.device('cuda:{}'.format(dev_id)) + # Set current device + th.cuda.set_device(device) + # Prepare dataset dataset = get_dataset(args.dataset) - V = dataset.vocab_size criterion = LabelSmoothing(V, padding_idx=dataset.pad_id, smoothing=0.1) dim_model = 512 - + # Build graph pool graph_pool = GraphPool() + # Create model model = make_model(V, V, N=args.N, dim_model=dim_model, universal=args.universal) - # Sharing weights between Encoder & Decoder model.src_embed.lut.weight = model.tgt_embed.lut.weight model.generator.proj.weight = model.tgt_embed.lut.weight - + # Move model to corresponding device model, criterion = model.to(device), criterion.to(device) - model_opt = NoamOpt(dim_model, 1, 4000 * 1300 / (args.batch * max(1, args.ngpu)), + # Optimizer + model_opt = NoamOpt(dim_model, 1, 4000, T.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9)) - + # Loss function if args.ngpu > 1: dev_rank = dev_id # current device id ndev = args.ngpu # number of devices (including cpu) @@ -82,6 +80,7 @@ def main(dev_id, args): ndev = 1 loss_compute = partial(SimpleLossCompute, criterion) + # Train & evaluate for epoch in range(100): start = time.time() train_iter = dataset(graph_pool, mode='train', batch_size=args.batch, @@ -97,7 +96,6 @@ def main(dev_id, args): run_epoch(epoch, valid_iter, dev_rank, 1, model, loss_compute(opt=None), is_train=False) end = time.time() - time.sleep(1) print("epoch time: {}".format(end - start)) """ @@ -106,11 +104,11 @@ def main(dev_id, args): src_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='src') tgt_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='tgt')[:-1] draw_atts(model.att_weight_map, src_seq, tgt_seq, exp_setting, 'epoch_{}'.format(epoch)) - - print('----------------------------------') + """ + args_filter = ['batch', 'gpus', 'viz', 'master_ip', 'master_port', 'grad_accum', 'ngpu'] + exp_setting = '-'.join('{}'.format(v) for k, v in vars(args).items() if k not in args_filter) with open('checkpoints/{}-{}.pkl'.format(exp_setting, epoch), 'wb') as f: torch.save(model.state_dict(), f) - """ if __name__ == '__main__': if not os.path.exists('checkpoints'): @@ -134,9 +132,7 @@ def main(dev_id, args): 'then update weights') args = argparser.parse_args() print(args) - #args_filter = ['batch', 'gpus', 'viz'] - #exp_setting = '-'.join('{}'.format(v) for k, v in vars(args).items() if k not in args_filter) - #devices = ['cpu'] if args.gpus == '-1' else [int(gpu_id) for gpu_id in args.gpus.split(',')] + devices = list(map(int, args.gpus.split(','))) if len(devices) == 1: args.ngpu = 0 if devices[0] < 0 else 1 From 3f682e543fe37e4c2d2836db97a3c066df7af938 Mon Sep 17 00:00:00 2001 From: Lingfan Yu Date: Wed, 30 Jan 2019 22:26:27 +0000 Subject: [PATCH 08/11] fix bugs --- examples/pytorch/transformer/loss/__init__.py | 49 +++++++++++++------ .../pytorch/transformer/translation_train.py | 42 +++++++++------- 2 files changed, 59 insertions(+), 32 deletions(-) diff --git a/examples/pytorch/transformer/loss/__init__.py b/examples/pytorch/transformer/loss/__init__.py index e14a80cd2cc5..3721f50102b9 100644 --- a/examples/pytorch/transformer/loss/__init__.py +++ b/examples/pytorch/transformer/loss/__init__.py @@ -38,7 +38,7 @@ def forward(self, x, target): class SimpleLossCompute(nn.Module): eps=1e-8 - def __init__(self, criterion, opt=None): + def __init__(self, criterion, grad_accum, opt=None): """ opt is required during training """ @@ -49,6 +49,17 @@ def __init__(self, criterion, opt=None): self.n_correct = 0 self.norm_term = 0 self.loss = 0 + self.batch_count = 0 + self.grad_accum = grad_accum + + def __enter__(self): + self.batch_count = 0 + + def __exit__(self, type, value, traceback): + # if not enough batches accumulated and there are gradients not applied, + # do one more step + if self.batch_count > 0: + self.step() @property def avg_loss(self): @@ -58,11 +69,18 @@ def avg_loss(self): def accuracy(self): return (self.n_correct + self.eps) / (self.norm_term + self.eps) - def backward_and_step(self): - self.loss.backward() + def step(self): self.opt.step() self.opt.optimizer.zero_grad() + + def backward_and_step(self): + self.loss.backward() + self.batch_count += 1 + if self.batch_count == self.grad_accum: + self.step() + self.batch_count = 0 + def __call__(self, y_pred, y, norm): y_pred = y_pred.contiguous().view(-1, y_pred.shape[-1]) y = y.contiguous().view(-1) @@ -78,22 +96,23 @@ def __call__(self, y_pred, y, norm): class MultiGPULossCompute(SimpleLossCompute): def __init__(self, criterion, ndev, grad_accum, model, opt=None): - super(MultiGPULossCompute, self).__init__(criterion, opt) + super(MultiGPULossCompute, self).__init__(criterion, grad_accum, opt=opt) self.ndev = ndev - self.grad_accum = grad_accum self.model = model - self.count = 0 def backward_and_step(self): # multi-gpu synchronous backward self.loss.backward() - self.count += 1 + self.batch_count += 1 # accumulate self.grad_accum times then synchronize and update - if self.count == self.grad_accum: - for param in self.model.parameters(): - if param.requires_grad and param.grad is not None: - dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) - param.grad.data /= self.ndev - self.opt.step() - self.opt.optimizer.zero_grad() - self.count = 0 + if self.batch_count == self.grad_accum: + self.step() + self.batch_count = 0 + + def step(self): + for param in self.model.parameters(): + if param.requires_grad and param.grad is not None: + dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) + param.grad.data /= self.ndev + self.opt.step() + self.opt.optimizer.zero_grad() diff --git a/examples/pytorch/transformer/translation_train.py b/examples/pytorch/transformer/translation_train.py index 8a0b14069296..31e6771653ce 100644 --- a/examples/pytorch/transformer/translation_train.py +++ b/examples/pytorch/transformer/translation_train.py @@ -8,20 +8,22 @@ import argparse import torch from functools import partial +import torch.distributed as dist def run_epoch(epoch, data_iter, dev_rank, ndev, model, loss_compute, is_train=True): universal = isinstance(model, UTransformer) - for i, g in enumerate(data_iter): - #print("Dev {} start batch {}".format(dev_rank, i)) - with T.set_grad_enabled(is_train): - if universal: - output, loss_act = model(g) - if is_train: loss_act.backward(retain_graph=True) - else: - output = model(g) - tgt_y = g.tgt_y - n_tokens = g.n_tokens - loss = loss_compute(output, tgt_y, n_tokens) + with loss_compute: + for i, g in enumerate(data_iter): + #print("Dev {} start batch {}".format(dev_rank, i)) + with T.set_grad_enabled(is_train): + if universal: + output, loss_act = model(g) + if is_train: loss_act.backward(retain_graph=True) + else: + output = model(g) + tgt_y = g.tgt_y + n_tokens = g.n_tokens + loss = loss_compute(output, tgt_y, n_tokens) if universal: for step in range(1, model.MAX_DEPTH + 1): @@ -49,7 +51,7 @@ def main(dev_id, args): else: device = torch.device('cuda:{}'.format(dev_id)) # Set current device - th.cuda.set_device(device) + th.cuda.set_device(device) # Prepare dataset dataset = get_dataset(args.dataset) V = dataset.vocab_size @@ -65,10 +67,6 @@ def main(dev_id, args): model.generator.proj.weight = model.tgt_embed.lut.weight # Move model to corresponding device model, criterion = model.to(device), criterion.to(device) - # Optimizer - model_opt = NoamOpt(dim_model, 1, 4000, - T.optim.Adam(model.parameters(), lr=1e-3, - betas=(0.9, 0.98), eps=1e-9)) # Loss function if args.ngpu > 1: dev_rank = dev_id # current device id @@ -78,7 +76,17 @@ def main(dev_id, args): else: # cpu or single gpu case dev_rank = 0 ndev = 1 - loss_compute = partial(SimpleLossCompute, criterion) + loss_compute = partial(SimpleLossCompute, criterion, args.grad_accum) + + if ndev > 1: + for param in model.parameters(): + dist.all_reduce(param.data, op=dist.ReduceOp.SUM) + param.data /= ndev + + # Optimizer + model_opt = NoamOpt(dim_model, 1, 4000, + T.optim.Adam(model.parameters(), lr=1e-3, + betas=(0.9, 0.98), eps=1e-9)) # Train & evaluate for epoch in range(100): From 91f9465f7f779bcbe66d2c104f135bd35d9c0a7b Mon Sep 17 00:00:00 2001 From: Lingfan Yu Date: Wed, 30 Jan 2019 23:24:15 +0000 Subject: [PATCH 09/11] fix and minor --- examples/pytorch/transformer/dataset/__init__.py | 4 +++- examples/pytorch/transformer/translation_train.py | 5 ++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/pytorch/transformer/dataset/__init__.py b/examples/pytorch/transformer/dataset/__init__.py index c52f336685e7..9cc86e957b7e 100644 --- a/examples/pytorch/transformer/dataset/__init__.py +++ b/examples/pytorch/transformer/dataset/__init__.py @@ -98,8 +98,10 @@ def __call__(self, graph_pool, mode='train', batch_size=32, k=1, graph_pool: a GraphPool object for accelerating. mode: train/valid/test batch_size: batch size - device: torch.device k: beam size(only required for test) + device: torch.device + dev_rank: rank (id) of current device + ndev: number of devices ''' src_data, tgt_data = self.src[mode], self.tgt[mode] n = len(src_data) diff --git a/examples/pytorch/transformer/translation_train.py b/examples/pytorch/transformer/translation_train.py index 31e6771653ce..3697c0afe03e 100644 --- a/examples/pytorch/transformer/translation_train.py +++ b/examples/pytorch/transformer/translation_train.py @@ -14,7 +14,6 @@ def run_epoch(epoch, data_iter, dev_rank, ndev, model, loss_compute, is_train=Tr universal = isinstance(model, UTransformer) with loss_compute: for i, g in enumerate(data_iter): - #print("Dev {} start batch {}".format(dev_rank, i)) with T.set_grad_enabled(is_train): if universal: output, loss_act = model(g) @@ -93,14 +92,14 @@ def main(dev_id, args): start = time.time() train_iter = dataset(graph_pool, mode='train', batch_size=args.batch, device=device, dev_rank=dev_rank, ndev=ndev) - valid_iter = dataset(graph_pool, mode='valid', batch_size=args.batch, - device=device, dev_rank=dev_rank, ndev=ndev) model.train(True) run_epoch(epoch, train_iter, dev_rank, ndev, model, loss_compute(opt=model_opt), is_train=True) if dev_rank == 0: model.att_weight_map = None model.eval() + valid_iter = dataset(graph_pool, mode='valid', batch_size=args.batch, + device=device, dev_rank=dev_rank, ndev=1) run_epoch(epoch, valid_iter, dev_rank, 1, model, loss_compute(opt=None), is_train=False) end = time.time() From c11c4bfb0c09d675350752dd7d7b0a39bffe4ebf Mon Sep 17 00:00:00 2001 From: Lingfan Yu Date: Sat, 2 Feb 2019 03:46:57 +0000 Subject: [PATCH 10/11] comments and clean up --- examples/pytorch/transformer/.gitmodules | 0 .../pytorch/transformer/dataset/__init__.py | 5 ++- examples/pytorch/transformer/loss/__init__.py | 45 ++++++++++++------- 3 files changed, 33 insertions(+), 17 deletions(-) delete mode 100644 examples/pytorch/transformer/.gitmodules diff --git a/examples/pytorch/transformer/.gitmodules b/examples/pytorch/transformer/.gitmodules deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/examples/pytorch/transformer/dataset/__init__.py b/examples/pytorch/transformer/dataset/__init__.py index 9cc86e957b7e..6a63044298de 100644 --- a/examples/pytorch/transformer/dataset/__init__.py +++ b/examples/pytorch/transformer/dataset/__init__.py @@ -99,7 +99,7 @@ def __call__(self, graph_pool, mode='train', batch_size=32, k=1, mode: train/valid/test batch_size: batch size k: beam size(only required for test) - device: torch.device + device: str or torch.device dev_rank: rank (id) of current device ndev: number of devices ''' @@ -108,7 +108,8 @@ def __call__(self, graph_pool, mode='train', batch_size=32, k=1, # make sure all devices have the same number of batch n = n // ndev * ndev - # XXX: is partition then shuffle equivalent to shuffle then partition? + # XXX: partition then shuffle may not be equivalent to shuffle then + # partition order = list(range(dev_rank, n, ndev)) if mode == 'train': random.shuffle(order) diff --git a/examples/pytorch/transformer/loss/__init__.py b/examples/pytorch/transformer/loss/__init__.py index 3721f50102b9..b0ede9a716b6 100644 --- a/examples/pytorch/transformer/loss/__init__.py +++ b/examples/pytorch/transformer/loss/__init__.py @@ -39,8 +39,17 @@ def forward(self, x, target): class SimpleLossCompute(nn.Module): eps=1e-8 def __init__(self, criterion, grad_accum, opt=None): - """ - opt is required during training + """Loss function and optimizer for single device + + Parameters + ---------- + criterion: torch.nn.Module + criterion to compute loss + grad_accum: int + number of batches to accumulate gradients + opt: Optimizer + Model optimizer to use. If None, then no backward and update will be + performed """ super(SimpleLossCompute, self).__init__() self.criterion = criterion @@ -73,10 +82,10 @@ def step(self): self.opt.step() self.opt.optimizer.zero_grad() - def backward_and_step(self): self.loss.backward() self.batch_count += 1 + # accumulate self.grad_accum times then synchronize and update if self.batch_count == self.grad_accum: self.step() self.batch_count = 0 @@ -84,9 +93,7 @@ def backward_and_step(self): def __call__(self, y_pred, y, norm): y_pred = y_pred.contiguous().view(-1, y_pred.shape[-1]) y = y.contiguous().view(-1) - self.loss = self.criterion( - y_pred, y - ) / norm + self.loss = self.criterion(y_pred, y) / norm if self.opt is not None: self.backward_and_step() self.n_correct += ((y_pred.max(dim=-1)[1] == y) & (y != self.criterion.padding_idx)).sum().item() @@ -96,20 +103,28 @@ def __call__(self, y_pred, y, norm): class MultiGPULossCompute(SimpleLossCompute): def __init__(self, criterion, ndev, grad_accum, model, opt=None): + """Loss function and optimizer for multiple devices + + Parameters + ---------- + criterion: torch.nn.Module + criterion to compute loss + ndev: int + number of devices used + grad_accum: int + number of batches to accumulate gradients + model: torch.nn.Module + model to optimizer (needed to iterate and synchronize all parameters) + opt: Optimizer + Model optimizer to use. If None, then no backward and update will be + performed + """ super(MultiGPULossCompute, self).__init__(criterion, grad_accum, opt=opt) self.ndev = ndev self.model = model - def backward_and_step(self): - # multi-gpu synchronous backward - self.loss.backward() - self.batch_count += 1 - # accumulate self.grad_accum times then synchronize and update - if self.batch_count == self.grad_accum: - self.step() - self.batch_count = 0 - def step(self): + # multi-gpu synchronize gradients for param in self.model.parameters(): if param.requires_grad and param.grad is not None: dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) From e5a1299cc76063a5aae544186c40461189b4efe5 Mon Sep 17 00:00:00 2001 From: Lingfan Yu Date: Mon, 11 Feb 2019 20:10:31 -0500 Subject: [PATCH 11/11] uncomment viz code --- examples/pytorch/transformer/translation_train.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/pytorch/transformer/translation_train.py b/examples/pytorch/transformer/translation_train.py index 3697c0afe03e..a4cdcb63063c 100644 --- a/examples/pytorch/transformer/translation_train.py +++ b/examples/pytorch/transformer/translation_train.py @@ -105,13 +105,11 @@ def main(dev_id, args): end = time.time() print("epoch time: {}".format(end - start)) - """ # Visualize attention if args.viz: src_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='src') tgt_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='tgt')[:-1] draw_atts(model.att_weight_map, src_seq, tgt_seq, exp_setting, 'epoch_{}'.format(epoch)) - """ args_filter = ['batch', 'gpus', 'viz', 'master_ip', 'master_port', 'grad_accum', 'ngpu'] exp_setting = '-'.join('{}'.format(v) for k, v in vars(args).items() if k not in args_filter) with open('checkpoints/{}-{}.pkl'.format(exp_setting, epoch), 'wb') as f: