jax on SLURM with CUDA 11.6 #25373

Closed Answered by ashok-arora
ashok-arora asked this question in Q&A

It finally worked. Here's my setup for future reference:

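  1. In the SLURM script:
module load cuda/11.6
module load spack
source /home/apps/spack/share/spack/setup-env.sh
spack load cudnn@8.7.0.84-11.8
  2. Code to check:
import logging

import jax
from omegaconf import OmegaConf

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# `config` is the job's OmegaConf config object, defined elsewhere in the script.
logger.info("Starting Job for Config:\n" + OmegaConf.to_yaml(config))
# jax.devices() lists the devices of the default backend.
logger.info("Available Backends: " + str(jax.devices()))
# Prints jax/jaxlib versions, platform details, and nvidia-smi output if available.
jax.print_environment_info()
  3. Output of the code: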
[2024-12-10 16:44:32,071][jax._src.xla_bridge][INFO] - Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
[2024-12-10 16:44:34,106][jax._src.xla_br…
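
For anyone who wants the job to fail fast instead of reading the log by eye, here is a minimal sketch of the same check as a guard. The helper names `count_platforms`, `has_gpu`, and `require_gpu` are hypothetical and not part of JAX; the only JAX pieces used are `jax.devices()` and the `platform` attribute of each device. Accepting `"cuda"` next to `"gpu"` as a platform string is an assumption made to be tolerant across versions.

```python
from collections import Counter


def count_platforms(devices):
    """Count devices per platform string, e.g. {"gpu": 2}."""
    return Counter(d.platform for d in devices)


def has_gpu(devices):
    """Return True if at least one device reports a GPU platform."""
    # "cuda" is accepted as an alternative spelling (assumption, see above).
    return any(d.platform in ("gpu", "cuda") for d in devices)


def require_gpu():
    """Exit with an error if JAX only sees non-GPU devices."""
    import jax  # imported lazily so the helpers above work without JAX installed

    devices = jax.devices()
    print("Devices per platform:", dict(count_platforms(devices)))
    if not has_gpu(devices):
        raise SystemExit("JAX did not find a GPU; check the module/spack loads.")
```

Calling `require_gpu()` at the top of the training script, after the `module load` and `spack load` lines have run in the SLURM script, should stop the job early if JAX silently fell back to CPU.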

Answer selected by ashok-arora