Skip to content

Commit

Permalink
add glove download
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Dec 13, 2018
1 parent 6945f3e commit ed2d170
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions examples/mxnet/tree_lstm/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import argparse
import time
import warnings
import zipfile
import numpy as np
import mxnet as mx
from mxnet import gluon
Expand All @@ -18,6 +20,18 @@ def batcher_dev(batch):
label=batch_trees.ndata['y'].as_in_context(ctx))
return batcher_dev

def prepare_glove():
    """Make sure the GloVe 840B.300d embedding text file is in the working directory.

    If ``glove.840B.300d.txt`` is absent or its SHA-1 does not match the
    expected digest, the zip archive is downloaded and extracted into the
    current directory, and the extracted file is verified again.

    Warns:
        UserWarning: if the extracted file still fails the SHA-1 check.
    """
    # Local import: only this function needs ``os``.
    import os

    glove_txt = 'glove.840B.300d.txt'
    # Single source of truth for the expected digest, used by both checks.
    glove_txt_sha1 = '294b9f37fa64cce31f9ebb409c266fc379527708'

    # NOTE(review): ``data.utils.check_sha1`` presumably opens the file to hash
    # it, so guard against a missing file first -- confirm against the helper.
    if os.path.exists(glove_txt) and data.utils.check_sha1(glove_txt,
                                                           sha1_hash=glove_txt_sha1):
        return

    zip_path = data.utils.download('http://nlp.stanford.edu/data/glove.840B.300d.zip',
                                   sha1_hash='8084fbacc2dee3b1fd1ca4cc534cbfff3519ed0d')
    with zipfile.ZipFile(zip_path, 'r') as zf:
        zf.extractall()

    # Re-verify only after a fresh extraction; a file that already passed the
    # check above does not need to be hashed a second time.
    if not data.utils.check_sha1(glove_txt, sha1_hash=glove_txt_sha1):
        warnings.warn('The downloaded glove embedding file checksum mismatch. File content '
                      'may be corrupted.')

def main(args):
np.random.seed(args.seed)
mx.random.seed(args.seed)
Expand All @@ -28,6 +42,9 @@ def main(args):
cuda = args.gpu >= 0
ctx = mx.gpu(args.gpu) if cuda else mx.cpu()

if args.use_glove:
prepare_glove()

trainset = data.SST()
train_loader = gluon.data.DataLoader(dataset=trainset,
batch_size=args.batch_size,
Expand Down Expand Up @@ -173,6 +190,7 @@ def main(args):
parser.add_argument('--lr', type=float, default=0.05)
parser.add_argument('--weight-decay', type=float, default=1e-4)
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--use-glove', action='store_true')
args = parser.parse_args()
print(args)
main(args)

0 comments on commit ed2d170

Please sign in to comment.