# DeepLense: Classification Using Vision Transformers

A PyTorch-based framework for classifying **strong gravitational lensing images** to identify **dark matter substructures** using modern deep learning architectures, especially **Vision Transformers (ViTs)**.

This project benchmarks multiple transformer-based models (including hybrid architectures) using the [timm](https://github.com/rwightman/pytorch-image-models) library and tracks performance metrics (Loss, AUROC, etc.) using [Weights & Biases](https://wandb.ai/site).

> Originally developed as part of **Google Summer of Code (GSoC) 2022**
> [Project Link](https://summerofcode.withgoogle.com/programs/2022/projects/iFKJMj0t)

---

## Problem Statement

Gravitational lensing distorts the images of distant galaxies as their light passes massive foreground objects.
This project applies deep learning to:

* Classify lensing images into dark matter categories
* Analyze substructure patterns
* Benchmark modern transformer architectures in astrophysical ML tasks

---

## Key Features

* Image classification using Vision Transformers
* Supports multiple model families (ViT, Swin, CoAtNet, EfficientNet, etc.)
* Experiment tracking with Weights & Biases
* Benchmarking across multiple datasets
* Modular and extensible pipeline

---

## Getting Started

### 1. Clone the Repository

```bash
git clone https://github.com/ML4SCI/DeepLense.git
cd DeepLense/DeepLense_Classification_Transformers_Archil_Srivastava
```

---

### 2. Install Dependencies

```bash
pip install torch torchvision timm wandb numpy matplotlib
```

---

### 3. Train a Model

Train using a model name from [timm](https://github.com/rwightman/pytorch-image-models). The script will prompt for a Weights & Biases login key, so a WandB account is required.

```bash
python train.py \
--dataset Model_I \
--model_name coatnet_nano_rw_224 \
--pretrained \
--tune \
--no-complex \
--device best
```

---

### 4. Evaluate the Model

```bash
python eval.py \
--run_id <wandb_run_id> \
--device cpu
```

---

## Project Structure

```
.
├── data/ # Dataset loading & preprocessing
├── models/ # Model architectures (ViT, etc.)
├── train.py # Training script
├── eval.py # Evaluation script
├── utils.py # Helper utilities
├── constants.py # Configurations
├── results/ # Output logs and results
```

---

## Datasets

The models are trained on **3 simulated datasets**:

| Dataset | Image Size | Characteristics |
| --------- | ---------- | ----------------------------- |
| Model I | 150×150 | Gaussian PSF + noise |
| Model II | 64×64 | Euclid observation simulation |
| Model III | 64×64 | HST observation simulation |

### Classes:

* Axion (Vortex substructure)
* CDM (Cold Dark Matter)
* No Substructure (baseline class)

> Note: Axion datasets include additional mass-related information.
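Because axion files carry that extra mass entry, loading code has to treat them differently from the other two classes. The sketch below is a hypothetical helper (class folder names, label order, and the `(image, mass)` layout are assumptions based on the note above, not the repo's actual loader):

```python
import numpy as np

# Hypothetical: class names, label order, and the (image, mass) tuple
# layout of axion samples are assumptions based on the dataset notes.
CLASSES = ["axion", "cdm", "no_sub"]
LABELS = {name: i for i, name in enumerate(CLASSES)}

def to_sample(raw, class_name):
    """Return (image, label); axion files carry an extra axion-mass entry."""
    if class_name == "axion":
        raw = raw[0]                     # keep the image, drop the mass
    img = np.asarray(raw, dtype=np.float32)[None, ...]   # add channel axis
    return img, LABELS[class_name]

img, y = to_sample(np.zeros((64, 64)), "cdm")                # Model II/III size
ax_img, ay = to_sample((np.zeros((64, 64)), 1e-23), "axion")
```

The single added channel axis reflects that all images are single-channel; a real loader would also handle file I/O and normalization.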

---

## Training Configuration

Example:

```bash
python train.py \
--dataset Model_I \
--model_name coatnet_nano_rw_224 \
--batch_size 32 \
--lr 0.001 \
--epochs 20 \
--optimizer adam \
--device best
```

### Important Parameters:

| Parameter | Description |
| ---------- | ------------------------------ |
| dataset | Model_I / Model_II / Model_III |
| model_name | Model from timm library |
| pretrained | Use pretrained weights |
| tune | Fine-tune model |
| batch_size | Batch size |
| lr | Learning rate |
| epochs | Training epochs |
| device | cpu / cuda / best |
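
These flags map onto an ordinary `argparse` CLI. The sketch below mirrors the table; the flag names follow the example invocations, but the defaults are illustrative assumptions, not the real ones in `train.py` (`--complex`/`--no-complex` toggles optional extra head layers on top of the timm model):

```python
import argparse

# Sketch of the CLI surface described in the table above; defaults here
# are illustrative assumptions, not the project's actual defaults.
def build_parser():
    p = argparse.ArgumentParser(description="Train a timm model on DeepLense data")
    p.add_argument("--dataset", required=True,
                   choices=["Model_I", "Model_II", "Model_III"])
    p.add_argument("--model_name", default="coatnet_nano_rw_224")
    p.add_argument("--pretrained", action="store_true")
    p.add_argument("--tune", action="store_true")
    # --complex / --no-complex: add extra layers at the end of the timm model
    p.add_argument("--complex", action=argparse.BooleanOptionalAction, default=False)
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--epochs", type=int, default=20)
    p.add_argument("--device", default="best",
                   choices=["cpu", "cuda", "mps", "best"])
    return p

args = build_parser().parse_args(
    ["--dataset", "Model_I", "--pretrained", "--no-complex"])
```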

---

## Evaluation

Evaluate a trained model on the test set with `eval.py`, passing the WandB `run_id` of the training run so the matching configuration is picked up:

```bash
python eval.py --run_id <wandb_run_id>
```

Metrics:

* Accuracy
* AUROC (class-wise & overall)
* Loss
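
For a 3-class problem, the class-wise and overall AUROC can be computed with scikit-learn. This is a minimal sketch on synthetic labels and softmax outputs, not the repo's own evaluation code:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Synthetic stand-ins for test-set labels and softmax probabilities
rng = np.random.default_rng(0)
y_true = rng.integers(0, 3, size=200)        # 3 classes: axion / cdm / no_sub
probs = rng.dirichlet(np.ones(3), size=200)  # rows sum to 1, like softmax

# Overall AUROC: macro-averaged one-vs-rest over all classes
overall = roc_auc_score(y_true, probs, multi_class="ovr", average="macro")

# Class-wise AUROC: each class scored against the rest
per_class = [roc_auc_score((y_true == c).astype(int), probs[:, c])
             for c in range(3)]
```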

---

## Results & Benchmarks

Around 9 model families were benchmarked (EfficientNet as a CNN baseline plus 8 transformer families), with several variants tested per family:

* EfficientNet
* Vision Transformer (ViT)
* ConViT
* CrossViT
* Bottleneck Transformer (BoTNet)
* EfficientFormer
* CoaT
* CoAtNet
* Swin Transformer

Full experiment logs:
https://wandb.ai/_archil/ml4sci_deeplense_final

---

## Suggested Improvements (For Contributors)

This project can be extended by:

* Adding new transformer architectures
* Improving dataset preprocessing
* Optimizing training pipelines
* Adding visualization dashboards
* Improving documentation & usability

---

## Contributing

We welcome contributions!

### How to contribute:

1. Fork the repository
2. Create a new branch
3. Make changes
4. Submit a Pull Request

### Good first contributions:

* Improve documentation
* Add comments to code
* Fix bugs
* Add visualization

---

## Limitations

* Requires WandB account for full tracking
* Dataset setup not fully automated
* Limited documentation for beginners

---

## References

* [PyTorch Image Models (timm)](https://github.com/rwightman/pytorch-image-models)
* [DeepLense GitHub Repository](https://github.com/ML4SCI/DeepLense)
* [GSoC 2022 Project](https://summerofcode.withgoogle.com/programs/2022/projects/iFKJMj0t)

---

## Summary

This project demonstrates the application of **state-of-the-art transformer architectures** to astrophysical data, specifically for understanding **dark matter substructures through gravitational lensing images**.

---


## Citation

```bibtex
@misc{rw2019timm,
  author = {Ross Wightman},
  title = {PyTorch Image Models},
  year = {2019},
  publisher = {GitHub},
  journal = {GitHub repository},
  doi = {10.5281/zenodo.4414861},
  howpublished = {\url{https://github.com/rwightman/pytorch-image-models}}
}
```

* Apoorva Singh, Yurii Halychanskyi, Marcos Tidball, DeepLense, (2021), GitHub repository, https://github.com/ML4SCI/DeepLense