Conversation
Hi @eric-haibin-lin @piiswrong, DataParallel and Barrier (ParallelState) are included as discussed.
from ...ndarray import NDArray
from ..utils import split_and_load

__all__ = ['DataParallel', 'Barrier', 'parallel_apply', 'split_kwargs']
Maybe we don't want to include "Barrier" in __all__ and display it in the API doc. For sparse I'm considering adding and exposing a wrapper around Barrier so that users only need to pass in indices and get weights back. Normally I don't think people need to see Barrier directly; SyncBN can still access it via parallel.Barrier.
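For context, a rough sketch of what such a wrapper might look like, assuming the push()/pull() interface Barrier exposes elsewhere in this PR (push returns an index, pull(idx) returns this device's reduced result). The class name, the reduce op, and the use of Parameter.row_sparse_data are illustrative assumptions, not part of this PR:

```python
import mxnet as mx
# Barrier as proposed in this PR (mxnet.gluon.contrib.parallel).

# The reduce op gathers every device's row indices and hands the merged set
# back to each device (deduplication omitted for brevity).
def _merge_indices(*indices):
    merged = mx.nd.concat(*[i.as_in_context(mx.cpu()) for i in indices], dim=0)
    return [merged.as_in_context(i.context) for i in indices]

# Hypothetical wrapper so users never touch Barrier directly: pass in indices,
# get the corresponding weight rows back.
class SparseRowPull(object):
    def __init__(self, num_devices, weight):
        self._barrier = Barrier(num_devices, _merge_indices)
        self._weight = weight   # a row_sparse gluon Parameter

    def __call__(self, row_ids):
        idx = self._barrier.push(row_ids)    # hand in this device's indices
        all_rows = self._barrier.pull(idx)   # merged indices from all devices
        return self._weight.row_sparse_data(all_rows)
```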
            output.wait_to_read()
            with lock:
                results[i] = output
        except Exception as e:
What exception will occur here?
The user-defined model may run into trouble during the forward pass; catching the exception here avoids printing errors from inside the different threads. We re-raise the exception when gathering the results.
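A minimal sketch of that pattern (names are illustrative, not the exact code in this PR): each worker records either its output or the exception it hit, and the parent re-raises when gathering.

```python
import threading

# Each worker stores either its output or the exception it raised, so nothing
# is printed from inside the thread.
def _worker(i, module, inp, results, lock):
    try:
        output = module(inp)
        output.wait_to_read()      # surface any deferred engine error in this thread
        with lock:
            results[i] = output
    except Exception as e:         # broad catch is intentional here
        with lock:
            results[i] = e

# The exception is re-raised in the main thread when results are gathered.
def _gather(results):
    outputs = []
    for res in results:
        if isinstance(res, Exception):
            raise res
        outputs.append(res)
    return outputs
```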
""" | ||
return self.out[idx] | ||
|
||
def get(self): |
I think Sparse can reuse the pull() interface
You mean removing the get() function?
@zhanghang1989 I am interested in this PR.
@pengzhao-intel This PR is mainly for Synchronized Cross-GPU Batch Norm: https://github.com/zhanghang1989/MXNet-Gluon-SyncBN
I have finished editing this repo. Could you start reviewing, @eric-haibin-lin @piiswrong? Please see the deployed docs: http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-10536/12/api/python/gluon/contrib.html?highlight=dataparallel#mxnet.gluon.contrib.parallel.DataParallelModel
docs/api/python/gluon/contrib.md
    :nosignatures:

    DataParallelModel
    DataParallelCriterion
I think we only need a DataParallel.
This makes it easier for users, because the situation gets complicated for networks with multiple outputs. We just handle that internally.
I think this is too many APIs and prone to misuse. Also, we use the term Loss instead of Criterion.
Will change the name :)
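For reference, a rough usage sketch of the two classes as documented in the link above. The constructor argument names and the training-loop variables here are assumptions based on the diff, and DataParallelCriterion may be renamed per the discussion above:

```python
import mxnet as mx
from mxnet import autograd, gluon
# Proposed in this PR; constructor arguments shown here are assumptions.
from mxnet.gluon.contrib.parallel import DataParallelModel, DataParallelCriterion

ctx_list = [mx.gpu(i) for i in range(2)]
# model, data and labels are placeholders for the user's Block and batch.
net = DataParallelModel(model, ctx_list=ctx_list)
criterion = DataParallelCriterion(gluon.loss.SoftmaxCELoss(axis=1), ctx_list=ctx_list)

with autograd.record():
    outputs = net(data)                  # per-device outputs (handles multi-output nets)
    losses = criterion(outputs, labels)  # pairs each device's outputs with its targets
    autograd.backward(losses)
```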
    The cross device operation is applying (e.g. AllReduce).
    """
    def __init__(self, counter, operation):
        self.mutex = threading.Lock()
internal variables should start with _
        return len(self.list)

    def __repr__(self):
        return 'ParallelState'
Barrier
        self.push_tasks = self.counter
        self.reduce_tasks = self.counter

    def __len__(self):
why is this needed?
We may want to get the number of devices from outside. For example, to calculate the global mean we divide the global sum by the global number of elements (local number of elements * N devices).
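A small sketch of that use case (names are illustrative): with len(barrier) giving the number of devices, the global mean is the reduced global sum divided by the total element count.

```python
# Each device pushes its partial sum; the barrier's op (e.g. an AllReduce-style
# reduction) combines them, and __len__ gives the number of participating devices.
idx = barrier.push(local_sum)
global_sum = barrier.pull(idx)
global_count = local_count * len(barrier)   # local elements * number of devices
global_mean = global_sum / global_count
```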
        self.out = self.op(*self.list)
        if isinstance(self.out, (list, tuple)):
            for xi in self.out:
                xi.wait_to_read()
why do you need to wait?
If we don't wait, the training of BN fails.
Is the async execution engine not able to handle this?
    Parameters
    ----------
    module : object
Block?
        return criterion_parallel_apply(self.module, inputs, targets, kwargs, self.sync)


def split_load_kwargs(inputs, kwargs, ctx_list, batch_axis=0):
internal functions should start with _
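For reference, a minimal sketch of what the (underscore-prefixed) helper does, built on gluon.utils.split_and_load. This is an illustration of the idea, not the PR's exact implementation:

```python
from mxnet.gluon.utils import split_and_load

def _split_load_kwargs(inputs, kwargs, ctx_list, batch_axis=0):
    """Split positional and keyword NDArrays along the batch axis and load one
    slice onto each context in ctx_list (illustrative sketch)."""
    # split each positional argument, then regroup per device
    inputs = [split_and_load(x, ctx_list, batch_axis=batch_axis) for x in inputs]
    inputs = list(zip(*inputs)) if inputs else [()] * len(ctx_list)
    # same for keyword arguments
    values = [split_and_load(v, ctx_list, batch_axis=batch_axis) for v in kwargs.values()]
    kwargs = [dict(zip(kwargs.keys(), vals)) for vals in zip(*values)] if kwargs \
        else [{}] * len(ctx_list)
    return inputs, kwargs
```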
    is_training = autograd.is_training()
    is_recording = autograd.is_recording()
    threads = [threading.Thread(target=_worker,
recreating threads at each forward is too slow
I mainly followed the PyTorch implementation. Could you show an example of improvement?
Use a thread pool created in the constructor and reuse it for each iteration? If PyTorch does it like this, maybe it's fine.
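A rough sketch of the thread-pool variant being suggested: build the pool once in the constructor and reuse it every forward, letting result() propagate worker exceptions. This is an illustration only; each worker in the real code would still need the autograd.record(is_training) handling shown in the excerpt that follows.

```python
from concurrent.futures import ThreadPoolExecutor

class _PooledApply(object):
    """Illustrative only: reuse one pool instead of spawning threads per forward."""
    def __init__(self, num_devices):
        self._pool = ThreadPoolExecutor(max_workers=num_devices)

    def __call__(self, module, inputs, kwargs_list):
        futures = [self._pool.submit(module, *inp, **kw)
                   for inp, kw in zip(inputs, kwargs_list)]
        # result() re-raises any worker exception in the calling thread
        return tuple(f.result() for f in futures)
```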
            with autograd.record(is_training):
                output = tuple_map(module(*input, **kwargs))
                for out in output:
                    out.wait_to_read()
why wait?
The training fails without waiting. I think it may because of with autograd.record(is_training):
Why? What's the error? That shouldn't happen. Are you using the most recent code? There was a recent fix for multithreading.
I haven't tested recently; I will take a look.
Actually, people can use sync=False mode in most cases.
        self.op = operation
        self._clear()

    def push(self, x):
How about having a single wait function instead of push/pull? Are push and pull ever called separately?
There was a case in the sparse application which only pulls one time from the master thread. We can change it to a single function.
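A sketch of the single-call alternative (hypothetical helper name), folding push and pull into one blocking exchange:

```python
def barrier_wait(barrier, x):
    """Hypothetical convenience helper: push this device's value and block
    until the reduced result for this device is available."""
    idx = barrier.push(x)
    return barrier.pull(idx)
```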


class Barrier(object):
    """Shared NDArray for cross device operation.
I don't think "Shared NDArray" is an accurate description.
"""Shared NDArray for cross device operation. | ||
|
||
A cross device operation that allows synchronized push and pull. It can be used in | ||
Cross-gpu Sycnhronized Batch Normalization and Sparse Blocks. |
Remove sparse blocks from the doc for now; we will not be using this for the first version.
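For context, a rough sketch of how the Barrier is exercised for synchronized batch norm: every device thread pushes its partial statistic, the barrier's op combines them, and every device pulls the same combined result. The reduce op and the worker setup here are illustrative, not the SyncBN implementation.

```python
import threading
import mxnet as mx
# Barrier as proposed in this PR (mxnet.gluon.contrib.parallel).

# Illustrative AllReduce-style reduction: sum the per-device partial sums and
# hand the total back to every device.
def allreduce_sum(*tensors):
    total = mx.nd.add_n(*[t.as_in_context(mx.cpu()) for t in tensors])
    return [total.as_in_context(t.context) for t in tensors]

num_devices = 2
barrier = Barrier(num_devices, allreduce_sum)
results = [None] * num_devices

def device_worker(rank, x):
    idx = barrier.push(x.sum())        # push this device's partial statistic
    results[rank] = barrier.pull(idx)  # every device receives the combined sum

threads = [threading.Thread(target=device_worker, args=(r, mx.nd.ones((4,))))
           for r in range(num_devices)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```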
        inputs, kwargs = _split_load_kwargs(inputs, kwargs, self.ctx_list)
        assert(len(inputs) == len(self.ctx_list))
        if len(self.ctx_list) == 1:
            return tuple([tuple_map(self.module(*inputs[0], **kwargs[0]))])
(x,) makes a tuple
    net.collect_params().initialize()
    criterion = gluon.loss.SoftmaxCELoss(axis=1)

    def test_net_sync(net, criterion, sync, nDevices):
nDevices -> num_devices
    def forward(self, x):
        idx = self.barrier.push(x)
        y = self.barrier.pull(idx)
        assert_allclose(y.asnumpy(), x.asnumpy(), rtol=1e-2, atol=1e-4)
should x and y be exactly the same? The tolerance seems large
class DataParallelModel(object):
    """Data Parallelism

    Hide the difference of single/multiple GPUs to the user.
Hide .. from ..
    Parameters
    ----------
    module : object
Is it better to call it block instead of module?
Block has a general meaning in Gluon. We can probably call it model, if you think that is better.
    # evaluation mode
    for i in range(iters):
        x = mx.random.uniform(shape=(8, 1, 28, 28))
        y = net(x)
what specifically are you checking here?
Mainly checking the behavior of Barrier:
assert_allclose(y.asnumpy(), x.asnumpy(), rtol=1e-2, atol=1e-4)
        self.out = self.op(*self.list)
        if isinstance(self.out, (list, tuple)):
            for xi in self.out:
                xi.wait_to_read()
Is the async execution engine not able to handle this?
Closing this since it is no longer needed for SyncBN (#11502).