diff --git a/examples/mxnet/tree_lstm/README.md b/examples/mxnet/tree_lstm/README.md
new file mode 100644
index 000000000000..49dfb9e2643b
--- /dev/null
+++ b/examples/mxnet/tree_lstm/README.md
@@ -0,0 +1,19 @@
+# Tree-LSTM
+This is a re-implementation of the following paper:
+
+> [**Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks**](http://arxiv.org/abs/1503.00075)
+> *Kai Sheng Tai, Richard Socher, and Christopher Manning*.
+
+The provided implementation achieves a test accuracy of 51.72, which is comparable to the result reported in the original paper: 51.0 (±0.5).
+
+## Data
+The script automatically downloads the [SST dataset](http://nlp.stanford.edu/sentiment/index.html), and also the GloVe 840B.300d embedding if `--use-glove` is specified (note: the download may take a while).
+
+## Usage
+```
+python train.py --gpu 0
+```
+
+## Speed Test
+
+See https://docs.google.com/spreadsheets/d/1eCQrVn7g0uWriz63EbEDdes2ksMdKdlbWMyT8PSU4rc .
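Before the training script below, a brief illustration may help: the SST trees ship as DGLGraphs whose node data carries the fields that the `batcher` in `train.py` reads (`'x'`, `'y'`, `'mask'`). The following is a minimal, hypothetical inspection snippet; the field semantics (word ids, labels, and a leaf mask) are assumptions inferred from how `train.py` and `tree_lstm.py` consume them, not guarantees of the `dgl.data.SST` API:

```
import os
os.environ['DGLBACKEND'] = 'mxnet'   # same backend selection as train.py

import dgl.data as data

trainset = data.SST()        # downloads the SST dataset on first use
tree = trainset[0]           # a single parse tree as a DGLGraph
print(tree.number_of_nodes())
print(tree.ndata['x'])       # per-node word ids (vocabulary indices)
print(tree.ndata['y'])       # per-node sentiment labels
print(tree.ndata['mask'])    # assumed to flag leaf (word) nodes with 1
```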
diff --git a/examples/mxnet/tree_lstm/train.py b/examples/mxnet/tree_lstm/train.py
new file mode 100644
index 000000000000..02ca1322b391
--- /dev/null
+++ b/examples/mxnet/tree_lstm/train.py
@@ -0,0 +1,208 @@
+import argparse
+import time
+import warnings
+import zipfile
+import os
+
+os.environ['DGLBACKEND'] = 'mxnet'
+os.environ['MXNET_GPU_MEM_POOL_TYPE'] = 'Round'
+
+import numpy as np
+import mxnet as mx
+from mxnet import gluon
+
+import dgl
+import dgl.data as data
+
+from tree_lstm import TreeLSTM
+
+def batcher(ctx):
+    def batcher_dev(batch):
+        batch_trees = dgl.batch(batch)
+        return data.SSTBatch(graph=batch_trees,
+                             mask=batch_trees.ndata['mask'].as_in_context(ctx),
+                             wordid=batch_trees.ndata['x'].as_in_context(ctx),
+                             label=batch_trees.ndata['y'].as_in_context(ctx))
+    return batcher_dev
+
+def prepare_glove():
+    if not (os.path.exists('glove.840B.300d.txt')
+            and data.utils.check_sha1('glove.840B.300d.txt',
+                                      sha1_hash='294b9f37fa64cce31f9ebb409c266fc379527708')):
+        zip_path = data.utils.download('http://nlp.stanford.edu/data/glove.840B.300d.zip',
+                                       sha1_hash='8084fbacc2dee3b1fd1ca4cc534cbfff3519ed0d')
+        with zipfile.ZipFile(zip_path, 'r') as zf:
+            zf.extractall()
+        if not data.utils.check_sha1('glove.840B.300d.txt',
+                                     sha1_hash='294b9f37fa64cce31f9ebb409c266fc379527708'):
+            warnings.warn('The downloaded glove embedding file checksum mismatch. File content '
+                          'may be corrupted.')
+
+def main(args):
+    np.random.seed(args.seed)
+    mx.random.seed(args.seed)
+
+    best_epoch = -1
+    best_dev_acc = 0
+
+    cuda = args.gpu >= 0
+    if cuda:
+        if args.gpu in mx.test_utils.list_gpus():
+            ctx = mx.gpu(args.gpu)
+        else:
+            print('Requested GPU id {} was not found. Defaulting to CPU implementation'.format(args.gpu))
+            ctx = mx.cpu()
+    else:
+        ctx = mx.cpu()  # fall back to CPU when no GPU is requested
+
+    if args.use_glove:
+        prepare_glove()
+
+    trainset = data.SST()
+    train_loader = gluon.data.DataLoader(dataset=trainset,
+                                         batch_size=args.batch_size,
+                                         batchify_fn=batcher(ctx),
+                                         shuffle=True,
+                                         num_workers=0)
+    devset = data.SST(mode='dev')
+    dev_loader = gluon.data.DataLoader(dataset=devset,
+                                       batch_size=100,
+                                       batchify_fn=batcher(ctx),
+                                       shuffle=True,
+                                       num_workers=0)
+
+    testset = data.SST(mode='test')
+    test_loader = gluon.data.DataLoader(dataset=testset,
+                                        batch_size=100,
+                                        batchify_fn=batcher(ctx),
+                                        shuffle=False, num_workers=0)
+
+    model = TreeLSTM(trainset.num_vocabs,
+                     args.x_size,
+                     args.h_size,
+                     trainset.num_classes,
+                     args.dropout,
+                     cell_type='childsum' if args.child_sum else 'nary',
+                     pretrained_emb=trainset.pretrained_emb,
+                     ctx=ctx)
+    print(model)
+    params_ex_emb = [x for x in model.collect_params().values()
+                     if x.grad_req != 'null' and x.shape[0] != trainset.num_vocabs]
+    params_emb = list(model.embedding.collect_params().values())
+    for p in params_emb:
+        p.lr_mult = 0.1  # use a 10x smaller learning rate for the embedding
+
+    model.initialize(mx.init.Xavier(magnitude=1), ctx=ctx)
+    model.hybridize()
+    trainer = gluon.Trainer(model.collect_params('^(?!embedding).*$'), 'adagrad',
+                            {'learning_rate': args.lr, 'wd': args.weight_decay})
+    trainer_emb = gluon.Trainer(model.collect_params('^embedding.*$'), 'adagrad',
+                                {'learning_rate': args.lr})
+
+    dur = []
+    L = gluon.loss.SoftmaxCrossEntropyLoss(axis=1)
+    for epoch in range(args.epochs):
+        t_epoch = time.time()
+        for step, batch in enumerate(train_loader):
+            g = batch.graph
+            n = g.number_of_nodes()
+
+            # TODO begin_states function?
+            h = mx.nd.zeros((n, args.h_size), ctx=ctx)
+            c = mx.nd.zeros((n, args.h_size), ctx=ctx)
+            if step >= 3:
+                t0 = time.time()  # tik
+            with mx.autograd.record():
+                pred = model(batch, h, c)
+                loss = L(pred, batch.label)
+
+            loss.backward()
+            trainer.step(args.batch_size)
+            trainer_emb.step(args.batch_size)
+
+            if step >= 3:
+                dur.append(time.time() - t0)  # tok
+
+            if step > 0 and step % args.log_every == 0:
+                pred = pred.argmax(axis=1)
+                acc = (batch.label == pred).sum()
+                root_ids = [i for i in range(batch.graph.number_of_nodes()) if batch.graph.out_degree(i) == 0]
+                root_acc = np.sum(batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids])
+
+                print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} | Root Acc {:.4f} | Time(s) {:.4f}".format(
+                    epoch, step, loss.sum().asscalar(), 1.0*acc.asscalar()/len(batch.label), 1.0*root_acc/len(root_ids), np.mean(dur)))
+        print('Epoch {:05d} training time {:.4f}s'.format(epoch, time.time() - t_epoch))
+
+        # eval on dev set
+        accs = []
+        root_accs = []
+        for step, batch in enumerate(dev_loader):
+            g = batch.graph
+            n = g.number_of_nodes()
+            h = mx.nd.zeros((n, args.h_size), ctx=ctx)
+            c = mx.nd.zeros((n, args.h_size), ctx=ctx)
+            pred = model(batch, h, c).argmax(axis=1)
+
+            acc = (batch.label == pred).sum().asscalar()
+            accs.append([acc, len(batch.label)])
+            root_ids = [i for i in range(batch.graph.number_of_nodes()) if batch.graph.out_degree(i) == 0]
+            root_acc = np.sum(batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids])
+            root_accs.append([root_acc, len(root_ids)])
+
+        dev_acc = 1.0*np.sum([x[0] for x in accs])/np.sum([x[1] for x in accs])
+        dev_root_acc = 1.0*np.sum([x[0] for x in root_accs])/np.sum([x[1] for x in root_accs])
+        print("Epoch {:05d} | Dev Acc {:.4f} | Root Acc {:.4f}".format(
+            epoch, dev_acc, dev_root_acc))
+
+        if dev_root_acc > best_dev_acc:
+            best_dev_acc = dev_root_acc
+            best_epoch = epoch
+            model.save_parameters('best_{}.params'.format(args.seed))
+        else:
+            if best_epoch <= epoch - 10:
+                break
+
+        # lr decay
+        trainer.set_learning_rate(max(1e-5, trainer.learning_rate*0.99))
+        print(trainer.learning_rate)
+        trainer_emb.set_learning_rate(max(1e-5, trainer_emb.learning_rate*0.99))
+        print(trainer_emb.learning_rate)
+
+    # test
+    model.load_parameters('best_{}.params'.format(args.seed))
+    accs = []
+    root_accs = []
+    for step, batch in enumerate(test_loader):
+        g = batch.graph
+        n = g.number_of_nodes()
+        h = mx.nd.zeros((n, args.h_size), ctx=ctx)
+        c = mx.nd.zeros((n, args.h_size), ctx=ctx)
+        pred = model(batch, h, c).argmax(axis=1)
+
+        acc = (batch.label == pred).sum().asscalar()
+        accs.append([acc, len(batch.label)])
+        root_ids = [i for i in range(batch.graph.number_of_nodes()) if batch.graph.out_degree(i) == 0]
+        root_acc = np.sum(batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids])
+        root_accs.append([root_acc, len(root_ids)])
+
+    test_acc = 1.0*np.sum([x[0] for x in accs])/np.sum([x[1] for x in accs])
+    test_root_acc = 1.0*np.sum([x[0] for x in root_accs])/np.sum([x[1] for x in root_accs])
+    print('------------------------------------------------------------------------------------')
+    print("Epoch {:05d} | Test Acc {:.4f} | Root Acc {:.4f}".format(
+        best_epoch, test_acc, test_root_acc))
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--gpu', type=int, default=0)
+    parser.add_argument('--seed', type=int, default=41)
+    parser.add_argument('--batch-size', type=int, default=256)
+    parser.add_argument('--child-sum', action='store_true')
+    parser.add_argument('--x-size', type=int, default=300)
+    parser.add_argument('--h-size', type=int, default=150)
+    parser.add_argument('--epochs', type=int, default=100)
+    parser.add_argument('--log-every', type=int, default=5)
+    parser.add_argument('--lr', type=float, default=0.05)
+    parser.add_argument('--weight-decay', type=float, default=1e-4)
+    parser.add_argument('--dropout', type=float, default=0.5)
+    parser.add_argument('--use-glove', action='store_true')
+    args = parser.parse_args()
+    print(args)
+    main(args)
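One detail of the script above worth spelling out: roots are identified as the nodes with out-degree zero, which works because the SST trees here have child-to-parent edges. Below is a hedged, vectorized sketch of the same computation; it assumes the installed DGL exposes `out_degrees()` over all nodes, and the helper name is illustrative, not part of the example:

```
import numpy as np

def root_ids_of(g):
    # Equivalent to: [i for i in range(g.number_of_nodes())
    #                 if g.out_degree(i) == 0]
    # Roots have no outgoing (child -> parent) edge.
    out_deg = g.out_degrees().asnumpy()
    return np.where(out_deg == 0)[0]
```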
diff --git a/examples/mxnet/tree_lstm/tree_lstm.py b/examples/mxnet/tree_lstm/tree_lstm.py
new file mode 100644
index 000000000000..f4c78d6bac71
--- /dev/null
+++ b/examples/mxnet/tree_lstm/tree_lstm.py
@@ -0,0 +1,129 @@
+"""
+Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks
+https://arxiv.org/abs/1503.00075
+"""
+import time
+import itertools
+import networkx as nx
+import numpy as np
+import mxnet as mx
+from mxnet import gluon
+import dgl
+
+class _TreeLSTMCellNodeFunc(gluon.HybridBlock):
+    def hybrid_forward(self, F, iou, b_iou, c):
+        iou = F.broadcast_add(iou, b_iou)
+        i, o, u = iou.split(num_outputs=3, axis=1)
+        i, o, u = i.sigmoid(), o.sigmoid(), u.tanh()
+        c = i * u + c
+        h = o * c.tanh()
+
+        return h, c
+
+class _TreeLSTMCellReduceFunc(gluon.HybridBlock):
+    def __init__(self, U_iou, U_f):
+        super(_TreeLSTMCellReduceFunc, self).__init__()
+        self.U_iou = U_iou
+        self.U_f = U_f
+
+    def hybrid_forward(self, F, h, c):
+        h_cat = h.reshape((0, -1))
+        f = self.U_f(h_cat).sigmoid().reshape_like(h)
+        c = (f * c).sum(axis=1)
+        iou = self.U_iou(h_cat)
+        return iou, c
+
+class _TreeLSTMCell(gluon.HybridBlock):
+    def __init__(self, h_size):
+        super(_TreeLSTMCell, self).__init__()
+        self._apply_node_func = _TreeLSTMCellNodeFunc()
+        self.b_iou = self.params.get('bias', shape=(1, 3 * h_size),
+                                     init='zeros')
+
+    def message_func(self, edges):
+        return {'h': edges.src['h'], 'c': edges.src['c']}
+
+    def apply_node_func(self, nodes):
+        iou = nodes.data['iou']
+        b_iou, c = self.b_iou.data(iou.context), nodes.data['c']
+        h, c = self._apply_node_func(iou, b_iou, c)
+        return {'h': h, 'c': c}
+
+class TreeLSTMCell(_TreeLSTMCell):
+    def __init__(self, x_size, h_size):
+        super(TreeLSTMCell, self).__init__(h_size)
+        self._reduce_func = _TreeLSTMCellReduceFunc(
+            gluon.nn.Dense(3 * h_size, use_bias=False),
+            gluon.nn.Dense(2 * h_size))
+        self.W_iou = gluon.nn.Dense(3 * h_size, use_bias=False)
+
+    def reduce_func(self, nodes):
+        h, c = nodes.mailbox['h'], nodes.mailbox['c']
+        iou, c = self._reduce_func(h, c)
+        return {'iou': iou, 'c': c}
+
+class ChildSumTreeLSTMCell(_TreeLSTMCell):
+    def __init__(self, x_size, h_size):
+        super(ChildSumTreeLSTMCell, self).__init__(h_size)
+        self.W_iou = gluon.nn.Dense(3 * h_size, use_bias=False)
+        self.U_iou = gluon.nn.Dense(3 * h_size, use_bias=False)
+        self.U_f = gluon.nn.Dense(h_size)
+
+    def reduce_func(self, nodes):
+        h_tild = nodes.mailbox['h'].sum(axis=1)
+        f = self.U_f(nodes.mailbox['h']).sigmoid()
+        c = (f * nodes.mailbox['c']).sum(axis=1)
+        return {'iou': self.U_iou(h_tild), 'c': c}
+
+class TreeLSTM(gluon.nn.Block):
+    def __init__(self,
+                 num_vocabs,
+                 x_size,
+                 h_size,
+                 num_classes,
+                 dropout,
+                 cell_type='nary',
+                 pretrained_emb=None,
+                 ctx=None):
+        super(TreeLSTM, self).__init__()
+        self.x_size = x_size
+        self.embedding = gluon.nn.Embedding(num_vocabs, x_size)
+        if pretrained_emb is not None:
+            print('Using glove')
+            self.embedding.initialize(ctx=ctx)
+            self.embedding.weight.set_data(pretrained_emb)
+        self.dropout = gluon.nn.Dropout(dropout)
+        self.linear = gluon.nn.Dense(num_classes)
+        cell = TreeLSTMCell if cell_type == 'nary' else ChildSumTreeLSTMCell
+        self.cell = cell(x_size, h_size)
+
+    def forward(self, batch, h, c):
+        """Compute tree-lstm prediction given a batch.
+
+        Parameters
+        ----------
+        batch : dgl.data.SSTBatch
+            The data batch.
+        h : Tensor
+            Initial hidden state.
+        c : Tensor
+            Initial cell state.
+
+        Returns
+        -------
+        logits : Tensor
+            The prediction of each node.
+        """
+        g = batch.graph
+        g.register_message_func(self.cell.message_func)
+        g.register_reduce_func(self.cell.reduce_func)
+        g.register_apply_node_func(self.cell.apply_node_func)
+        # feed embedding
+        embeds = self.embedding(batch.wordid * batch.mask)
+        g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.expand_dims(-1)
+        g.ndata['h'] = h
+        g.ndata['c'] = c
+        # propagate
+        dgl.prop_nodes_topo(g)
+        # compute logits
+        h = self.dropout(g.ndata.pop('h'))
+        logits = self.linear(h)
+        return logits
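For reference, these are the Child-Sum Tree-LSTM equations from the cited paper that `ChildSumTreeLSTMCell` implements. The code fuses the `i`, `o`, `u` projections into the single `W_iou`/`U_iou` dense layers, and (a detail specific to this example) applies the `W` input terms only at leaf nodes via `batch.mask`, since interior SST nodes carry no word:

```
\tilde{h}_j = \sum_{k \in C(j)} h_k
i_j = \sigma\big(W^{(i)} x_j + U^{(i)} \tilde{h}_j + b^{(i)}\big)
f_{jk} = \sigma\big(W^{(f)} x_j + U^{(f)} h_k + b^{(f)}\big)
o_j = \sigma\big(W^{(o)} x_j + U^{(o)} \tilde{h}_j + b^{(o)}\big)
u_j = \tanh\big(W^{(u)} x_j + U^{(u)} \tilde{h}_j + b^{(u)}\big)
c_j = i_j \odot u_j + \sum_{k \in C(j)} f_{jk} \odot c_k
h_j = o_j \odot \tanh(c_j)
```

In the code, `reduce_func` computes \tilde{h}_j, the f_{jk} gates, and the \sum f \odot c term over each node's mailbox, while `apply_node_func` adds the shared `b_iou` bias, splits out the three gates, and produces h_j and c_j.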
diff --git a/examples/pytorch/tree_lstm/train.py b/examples/pytorch/tree_lstm/train.py
index e1c013b88530..b8f9bcc73e59 100644
--- a/examples/pytorch/tree_lstm/train.py
+++ b/examples/pytorch/tree_lstm/train.py
@@ -9,7 +9,7 @@
 from torch.utils.data import DataLoader
 
 import dgl
-from dgl.data.tree import SST
+from dgl.data.tree import SST, SSTBatch
 
 from tree_lstm import TreeLSTM
 