
mxnet: enable async training (#114)
ymjiang authored Sep 22, 2019
1 parent 7fa12a3 commit 33a7f91
Showing 1 changed file with 35 additions and 4 deletions.
39 changes: 35 additions & 4 deletions byteps/mxnet/__init__.py
@@ -20,6 +20,7 @@

 import warnings
 import mxnet as mx
+import os
 
 from byteps.mxnet.ops import byteps_push_pull, byteps_declare_tensor
 from byteps.mxnet.ops import init, shutdown
@@ -32,6 +33,11 @@ class DistributedOptimizer(mx.optimizer.Optimizer):
"""This is where BytePS's DistributedOptimizer wrapper for MXNet goes"""
def __init__(self, optimizer):
self._optimizer = optimizer
self._enable_async = (int(os.getenv('BYTEPS_ENABLE_ASYNC', 0)) != 0)
if self._enable_async:
assert int(os.getenv('DMLC_NUM_WORKER'))>1, \
"Async is only valid for distributed training"
print('BytePS: enable asynchronous training')

def __getattr__(self, item):
return getattr(self._optimizer, item)
@@ -50,13 +56,38 @@ def _do_push_pull(self, index, grad):
             byteps_push_pull(grad, version=0, priority=-index,
                              name="gradient_" + str(index), is_average=True)
 
+    def _do_push_pull_param(self, index, delta_weight):
+        if isinstance(index, (tuple, list)):
+            for i in range(len(index)):
+                byteps_declare_tensor(delta_weight[i], "weight_" + str(index[i]))
+                byteps_push_pull(delta_weight[i], version=0, priority=-index[i],
+                                 name="weight_" + str(index[i]), is_average=False)
+        else:
+            byteps_declare_tensor(delta_weight, "weight_" + str(index))
+            byteps_push_pull(delta_weight, version=0, priority=-index,
+                             name="weight_" + str(index), is_average=False)
+
     def update(self, index, weight, grad, state):
-        self._do_push_pull(index, grad)
-        self._optimizer.update(index, weight, grad, state)
+        if self._enable_async:
+            temp_weight = weight.copy()
+            self._optimizer.update(index, weight, grad, state)
+            # push delta weight, and pull weight back to the same tensor
+            weight.__isub__(temp_weight)
+            self._do_push_pull_param(index, weight)
+        else:
+            self._do_push_pull(index, grad)
+            self._optimizer.update(index, weight, grad, state)
 
     def update_multi_precision(self, index, weight, grad, state):
-        self._do_push_pull(index, grad)
-        self._optimizer.update_multi_precision(index, weight, grad, state)
+        if self._enable_async:
+            temp_weight = weight.copy()
+            self._optimizer.update_multi_precision(index, weight, grad, state)
+            # push delta weight, and pull weight back to the same tensor
+            weight.__isub__(temp_weight)
+            self._do_push_pull_param(index, weight)
+        else:
+            self._do_push_pull(index, grad)
+            self._optimizer.update_multi_precision(index, weight, grad, state)
 
     def set_learning_rate(self, lr):
         self._optimizer.set_learning_rate(lr)
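For readers following the change, below is a minimal usage sketch (not part of the commit) of how the asynchronous mode added above could be enabled, assuming the standard BytePS MXNet workflow of calling bps.init() and wrapping a base optimizer in DistributedOptimizer. The SGD optimizer and the manual environment-variable setup are illustrative only; in a real deployment the BytePS launcher sets DMLC_NUM_WORKER and the other DMLC_* variables.

# Illustrative sketch, not part of the commit: enabling the asynchronous mode
# introduced by this change. Assumes the usual BytePS MXNet setup; in practice
# the BytePS launcher sets DMLC_NUM_WORKER and related variables.
import os
os.environ['BYTEPS_ENABLE_ASYNC'] = '1'  # checked in DistributedOptimizer.__init__

import mxnet as mx
import byteps.mxnet as bps

bps.init()

# Any MXNet optimizer can be wrapped; SGD here is just an example.
base_opt = mx.optimizer.SGD(learning_rate=0.01)
opt = bps.DistributedOptimizer(base_opt)

# With BYTEPS_ENABLE_ASYNC set (and more than one worker), each opt.update()
# applies the local update, pushes the weight delta (new weight minus old
# weight), and pulls the latest global weight back into the same tensor,
# instead of push-pulling averaged gradients.

Pushing the weight delta with is_average=False, rather than the averaged gradient, is what allows the server to fold in each worker's update as it arrives, which is the asynchronous behavior the commit title refers to.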
