From ed2d1709651dd0d8c0cb871a18b64c216155415c Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Sat, 8 Dec 2018 02:27:22 -0500 Subject: [PATCH] add glove download --- examples/mxnet/tree_lstm/train.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/examples/mxnet/tree_lstm/train.py b/examples/mxnet/tree_lstm/train.py index b2346f2a0406..365762a5259b 100644 --- a/examples/mxnet/tree_lstm/train.py +++ b/examples/mxnet/tree_lstm/train.py @@ -1,5 +1,7 @@ import argparse import time +import warnings +import zipfile import numpy as np import mxnet as mx from mxnet import gluon @@ -18,6 +20,18 @@ def batcher_dev(batch): label=batch_trees.ndata['y'].as_in_context(ctx)) return batcher_dev +def prepare_glove(): + if not data.utils.check_sha1('glove.840B.300d.txt', + sha1_hash='294b9f37fa64cce31f9ebb409c266fc379527708'): + zip_path = data.utils.download('http://nlp.stanford.edu/data/glove.840B.300d.zip', + sha1_hash='8084fbacc2dee3b1fd1ca4cc534cbfff3519ed0d') + with zipfile.ZipFile(zip_path, 'r') as zf: + zf.extractall() + if not data.utils.check_sha1('glove.840B.300d.txt', + sha1_hash=TODO1): + warnings.warn('The downloaded glove embedding file checksum mismatch. File content ' + 'may be corrupted.') + def main(args): np.random.seed(args.seed) mx.random.seed(args.seed) @@ -28,6 +42,9 @@ def main(args): cuda = args.gpu >= 0 ctx = mx.gpu(args.gpu) if cuda else mx.cpu() + if args.use_glove: + prepare_glove() + trainset = data.SST() train_loader = gluon.data.DataLoader(dataset=trainset, batch_size=args.batch_size, @@ -173,6 +190,7 @@ def main(args): parser.add_argument('--lr', type=float, default=0.05) parser.add_argument('--weight-decay', type=float, default=1e-4) parser.add_argument('--dropout', type=float, default=0.5) + parser.add_argument('--use-glove', action='store_true') args = parser.parse_args() print(args) main(args)