From ba1bc5e3841c9863fdd0a71381806e2027cd524a Mon Sep 17 00:00:00 2001 From: DhruvrajSinhZala24 Date: Mon, 23 Mar 2026 18:13:41 +0530 Subject: [PATCH] Fix classification transformer runtime compatibility issues --- .../README.md | 9 +++++++ .../eval.py | 24 ++++++++++++++++--- .../requirements.txt | 13 ++++++++++ .../train.py | 2 +- .../utils.py | 6 ++++- 5 files changed, 49 insertions(+), 5 deletions(-) create mode 100644 DeepLense_Classification_Transformers_Archil_Srivastava/requirements.txt diff --git a/DeepLense_Classification_Transformers_Archil_Srivastava/README.md b/DeepLense_Classification_Transformers_Archil_Srivastava/README.md index 8ca8a1d..04d058a 100644 --- a/DeepLense_Classification_Transformers_Archil_Srivastava/README.md +++ b/DeepLense_Classification_Transformers_Archil_Srivastava/README.md @@ -30,6 +30,15 @@ ___Note__: Axion files have extra data corresponding to mass of axion used in si
+# __Setup__ + +Install the Python dependencies for this project before running training or evaluation: +```bash +python3 -m pip install -r requirements.txt +``` + +If PyTorch is not installed yet, install a compatible `torch` and `torchvision` build for your platform first from the [official PyTorch install guide](https://pytorch.org/get-started/locally/), then install the remaining project requirements. + # __Training__ Use the train.py script to train a particular model (using timm model name). The script will ask for a WandB login key, hence a WandB account is needed. Example: diff --git a/DeepLense_Classification_Transformers_Archil_Srivastava/eval.py b/DeepLense_Classification_Transformers_Archil_Srivastava/eval.py index a3f12b3..ab1a54f 100644 --- a/DeepLense_Classification_Transformers_Archil_Srivastava/eval.py +++ b/DeepLense_Classification_Transformers_Archil_Srivastava/eval.py @@ -56,9 +56,27 @@ def evaluate(model, data_loader, loss_fn, device): # Concatenate all results logits, y = torch.cat(logits), torch.cat(y) loss.append(loss_fn(logits, y)) - accuracy.append(accuracy_fn(logits, y, num_classes=NUM_CLASSES)) - class_auroc.append(auroc_fn(logits, y, num_classes=NUM_CLASSES, average=None)) - macro_auroc.append(auroc_fn(logits, y, num_classes=NUM_CLASSES, average="macro")) + accuracy.append( + accuracy_fn(logits, y, task="multiclass", num_classes=NUM_CLASSES) + ) + class_auroc.append( + auroc_fn( + logits, + y, + task="multiclass", + num_classes=NUM_CLASSES, + average=None, + ) + ) + macro_auroc.append( + auroc_fn( + logits, + y, + task="multiclass", + num_classes=NUM_CLASSES, + average="macro", + ) + ) result = { "ground_truth": y, diff --git a/DeepLense_Classification_Transformers_Archil_Srivastava/requirements.txt b/DeepLense_Classification_Transformers_Archil_Srivastava/requirements.txt new file mode 100644 index 0000000..4c2b4e7 --- /dev/null +++ b/DeepLense_Classification_Transformers_Archil_Srivastava/requirements.txt @@ -0,0 +1,13 @@ +# Install PyTorch and torchvision for your platform first if needed: +# https://pytorch.org/get-started/locally/ + +einops>=0.6,<1 +matplotlib>=3.7,<4 +numpy>=1.24,<3 +scikit-learn>=1.3,<2 +timm>=0.9,<1 +torch>=2.1,<3 +torchmetrics>=1.3,<2 +torchvision>=0.16,<1 +tqdm>=4.66,<5 +wandb>=0.16,<1 diff --git a/DeepLense_Classification_Transformers_Archil_Srivastava/train.py b/DeepLense_Classification_Transformers_Archil_Srivastava/train.py index a5a6303..4b5e6d3 100644 --- a/DeepLense_Classification_Transformers_Archil_Srivastava/train.py +++ b/DeepLense_Classification_Transformers_Archil_Srivastava/train.py @@ -319,7 +319,7 @@ def train( # Scheduler if run_config.decay_lr: scheduler = CosineAnnealingWarmRestarts( - optimizer, T_0=15, T_mult=1, eta_min=1e-6, verbose=True + optimizer, T_0=15, T_mult=1, eta_min=1e-6 ) else: scheduler = None diff --git a/DeepLense_Classification_Transformers_Archil_Srivastava/utils.py b/DeepLense_Classification_Transformers_Archil_Srivastava/utils.py index 759d3fe..1d5bc8f 100644 --- a/DeepLense_Classification_Transformers_Archil_Srivastava/utils.py +++ b/DeepLense_Classification_Transformers_Archil_Srivastava/utils.py @@ -29,7 +29,11 @@ def get_device(device): return xm.xla_device() if (device == "cuda" or device == "best") and torch.cuda.is_available(): return "cuda" - if (device == "mps" or device == "best") and torch.has_mps: + if ( + (device == "mps" or device == "best") + and hasattr(torch.backends, "mps") + and torch.backends.mps.is_available() + ): return "mps" if device == "cpu" or device == "best": return "cpu"