Issue installing and upgrading JAX

Primary information

Username: zhangzho
Cluster: Baobab

Description

I need to load the JAX module to run a large neural network model on Baobab. The module jax/0.3.25-CUDA-11.7.0 is already installed on Baobab, but this version seems too old: I always get the error ‘AttributeError: module jax has no attribute typing’ when I try to run the model. I then created a virtual environment to install a newer version of JAX, but this does not work either. Whenever I try to install JAX with CUDA (NVIDIA GPU) support, I get the error ‘OSError: [Errno 28] No space left on device’.
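To narrow down the AttributeError, a small check like the one below may help. It is a sketch that assumes it is run with the same Python interpreter (module or virtual environment) that runs the model. It reports the JAX version, whether `jax.typing` is reachable as an attribute right after `import jax`, and whether the submodule can be imported explicitly. If the submodule is missing entirely, the release is likely too old; if it imports but is not an attribute after `import jax`, adding `import jax.typing` may be enough.

```python
"""Report the JAX version and the availability of jax.typing.

Assumption: run with the same interpreter that runs the model.
"""
import importlib


def jax_typing_status() -> str:
    """Return a one-line summary of the jax.typing situation."""
    try:
        jax = importlib.import_module("jax")
    except ImportError:
        return "jax is not importable from this interpreter"

    # Check the attribute BEFORE importing the submodule explicitly,
    # because an explicit import would add the attribute as a side effect.
    attr_after_import = hasattr(jax, "typing")

    try:
        importlib.import_module("jax.typing")
        submodule_importable = True
    except ImportError:
        # ModuleNotFoundError is a subclass of ImportError.
        submodule_importable = False

    return (
        f"jax {jax.__version__}: "
        f"attribute after import = {attr_after_import}, "
        f"submodule importable = {submodule_importable}"
    )


if __name__ == "__main__":
    print(jax_typing_status())
```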

Steps to Reproduce

  1. Load the modules: ‘module load GCC/12.2.0 OpenMPI/4.1.4 CUDA/12.2.2 Python/3.10.8’
  2. Create a virtual environment.
  3. Install JAX: ‘pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html’
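Errno 28 during step 3 usually means that one of the locations pip writes to has filled up. The sketch below reports free space for two such locations, under stated assumptions: pip unpacks and builds packages under the temporary directory chosen by Python's `tempfile` module (controlled by `TMPDIR` on Linux), and it caches downloads in `PIP_CACHE_DIR` if set, otherwise in `$XDG_CACHE_HOME/pip` or `~/.cache/pip`. Note that `shutil.disk_usage` reports filesystem free space, not a per-user quota, so a full home-directory quota may not show up here.

```python
"""Report free space where pip is assumed to write during `pip install`.

Assumptions: temp files go to tempfile.gettempdir() (TMPDIR on Linux);
the cache goes to PIP_CACHE_DIR, else $XDG_CACHE_HOME/pip, else ~/.cache/pip.
"""
import os
import shutil
import tempfile
from pathlib import Path


def pip_write_locations() -> dict:
    """Return the assumed temp and cache directories used by pip."""
    cache_root = os.environ.get("XDG_CACHE_HOME") or str(Path.home() / ".cache")
    cache_dir = os.environ.get("PIP_CACHE_DIR") or os.path.join(cache_root, "pip")
    return {"temp": tempfile.gettempdir(), "cache": cache_dir}


def free_gib(path: str) -> float:
    """Free space in GiB on the filesystem holding `path`.

    If `path` does not exist yet (e.g. an empty cache), walk up to the
    nearest existing parent directory.
    """
    p = Path(path)
    while not p.exists() and p != p.parent:
        p = p.parent
    return shutil.disk_usage(p).free / 2**30


if __name__ == "__main__":
    for name, path in pip_write_locations().items():
        print(f"{name:5s} {path}: {free_gib(path):.1f} GiB free")
```

If either location turns out to be small, one commonly suggested workaround (not verified on Baobab) is to point `TMPDIR` at a directory with more space before running pip and to pass `--no-cache-dir` to `pip install`.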

Expected Result

JAX version 0.4.19 is successfully installed.

Actual Result

‘OSError: [Errno 28] No space left on device’