[numpy] [DO NOT MERGE] Fix d2l chapters 9 and 13 (apache#15246)
* Add npx batch_dot and topk

* Text embedding uses numpy

* Fix SoftmaxCrossEntropyLoss with np

* Fix sentiment cnn

* Fix pylint

* Fix dot attention

* Fix seq2seq attention

* Add np.tile

* Fix transformer

* Fix ci

* Fix ci and rebase
reminisce authored and haojin2 committed Jul 22, 2019
1 parent 092bdf9 commit 0138b04
Showing 19 changed files with 273 additions and 56 deletions.
23 changes: 23 additions & 0 deletions python/mxnet/_numpy_op_doc.py
@@ -86,3 +86,26 @@ def _np_zeros_like(a):
Array of zeros with the same shape and type as `a`.
"""
pass


def _np_repeat(a, repeats, axis=None):
"""Repeat elements of an array.
Parameters
----------
a : ndarray
Input array.
repeats : int or array of ints
The number of repetitions for each element. `repeats` is broadcasted
to fit the shape of the given axis.
axis : int, optional
The axis along which to repeat values. By default, use the
flattened input array, and return a flat output array.
Returns
-------
repeated_array : ndarray
Output array which has the same shape as `a`, except along
the given axis.
"""
pass
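
For reference, a quick usage sketch of the documented behavior (assuming numpy semantics are activated via `npx.set_np()`; results in comments use the default float32 dtype):

from mxnet import np, npx
npx.set_np()

x = np.array([[1, 2], [3, 4]])
np.repeat(x, 2)          # flattened output: [1. 1. 2. 2. 3. 3. 4. 4.]
np.repeat(x, 2, axis=1)  # [[1. 1. 2. 2.], [3. 3. 4. 4.]]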
29 changes: 22 additions & 7 deletions python/mxnet/contrib/text/embedding.py
@@ -35,6 +35,9 @@
from ... import ndarray as nd
from ... import registry
from ... import base
from ...util import is_np_array
from ... import numpy as _mx_np
from ... import numpy_extension as _mx_npx


def register(embedding_cls):
@@ -295,12 +298,15 @@ def _load_embedding(self, pretrained_file_path, elem_delim, init_unknown_vec, en
tokens.add(token)

self._vec_len = vec_len
self._idx_to_vec = nd.array(all_elems).reshape((-1, self.vec_len))
array_fn = _mx_np.array if is_np_array() else nd.array
self._idx_to_vec = array_fn(all_elems).reshape((-1, self.vec_len))

if loaded_unknown_vec is None:
self._idx_to_vec[C.UNKNOWN_IDX] = init_unknown_vec(shape=self.vec_len)
init_val = init_unknown_vec(shape=self.vec_len)
self._idx_to_vec[C.UNKNOWN_IDX] =\
init_val.as_np_ndarray() if is_np_array() else init_val
else:
self._idx_to_vec[C.UNKNOWN_IDX] = nd.array(loaded_unknown_vec)
self._idx_to_vec[C.UNKNOWN_IDX] = array_fn(loaded_unknown_vec)

def _index_tokens_from_vocabulary(self, vocabulary):
self._token_to_idx = vocabulary.token_to_idx.copy() \
@@ -328,7 +334,8 @@ def _set_idx_to_vec_by_embeddings(self, token_embeddings, vocab_len, vocab_idx_t
"""

new_vec_len = sum(embed.vec_len for embed in token_embeddings)
new_idx_to_vec = nd.zeros(shape=(vocab_len, new_vec_len))
zeros_fn = _mx_np.zeros if is_np_array() else nd.zeros
new_idx_to_vec = zeros_fn(shape=(vocab_len, new_vec_len))

col_start = 0
# Concatenate all the embedding vectors in token_embeddings.
@@ -397,7 +404,13 @@ def get_vecs_by_tokens(self, tokens, lower_case_backup=False):
else self.token_to_idx.get(token.lower(), C.UNKNOWN_IDX)
for token in tokens]

vecs = nd.Embedding(nd.array(indices), self.idx_to_vec, self.idx_to_vec.shape[0],
if is_np_array():
embedding_fn = _mx_npx.Embedding
array_fn = _mx_np.array
else:
embedding_fn = nd.Embedding
array_fn = nd.array
vecs = embedding_fn(array_fn(indices), self.idx_to_vec, self.idx_to_vec.shape[0],
self.idx_to_vec.shape[1])

return vecs[0] if to_reduce else vecs
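
The pattern above recurs throughout this change set: select the mxnet.numpy operator when numpy semantics are active, otherwise fall back to the classic NDArray operator. A minimal standalone sketch of the idiom (the helper name `dispatch_array_fn` is illustrative, not part of the patch):

from mxnet import nd
from mxnet import numpy as _mx_np
from mxnet.util import is_np_array

def dispatch_array_fn():
    # Hypothetical helper: return the array constructor that matches the
    # currently active frontend (mxnet.numpy vs. classic NDArray).
    return _mx_np.array if is_np_array() else nd.array

array_fn = dispatch_array_fn()
indices_arr = array_fn([1, 2, 3])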
@@ -425,7 +438,8 @@ def update_token_vectors(self, tokens, new_vectors):
if not isinstance(tokens, list):
tokens = [tokens]
if len(new_vectors.shape) == 1:
new_vectors = new_vectors.expand_dims(0)
expand_dims_fn = _mx_np.expand_dims if is_np_array() else nd.expand_dims
new_vectors = expand_dims_fn(new_vectors, axis=0)

else:
assert isinstance(new_vectors, nd.NDArray) and len(new_vectors.shape) == 2, \
@@ -444,7 +458,8 @@ def update_token_vectors(self, tokens, new_vectors):
'`unknown_token` %s in `tokens`. This is to avoid unintended '
'updates.' % (token, self.idx_to_token[C.UNKNOWN_IDX]))

self._idx_to_vec[nd.array(indices)] = new_vectors
array_fn = _mx_np.array if is_np_array() else nd.array
self._idx_to_vec[array_fn(indices)] = new_vectors

@classmethod
def _check_pretrained_file_names(cls, pretrained_file_name):
2 changes: 1 addition & 1 deletion python/mxnet/gluon/block.py
@@ -553,7 +553,7 @@ def __call__(self, *args):
for hook in self._forward_hooks.values():
hook(self, args, out)
if _mx_npx.is_np_array():
_check_all_np_ndarrays(_flatten(out, "output")[0])
_check_all_np_ndarrays(out)
return out

def forward(self, *args):
21 changes: 16 additions & 5 deletions python/mxnet/gluon/loss.py
@@ -357,17 +357,28 @@ def __init__(self, axis=-1, sparse_label=True, from_logits=False, weight=None,
self._sparse_label = sparse_label
self._from_logits = from_logits

@_adapt_np_array
def hybrid_forward(self, F, pred, label, sample_weight=None):
if is_np_array():
log_softmax = F.npx.log_softmax
pick = F.npx.pick
else:
log_softmax = F.log_softmax
pick = F.pick
if not self._from_logits:
pred = F.log_softmax(pred, self._axis)
pred = log_softmax(pred, self._axis)
if self._sparse_label:
loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
loss = -pick(pred, label, axis=self._axis, keepdims=True)
else:
label = _reshape_like(F, label, pred)
loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
loss = -(pred * label).sum(axis=self._axis, keepdims=True)
loss = _apply_weighting(F, loss, self._weight, sample_weight)
return F.mean(loss, axis=self._batch_axis, exclude=True)
if is_np_array():
if F is ndarray:
return loss.mean(axis=tuple(range(1, loss.ndim)))
else:
return F.npx.batch_flatten(loss).mean(axis=1)
else:
return loss.mean(axis=self._batch_axis, exclude=True)


SoftmaxCELoss = SoftmaxCrossEntropyLoss
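
A quick smoke test of the rewritten loss under numpy semantics (a sketch; shapes and values are illustrative):

from mxnet import np, npx, gluon
npx.set_np()

loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
pred = np.array([[10., 1., 0.1], [0.1, 10., 1.]])  # unnormalized scores
label = np.array([0, 1])                           # sparse class indices
loss = loss_fn(pred, label)                        # per-sample mean loss
print(loss.shape)                                  # (2,)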
16 changes: 9 additions & 7 deletions python/mxnet/gluon/nn/basic_layers.py
@@ -265,12 +265,13 @@ def __init__(self, rate, axes=(), **kwargs):
self._rate = rate
self._axes = axes

@_adapt_np_array
def hybrid_forward(self, F, x):
if self._rate > 0:
return F.Dropout(x, p=self._rate, axes=self._axes, name='fwd', cudnn_off=False)
dropout = F.npx.Dropout if is_np_array() else F.Dropout
return dropout(x, p=self._rate, axes=self._axes, name='fwd', cudnn_off=False)
else:
return F.identity(x)
copy = F.np.copy if is_np_array() else F.identity
return copy(x)

def __repr__(self):
s = '{name}(p = {_rate}, axes={_axes})'
@@ -360,8 +361,9 @@ def cast(self, dtype):
dtype = 'float32'
super(BatchNorm, self).cast(dtype)

@_adapt_np_array
def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
if is_np_array():
F = F.npx
return F.BatchNorm(x, gamma, beta, running_mean, running_var,
name='fwd', **self._kwargs)

@@ -612,10 +614,10 @@ def __init__(self, axis=-1, epsilon=1e-5, center=True, scale=True,
shape=(in_channels,), init=beta_initializer,
allow_deferred_init=True)

@_adapt_np_array
def hybrid_forward(self, F, data, gamma, beta):
norm_data = F.LayerNorm(data, gamma=gamma, beta=beta, axis=self._axis, eps=self._epsilon)
return norm_data
if is_np_array():
F = F.npx
return F.LayerNorm(data, gamma=gamma, beta=beta, axis=self._axis, eps=self._epsilon)

def __repr__(self):
s = '{name}({content}'
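
The same `is_np_array()` switch now gates Dropout, BatchNorm, and LayerNorm. A minimal end-to-end check (a sketch, assuming CPU inference) that these layers accept mxnet.numpy inputs:

from mxnet import np, npx, gluon
npx.set_np()

net = gluon.nn.Sequential()
net.add(gluon.nn.Dropout(0.5), gluon.nn.LayerNorm())
net.initialize()

y = net(np.ones((2, 4)))   # forward pass with an mxnet.numpy input
print(type(y).__name__)    # ndarray -- the mxnet.numpy type, not classic NDArray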
5 changes: 4 additions & 1 deletion python/mxnet/gluon/parameter.py
@@ -369,7 +369,10 @@ def _reduce(self):
ctx = context.cpu()
if self._stype == 'default':
block = self.list_data()
data = ndarray.add_n(*(w.copyto(ctx).as_nd_ndarray() for w in block)) / len(block)
if is_np_array():
data = sum([w.copyto(ctx) for w in block]) / len(block)
else:
data = ndarray.add_n(*(w.copyto(ctx) for w in block)) / len(block)
else:
# fetch all rows for 'row_sparse' param
all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=ctx)
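
For intuition, the numpy branch averages per-device copies with plain Python `sum`, since `ndarray.add_n` is a classic-NDArray operator. A standalone sketch of the same arithmetic:

from mxnet import np, npx, context
npx.set_np()

ctx = context.cpu()
copies = [np.ones((2, 2)), np.ones((2, 2)) * 3]  # stand-ins for per-device weights
avg = sum(w.copyto(ctx) for w in copies) / len(copies)
print(avg)                                       # every element is 2.0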
35 changes: 22 additions & 13 deletions python/mxnet/gluon/utils.py
@@ -18,6 +18,8 @@
# coding: utf-8
# pylint: disable=
"""Parallelization utility optimizer."""
from __future__ import absolute_import

__all__ = ['split_data', 'split_and_load', 'clip_global_norm',
'check_sha1', 'download']

@@ -39,6 +41,7 @@ class requests_failed_to_import(object):

from .. import ndarray
from ..util import is_np_shape, is_np_array, wraps_safely
from .. import numpy as _mx_np # pylint: disable=reimported


def split_data(data, num_slice, batch_axis=0, even_split=True):
@@ -112,15 +115,14 @@ def split_and_load(data, ctx_list, batch_axis=0, even_split=True):
list of NDArray
Each corresponds to a context in `ctx_list`.
"""
# TODO(junwu): temp solution for supporting np.ndarray
# rewrite this using np ops
array_fn = _mx_np.array if is_np_array() else ndarray.array
if not isinstance(data, ndarray.NDArray):
data = ndarray.array(data, ctx=ctx_list[0])
data = array_fn(data, ctx=ctx_list[0])
if len(ctx_list) == 1:
if is_np_array():
data = data.as_np_ndarray()
return [data.as_in_context(ctx_list[0])]

# TODO(junwu): temp solution for supporting np.ndarray
# rewrite this using np ops
slices = split_data(data, len(ctx_list), batch_axis, even_split)
if is_np_array():
slices = [i.as_np_ndarray() for i in slices]
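
From the caller's side usage is unchanged; with numpy semantics active, `split_and_load` now returns mxnet.numpy arrays (a sketch using two CPU contexts for illustration):

from mxnet import np, npx, context
from mxnet.gluon.utils import split_and_load
npx.set_np()

batch = np.ones((8, 3))
parts = split_and_load(batch, ctx_list=[context.cpu(0), context.cpu(1)])
print([p.shape for p in parts])  # [(4, 3), (4, 3)]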
@@ -445,7 +447,7 @@ def _check_same_symbol_type(symbols):
Raise type error if the types are different. Return the class of
the symbols."""
from ..symbol.numpy import _Symbol as np_symbol
from ..symbol import Symbol as classic_symbol
from ..symbol import Symbol as nd_symbol
is_np_sym = bool(isinstance(symbols[0], np_symbol))
for s in symbols[1:]:
if is_np_sym != isinstance(s, np_symbol):
@@ -460,18 +462,25 @@ def _check_same_symbol_type(symbols):
'on each of them; if you want classic ndarray output(s) from the '
'computation graph, please convert all the numpy symbols in the list '
'to classic symbols by calling `as_nd_ndarray()` on each of them.')
return np_symbol if is_np_sym else classic_symbol
return np_symbol if is_np_sym else nd_symbol


def _check_all_np_ndarrays(out):
"""Check if ndarrays in out are all np.ndarray"""
"""Check if ndarrays/symbols in out are all np.ndarray/np._Symbol."""
from ..numpy import ndarray as np_ndarray
from ..symbol.numpy import _Symbol as np_symbol
assert isinstance(out, (list, tuple))
for array in out:
if not isinstance(array, (np_ndarray, np_symbol)):
raise TypeError('Expected np.ndarray or np._Symbol type in output, while received type '
'{}'.format(str(type(array))))
from ..symbol import Symbol as nd_symbol
from ..ndarray import NDArray as nd_ndarray

# pylint: disable=no-else-raise
if isinstance(out, (nd_ndarray, nd_symbol)) and not isinstance(out, (np_ndarray, np_symbol)):
raise TypeError("Block's output ndarrays/symbols must be of type `mxnet.numpy.ndarray`"
" or `mxnet.symbol.numpy._Symbol`, while got output type {}"
.format(str(type(out))))
elif isinstance(out, (list, tuple)):
for i in out:
_check_all_np_ndarrays(i)
# pylint: enable=no-else-raise


def _to_classic_arrays(*args, **kwargs):
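
Because Blocks may return nested structures, the check now recurses through lists and tuples and only raises on a classic array or symbol at a leaf. An illustration (a sketch that exercises the private helper directly):

from mxnet import np, npx
from mxnet.gluon.utils import _check_all_np_ndarrays
npx.set_np()

out = (np.zeros((1,)), [np.ones((2,)), np.ones((3,))])
_check_all_np_ndarrays(out)  # passes silently: every leaf is an mxnet.numpy.ndarray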
38 changes: 37 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
@@ -26,7 +26,7 @@

__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax',
'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
'clip', 'split', 'swapaxes', 'expand_dims']
'clip', 'split', 'swapaxes', 'expand_dims', 'tile']


@set_module('mxnet.ndarray.numpy')
@@ -593,3 +593,39 @@ def split(ary, indices_or_sections, axis=0):
if not isinstance(ret, list):
raise NotImplementedError('single output from split is not supported yet...')
return ret


@set_module('mxnet.ndarray.numpy')
def tile(A, reps):
"""
Construct an array by repeating A the number of times given by reps.
If `reps` has length ``d``, the result will have dimension of
``max(d, A.ndim)``.
If ``A.ndim < d``, `A` is promoted to be d-dimensional by prepending new
axes. So a shape (3,) array is promoted to (1, 3) for 2-D replication,
or shape (1, 1, 3) for 3-D replication. If this is not the desired
behavior, promote `A` to d-dimensions manually before calling this
function.
If ``A.ndim > d``, `reps` is promoted to `A`.ndim by pre-pending 1's to it.
Thus for an `A` of shape (2, 3, 4, 5), a `reps` of (2, 2) is treated as
(1, 1, 2, 2).
Note : Although tile may be used for broadcasting, it is strongly
recommended to use numpy's broadcasting operations and functions.
Parameters
----------
A : ndarray
The input array.
reps : tuple of integers
The number of repetitions of `A` along each axis.
Returns
-------
c : ndarray
The tiled output array.
"""
return _npi.tile(A, reps)
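
A usage sketch mirroring the docstring (results in comments use the default float32 dtype):

from mxnet import np, npx
npx.set_np()

a = np.array([0, 1, 2])
np.tile(a, 2)        # [0. 1. 2. 0. 1. 2.]
np.tile(a, (2, 2))   # [[0. 1. 2. 0. 1. 2.],
                     #  [0. 1. 2. 0. 1. 2.]]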
