[Model][MXNet] MXNet Tree LSTM example #279
Changes from all commits
cd29a85
5341b21
7eb2271
2c14e3b
d117b4d
08d4153
77a2ccf
0c90b6f

README.md
@@ -0,0 +1,19 @@
# Tree-LSTM
This is a re-implementation of the following paper:

> [**Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks**](http://arxiv.org/abs/1503.00075)
> *Kai Sheng Tai, Richard Socher, and Christopher Manning*.
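
For reference, the Child-Sum variant of the cell (selected here with `--child-sum`) computes, for a node `j` with children `C(j)`, the following (equations from the paper; the N-ary variant used by default replaces the shared matrices with per-child-position ones):

```
\tilde{h}_j = \sum_{k \in C(j)} h_k
i_j    = \sigma(W^{(i)} x_j + U^{(i)} \tilde{h}_j + b^{(i)})
f_{jk} = \sigma(W^{(f)} x_j + U^{(f)} h_k + b^{(f)})
o_j    = \sigma(W^{(o)} x_j + U^{(o)} \tilde{h}_j + b^{(o)})
u_j    = \tanh(W^{(u)} x_j + U^{(u)} \tilde{h}_j + b^{(u)})
c_j    = i_j \odot u_j + \sum_{k \in C(j)} f_{jk} \odot c_k
h_j    = o_j \odot \tanh(c_j)
```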

The provided implementation reaches a test accuracy of 51.72, comparable to the 51.0 (±0.5) reported in the original paper.

## Data
The script downloads the [SST dataset](http://nlp.stanford.edu/sentiment/index.html) automatically; if `--use-glove` is specified, it also downloads the GloVe 840B.300d embedding (note: the downloads may take a while). See the example below.
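
For example, to train with the pretrained GloVe vectors (both downloads happen on the first run; the flag is defined in `train.py`):

```
python train.py --gpu 0 --use-glove
```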

## Usage
```
python train.py --gpu 0
```
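
The other flags defined in `train.py` can be combined with this. For example, a run of the Child-Sum variant on CPU (any `--gpu` value below 0 falls back to CPU):

```
python train.py --gpu -1 --child-sum
```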

## Speed Test

See [this spreadsheet](https://docs.google.com/spreadsheets/d/1eCQrVn7g0uWriz63EbEDdes2ksMdKdlbWMyT8PSU4rc).

train.py
@@ -0,0 +1,208 @@
import argparse
import time
import warnings
import zipfile
import os

# These must be set before importing mxnet/dgl.
os.environ['DGLBACKEND'] = 'mxnet'
os.environ['MXNET_GPU_MEM_POOL_TYPE'] = 'Round'

import numpy as np
import mxnet as mx
from mxnet import gluon

import dgl
import dgl.data as data

from tree_lstm import TreeLSTM
def batcher(ctx):
    """Batch a list of tree samples into a single batched graph on `ctx`."""
    def batcher_dev(batch):
        batch_trees = dgl.batch(batch)
        return data.SSTBatch(graph=batch_trees,
                             mask=batch_trees.ndata['mask'].as_in_context(ctx),
                             wordid=batch_trees.ndata['x'].as_in_context(ctx),
                             label=batch_trees.ndata['y'].as_in_context(ctx))
    return batcher_dev

def prepare_glove():
    # Download and verify the GloVe 840B.300d embedding if it is not already present.
    if not (os.path.exists('glove.840B.300d.txt')
            and data.utils.check_sha1('glove.840B.300d.txt',
                                      sha1_hash='294b9f37fa64cce31f9ebb409c266fc379527708')):
        zip_path = data.utils.download('http://nlp.stanford.edu/data/glove.840B.300d.zip',
                                       sha1_hash='8084fbacc2dee3b1fd1ca4cc534cbfff3519ed0d')
        with zipfile.ZipFile(zip_path, 'r') as zf:
            zf.extractall()
        if not data.utils.check_sha1('glove.840B.300d.txt',
                                     sha1_hash='294b9f37fa64cce31f9ebb409c266fc379527708'):
            warnings.warn('The downloaded GloVe embedding file checksum mismatch. File content '
                          'may be corrupted.')

> Review comment (on `prepare_glove`): I think the PyTorch Tree-LSTM example should prepare GloVe inside the training script too.

def main(args):
    np.random.seed(args.seed)
    mx.random.seed(args.seed)

    best_epoch = -1
    best_dev_acc = 0

    # Fall back to CPU when no GPU is requested or the requested id is absent.
    cuda = args.gpu >= 0
    if cuda:
        if args.gpu in mx.test_utils.list_gpus():
            ctx = mx.gpu(args.gpu)
        else:
            print('Requested GPU id {} was not found. Defaulting to CPU implementation'.format(args.gpu))
            ctx = mx.cpu()
    else:
        ctx = mx.cpu()

    if args.use_glove:
        prepare_glove()
    trainset = data.SST()
    train_loader = gluon.data.DataLoader(dataset=trainset,
                                         batch_size=args.batch_size,
                                         batchify_fn=batcher(ctx),
                                         shuffle=True,
                                         num_workers=0)
    devset = data.SST(mode='dev')
    dev_loader = gluon.data.DataLoader(dataset=devset,
                                       batch_size=100,
                                       batchify_fn=batcher(ctx),
                                       shuffle=True,
                                       num_workers=0)

    testset = data.SST(mode='test')
    test_loader = gluon.data.DataLoader(dataset=testset,
                                        batch_size=100,
                                        batchify_fn=batcher(ctx),
                                        shuffle=False, num_workers=0)
    model = TreeLSTM(trainset.num_vocabs,
                     args.x_size,
                     args.h_size,
                     trainset.num_classes,
                     args.dropout,
                     cell_type='childsum' if args.child_sum else 'nary',
                     pretrained_emb=trainset.pretrained_emb,
                     ctx=ctx)
    print(model)
    params_ex_emb = [x for x in model.collect_params().values()
                     if x.grad_req != 'null' and x.shape[0] != trainset.num_vocabs]
    params_emb = list(model.embedding.collect_params().values())
    # Train the embedding with a smaller effective learning rate.
    for p in params_emb:
        p.lr_mult = 0.1

    model.initialize(mx.init.Xavier(magnitude=1), ctx=ctx)
    model.hybridize()
    trainer = gluon.Trainer(model.collect_params('^(?!embedding).*$'), 'adagrad',
                            {'learning_rate': args.lr, 'wd': args.weight_decay})
    trainer_emb = gluon.Trainer(model.collect_params('^embedding.*$'), 'adagrad',
                                {'learning_rate': args.lr})

    dur = []
    L = gluon.loss.SoftmaxCrossEntropyLoss(axis=1)

> Review comment (on the loss): In DyNet implementation, they use reduction=

    for epoch in range(args.epochs):
        t_epoch = time.time()
        for step, batch in enumerate(train_loader):
            g = batch.graph
            n = g.number_of_nodes()

            # TODO begin_states function?
            h = mx.nd.zeros((n, args.h_size), ctx=ctx)
            c = mx.nd.zeros((n, args.h_size), ctx=ctx)
            if step >= 3:
                t0 = time.time()  # tik
            with mx.autograd.record():
                pred = model(batch, h, c)
                loss = L(pred, batch.label)

            loss.backward()
            trainer.step(args.batch_size)
            trainer_emb.step(args.batch_size)

            if step >= 3:
                dur.append(time.time() - t0)  # tok

            if step > 0 and step % args.log_every == 0:
                pred = pred.argmax(axis=1)
                acc = (batch.label == pred).sum()
                # Root nodes have no outgoing edges (edges point from child to parent).
                root_ids = [i for i in range(batch.graph.number_of_nodes()) if batch.graph.out_degree(i) == 0]
                root_acc = np.sum(batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids])

                print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} | Root Acc {:.4f} | Time(s) {:.4f}".format(
                    epoch, step, loss.sum().asscalar(), 1.0 * acc.asscalar() / len(batch.label), 1.0 * root_acc / len(root_ids), np.mean(dur)))
        print('Epoch {:05d} training time {:.4f}s'.format(epoch, time.time() - t_epoch))
        # eval on dev set
        accs = []
        root_accs = []
        for step, batch in enumerate(dev_loader):
            g = batch.graph
            n = g.number_of_nodes()
            h = mx.nd.zeros((n, args.h_size), ctx=ctx)
            c = mx.nd.zeros((n, args.h_size), ctx=ctx)
            pred = model(batch, h, c).argmax(axis=1)

            acc = (batch.label == pred).sum().asscalar()
            accs.append([acc, len(batch.label)])
            root_ids = [i for i in range(batch.graph.number_of_nodes()) if batch.graph.out_degree(i) == 0]
            root_acc = np.sum(batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids])
            root_accs.append([root_acc, len(root_ids)])

        dev_acc = 1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])
        dev_root_acc = 1.0 * np.sum([x[0] for x in root_accs]) / np.sum([x[1] for x in root_accs])
        print("Epoch {:05d} | Dev Acc {:.4f} | Root Acc {:.4f}".format(
            epoch, dev_acc, dev_root_acc))

        # Keep the best checkpoint by dev root accuracy; stop early after
        # 10 epochs without improvement.
        if dev_root_acc > best_dev_acc:
            best_dev_acc = dev_root_acc
            best_epoch = epoch
            model.save_parameters('best_{}.params'.format(args.seed))
        else:
            if best_epoch <= epoch - 10:
                break

        # lr decay
        trainer.set_learning_rate(max(1e-5, trainer.learning_rate * 0.99))
        print(trainer.learning_rate)
        trainer_emb.set_learning_rate(max(1e-5, trainer_emb.learning_rate * 0.99))
        print(trainer_emb.learning_rate)
    # test
    model.load_parameters('best_{}.params'.format(args.seed))
    accs = []
    root_accs = []
    for step, batch in enumerate(test_loader):
        g = batch.graph
        n = g.number_of_nodes()
        h = mx.nd.zeros((n, args.h_size), ctx=ctx)
        c = mx.nd.zeros((n, args.h_size), ctx=ctx)
        pred = model(batch, h, c).argmax(axis=1)

        acc = (batch.label == pred).sum().asscalar()
        accs.append([acc, len(batch.label)])
        root_ids = [i for i in range(batch.graph.number_of_nodes()) if batch.graph.out_degree(i) == 0]
        root_acc = np.sum(batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids])
        root_accs.append([root_acc, len(root_ids)])

    test_acc = 1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])
    test_root_acc = 1.0 * np.sum([x[0] for x in root_accs]) / np.sum([x[1] for x in root_accs])
    print('------------------------------------------------------------------------------------')
    print("Epoch {:05d} | Test Acc {:.4f} | Root Acc {:.4f}".format(
        best_epoch, test_acc, test_root_acc))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--seed', type=int, default=41)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--child-sum', action='store_true')
    parser.add_argument('--x-size', type=int, default=300)
    parser.add_argument('--h-size', type=int, default=150)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--log-every', type=int, default=5)
    parser.add_argument('--lr', type=float, default=0.05)
    parser.add_argument('--weight-decay', type=float, default=1e-4)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--use-glove', action='store_true')
    args = parser.parse_args()
    print(args)
    main(args)

tree_lstm.py
@@ -0,0 +1,129 @@
""" | ||
Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks | ||
https://arxiv.org/abs/1503.00075 | ||
""" | ||
import time | ||
import itertools | ||
import networkx as nx | ||
import numpy as np | ||
import mxnet as mx | ||
from mxnet import gluon | ||
import dgl | ||
|
||
class _TreeLSTMCellNodeFunc(gluon.HybridBlock): | ||
def hybrid_forward(self, F, iou, b_iou, c): | ||
iou = F.broadcast_add(iou, b_iou) | ||
i, o, u = iou.split(num_outputs=3, axis=1) | ||
i, o, u = i.sigmoid(), o.sigmoid(), u.tanh() | ||
c = i * u + c | ||
h = o * c.tanh() | ||
|
||
return h, c | ||
|
||
class _TreeLSTMCellReduceFunc(gluon.HybridBlock): | ||
def __init__(self, U_iou, U_f): | ||
super(_TreeLSTMCellReduceFunc, self).__init__() | ||
self.U_iou = U_iou | ||
self.U_f = U_f | ||
|
||
def hybrid_forward(self, F, h, c): | ||
h_cat = h.reshape((0, -1)) | ||
f = self.U_f(h_cat).sigmoid().reshape_like(h) | ||
c = (f * c).sum(axis=1) | ||
iou = self.U_iou(h_cat) | ||
return iou, c | ||
|
||
class _TreeLSTMCell(gluon.HybridBlock):
    def __init__(self, h_size):
        super(_TreeLSTMCell, self).__init__()
        self._apply_node_func = _TreeLSTMCellNodeFunc()
        self.b_iou = self.params.get('bias', shape=(1, 3 * h_size),
                                     init='zeros')

    def message_func(self, edges):
        return {'h': edges.src['h'], 'c': edges.src['c']}

    def apply_node_func(self, nodes):
        iou = nodes.data['iou']
        b_iou, c = self.b_iou.data(iou.context), nodes.data['c']
        h, c = self._apply_node_func(iou, b_iou, c)
        return {'h': h, 'c': c}

class TreeLSTMCell(_TreeLSTMCell):
    """N-ary Tree-LSTM cell (binary for the constituency trees in SST)."""
    def __init__(self, x_size, h_size):
        super(TreeLSTMCell, self).__init__(h_size)
        self._reduce_func = _TreeLSTMCellReduceFunc(
            gluon.nn.Dense(3 * h_size, use_bias=False),
            gluon.nn.Dense(2 * h_size))
        self.W_iou = gluon.nn.Dense(3 * h_size, use_bias=False)

    def reduce_func(self, nodes):
        h, c = nodes.mailbox['h'], nodes.mailbox['c']
        iou, c = self._reduce_func(h, c)
        return {'iou': iou, 'c': c}

class ChildSumTreeLSTMCell(_TreeLSTMCell):
    def __init__(self, x_size, h_size):
        # Pass h_size through so the shared b_iou parameter is created.
        super(ChildSumTreeLSTMCell, self).__init__(h_size)
        self.W_iou = gluon.nn.Dense(3 * h_size, use_bias=False)
        self.U_iou = gluon.nn.Dense(3 * h_size, use_bias=False)
        # flatten=False applies U_f to each child's hidden state separately,
        # giving one forget gate per child (f_jk in the paper); the default
        # flatten=True would collapse the child axis.
        self.U_f = gluon.nn.Dense(h_size, flatten=False)

    def reduce_func(self, nodes):
        h_tild = nodes.mailbox['h'].sum(axis=1)
        f = self.U_f(nodes.mailbox['h']).sigmoid()
        c = (f * nodes.mailbox['c']).sum(axis=1)
        return {'iou': self.U_iou(h_tild), 'c': c}
class TreeLSTM(gluon.nn.Block):
    def __init__(self,
                 num_vocabs,
                 x_size,
                 h_size,
                 num_classes,
                 dropout,
                 cell_type='nary',
                 pretrained_emb=None,
                 ctx=None):
        super(TreeLSTM, self).__init__()
        self.x_size = x_size
        self.embedding = gluon.nn.Embedding(num_vocabs, x_size)
        if pretrained_emb is not None:
            print('Using glove')
            self.embedding.initialize(ctx=ctx)
            self.embedding.weight.set_data(pretrained_emb)
        self.dropout = gluon.nn.Dropout(dropout)
        self.linear = gluon.nn.Dense(num_classes)
        cell = TreeLSTMCell if cell_type == 'nary' else ChildSumTreeLSTMCell
        self.cell = cell(x_size, h_size)

    def forward(self, batch, h, c):
        """Compute tree-lstm prediction given a batch.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.

        Returns
        -------
        logits : Tensor
            The prediction of each node.
        """
        g = batch.graph
        g.register_message_func(self.cell.message_func)
        g.register_reduce_func(self.cell.reduce_func)
        g.register_apply_node_func(self.cell.apply_node_func)
        # feed embedding; the mask zeroes out the input for non-leaf nodes
        embeds = self.embedding(batch.wordid * batch.mask)
        g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.expand_dims(-1)
        g.ndata['h'] = h
        g.ndata['c'] = c
        # propagate messages from leaves to roots in topological order
        dgl.prop_nodes_topo(g)
        # compute logits
        h = self.dropout(g.ndata.pop('h'))
        logits = self.linear(h)
        return logits

> Review comment: Does MXNet Tree-LSTM produce the same result as PyTorch? That's interesting.