This repository contains build scripts for creating a Docker image that runs JAX with CUDA 12.1 support. The pre-built image is available on Docker Hub with the default tag `avirupdas55/jax:v1`.
To use the pre-built Docker image with the default tag, simply pull it from Docker Hub:
```bash
docker pull avirupdas55/jax:v1
```

If you prefer to build the image locally using the provided build script, follow the steps below:
- Clone this repository:
  ```bash
  git clone https://github.com/agent-lab/JAX_Docker.git
  ```
- Navigate to the repository:
  ```bash
  cd JAX_Docker
  ```
- Run the build script, optionally passing a custom tag and a Docker Hub ID:
  ```bash
  bash docker_build.sh custom-tag your-docker-id
  ```
  Replace `custom-tag` with the desired tag name and `your-docker-id` with your Docker Hub ID. If no arguments are provided, the default tag `avirupdas55/jax:v1` is used.
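The build script itself is not reproduced here, so the following is only a sketch of how its two optional arguments might map onto the image name. It assumes the first argument is the tag, the second is the Docker Hub ID, and the Dockerfile sits in the repository root; `resolve_image` and `build_image` are hypothetical names, so check `docker_build.sh` for the actual behaviour.

```bash
#!/usr/bin/env bash
# Hypothetical sketch of docker_build.sh's argument handling; the real
# script may differ. Assumes $1 = tag, $2 = Docker Hub ID, and a
# Dockerfile in the current directory.

# Print the full image name, falling back to the documented default
# avirupdas55/jax:v1 when an argument is omitted.
resolve_image() {
  local tag="${1:-v1}"
  local hub_id="${2:-avirupdas55}"
  echo "${hub_id}/jax:${tag}"
}

# Build the image from the Dockerfile in the current directory.
build_image() {
  local image
  image="$(resolve_image "$@")"
  docker build -t "$image" .
}
```

For example, `resolve_image` prints `avirupdas55/jax:v1`, while `build_image v2 your-docker-id` would build `your-docker-id/jax:v2`.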
Once you have the Docker image, you can run JAX scripts inside a container. Use the following command to start a container:
```bash
docker run -it --gpus all avirupdas55/jax:v1
```

This command launches an interactive container with access to all GPUs.
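To confirm that JAX inside the container actually sees the GPUs, a quick check like the one below may help. It is a sketch: the `jax_check_gpu` function name is hypothetical, and it assumes the NVIDIA Container Toolkit is installed on the host, which `--gpus all` relies on.

```bash
#!/usr/bin/env bash
# Hypothetical helper: start a throwaway container (--rm) and ask JAX
# which backend and devices it can see.
jax_check_gpu() {
  local image="${1:-avirupdas55/jax:v1}"
  docker run --rm --gpus all "$image" \
    python -c 'import jax; print(jax.default_backend()); print(jax.devices())'
}
```

Running `jax_check_gpu` should print `gpu` on the first line when the GPUs are visible; `cpu` (or an error) suggests the container cannot reach them.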
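You can also start a Jupyter Notebook server inside the Docker container. From a shell in the container, run:

```bash
jupyter-notebook --ip='0.0.0.0' --port=8888 --no-browser --allow-root --NotebookApp.allow_origin='*'
```

To open the notebook in a browser on the host, the container must publish that port, for example by adding `-p 8888:8888` to the `docker run` command.

You can also run a JAX script directly when starting the container. For example: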
```bash
docker run -it --gpus all avirupdas55/jax:v1 python my_jax_script.py
```

Replace `my_jax_script.py` with the name of your JAX script.
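The command above only works if `my_jax_script.py` already exists inside the image. One common approach, sketched below under the assumption that your script lives in the current host directory, is to bind-mount that directory into the container. The `jax_run` function and the `/workspace` mount point are hypothetical choices, not part of this repository.

```bash
#!/usr/bin/env bash
# Hypothetical wrapper around the documented `docker run` command.
# JAX_IMAGE can be overridden, e.g. with a locally built tag.
JAX_IMAGE="${JAX_IMAGE:-avirupdas55/jax:v1}"

jax_run() {
  # -v mounts the current host directory at /workspace (an arbitrary path),
  # -w makes it the working directory, and --rm deletes the container on exit.
  # Any arguments are run as the command inside the container.
  docker run -it --rm --gpus all \
    -v "$PWD":/workspace -w /workspace \
    "$JAX_IMAGE" "$@"
}
```

With this in place, `jax_run python my_jax_script.py` runs the script from the host directory, and any files it writes under `/workspace` appear in that directory on the host.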
To use the Docker image with Singularity, run your JAX script directly from the Docker Hub image; the `--nv` flag gives the container access to the host's NVIDIA GPUs:

```bash
singularity exec --nv docker://avirupdas55/jax:v1 python my_jax_script.py
```

Replace `my_jax_script.py` with the name of your JAX script.
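For repeated runs, for example on a cluster, it may be convenient to convert the image to a local SIF file once and point `singularity exec` at that file, which can then be reused or copied to compute nodes without fetching from Docker Hub again. A minimal sketch, where `jax_sing` and the `jax_v1.sif` file name are hypothetical:

```bash
#!/usr/bin/env bash
# Hypothetical helper: pull the Docker image into a local SIF file on
# first use, then run commands from that file.
SIF="${SIF:-jax_v1.sif}"

jax_sing() {
  if [[ ! -f "$SIF" ]]; then
    # One-time conversion of the Docker Hub image to a SIF file
    singularity pull "$SIF" docker://avirupdas55/jax:v1
  fi
  # --nv exposes the host NVIDIA driver and GPUs to the container
  singularity exec --nv "$SIF" "$@"
}
```

Usage: `jax_sing python my_jax_script.py`.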
- This Docker image includes JAX with CUDA 12.1 support.
- The default Docker image tag is `avirupdas55/jax:v1`.
- For more information about JAX, refer to the official JAX documentation.
- For CUDA documentation, visit the NVIDIA CUDA Toolkit documentation.
Feel free to customise the Dockerfile, build script, and repository to suit your specific requirements.
Happy coding with JAX! 🚀
