-
Notifications
You must be signed in to change notification settings - Fork 35
Open
Description
As the title says, TRAK produces incorrect numerical results with CUDA 11.8: in the example below, `CudaProjector.project` returns an all-zero tensor for random, nonzero input gradients.
- Minimal reproducible example
# Minimal repro: project a batch of random gradients with CudaProjector and
# print the result. The reported output (shown after this snippet) is an
# all-zero tensor.
import torch
# NOTE(review): AbstractProjector is imported but not used in this snippet.
from trak.projectors import ProjectionType, AbstractProjector, CudaProjector
# Length of each input gradient vector (1,000,000 elements).
grad_dim = int(1e6)
projector = CudaProjector(
grad_dim=grad_dim,
proj_dim=32768,  # presumably the projected (output) dimension -- confirm against TRAK docs
seed=42,
proj_type=ProjectionType.normal,
device='cuda:0',
max_batch_size=8,  # equals the batch size of `grad` constructed below
)
# Batch of 8 random gradient vectors, shape (8, grad_dim), on the same device.
# NOTE(review): no torch.manual_seed call, so `grad` differs between runs.
grad = torch.randn(8, grad_dim, device='cuda:0')
proj = projector.project(grad, model_id=0)
print(proj)
Output:
tensor([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], device='cuda:0')

- Q: How did I install TRAK?
# Install steps as run by the reporter; the order of the conda installs
# affects what the solver resolves, so the commands are kept verbatim.
# NOTE(review): the environment listing below is from an env named `trak_118`
# running Python 3.8.20, not the `trak` / Python 3.10.16 env created here.
conda create -n trak python=3.10.16
conda activate trak
# PyTorch wheels from the cu118 index (listing shows torch 2.4.1+cu118).
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# NOTE(review): after these two installs the listing shows cuda-nvcc 11.8.89 and
# cuda-toolkit 11.8.0, but cuda-version 12.8 and 12.x cuda-cudart / libcublas,
# i.e. a mixed 11.8 / 12.x CUDA toolchain. Pinning cuda-version=11.8 or using the
# nvidia/label/cuda-11.8.0 channel may be needed for a consistent 11.8 toolkit -- verify.
conda install cuda-nvcc=11.8 -c nvidia -y
conda install -c nvidia cuda-toolkit=11.8 -y
# Set CUDA_HOME to the conda environment prefix.
export CUDA_HOME=$CONDA_PREFIX
# NOTE(review): "python3.x" looks like a placeholder; substitute the actual
# interpreter version (e.g. python3.10) when running this.
export PYTHONPATH=$CONDA_PREFIX/lib/python3.x/site-packages:$PYTHONPATH
# Presumably the [fast] extra is what installs fast_jl (0.1.3 in the listing).
# In zsh, quote the requirement ('traker[fast]==0.3.2') to avoid glob expansion.
pip install traker[fast]==0.3.2
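Environment settings
(Note: the listing below was captured from the `trak_118` environment running Python 3.8.20, which differs from the `trak` / Python 3.10.16 environment created by the commands above. It also shows cuda-nvcc and cuda-toolkit at 11.8 while cuda-version, cuda-cudart and libcublas resolved to 12.x.)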
### Environment Info
Python 3.8.20
Package Version
------------------------ ------------
fast_jl 0.1.3
filelock 3.13.1
fsspec 2024.6.1
Jinja2 3.1.4
MarkupSafe 2.1.5
mpmath 1.3.0
networkx 3.0
numpy 1.24.1
nvidia-cublas-cu11 11.11.3.6
nvidia-cuda-cupti-cu11 11.8.87
nvidia-cuda-nvrtc-cu11 11.8.89
nvidia-cuda-runtime-cu11 11.8.89
nvidia-cudnn-cu11 9.1.0.70
nvidia-cufft-cu11 10.9.0.58
nvidia-curand-cu11 10.3.0.86
nvidia-cusolver-cu11 11.4.1.48
nvidia-cusparse-cu11 11.7.5.86
nvidia-nccl-cu11 2.20.5
nvidia-nvtx-cu11 11.8.86
pillow 10.2.0
pip 24.2
setuptools 75.1.0
sympy 1.13.1
torch 2.4.1+cu118
torchvision 0.19.1+cu118
tqdm 4.67.1
traker 0.3.2
triton 3.0.0
typing_extensions 4.12.2
wheel 0.44.0
# packages in environment at /data/yonghyun/anaconda3/envs/trak_118:
#
# Name Version Build Channel
_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
bzip2 1.0.8 h5eee18b_6
ca-certificates 2024.12.31 h06a4308_0
cuda-cccl_linux-64 12.8.55 0 nvidia
cuda-command-line-tools 12.8.0 0 nvidia
cuda-compiler 12.6.2 0 nvidia
cuda-cudart 12.8.57 0 nvidia
cuda-cudart-dev 12.8.57 0 nvidia
cuda-cudart-dev_linux-64 12.8.57 0 nvidia
cuda-cudart-static 12.8.57 0 nvidia
cuda-cudart-static_linux-64 12.8.57 0 nvidia
cuda-cudart_linux-64 12.8.57 0 nvidia
cuda-cuobjdump 12.8.55 0 nvidia
cuda-cupti 12.8.57 0 nvidia
cuda-cupti-dev 12.8.57 0 nvidia
cuda-cuxxfilt 12.8.55 0 nvidia
cuda-documentation 12.4.127 0 nvidia
cuda-driver-dev 12.8.57 0 nvidia
cuda-driver-dev_linux-64 12.8.57 0 nvidia
cuda-gdb 12.8.55 0 nvidia
cuda-libraries 12.8.0 0 nvidia
cuda-libraries-dev 12.8.0 0 nvidia
cuda-nsight 12.8.55 0 nvidia
cuda-nvcc 11.8.89 0 nvidia
cuda-nvdisasm 12.8.55 0 nvidia
cuda-nvml-dev 12.8.55 0 nvidia
cuda-nvprof 12.8.57 0 nvidia
cuda-nvprune 12.8.55 0 nvidia
cuda-nvrtc 12.8.61 0 nvidia
cuda-nvrtc-dev 12.8.61 0 nvidia
cuda-nvtx 12.8.55 0 nvidia
cuda-nvvp 12.8.57 0 nvidia
cuda-opencl 12.8.55 0 nvidia
cuda-opencl-dev 12.8.55 0 nvidia
cuda-profiler-api 12.8.55 0 nvidia
cuda-sanitizer-api 12.8.55 0 nvidia
cuda-toolkit 11.8.0 0 nvidia
cuda-tools 12.8.0 0 nvidia
cuda-version 12.8 3 nvidia
cuda-visual-tools 12.8.0 0 nvidia
dbus 1.13.18 hb2f20db_0
expat 2.6.4 h6a678d5_0
fast-jl 0.1.3 pypi_0 pypi
filelock 3.13.1 pypi_0 pypi
fontconfig 2.14.1 h55d465d_3
freetype 2.12.1 h4a9f257_0
fsspec 2024.6.1 pypi_0 pypi
gds-tools 1.13.0.11 0 nvidia
glib 2.78.4 h6a678d5_0
glib-tools 2.78.4 h6a678d5_0
gmp 6.3.0 h6a678d5_0
icu 73.1 h6a678d5_0
jinja2 3.1.4 pypi_0 pypi
ld_impl_linux-64 2.40 h12ee557_0
libcublas 12.8.3.14 0 nvidia
libcublas-dev 12.8.3.14 0 nvidia
libcufft 11.3.3.41 0 nvidia
libcufft-dev 11.3.3.41 0 nvidia
libcufile 1.13.0.11 0 nvidia
libcufile-dev 1.13.0.11 0 nvidia
libcurand 10.3.9.55 0 nvidia
libcurand-dev 10.3.9.55 0 nvidia
libcusolver 11.7.2.55 0 nvidia
libcusolver-dev 11.7.2.55 0 nvidia
libcusparse 12.5.7.53 0 nvidia
libcusparse-dev 12.5.7.53 0 nvidia
libffi 3.4.4 h6a678d5_1
libgcc-ng 11.2.0 h1234567_1
libglib 2.78.4 hdc74915_0
libgomp 11.2.0 h1234567_1
libiconv 1.16 h5eee18b_3
libnpp 12.3.3.65 0 nvidia
libnpp-dev 12.3.3.65 0 nvidia
libnvfatbin 12.8.55 0 nvidia
libnvfatbin-dev 12.8.55 0 nvidia
libnvjitlink 12.8.61 1 nvidia
libnvjitlink-dev 12.8.61 1 nvidia
libnvjpeg 12.3.5.57 0 nvidia
libnvjpeg-dev 12.3.5.57 0 nvidia
libpng 1.6.39 h5eee18b_0
libstdcxx-ng 11.2.0 h1234567_1
libuuid 1.41.5 h5eee18b_0
libxcb 1.15 h7f8727e_0
libxkbcommon 1.0.1 h097e994_2
libxml2 2.13.5 hfdd30dd_0
markupsafe 2.1.5 pypi_0 pypi
mpmath 1.3.0 pypi_0 pypi
ncurses 6.4 h6a678d5_0
networkx 3.0 pypi_0 pypi
nsight-compute 2025.1.0.14 0 nvidia
nspr 4.35 h6a678d5_0
nss 3.89.1 h6a678d5_0
numpy 1.24.1 pypi_0 pypi
nvidia-cublas-cu11 11.11.3.6 pypi_0 pypi
nvidia-cuda-cupti-cu11 11.8.87 pypi_0 pypi
nvidia-cuda-nvrtc-cu11 11.8.89 pypi_0 pypi
nvidia-cuda-runtime-cu11 11.8.89 pypi_0 pypi
nvidia-cudnn-cu11 9.1.0.70 pypi_0 pypi
nvidia-cufft-cu11 10.9.0.58 pypi_0 pypi
nvidia-curand-cu11 10.3.0.86 pypi_0 pypi
nvidia-cusolver-cu11 11.4.1.48 pypi_0 pypi
nvidia-cusparse-cu11 11.7.5.86 pypi_0 pypi
nvidia-nccl-cu11 2.20.5 pypi_0 pypi
nvidia-nvtx-cu11 11.8.86 pypi_0 pypi
ocl-icd 2.3.2 h5eee18b_1
openssl 3.0.15 h5eee18b_0
pcre2 10.42 hebb0a14_1
pillow 10.2.0 pypi_0 pypi
pip 24.2 py38h06a4308_0
python 3.8.20 he870216_0
readline 8.2 h5eee18b_0
setuptools 75.1.0 py38h06a4308_0
sqlite 3.45.3 h5eee18b_0
sympy 1.13.1 pypi_0 pypi
tk 8.6.14 h39e8969_0
torch 2.4.1+cu118 pypi_0 pypi
torchvision 0.19.1+cu118 pypi_0 pypi
tqdm 4.67.1 pypi_0 pypi
traker 0.3.2 pypi_0 pypi
triton 3.0.0 pypi_0 pypi
typing-extensions 4.12.2 pypi_0 pypi
wheel 0.44.0 py38h06a4308_0
xz 5.4.6 h5eee18b_1
zlib 1.2.13 h5eee18b_1
### NVIDIA Info
Wed Feb 5 15:12:37 2025
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.216.01 Driver Version: 535.216.01 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA H100 PCIe Off | 00000000:2D:00.0 Off | 0 |
| N/A 40C P0 52W / 350W | 17MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
Haruka1307
Metadata
Metadata
Assignees
Labels
No labels