Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
device: "cuda"
learning_method: "contrastive_embedding"
dataset_name: "Model_II"
batch_size: 512
epochs: 10
learning_rate: 0.0001
margin: 1.0
num_channels: 1
temperature: 0.5
num_classes: 3
split_ratio: 0.25
num_workers: 8

# REQUIRED - set this to your actual path
saved_model_path: ""
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
device: "cuda"
learning_method: "contrastive_embedding"
batch_size: 128
epochs: 10
learning_rate: 0.001
margin: 1.0
num_channels: 1
temperature: 0.5
num_classes: 2
num_workers: 4

# REQUIRED - set this to your actual path
saved_model_path: ""
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
device: "cuda"
dataset_name: "Model_II"
num_classes: 3
image_size: 224
channels: 1
batch_size: 512
num_workers: 8
labels_map:
0: "axion"
1: "cdm"
2: "no_sub"

# REQUIRED - set these to your actual paths
log_dir: ""
finetune_model_path: ""
69 changes: 23 additions & 46 deletions Transformers_Classification_DeepLense_Kartik_Sachdev/finetune.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import os
import argparse
import yaml

import torch
import torch.nn as nn
import torch.optim as optim
Expand All @@ -9,51 +13,24 @@
from utils.util import load_model_add_head
from torchsummary import summary

# Set device
device = "cuda" # torch.device("cuda" if torch.cuda.is_available() else "cpu")
learning_method = "contrastive_embedding"
saved_model_path = "/home/kartik/git/deepLense_transformer_ssl/output/pretrained_contrastive_embedding.pth"

# Set hyperparameters
batch_size = 128
epochs = 10
learning_rate = 0.001
margin = 1.0
num_channels = 1
temperature = 0.5
num_classes = 2

# Load dataset
default_dataset_setup = DefaultDatasetSetup()
train_dataset = default_dataset_setup.get_default_trainset_ssl()
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Load pretrained model and add head
pretrain_head = nn.Identity()
pretrain_model = CustomResNet(num_channels, head=pretrain_head)
pretrain_model.summarize()

in_features = pretrain_model.get_last_layer_features()
finetune_head = nn.Sequential(
nn.Linear(in_features, 256), nn.ReLU(), nn.Linear(256, num_classes)
)

model = load_model_add_head(
pretrain_model=pretrain_model,
saved_model_path=saved_model_path,
head=finetune_head,
freeze_pretrain_layers=True,
)

model.to(device)
summary(model=model, input_size=(1, 1, 224, 224), device=device)


# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# Training loop
train_simplistic(
epochs, model, device, train_loader, criterion, optimizer, saved_model_path
# ── Argument parsing ───────────────────────────────────────────────────────────
parser = argparse.ArgumentParser(description="Contrastive Finetuning for DeepLense")
parser.add_argument(
"--config",
type=str,
default="config/yaml/finetune_config.yaml",
help="Path to YAML config file"
)
parser.add_argument("--saved_model_path", type=str, default=None, help="Path to pretrained model .pth file")
parser.add_argument("--device", type=str, default=None, help="cuda or cpu")
args = parser.parse_args()

# ── Load YAML config ───────────────────────────────────────────────────────────
config = {}
if os.path.exists(args.config):
with open(args.config, "r") as f:
config = yaml.safe_load(f) or {}

# ── Settings (CLI overrides YAML) ──────────────────────────────────────────────
device = args.device or config.get("dev
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
import argparse
import yaml

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from utils.dataset import DeepLenseDatasetSSL, DefaultDatasetSetupSSL
from torch.utils.data import DataLoader, random_split

from utils.dataset import DeepLenseDatasetSSL, DefaultDatasetSetupSSL
from models.cnn_zoo import CustomResNet
from utils.losses.contrastive_loss import ContrastiveLossEuclidean
from utils.train import train_simplistic
Expand All @@ -18,78 +20,23 @@
import torchvision
from models.utils.finetune_model import FinetuneModel

# Set device
device = "cuda" # torch.device("cuda" if torch.cuda.is_available() else "cpu")
learning_method = "contrastive_embedding"
saved_model_path = "/home/kartik/git/DeepLense/Transformers_Classification_DeepLense_Kartik_Sachdev/logger/2023-07-23-13-30-24/checkpoint/Resnet_finetune_Model_II_2023-07-23-13-30-24.pt"

# Set hyperparameters
batch_size = 512
epochs = 10
learning_rate = 0.0001
margin = 1.0
num_channels = 1
temperature = 0.5
num_classes = 3

# Load dataset
default_dataset_setup = DefaultDatasetSetupSSL(dataset_name="Model_II")
train_dataset = default_dataset_setup.get_dataset()

# split in train and valid set
split_ratio = 0.25 # 0.25
valid_len = int(split_ratio * len(train_dataset))
train_len = len(train_dataset) - valid_len

train_dataset, val_set = random_split(train_dataset, [train_len, valid_len])

train_loader = DataLoader(
dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=8
# ── Argument parsing ───────────────────────────────────────────────────────────
parser = argparse.ArgumentParser(description="BYOL Finetuning for DeepLense")
parser.add_argument(
"--config",
type=str,
default="config/yaml/finetune_byol_config.yaml",
help="Path to YAML config file"
)
parser.add_argument("--saved_model_path", type=str, default=None, help="Path to pretrained BYOL .pt file")
parser.add_argument("--device", type=str, default=None, help="cuda or cpu")
args = parser.parse_args()

val_loader = DataLoader(
dataset=val_set, batch_size=batch_size, shuffle=True, num_workers=8
)
# ── Load YAML config ───────────────────────────────────────────────────────────
config = {}
if os.path.exists(args.config):
with open(args.config, "r") as f:
config = yaml.safe_load(f) or {}


# Load pretrained model and add head
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = BYOLSingleChannel(backbone, num_ftrs=512)
model.load_state_dict(torch.load(saved_model_path))
print(">>>>>> Keys matched. Model loaded")
model.to(device)
# last_layer_num = get_last_layer_features(model=model, num_input=2)
# print(last_layer_num)

input_feature = 256
finetune_head = nn.Sequential(
nn.Linear(input_feature, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, input_feature),
nn.BatchNorm1d(input_feature),
nn.ReLU(),
nn.Linear(input_feature, num_classes),
)

finetune_model = FinetuneModelByol(backbone=model, head=finetune_head)
finetune_model.to(device=device)
summary(finetune_model, input_size=(10, 1, 224, 224), device=device)


# Define optimizer and loss function
optimizer = optim.Adam(finetune_model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# Training loop
train_simplistic(
epochs,
finetune_model,
device,
train_loader,
criterion,
optimizer,
saved_model_path,
valid_loader=val_loader,
)
# ── Settings (CLI overrides
77 changes: 62 additions & 15 deletions Transformers_Classification_DeepLense_Kartik_Sachdev/inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import print_function

from turtle import down
import os
import argparse
import yaml

from utils.dataset import DefaultDatasetSetupSSL
from utils.inference import InferenceSSL
Expand All @@ -12,17 +14,63 @@
from torchsummary import summary


def parse_args():
parser = argparse.ArgumentParser(description="BYOL Inference for DeepLense")
parser.add_argument(
"--config",
type=str,
default="config/yaml/inference_config.yaml",
help="Path to YAML config file"
)
parser.add_argument("--log_dir", type=str, default=None, help="Path to log directory")
parser.add_argument("--finetune_model_path", type=str, default=None, help="Path to finetuned model .pt file")
parser.add_argument("--dataset_name", type=str, default=None, help="Dataset name e.g. Model_II")
parser.add_argument("--device", type=str, default=None, help="cuda or cpu")
return parser.parse_args()


def main():
device = "cuda"
num_classes = 3
dataset_name = "Model_II"
labels_map = {0: "axion", 1: "cdm", 2: "no_sub"}
image_size = 224
channels = 1
log_dir = "/home/kartik/git/DeepLense/Transformers_Classification_DeepLense_Kartik_Sachdev/logger/2023-07-23-13-30-24"
finetune_model_path = "/home/kartik/git/DeepLense/Transformers_Classification_DeepLense_Kartik_Sachdev/logger/2023-07-23-13-30-24/checkpoint/Resnet_finetune_Model_II.pt"
batch_size = 512
num_workers = 8
args = parse_args()

# Load YAML config
config = {}
if os.path.exists(args.config):
with open(args.config, "r") as f:
config = yaml.safe_load(f) or {}

# CLI args override YAML values
device = args.device or config.get("device", "cuda")
num_classes = config.get("num_classes", 3)
dataset_name = args.dataset_name or config.get("dataset_name", "Model_II")
labels_map = config.get("labels_map", {0: "axion", 1: "cdm", 2: "no_sub"})
image_size = config.get("image_size", 224)
channels = config.get("channels", 1)
batch_size = config.get("batch_size", 512)
num_workers = config.get("num_workers", 8)
log_dir = args.log_dir or config.get("log_dir", None)
finetune_model_path = args.finetune_model_path or config.get("finetune_model_path", None)

# Validate required paths
if not log_dir:
raise ValueError(
"log_dir is required.\n"
"Set it in config/yaml/inference_config.yaml or pass --log_dir <path>"
)
if not finetune_model_path:
raise ValueError(
"finetune_model_path is required.\n"
"Set it in config/yaml/inference_config.yaml or pass --finetune_model_path <path>"
)
if not os.path.exists(log_dir):
raise FileNotFoundError(
f"log_dir not found: {log_dir}\n"
"Please check your config/yaml/inference_config.yaml or --log_dir argument."
)
if not os.path.exists(finetune_model_path):
raise FileNotFoundError(
f"finetune_model_path not found: {finetune_model_path}\n"
"Please check your config/yaml/inference_config.yaml or --finetune_model_path argument."
)

# Load pretrained model and add head
resnet = torchvision.models.resnet18()
Expand All @@ -45,7 +93,6 @@ def main():
print(">>>> Keys matched")
summary(finetune_model, input_size=(10, 1, 224, 224), device=device)

# testset
# setup default dataset
default_dataset_setup = DefaultDatasetSetupSSL()
default_dataset_setup.setup(dataset_name=dataset_name)
Expand All @@ -56,7 +103,7 @@ def main():
default_dataset_setup.visualize_dataset(train_dataset)

# split in train and valid set
split_ratio = 0.05 # 0.25
split_ratio = 0.05
valid_len = int(split_ratio * len(train_dataset))
train_len = len(train_dataset) - valid_len

Expand Down Expand Up @@ -84,12 +131,12 @@ def main():
image_size=image_size,
channels=channels,
destination_dir="data",
log_dir=log_dir, # log_dir
log_dir=log_dir,
)

infer_obj.infer_plot_roc()
infer_obj.generate_plot_confusion_matrix()


if __name__ == "__main__":
main()
main()