Hi @kdlamb! Installing these ML libraries in a way that actually works with the GPU is, to put it mildly, a nightmare. 🧛 We have solved this within the Pangeo community by carefully maintaining a set of Docker images that are tested and updated regularly: https://github.com/pangeo-data/pangeo-docker-images/ JAX, coming from Google, is much more compatible with TensorFlow than with PyTorch, and it is already installed in the TensorFlow notebook image (https://github.com/pangeo-data/pangeo-docker-images/blob/master/ml-notebook/environment.yml). Is it possible for you to use that image instead (select it when you start up your notebook)? Or do you need to use both JAX and PyTorch in the same project?
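As a quick check (a minimal sketch, not from the original thread, assuming a recent JAX where `jax.devices()` behaves as documented), something like this run inside the ml-notebook image should confirm that JAX actually sees the GPU:

```python
import jax
import jax.numpy as jnp

# On a working GPU image this should list a CUDA/GPU device, not just CPU.
print(jax.devices())

# A tiny computation that forces XLA to compile for the default backend.
print(jnp.ones((3,)).sum())
```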
Hi, I'm trying to install JAX inside the PyTorch ML Notebook GPU Docker image environment. I tried installing JAX from the wheels for both CUDA 11 and CUDA 12 from within a JupyterLab notebook using pip:
```
!pip install --upgrade pip
!pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
The installation worked, but every time I tried to call any function from jax I got an error. For example, trying to create an array:
```python
a = jnp.ones((3,))
```
gives me the following error:
```
2023-08-16 21:08:41.215198: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:413] There was an error before creating cudnn handle (35): cudaErrorInsufficientDriver : CUDA driver version is insufficient for CUDA runtime version
XlaRuntimeError Traceback (most recent call last)
Cell In[24], line 1
----> 1 a = jnp.ones((3,))
File /srv/conda/envs/notebook/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:2180, in ones(shape, dtype)
2178 shape = canonicalize_shape(shape)
2179 dtypes.check_user_dtype_supported(dtype, "ones")
-> 2180 return lax.full(shape, 1, _jnp_dtype(dtype))
File /srv/conda/envs/notebook/lib/python3.10/site-packages/jax/_src/lax/lax.py:1210, in full(shape, fill_value, dtype)
1208 dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value))
1209 fill_value = _convert_element_type(fill_value, dtype, weak_type)
-> 1210 return broadcast(fill_value, shape)
File /srv/conda/envs/notebook/lib/python3.10/site-packages/jax/_src/lax/lax.py:772, in broadcast(operand, sizes)
758 """Broadcasts an array, adding new leading dimensions
759
760 Args:
(...)
769 jax.lax.broadcast_in_dim : add new dimensions at any location in the array shape.
770 """
771 dims = tuple(range(len(sizes), len(sizes) + np.ndim(operand)))
--> 772 return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
File /srv/conda/envs/notebook/lib/python3.10/site-packages/jax/_src/lax/lax.py:801, in broadcast_in_dim(operand, shape, broadcast_dimensions)
799 else:
800 dyn_shape, static_shape = [], shape # type: ignore
--> 801 return broadcast_in_dim_p.bind(
802 operand, *dyn_shape, shape=tuple(static_shape),
803 broadcast_dimensions=tuple(broadcast_dimensions))
File /srv/conda/envs/notebook/lib/python3.10/site-packages/jax/_src/core.py:386, in Primitive.bind(self, *args, **params)
383 def bind(self, *args, **params):
384 assert (not config.jax_enable_checks or
385 all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 386 return self.bind_with_trace(find_top_trace(args), args, params)
File /srv/conda/envs/notebook/lib/python3.10/site-packages/jax/_src/core.py:389, in Primitive.bind_with_trace(self, trace, args, params)
388 def bind_with_trace(self, trace, args, params):
--> 389 out = trace.process_primitive(self, map(trace.full_raise, args), params)
390 return map(full_lower, out) if self.multiple_results else full_lower(out)
File /srv/conda/envs/notebook/lib/python3.10/site-packages/jax/_src/core.py:821, in EvalTrace.process_primitive(self, primitive, tracers, params)
820 def process_primitive(self, primitive, tracers, params):
--> 821 return primitive.impl(*tracers, **params)
File /srv/conda/envs/notebook/lib/python3.10/site-packages/jax/_src/dispatch.py:131, in apply_primitive(prim, *args, **params)
129 try:
130 in_avals, in_shardings = util.unzip2([arg_spec(a) for a in args])
--> 131 compiled_fun = xla_primitive_callable(
132 prim, in_avals, OrigShardings(in_shardings), **params)
133 except pxla.DeviceAssignmentMismatchError as e:
134 fails, = e.args
File /srv/conda/envs/notebook/lib/python3.10/site-packages/jax/_src/util.py:263, in cache.<locals>.wrap.<locals>.wrapper(*args, **kwargs)
261 return f(*args, **kwargs)
262 else:
--> 263 return cached(config._trace_context(), *args, **kwargs)
File /srv/conda/envs/notebook/lib/python3.10/site-packages/jax/_src/util.py:256, in cache.<locals>.wrap.<locals>.cached(_, *args, **kwargs)
254 @functools.lru_cache(max_size)
255 def cached(_, *args, **kwargs):
--> 256 return f(*args, **kwargs)
File /srv/conda/envs/notebook/lib/python3.10/site-packages/jax/_src/dispatch.py:222, in xla_primitive_callable(prim, in_avals, orig_in_shardings, **params)
220 return out,
221 donated_invars = (False,) * len(in_avals)
--> 222 compiled = _xla_callable_uncached(
223 lu.wrap_init(prim_fun), prim.name, donated_invars, False, in_avals,
224 orig_in_shardings)
225 if not prim.multiple_results:
226 return lambda *args, **kw: compiled(*args, **kw)[0]
File /srv/conda/envs/notebook/lib/python3.10/site-packages/jax/_src/dispatch.py:252, in _xla_callable_uncached(fun, name, donated_invars, keep_unused, in_avals, orig_in_shardings)
247 def _xla_callable_uncached(fun: lu.WrappedFun, name, donated_invars,
248 keep_unused, in_avals, orig_in_shardings):
249 computation = sharded_lowering(
250 fun, name, donated_invars, keep_unused, True, in_avals, orig_in_shardings,
251 lowering_platform=None)
--> 252 return computation.compile().unsafe_call
File /srv/conda/envs/notebook/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:2206, in MeshComputation.compile(self, compiler_options)
2203 executable = MeshExecutable.from_trivial_jaxpr(
2204 **self.compile_args)
2205 else:
-> 2206 executable = UnloadedMeshExecutable.from_hlo(
2207 self._name,
2208 self._hlo,
2209 **self.compile_args,
2210 compiler_options=compiler_options)
2211 if compiler_options is None:
2212 self._executable = executable
File /srv/conda/envs/notebook/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:2544, in UnloadedMeshExecutable.from_hlo(failed resolving arguments)
2541 mesh = i.mesh # type: ignore
2542 break
-> 2544 xla_executable, compile_options = _cached_compilation(
2545 hlo, name, mesh, spmd_lowering,
2546 tuple_args, auto_spmd_lowering, allow_prop_to_outputs,
2547 tuple(host_callbacks), backend, da, pmap_nreps,
2548 compiler_options_keys, compiler_options_values)
2550 if hasattr(backend, "compile_replicated"):
2551 semantics_in_shardings = SemanticallyEqualShardings(in_shardings) # type: ignore
File /srv/conda/envs/notebook/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:2454, in _cached_compilation(computation, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, _allow_propagation_to_outputs, host_callbacks, backend, da, pmap_nreps, compiler_options_keys, compiler_options_values)
2449 return None, compile_options
2451 with dispatch.log_elapsed_time(
2452 "Finished XLA compilation of {fun_name} in {elapsed_time} sec",
2453 fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
-> 2454 xla_executable = dispatch.compile_or_get_cached(
2455 backend, computation, dev, compile_options, host_callbacks)
2456 return xla_executable, compile_options
File /srv/conda/envs/notebook/lib/python3.10/site-packages/jax/_src/dispatch.py:496, in compile_or_get_cached(backend, computation, devices, compile_options, host_callbacks)
492 use_compilation_cache = (compilation_cache.is_initialized() and
493 backend.platform in supported_platforms)
495 if not use_compilation_cache:
--> 496 return backend_compile(backend, computation, compile_options,
497 host_callbacks)
499 cache_key = compilation_cache.get_cache_key(
500 computation, devices, compile_options, backend)
502 executable, compile_time_retrieved = _cache_read(
503 module_name, cache_key, compile_options, backend)
File /srv/conda/envs/notebook/lib/python3.10/site-packages/jax/_src/profiler.py:314, in annotate_function.<locals>.wrapper(*args, **kwargs)
311 @wraps(func)
312 def wrapper(*args, **kwargs):
313 with TraceAnnotation(name, **decorator_kwargs):
--> 314 return func(*args, **kwargs)
315 return wrapper
File /srv/conda/envs/notebook/lib/python3.10/site-packages/jax/_src/dispatch.py:464, in backend_compile(backend, module, options, host_callbacks)
459 return backend.compile(built_c, compile_options=options,
460 host_callbacks=host_callbacks)
461 # Some backends don't have `host_callbacks` option yet
462 # TODO(sharadmv): remove this fallback when all backends allow `compile`
463 # to take in `host_callbacks`
--> 464 return backend.compile(built_c, compile_options=options)
XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
```
So I'm guessing the issue is a wrong/incompatible CUDA version? Any ideas how to fix this?
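One way to check that hypothesis (a sketch, not from the original thread; it assumes `nvidia-smi` is available inside the image) is to compare the driver's maximum supported CUDA version, shown in the `nvidia-smi` header, against the CUDA runtime that the `cuda11_pip` extra pulled in:

```python
import subprocess

# The "CUDA Version: ..." field in the nvidia-smi header is the newest CUDA
# runtime the installed driver supports; it must be at least as new as the
# runtime bundled by the jax[cuda11_pip] / jax[cuda12_pip] wheels, otherwise
# jax fails with cudaErrorInsufficientDriver as above.
print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
```

In a notebook cell, `!nvidia-smi` does the same thing.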