Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions prompter/config/lucchi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
prompter = dict(
backbone=dict(
model_name='convnext_small',
pretrained=True,
num_classes=0,
global_pool=''
),
neck=dict(
in_channels=[96, 192, 384, 768],
out_channels=256,
num_outs=4,
add_extra_convs='on_input',
),
dropout=0.1,
space=16,
hidden_dim=256
)

data = dict(
name='lucchi',
num_classes=1,
batch_size_per_gpu=8,
num_workers=8,
train=dict(transform=[
dict(type='RandomCrop', height=768, width=1024, p=1),
dict(type='RandomGridShuffle', grid=(4, 4), p=0.5),
dict(type='ColorJitter', brightness=0.25, contrast=0.25, saturation=0.1, hue=0.05, p=0.2),
#dict(type='RandomRotate90', p=0.5),
dict(type='Downscale', scale_max=0.5, scale_min=0.5, p=0.15),
dict(type='Blur', blur_limit=10, p=0.2),
dict(type='GaussNoise', var_limit=50, p=0.25),
dict(type='ZoomBlur', p=0.1, max_factor=1.05),
dict(type='HorizontalFlip', p=0.5),
dict(type='VerticalFlip', p=0.5),
dict(type='ShiftScaleRotate', shift_limit=0.3, scale_limit=0.1, rotate_limit=0, border_mode=0, value=0, p=0.5),
dict(type='PadIfNeeded', min_height=None, min_width=None, pad_height_divisor=prompter["space"],
pad_width_divisor=prompter["space"], position="top_left", p=1),
dict(type='Normalize'),
]),
val=dict(transform=[
dict(type='PadIfNeeded', min_height=None, min_width=None, pad_height_divisor=prompter["space"],
pad_width_divisor=prompter["space"], position="top_left", p=1),
dict(type='Normalize'),
]),
test=dict(transform=[
dict(type='PadIfNeeded', min_height=None, min_width=None, pad_height_divisor=prompter["space"],
pad_width_divisor=prompter["space"], position="top_left", p=1),
dict(type='Normalize'),
]),
)

optimizer = dict(
type='Adam',
lr=1e-4,
weight_decay=1e-4
)

criterion = dict(
matcher=dict(type='HungarianMatcher', dis_type='l2', set_cost_point=0.1, set_cost_class=1),
eos_coef=0.25,
reg_loss_coef=5e-3,
cls_loss_coef=1.0,
mask_loss_coef=1.0
)

test = dict(nms_thr=12, match_dis=12, filtering=True)
27 changes: 24 additions & 3 deletions prompter/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

from PIL import Image
import re
import cv2 as cv

def read_from_json(json_path):
with open(json_path, 'r', encoding='utf-8') as f:
Expand All @@ -29,7 +31,6 @@ def __init__(
self.data = anno_json
self.img_paths = list(anno_json.keys())
self.keys = ['image', 'keypoints'] + [f'keypoints{i}' for i in range(1, cfg.data.num_classes)] + ['mask']

self.phase = mode
self.dataset = cfg.data.name

Expand All @@ -53,7 +54,23 @@ def __getitem__(self, index: int):

img_path = self.img_paths[index]

values = ([io.imread(f'../segmentor/{img_path}')[..., :3]] +
# Define the regex pattern to match the 'mask' part
pattern = re.compile(r'(/mask/)')
# Replace 'mask' with 'raw' in the path
raw_path = re.sub(pattern, r'/raw/', img_path)

if raw_path[-3:] == 'mat':
raw_path = raw_path[:-3] + 'png'

np_img = []
if self.dataset == 'lucchi' or self.dataset == 'Lucchipp':
np_img = io.imread(f'../segmentor/{raw_path}')
np_img = [cv.merge((np_img, np_img, np_img))]
else:
np_img = [io.imread(f'../segmentor/{raw_path}')[..., :3]]


values = (np_img +
[np.array(point).reshape(-1, 2) for point in self.data[img_path]])

if self.dataset == 'kumar':
Expand All @@ -64,6 +81,8 @@ def __getitem__(self, index: int):
mask = np.load(f'../segmentor/{mask_path}')
elif self.dataset == 'cpm17':
mask = scipy.io.loadmat(f'../segmentor/{img_path[:-4].replace("Images", "Labels")}.mat')['inst_map']
elif self.dataset == 'lucchi' or self.dataset == 'Lucchipp':
mask = np.asarray(Image.open(f'../segmentor/{img_path}'))
else:
mask = np.load(f'../segmentor/{img_path.replace("Images", "Masks")[:-4]}.npy', allow_pickle=True)[()][
'inst_map']
Expand All @@ -73,6 +92,8 @@ def __getitem__(self, index: int):

ori_shape = values[0].shape[:2]
sample = dict(zip(self.keys, values))
sample['keypoints'] = [tuple(k) for k in sample['keypoints']]

res = self.transform(**sample)
res = list(res.values())

Expand Down
4 changes: 3 additions & 1 deletion prompter/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from engine import train_one_epoch, evaluate
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import os


os.environ["WANDB__SERVICE_WAIT"] = "300"
def parse_args():
parser = argparse.ArgumentParser('Cell prompter')
parser.add_argument('--config', default='pannuke123.py', type=str)
Expand Down Expand Up @@ -235,7 +237,7 @@ def main():
checkpoint,
f"checkpoint/{args.output_dir}/best.pth",
)
except NameError:
except AttributeError:
pass

if is_main_process() and args.use_wandb:
Expand Down
14 changes: 13 additions & 1 deletion prompter/predict_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from main import parse_args
from mmengine.config import Config
import re
import cv2 as cv

args = parse_args()
cfg = Config.fromfile(f'config/{args.config}')
Expand Down Expand Up @@ -38,7 +40,17 @@

def process_files(files):
for file in tqdm(files):
img = io.imread(f'../segmentor/{file}')[..., :3]
pattern = re.compile(r'(/mask/)')
raw_path = re.sub(pattern, r'/raw/', file)

if raw_path[-3:] == 'mat':
raw_path = raw_path[:-3] + 'png'

if dataset == 'lucchi' or dataset == 'Lucchipp':
img = io.imread(f'../segmentor/{raw_path}')
img = cv.merge((img, img, img))
else:
img = io.imread(f'../segmentor/{raw_path}')[..., :3]

image = transform(image=img)['image'].unsqueeze(0).to(device)

Expand Down
Binary file removed prompter/timm/__pycache__/__init__.cpython-37.pyc
Binary file not shown.
Binary file removed prompter/timm/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file removed prompter/timm/__pycache__/version.cpython-37.pyc
Binary file not shown.
Binary file removed prompter/timm/__pycache__/version.cpython-39.pyc
Binary file not shown.
Binary file removed prompter/timm/data/__pycache__/__init__.cpython-37.pyc
Binary file not shown.
Binary file removed prompter/timm/data/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed prompter/timm/data/__pycache__/config.cpython-37.pyc
Binary file not shown.
Binary file removed prompter/timm/data/__pycache__/config.cpython-39.pyc
Binary file not shown.
Binary file removed prompter/timm/data/__pycache__/constants.cpython-37.pyc
Binary file not shown.
Binary file removed prompter/timm/data/__pycache__/constants.cpython-39.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed prompter/timm/data/__pycache__/loader.cpython-37.pyc
Binary file not shown.
Binary file removed prompter/timm/data/__pycache__/mixup.cpython-37.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed prompter/timm/models/__pycache__/beit.cpython-37.pyc
Binary file not shown.
Binary file removed prompter/timm/models/__pycache__/beit.cpython-39.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed prompter/timm/models/__pycache__/cait.cpython-37.pyc
Binary file not shown.
Binary file removed prompter/timm/models/__pycache__/coat.cpython-37.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed prompter/timm/models/__pycache__/deit.cpython-37.pyc
Binary file not shown.
Binary file not shown.
Binary file removed prompter/timm/models/__pycache__/dla.cpython-37.pyc
Binary file not shown.
Binary file removed prompter/timm/models/__pycache__/dpn.cpython-37.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed prompter/timm/models/__pycache__/hub.cpython-37.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed prompter/timm/models/__pycache__/nest.cpython-37.pyc
Binary file not shown.
Binary file not shown.
Binary file removed prompter/timm/models/__pycache__/pit.cpython-37.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed prompter/timm/models/__pycache__/tnt.cpython-37.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed prompter/timm/models/__pycache__/vgg.cpython-37.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed prompter/timm/models/__pycache__/volo.cpython-37.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed prompter/timm/utils/__pycache__/agc.cpython-37.pyc
Binary file not shown.
Binary file removed prompter/timm/utils/__pycache__/agc.cpython-39.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed prompter/timm/utils/__pycache__/cuda.cpython-37.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed prompter/timm/utils/__pycache__/jit.cpython-37.pyc
Binary file not shown.
Binary file removed prompter/timm/utils/__pycache__/log.cpython-37.pyc
Binary file not shown.
Binary file not shown.
Binary file removed prompter/timm/utils/__pycache__/misc.cpython-37.pyc
Binary file not shown.
Binary file removed prompter/timm/utils/__pycache__/model.cpython-37.pyc
Binary file not shown.
Binary file removed prompter/timm/utils/__pycache__/model.cpython-39.pyc
Binary file not shown.
Binary file not shown.
Binary file removed prompter/timm/utils/__pycache__/random.cpython-37.pyc
Binary file not shown.
Binary file not shown.
59 changes: 59 additions & 0 deletions segmentor/config/lucchi_b.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
segmentor = dict(
type='PromptNucSeg-B',
img_size=512,
patch_size=16,
multimask=False
)

input_shape = segmentor['img_size']
data = dict(
name='lucchi',
num_classes=1,
num_mask_per_img=25,
batch_size_per_gpu=8,
num_workers=8,
num_neg_prompt=0,
train=dict(transform=[
dict(type='RandomCrop', height=512, width=512, p=1),
dict(type='RandomRotate90', p=0.5),
dict(type='HorizontalFlip', p=0.5),
dict(type='VerticalFlip', p=0.5),
dict(type='Downscale', scale_max=0.5, scale_min=0.5, p=0.15),
dict(type='Blur', blur_limit=10, p=0.2),
dict(type='GaussNoise', var_limit=50, p=0.25),
dict(type='ColorJitter', brightness=0.25, contrast=0.25, saturation=0.1, hue=0.05, p=0.2),
dict(type='Superpixels', p=0.1, p_replace=0.1, n_segments=200, max_size=int(input_shape / 2)),
dict(type='ZoomBlur', p=0.1, max_factor=1.05),
dict(type='RandomSizedCrop', min_max_height=(int(input_shape / 2), input_shape),
height=input_shape,
width=input_shape,
p=0.1),
dict(type='ElasticTransform', p=0.2, sigma=25, alpha=0.5, alpha_affine=15),
dict(type='Normalize')
]),
val=dict(transform=[
dict(type='Normalize'),
]),
test=dict(transform=[
dict(type='Normalize'),
]),
post=dict(iou_threshold=0.5)
)

optimizer = dict(
type='Adam',
lr=1e-4,
weight_decay=1e-4
)

scheduler = dict(
type='MultiStepLR',
milestones=[300],
gamma=0.1
)

criterion = dict(
loss_focal=20,
loss_dice=1,
loss_iou=1
)
59 changes: 59 additions & 0 deletions segmentor/config/lucchi_l.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
segmentor = dict(
type='PromptNucSeg-L',
img_size=256,
patch_size=16,
multimask=False
)

input_shape = segmentor['img_size']
data = dict(
name='lucchi',
num_classes=1,
num_mask_per_img=25,
batch_size_per_gpu=8,
num_workers=8,
num_neg_prompt=0,
train=dict(transform=[
dict(type='RandomCrop', height=256, width=256, p=1),
dict(type='RandomRotate90', p=0.5),
dict(type='HorizontalFlip', p=0.5),
dict(type='VerticalFlip', p=0.5),
dict(type='Downscale', scale_max=0.5, scale_min=0.5, p=0.15),
dict(type='Blur', blur_limit=10, p=0.2),
dict(type='GaussNoise', var_limit=50, p=0.25),
dict(type='ColorJitter', brightness=0.25, contrast=0.25, saturation=0.1, hue=0.05, p=0.2),
dict(type='Superpixels', p=0.1, p_replace=0.1, n_segments=200, max_size=int(input_shape / 2)),
dict(type='ZoomBlur', p=0.1, max_factor=1.05),
dict(type='RandomSizedCrop', min_max_height=(int(input_shape / 2), input_shape),
height=input_shape,
width=input_shape,
p=0.1),
dict(type='ElasticTransform', p=0.2, sigma=25, alpha=0.5, alpha_affine=15),
dict(type='Normalize')
]),
val=dict(transform=[
dict(type='Normalize'),
]),
test=dict(transform=[
dict(type='Normalize'),
]),
post=dict(iou_threshold=0.5)
)

optimizer = dict(
type='Adam',
lr=1e-4,
weight_decay=1e-4
)

scheduler = dict(
type='MultiStepLR',
milestones=[300],
gamma=0.1
)

criterion = dict(
loss_focal=20,
loss_dice=1,
loss_iou=1
)
19 changes: 17 additions & 2 deletions segmentor/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from skimage import io
from torch.utils.data import Dataset
from albumentations.pytorch import ToTensorV2
import re
import cv2 as cv
from PIL import Image


class DataFolder(Dataset):
Expand Down Expand Up @@ -56,10 +59,19 @@ def __getitem__(self, idx):
mask_path = '/'.join(sub_paths)
elif self.dataset == 'cpm17':
mask_path = f'{img_path[:-4].replace("Images", "Labels")}.mat'
elif self.dataset == 'lucchi' or self.dataset == 'Lucchipp':
mask_path = img_path
else:
mask_path = f'{img_path[:-4].replace("Images", "Masks")}.npy'

img, mask = io.imread(img_path)[..., :3], load_maskfile(mask_path)
if self.dataset == 'lucchi' or self.dataset == 'Lucchipp':
pattern = re.compile(r'(/mask/)')
raw_path = re.sub(pattern, r'/raw/', img_path)
img = io.imread(raw_path)
img = cv.merge((img, img, img))
mask = load_maskfile(mask_path)
else:
img, mask = io.imread(img_path)[..., :3], load_maskfile(mask_path)

if self.mode != 'train':
res = self.transform(image=img)
Expand Down Expand Up @@ -141,7 +153,7 @@ def __getitem__(self, idx):
prompt_points = torch.empty(0, (self.num_neg_prompt + 1), 2)
prompt_labels = torch.empty(0, (self.num_neg_prompt + 1))
all_points = torch.empty(0, 2)
inst_map = torch.empty(0, 256, 256)
inst_map = torch.empty(0, 512, 512)
cell_types = torch.empty(0)

return img, inst_map.long(), prompt_points, prompt_labels, cell_types, all_points
Expand All @@ -157,6 +169,9 @@ def load_maskfile(mask_path: str):
inst_map = scipy.io.loadmat(mask_path)['inst_map']
type_map = (inst_map.copy() > 0).astype(float)

elif 'lucchi' or 'Lucchipp' in mask_path:
inst_map = np.asarray(Image.open(mask_path))
type_map = (inst_map.copy() > 0).astype(float)
else:
inst_map = np.load(mask_path)
type_map = (inst_map.copy() > 0).astype(float)
Expand Down
Loading