Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add option to control use of MKL-DNN in jaxlib easyblock #2619

Merged
merged 5 commits into from
Dec 8, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions easybuild/easyblocks/j/jaxlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,17 @@

@author: Denis Kristak (INUITS)
@author: Alexander Grund (TU Dresden)
@author: Alex Domingo (Vrije Universiteit Brussel)
"""

import os
import tempfile

from distutils.version import LooseVersion

import easybuild.tools.environment as env
from easybuild.easyblocks.generic.pythonpackage import PythonPackage
from easybuild.framework.easyconfig import CUSTOM
from easybuild.tools.build_log import EasyBuildError
from easybuild.tools.filetools import apply_regex_substitutions, which
from easybuild.tools.modules import get_software_root, get_software_version
Expand All @@ -52,6 +56,11 @@ def extra_options():
extra_vars['buildcmd'][0] = '%(python)s build/build.py'
extra_vars['install_src'][0] = 'dist/*.whl'

# Custom parameters
extra_vars.update({
'use_mkl_dnn': [True, "Enable support for Intel MKL-DNN", CUSTOM],
})

return extra_vars

def configure_step(self):
Expand All @@ -69,13 +78,15 @@ def configure_step(self):

# Collect options for the build script
# Used only by the build script
options = [
'--target_cpu_features=default', # Using copt for optimizations
]

# C++ flags are set through copt below
options = ['--target_cpu_features=default']

# Passed directly to bazel
bazel_startup_options = [
'--output_user_root=%s' % tempfile.mkdtemp(suffix='-bazel', dir=self.builddir),
]

# Passed to the build command of bazel
bazel_options = [
'--jobs=%s' % self.cfg['parallel'],
Expand Down Expand Up @@ -107,16 +118,22 @@ def configure_step(self):
'--cudnn_version=' + cudnn_version,
])

nccl_root = get_software_root('NCCL')
if nccl_root:
options.append('--enable_nccl')
else:
options.append('--noenable_nccl')
if LooseVersion(self.version) >= LooseVersion('0.1.70'):
nccl_root = get_software_root('NCCL')
if nccl_root:
options.append('--enable_nccl')
else:
options.append('--noenable_nccl')

config_env_vars['GCC_HOST_COMPILER_PATH'] = which(os.getenv('CC'))
else:
options.append('--noenable_cuda')

if self.cfg['use_mkl_dnn']:
options.append('--enable_mkl_dnn')
else:
options.append('--noenable_mkl_dnn')

# Prepend to buildopts so users can overwrite this
self.cfg['buildopts'] = ' '.join(
options +
Expand Down