Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions DeepLense_Regression_Zhongchao_Guan/CNNT.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,10 @@ def __init__(self, num_classes, depth, heads, mlp_dim, pool='mean', dim_head=64,
self.to_patch_embedding = CNN()

# hyper-params
num_patches = 16 * 16
max_patches = 256 # maximum expected patches, avoids hardcoded assumption
dim = 32

self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.pos_embedding = nn.Parameter(torch.randn(1, max_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)

Expand All @@ -191,17 +191,26 @@ def __init__(self, num_classes, depth, heads, mlp_dim, pool='mean', dim_head=64,
)

def forward(self, img):
    """Forward pass of the CNN-Transformer hybrid.

    Args:
        img: input image batch; fed to the CNN patch embedder, which is
            expected to return patch tokens of shape (batch, n, dim).
            TODO confirm output shape against the CNN definition.

    Returns:
        Output of ``self.mlp_head`` — one row per batch element.
    """
    x = self.to_patch_embedding(img)
    # b: batch size, n: number of patches, last dim: token embedding size
    b, n, _ = x.shape

    # Prepend the learnable [CLS] token to every sequence in the batch.
    # expand() creates a broadcast view — no copy needed, since cat() below
    # materialises the result anyway (replaces the einops `repeat` call).
    cls_tokens = self.cls_token.expand(b, -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)

    # Positional embedding: slice the learned table when the sequence fits;
    # otherwise linearly interpolate it to the required length so any patch
    # count is supported without a hardcoded assumption.
    seq_len = n + 1  # +1 accounts for the [CLS] token
    if seq_len > self.pos_embedding.shape[1]:
        pos_emb = torch.nn.functional.interpolate(
            self.pos_embedding.permute(0, 2, 1),  # (1, dim, max_patches+1)
            size=seq_len,
            mode='linear',
            align_corners=False,
        ).permute(0, 2, 1)  # back to (1, seq_len, dim)
    else:
        pos_emb = self.pos_embedding[:, :seq_len]

    # Out-of-place add: avoids in-place mutation of an autograd-tracked tensor.
    x = x + pos_emb
    x = self.dropout(x)
    x = self.transformer(x)

    # Pool token dimension: mean over all tokens, or take the [CLS] token.
    x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

    x = self.to_latent(x)
    return self.mlp_head(x)
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@
from models.cnn_zoo import CustomResNet
from utils.losses.contrastive_loss import ContrastiveLossEuclidean
from utils.train import train_simplistic
import argparse
from utils.util import load_model_add_head
from torchsummary import summary

# Set device
# NOTE(review): hardcodes CUDA; the commented fallback would degrade to CPU
# gracefully — confirm whether CPU runs are meant to be supported.
device = "cuda"  # torch.device("cuda" if torch.cuda.is_available() else "cpu")
learning_method = "contrastive_embedding"

# The pretrained checkpoint path is supplied on the command line rather than
# being hardcoded to a developer-specific absolute path (the stale assignment
# left over from the old version was dead code, immediately overwritten here).
parser = argparse.ArgumentParser()
parser.add_argument('--saved_model_path', type=str, required=True,
                    help='Path to saved pretrained model checkpoint')
args = parser.parse_args()
saved_model_path = args.saved_model_path

# Set hyperparameters
batch_size = 128
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
get_last_layer_features,
)
from torchsummary import summary
import argparse
from models.byol import BYOLSingleChannel, FinetuneModelByol
import torchvision
from models.utils.finetune_model import FinetuneModel
Expand All @@ -22,6 +23,11 @@
# Set device
# NOTE(review): hardcodes CUDA; the commented fallback would degrade to CPU
# gracefully — confirm whether CPU runs are meant to be supported.
device = "cuda"  # torch.device("cuda" if torch.cuda.is_available() else "cpu")
learning_method = "contrastive_embedding"

# The BYOL checkpoint path is supplied on the command line rather than being
# hardcoded to a developer-specific absolute path (the stale assignment left
# over from the old version was dead code, immediately overwritten here).
parser = argparse.ArgumentParser()
parser.add_argument('--saved_model_path', type=str, required=True,
                    help='Path to saved BYOL model checkpoint')
args = parser.parse_args()
saved_model_path = args.saved_model_path

# Set hyperparameters
batch_size = 512
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from models.byol import BYOLSingleChannel, FinetuneModelByol
import torchvision
from torchsummary import summary
import argparse


def main():
Expand All @@ -19,8 +20,14 @@ def main():
labels_map = {0: "axion", 1: "cdm", 2: "no_sub"}
image_size = 224
channels = 1
log_dir = "/home/kartik/git/DeepLense/Transformers_Classification_DeepLense_Kartik_Sachdev/logger/2023-07-23-13-30-24"
finetune_model_path = "/home/kartik/git/DeepLense/Transformers_Classification_DeepLense_Kartik_Sachdev/logger/2023-07-23-13-30-24/checkpoint/Resnet_finetune_Model_II.pt"
parser = argparse.ArgumentParser()
parser.add_argument('--log_dir', type=str, required=True,
help='Path to log directory')
parser.add_argument('--finetune_model_path', type=str, required=True,
help='Path to finetuned model checkpoint')
args = parser.parse_args()
log_dir = args.log_dir
finetune_model_path = args.finetune_model_path
batch_size = 512
num_workers = 8

Expand Down