Skip to content

Commit

Permalink
1bit: don't do wd mom for uncompressed gradients (#60)
Browse files Browse the repository at this point in the history
* 1bit: update

* 1bit: test

* 1bit: register wdmom
  • Loading branch information
jasperzhong authored Jul 30, 2020
1 parent 7dc8d7f commit 11a7ec0
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
14 changes: 8 additions & 6 deletions byteps/common/compressor/impl/onebit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

#include <cstring>

#include "onebit.h"
#include "../compressor_registry.h"
#include "onebit.h"

namespace byteps {
namespace common {
Expand Down Expand Up @@ -51,10 +51,11 @@ tensor_t OnebitCompressor::CompressImpl(index_t* dst, const scalar_t* src,

#pragma omp parallel for simd
for (size_t i = 0; i < chunk_len; ++i) {
index_t x = src[i * PACKING_SIZE] < 0;
size_t idx = i * PACKING_SIZE;
index_t x = src[idx] < 0;
for (size_t j = 1; j < PACKING_SIZE; ++j) {
x <<= 1;
x |= src[i * PACKING_SIZE + j] < 0;
x |= src[idx + j] < 0;
}
dst[i] = x;
}
Expand Down Expand Up @@ -90,9 +91,10 @@ tensor_t OnebitCompressor::DecompressImpl(scalar_t* dst, const index_t* src,
#pragma omp parallel for simd
for (int i = chunk_len - 1; i >= 0; --i) {
index_t x = ptr[i];
size_t idx = i * PACKING_SIZE;
for (int j = PACKING_SIZE - 1; j >= 0; --j) {
int sign = 1 - ((x & 0x01) << 1);
dst[i * PACKING_SIZE + j] = sign * scale;
dst[idx + j] = sign * scale;
x >>= 1;
}
}
Expand Down Expand Up @@ -123,10 +125,10 @@ void OnebitCompressor::FastUpdateErrorImpl(scalar_t* error, scalar_t* corrected,
#pragma omp parallel for simd
for (int i = chunk_len - 1; i >= 0; --i) {
index_t x = compressed[i];
size_t idx = i * PACKING_SIZE;
for (int j = PACKING_SIZE - 1; j >= 0; --j) {
int sign = ((x & 0x01) << 1) - 1;
error[i * PACKING_SIZE + j] =
corrected[i * PACKING_SIZE + j] + sign * scale;
error[idx + j] = corrected[idx + j] + sign * scale;
x >>= 1;
}
}
Expand Down
15 changes: 12 additions & 3 deletions byteps/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import os
import struct
import warnings
from functools import reduce

import mxnet as mx
import mxnet.ndarray as nd
Expand Down Expand Up @@ -218,8 +219,6 @@ def __init__(self, params, optimizer, optimizer_params=None, root_rank=0, compre
self._intra_compressors = {}
for i, param in enumerate(self._params):
byteps_declare_tensor("parameter_" + str(i))
self._intra_compressors[param.name] = type(self._intra_compressor)(
**self._intra_compressor.__dict__)
if param.grad_req != 'null':
byteps_params = dict(
filter(lambda attr: attr[0].startswith(
Expand Down Expand Up @@ -280,7 +279,7 @@ def _register_compressor(self, params, optimizer_params, compression_params):
if compression_params.get("momentum"):
# the 1bit compressor uses an additional momentum for weight decay
if compressor == "onebit" and "wd" in optimizer_params:
intra_compressor = Compression.wdmom(
Compression.wdmom = Compression.wdmom(
intra_compressor, optimizer_params["momentum"], optimizer_params["wd"])
del optimizer_params["wd"]

Expand Down Expand Up @@ -316,6 +315,7 @@ def _allreduce_grads(self):

def _init_params(self):
tensors = []
threshold = int(os.environ.get("BYTEPS_MIN_COMPRESS_BYTES", 65536))
for param in self._params_to_init:
if param._deferred_init:
tensors.append(param)
Expand All @@ -326,6 +326,15 @@ def _init_params(self):
if rank() != self.root_rank:
param_arrays[0].__imul__(0)

# register intra-node compressor
size = reduce(lambda x, y: x*y, param_arrays[0].shape)
if size >= threshold:
self._intra_compressors[param.name] = type(
Compression.wdmom)(**Compression.wdmom.__dict__)
else:
self._intra_compressors[param.name] = type(
self._intra_compressor)(**self._intra_compressor.__dict__)

compressed, ctx = self._intra_compressors[param.name].compress(
param_arrays[0])
byteps_push_pull(compressed, version=0, priority=0,
Expand Down

0 comments on commit 11a7ec0

Please sign in to comment.