Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Backport PRs in v1.7.x missing from v1.x to v1.8.x (#19262)
Browse files Browse the repository at this point in the history
* * Fix einsum gradient (#18482)

* [v1.7.x] Backport PRs of numpy features (#18653)

* add zero grad for npi_unique (#18080)

* fix np.clip scalar input case (#17788)

* fix true_divide (#18393)

Co-authored-by: Hao Jin <hjjn.amzn@gmail.com>
Co-authored-by: Xi Wang <xidulu@gmail.com>

* [v1.7.x] backport mixed type binary ops to v1.7.x (#18649)

* Fix Windows GPU CI (#17962)

Update Windows CI to use VS 2019 and enable x64 bit toolchain. Previously we are using an older 32 bit toolchain causing OOM errors during linking. Switching to x64 bit toolchain on the older VS version previously used by the CI was attempted in #17912 and did not work. Update to Cuda 10.2 as it is required by VS 2019. Switch to ninja-build on Windows to speed up build as ninja-build is now preinstalled. Remove logic to install cmake 3.16 on every PR as cmake 3.17 is now preinstalled. Add build retrials due to cuda thrust + VS2019 flakyness.

Co-authored-by: vexilligera <vexilligera@gmail.com>

* backport mixed type

Co-authored-by: Leonard Lausen <lausen@amazon.com>
Co-authored-by: vexilligera <vexilligera@gmail.com>

* revise activations (#18700)

* [v1.6] Fix the monitor_callback invalid issue during calibration with variable input shapes (#18632) (#18703)

* Fix the monitor_callback invalid issue during calibration with variable input shapes

* retrigger CI

* Add UT for monitor check and disable codecov

Co-authored-by: Tao Lv <tao.a.lv@intel.com>

* Fail build_windows.py if all retries failed (#18177)

* Update to thrust 1.9.8 on Windows (#18218)

* Update to thrust 1.9.8 on Windows

* Remove debug logic

* Re-enable build retries on MSVC (#18230)

Updating thrust alone did not help. Similar issues (though less often) still
occur with updated thrust, and also with nvidia cub. Tracked upstream at
NVIDIA/thrust#1090

Co-authored-by: Ke Han <38852697+hanke580@users.noreply.github.com>
Co-authored-by: Xingjian Shi <xshiab@connect.ust.hk>
Co-authored-by: Hao Jin <hjjn.amzn@gmail.com>
Co-authored-by: Xi Wang <xidulu@gmail.com>
Co-authored-by: Yijun Chen <chenyijun0902@gmail.com>
Co-authored-by: vexilligera <vexilligera@gmail.com>
Co-authored-by: ciyong <ciyong.chen@intel.com>
Co-authored-by: Tao Lv <tao.a.lv@intel.com>
  • Loading branch information
9 people committed Oct 2, 2020
1 parent 51cc0af commit 371b312
Show file tree
Hide file tree
Showing 32 changed files with 700 additions and 520 deletions.
3 changes: 3 additions & 0 deletions .codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ codecov:
require_ci_to_pass: yes

coverage:
status:
project: off
patch: off
precision: 2
round: down
range: "70...100"
Expand Down
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ cmake_install.cmake
# Mac OS X
.DS_Store

# Windows
windows_package.7z
windows_package

#Notebook Automated Test
!tests/nightly/test_tutorial_config.txt
!tests/nightly/TestNotebook
Expand Down
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ if(MSVC)
add_definitions(-DDMLC_STRICT_CXX11)
add_definitions(-DNOMINMAX)
set(CMAKE_C_FLAGS "/MP")
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} /bigobj")
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} ${CMAKE_CXX_FLAGS} /bigobj")
else()
include(CheckCXXCompilerFlag)
if(USE_CXX14_IF_AVAILABLE)
Expand Down
93 changes: 66 additions & 27 deletions ci/build_windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,18 @@
import tempfile
import time
import zipfile
import requests
from distutils.dir_util import copy_tree
from enum import Enum
from subprocess import check_call
from subprocess import check_call, call

from util import *

KNOWN_VCVARS = {
# https://gitlab.kitware.com/cmake/cmake/issues/18920
'VS 2015': r'C:\Program Files (x86)\Microsoft Visual Studio 14.0\VC\bin\x86_amd64\vcvarsx86_amd64.bat',
'VS 2017': r'C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvarsx86_amd64.bat'
'VS 2017': r'C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvarsx86_amd64.bat',
'VS 2019': r'C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Auxiliary\Build\vcvars64.bat',
}


Expand All @@ -54,6 +57,8 @@ class BuildFlavour(Enum):

CMAKE_FLAGS = {
'WIN_CPU': (
'-DCMAKE_C_COMPILER=cl '
'-DCMAKE_CXX_COMPILER=cl '
'-DUSE_CUDA=OFF '
'-DUSE_CUDNN=OFF '
'-DENABLE_CUDA_RTC=OFF '
Expand All @@ -67,6 +72,8 @@ class BuildFlavour(Enum):
'-DCMAKE_BUILD_TYPE=Release')

, 'WIN_CPU_MKLDNN': (
'-DCMAKE_C_COMPILER=cl '
'-DCMAKE_CXX_COMPILER=cl '
'-DUSE_CUDA=OFF '
'-DUSE_CUDNN=OFF '
'-DENABLE_CUDA_RTC=OFF '
Expand All @@ -80,6 +87,8 @@ class BuildFlavour(Enum):
'-DCMAKE_BUILD_TYPE=Release')

, 'WIN_CPU_MKLDNN_MKL': (
'-DCMAKE_C_COMPILER=cl '
'-DCMAKE_CXX_COMPILER=cl '
'-DUSE_CUDA=OFF '
'-DUSE_CUDNN=OFF '
'-DENABLE_CUDA_RTC=OFF '
Expand All @@ -93,6 +102,8 @@ class BuildFlavour(Enum):
'-DCMAKE_BUILD_TYPE=Release')

, 'WIN_CPU_MKL': (
'-DCMAKE_C_COMPILER=cl '
'-DCMAKE_CXX_COMPILER=cl '
'-DUSE_CUDA=OFF '
'-DUSE_CUDNN=OFF '
'-DENABLE_CUDA_RTC=OFF '
Expand All @@ -106,6 +117,8 @@ class BuildFlavour(Enum):
'-DCMAKE_BUILD_TYPE=Release')

, 'WIN_GPU': (
'-DCMAKE_C_COMPILER=cl '
'-DCMAKE_CXX_COMPILER=cl '
'-DUSE_CUDA=ON '
'-DUSE_CUDNN=ON '
'-DENABLE_CUDA_RTC=ON '
Expand All @@ -115,11 +128,12 @@ class BuildFlavour(Enum):
'-DUSE_LAPACK=ON '
'-DUSE_DIST_KVSTORE=OFF '
'-DMXNET_CUDA_ARCH="5.2" '
'-DCMAKE_CXX_FLAGS="/FS /MD /O2 /Ob2" '
'-DUSE_MKL_IF_AVAILABLE=OFF '
'-DCMAKE_BUILD_TYPE=Release')

, 'WIN_GPU_MKLDNN': (
'-DCMAKE_C_COMPILER=cl '
'-DCMAKE_CXX_COMPILER=cl '
'-DUSE_CUDA=ON '
'-DUSE_CUDNN=ON '
'-DENABLE_CUDA_RTC=ON '
Expand All @@ -130,7 +144,6 @@ class BuildFlavour(Enum):
'-DUSE_DIST_KVSTORE=OFF '
'-DMXNET_CUDA_ARCH="5.2" '
'-DUSE_MKLDNN=ON '
'-DCMAKE_CXX_FLAGS="/FS /MD /O2 /Ob2" '
'-DCMAKE_BUILD_TYPE=Release')

}
Expand All @@ -140,39 +153,65 @@ def windows_build(args):
logging.info("Using vcvars environment:\n{}".format(args.vcvars))

path = args.output
os.makedirs(path, exist_ok=True)

mxnet_root = get_mxnet_root()
logging.info("Found MXNet root: {}".format(mxnet_root))

url = 'https://github.com/Kitware/CMake/releases/download/v3.16.1/cmake-3.16.1-win64-x64.zip'
with tempfile.TemporaryDirectory() as tmpdir:
cmake_file_path = download_file(url, tmpdir)
with zipfile.ZipFile(cmake_file_path, 'r') as zip_ref:
# Create $tmpdir\cmake-3.16.1-win64-x64\bin\cmake.exe
zip_ref.extractall(tmpdir)
if 'GPU' in args.flavour:
# Get Thrust version to be shipped in Cuda 11, due to flakyness of
# older Thrust versions with MSVC 19 compiler
with remember_cwd():
tmpdirname = tempfile.mkdtemp()
os.chdir(tmpdirname)
r = requests.get('https://github.com/thrust/thrust/archive/1.9.8.zip', allow_redirects=True)
with open('thrust.zip', 'wb') as f:
f.write(r.content)
with zipfile.ZipFile('thrust.zip', 'r') as zip_ref:
zip_ref.extractall('.')
thrust_path = os.path.join(tmpdirname, "thrust-1.9.8")


# cuda thrust / CUB + VS 2019 is flaky: try multiple times if fail
MAXIMUM_TRY = 5
build_try = 0

while build_try < MAXIMUM_TRY:
if os.path.exists(path):
shutil.rmtree(path)
os.makedirs(path, exist_ok=True)

with remember_cwd():
os.chdir(path)
cmd = "\"{}\" && {} -G \"NMake Makefiles JOM\" {} {}".format(
args.vcvars,
os.path.join(tmpdir, 'cmake-3.16.1-win64-x64', 'bin', 'cmake.exe'),
CMAKE_FLAGS[args.flavour], mxnet_root)
env = os.environ.copy()
if 'GPU' in args.flavour:
env["CXXFLAGS"] = '/FS /MD /O2 /Ob2 /I {}'.format(thrust_path)
env["CUDAFLAGS"] = '-I {}'.format(thrust_path)
cmd = "\"{}\" && cmake -GNinja {} {}".format(args.vcvars,
CMAKE_FLAGS[args.flavour],
mxnet_root)
logging.info("Generating project with CMake:\n{}".format(cmd))
check_call(cmd, shell=True)
check_call(cmd, shell=True, env=env)

cmd = "\"{}\" && jom".format(args.vcvars)
logging.info("Building with jom:\n{}".format(cmd))
cmd = "\"{}\" && ninja".format(args.vcvars)
logging.info("Building:\n{}".format(cmd))

t0 = int(time.time())
check_call(cmd, shell=True)
ret = call(cmd, shell=True)


logging.info(
"Build flavour: {} complete in directory: \"{}\"".format(
args.flavour, os.path.abspath(path)))
logging.info("Build took {}".format(
datetime.timedelta(seconds=int(time.time() - t0))))
windows_package(args)
if ret != 0:
build_try += 1
logging.info("{} build(s) have failed".format(build_try))
else:
logging.info("Build flavour: {} complete in directory: \"{}\"".format(args.flavour, os.path.abspath(path)))
logging.info("Build took {}".format(datetime.timedelta(seconds=int(time.time() - t0))))
break

if ret == 0:
windows_package(args)
else:
logging.info("Build failed")
sys.exit(1)


def windows_package(args):
Expand Down Expand Up @@ -233,7 +272,7 @@ def main():

parser.add_argument("--vcvars",
help="vcvars batch file location, typically inside vs studio install dir",
default=KNOWN_VCVARS['VS 2015'],
default=KNOWN_VCVARS['VS 2019'],
type=str)

parser.add_argument("--arch",
Expand All @@ -258,7 +297,7 @@ def main():
if 'OpenCV_DIR' not in os.environ:
os.environ["OpenCV_DIR"] = "C:\\Program Files\\OpenCV-v3.4.1\\build"
if 'CUDA_PATH' not in os.environ:
os.environ["CUDA_PATH"] = "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v9.2"
os.environ["CUDA_PATH"] = "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.2"
if 'MKL_ROOT' not in os.environ:
os.environ["MKL_ROOT"] = "C:\\Program Files (x86)\\IntelSWTools\\compilers_and_libraries\\windows\\mkl"
windows_build(args)
Expand Down
11 changes: 11 additions & 0 deletions include/mxnet/imperative.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,16 @@ class Imperative {
}
return is_np_shape_thread_local_ ? 1 : 0;
}

/*! \brief return current numpy default dtype compatibility status.
* */
bool is_np_default_dtype() const {
if (is_np_default_dtype_global_) {
return true;
}
return false;
}

/*! \brief specify numpy compatibility off, thread local on or global on. */
bool set_is_np_shape(int is_np_shape) {
NumpyShape flag = static_cast<NumpyShape>(is_np_shape);
Expand Down Expand Up @@ -215,6 +225,7 @@ class Imperative {
static MX_THREAD_LOCAL bool is_np_shape_thread_local_;
#endif
bool is_np_shape_global_{false};
bool is_np_default_dtype_global_{false};
/*! \brief node count used for naming */
std::atomic<uint64_t> node_count_{0};
/*! \brief variable count used for naming */
Expand Down
18 changes: 13 additions & 5 deletions python/mxnet/gluon/nn/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ def __init__(self, alpha_initializer=initializer.Constant(0.25),
init=alpha_initializer)

def hybrid_forward(self, F, x, alpha):
return F.LeakyReLU(x, gamma=alpha, act_type='prelu', name='fwd')
leaky_relu = F.npx.leaky_relu if is_np_array() else F.LeakyReLU
return leaky_relu(x, gamma=alpha, act_type='prelu', name='fwd')


class ELU(HybridBlock):
Expand Down Expand Up @@ -167,7 +168,8 @@ def __init__(self, alpha=1.0, **kwargs):
self._alpha = alpha

def hybrid_forward(self, F, x):
return F.LeakyReLU(x, act_type='elu', slope=self._alpha)
leaky_relu = F.npx.leaky_relu if is_np_array() else F.LeakyReLU
return leaky_relu(x, act_type='elu', slope=self._alpha)


class SELU(HybridBlock):
Expand All @@ -187,7 +189,9 @@ def __init__(self, **kwargs):
super(SELU, self).__init__(**kwargs)

def hybrid_forward(self, F, x):
return F.LeakyReLU(x, act_type='selu', name='fwd')
leaky_relu = F.npx.leaky_relu if is_np_array() else F.LeakyReLU
return leaky_relu(x, act_type='selu', name='fwd')


class GELU(HybridBlock):
r"""
Expand All @@ -206,7 +210,8 @@ def __init__(self, **kwargs):
super(GELU, self).__init__(**kwargs)

def hybrid_forward(self, F, x):
return F.LeakyReLU(x, act_type='gelu', name='fwd')
leaky_relu = F.npx.leaky_relu if is_np_array() else F.LeakyReLU
return leaky_relu(x, act_type='gelu', name='fwd')


class Swish(HybridBlock):
Expand All @@ -232,4 +237,7 @@ def __init__(self, beta=1.0, **kwargs):
self._beta = beta

def hybrid_forward(self, F, x):
return x * F.sigmoid(self._beta * x, name='fwd')
if is_np_array():
return x * F.npx.sigmoid(self._beta * x)
else:
return x * F.sigmoid(self._beta * x, name='fwd')
5 changes: 5 additions & 0 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6174,6 +6174,11 @@ def clip(a, a_min, a_max, out=None):
>>> np.clip(a, 3, 6, out=a)
array([3., 3., 3., 3., 4., 5., 6., 6., 6., 6.], dtype=float32)
"""
from numbers import Number
if isinstance(a, Number):
# In case input is a scalar, the computation would fall back to native numpy.
# The value returned would be a python scalar.
return _np.clip(a, a_min, a_max, out=None)
return _mx_nd_np.clip(a, a_min, a_max, out=out)


Expand Down
8 changes: 5 additions & 3 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1529,13 +1529,15 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou
if isinstance(rhs, numeric_types):
return fn_scalar(lhs, rhs, out=out)
else:
is_int = isinstance(rhs, integer_types)
if rfn_scalar is None:
# commutative function
return lfn_scalar(rhs, float(lhs), out=out)
return lfn_scalar(rhs, scalar=float(lhs), is_int=is_int, out=out)
else:
return rfn_scalar(rhs, float(lhs), out=out)
return rfn_scalar(rhs, scalar=float(lhs), is_int=is_int, out=out)
elif isinstance(rhs, numeric_types):
return lfn_scalar(lhs, float(rhs), out=out)
is_int = isinstance(rhs, integer_types)
return lfn_scalar(lhs, scalar=float(rhs), is_int=is_int, out=out)
elif isinstance(rhs, Symbol):
return fn_array(lhs, rhs, out=out)
else:
Expand Down
19 changes: 19 additions & 0 deletions src/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <nnvm/node.h>
#include <mxnet/engine.h>
#include <mxnet/ndarray.h>
#include <mxnet/imperative.h>
#include <mxnet/op_attr_types.h>
#include <mxnet/graph_attr_types.h>
#include <nnvm/graph_attr_types.h>
Expand Down Expand Up @@ -874,6 +875,11 @@ inline bool is_float(const int dtype) {
return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16;
}

inline bool is_int(const int dtype) {
return dtype == mshadow::kUint8 || dtype == mshadow::kInt8 ||
dtype == mshadow::kInt32 || dtype == mshadow::kInt64;
}

inline int get_more_precise_type(const int type1, const int type2) {
if (type1 == type2) return type1;
if (is_float(type1) && is_float(type2)) {
Expand Down Expand Up @@ -910,6 +916,19 @@ inline int np_binary_out_infer_type(const int type1, const int type2) {
return get_more_precise_type(type1, type2);
}

inline int GetDefaultDtype() {
return Imperative::Get()->is_np_default_dtype() ?
mshadow::kFloat64 :
mshadow::kFloat32;
}

inline int GetDefaultDtype(int dtype) {
if (dtype != -1) return dtype;
return Imperative::Get()->is_np_default_dtype() ?
mshadow::kFloat64 :
mshadow::kFloat32;
}

} // namespace common
} // namespace mxnet
#endif // MXNET_COMMON_UTILS_H_
6 changes: 2 additions & 4 deletions src/operator/contrib/gradient_multiplier_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ In forward pass it acts as an identity transform. During backpropagation it
multiplies the gradient from the subsequent level by a scalar factor lambda and passes it to
the preceding layer.
)code" ADD_FILELINE)
.set_attr_parser([](NodeAttrs* attrs) {
attrs->parsed = dmlc::stod(attrs->dict["scalar"]);
})
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 1, false, true, true>)
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", UnaryOp::IdentityComputeEx<cpu>)
Expand All @@ -88,7 +86,7 @@ the preceding layer.
[](const NodeAttrs& attrs){
return std::vector<bool>{true};
})
.add_argument("scalar", "float", "lambda multiplier");
.add_arguments(NumpyBinaryScalarParam::__FIELDS__());

MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_contrib_backward_gradientmultiplier)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
Expand Down
Loading

0 comments on commit 371b312

Please sign in to comment.