diff --git a/DeepLense_Regression_Zhongchao_Guan/CNNT.py b/DeepLense_Regression_Zhongchao_Guan/CNNT.py
index 3cfb8d3..7cbb7bd 100644
--- a/DeepLense_Regression_Zhongchao_Guan/CNNT.py
+++ b/DeepLense_Regression_Zhongchao_Guan/CNNT.py
@@ -173,10 +173,10 @@ def __init__(self, num_classes, depth, heads, mlp_dim, pool='mean', dim_head=64,
         self.to_patch_embedding = CNN()
 
         # hyper-params
-        num_patches = 16 * 16
+        max_patches = 256  # maximum expected patches, avoids hardcoded assumption
         dim = 32
 
-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
+        self.pos_embedding = nn.Parameter(torch.randn(1, max_patches + 1, dim))
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
         self.dropout = nn.Dropout(emb_dropout)
 
@@ -191,17 +191,26 @@ def __init__(self, num_classes, depth, heads, mlp_dim, pool='mean', dim_head=64,
         )
 
     def forward(self, img):
-        x = self.to_patch_embedding(img)
-
-        # b: batch size n: patch number _: dim of patch
-        b, n, _ = x.shape
-
-        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)
-        x = torch.cat((cls_tokens, x), dim=1)
-        x += self.pos_embedding[:, :(n + 1)]
-        x = self.dropout(x)
-        x = self.transformer(x)
-        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
-
-        x = self.to_latent(x)
-        return self.mlp_head(x)
+        x = self.to_patch_embedding(img)
+        b, n, _ = x.shape
+
+        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)
+        x = torch.cat((cls_tokens, x), dim=1)
+
+        # dynamically handle any patch size
+        if (n + 1) > self.pos_embedding.shape[1]:
+            pos_emb = torch.nn.functional.interpolate(
+                self.pos_embedding.permute(0, 2, 1),
+                size=(n + 1),
+                mode='linear',
+                align_corners=False
+            ).permute(0, 2, 1)
+        else:
+            pos_emb = self.pos_embedding[:, :(n + 1)]
+
+        x += pos_emb
+        x = self.dropout(x)
+        x = self.transformer(x)
+        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
+        x = self.to_latent(x)
+        return self.mlp_head(x)
\ No newline at end of file
diff --git a/Transformers_Classification_DeepLense_Kartik_Sachdev/finetune.py b/Transformers_Classification_DeepLense_Kartik_Sachdev/finetune.py
index cddae46..e0bb427 100644
--- a/Transformers_Classification_DeepLense_Kartik_Sachdev/finetune.py
+++ b/Transformers_Classification_DeepLense_Kartik_Sachdev/finetune.py
@@ -6,13 +6,18 @@
 from models.cnn_zoo import CustomResNet
 from utils.losses.contrastive_loss import ContrastiveLossEuclidean
 from utils.train import train_simplistic
+import argparse
 from utils.util import load_model_add_head
 from torchsummary import summary
 
 # Set device
 device = "cuda"  # torch.device("cuda" if torch.cuda.is_available() else "cpu")
 learning_method = "contrastive_embedding"
-saved_model_path = "/home/kartik/git/deepLense_transformer_ssl/output/pretrained_contrastive_embedding.pth"
+parser = argparse.ArgumentParser()
+parser.add_argument('--saved_model_path', type=str, required=True,
+                    help='Path to saved pretrained model checkpoint')
+args = parser.parse_args()
+saved_model_path = args.saved_model_path
 
 # Set hyperparameters
 batch_size = 128
diff --git a/Transformers_Classification_DeepLense_Kartik_Sachdev/finetune_byol.py b/Transformers_Classification_DeepLense_Kartik_Sachdev/finetune_byol.py
index c90729b..7909432 100644
--- a/Transformers_Classification_DeepLense_Kartik_Sachdev/finetune_byol.py
+++ b/Transformers_Classification_DeepLense_Kartik_Sachdev/finetune_byol.py
@@ -14,6 +14,7 @@
     get_last_layer_features,
 )
 from torchsummary import summary
+import argparse
 from models.byol import BYOLSingleChannel, FinetuneModelByol
 import torchvision
 from models.utils.finetune_model import FinetuneModel
@@ -22,6 +23,11 @@
 device = "cuda"  # torch.device("cuda" if torch.cuda.is_available() else "cpu")
 learning_method = "contrastive_embedding"
 saved_model_path = "/home/kartik/git/DeepLense/Transformers_Classification_DeepLense_Kartik_Sachdev/logger/2023-07-23-13-30-24/checkpoint/Resnet_finetune_Model_II_2023-07-23-13-30-24.pt"
+parser = argparse.ArgumentParser()
+parser.add_argument('--saved_model_path', type=str, required=True,
+                    help='Path to saved BYOL model checkpoint')
+args = parser.parse_args()
+saved_model_path = args.saved_model_path
 
 # Set hyperparameters
 batch_size = 512
diff --git a/Transformers_Classification_DeepLense_Kartik_Sachdev/inference.py b/Transformers_Classification_DeepLense_Kartik_Sachdev/inference.py
index 5bda766..15d2a44 100644
--- a/Transformers_Classification_DeepLense_Kartik_Sachdev/inference.py
+++ b/Transformers_Classification_DeepLense_Kartik_Sachdev/inference.py
@@ -10,6 +10,7 @@
 from models.byol import BYOLSingleChannel, FinetuneModelByol
 import torchvision
 from torchsummary import summary
+import argparse
 
 
 def main():
@@ -19,8 +20,14 @@ def main():
     labels_map = {0: "axion", 1: "cdm", 2: "no_sub"}
     image_size = 224
     channels = 1
-    log_dir = "/home/kartik/git/DeepLense/Transformers_Classification_DeepLense_Kartik_Sachdev/logger/2023-07-23-13-30-24"
-    finetune_model_path = "/home/kartik/git/DeepLense/Transformers_Classification_DeepLense_Kartik_Sachdev/logger/2023-07-23-13-30-24/checkpoint/Resnet_finetune_Model_II.pt"
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--log_dir', type=str, required=True,
+                        help='Path to log directory')
+    parser.add_argument('--finetune_model_path', type=str, required=True,
+                        help='Path to finetuned model checkpoint')
+    args = parser.parse_args()
+    log_dir = args.log_dir
+    finetune_model_path = args.finetune_model_path
     batch_size = 512
     num_workers = 8