This repository was archived by the owner on Oct 25, 2021. It is now read-only.
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -0,0 +1,5 @@
[tool.nitpick]
style = "https://raw.githubusercontent.com/catalyst-team/codestyle/master/styles/nitpick-style-catalyst.toml"

[tool.black]
line-length = 79
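
For context, the [tool.black] table caps lines at 79 characters (matching the flake8 limit below), and [tool.nitpick] pins the shared catalyst code-style file. A quick way to sanity-check the limit programmatically — a minimal sketch, assuming black is installed (black.Mode is the modern name for black.FileMode):

import black

# Reformat a snippet under the project's 79-character limit.
snippet = "x = {  'a':37,'b':42,'c':927}\n"
print(black.format_str(snippet, mode=black.Mode(line_length=79)))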
2 changes: 1 addition & 1 deletion requirements/requirements-dev.txt
@@ -1 +1 @@
catalyst-codestyle
catalyst-codestyle==20.06.1
2 changes: 1 addition & 1 deletion requirements/requirements.txt
@@ -1,3 +1,3 @@
catalyst[cv]==20.5
catalyst[cv]==20.6
jinja2
safitty>=1.2.3
15 changes: 7 additions & 8 deletions scripts/predictions2labels.py
@@ -8,6 +8,7 @@


def build_args(parser):
"""Constructs the command-line arguments for ``predictions2labels``."""
parser.add_argument("--in-npy", type=Path, required=True)
parser.add_argument("--in-csv-infer", type=Path, required=True)
parser.add_argument("--in-csv-train", type=Path, required=True)
@@ -19,6 +20,7 @@ def build_args(parser):


def parse_args():
"""Parses the command line arguments for the main method."""
parser = argparse.ArgumentParser()
build_args(parser)
args = parser.parse_args()
@@ -31,11 +33,8 @@ def softmax(x):
return e_x / e_x.sum(axis=1, keepdims=True)


def path2name(x):
return Path(x).name


def main(args, _=None):
"""Run the ``predictions2labels`` script."""
logits = np.load(args.in_npy, mmap_mode="r")
probs = softmax(logits)
confidence = np.max(probs, axis=1)
@@ -61,7 +60,7 @@ def main(args, _=None):
"Pseudo Labeling done. Nothing more to label."
)

counter_ = 0
counter = 0
for _, row in df_infer.iterrows():
if row["confidence"] < args.threshold:
continue
@@ -72,10 +71,10 @@ def main(args, _=None):
filepath_dst.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(filepath_src, filepath_dst)

counter_ += 1
print(f"Predicted: {counter_} ({100 * counter_ / len(df_infer):2.2f}%)")
counter += 1
print(f"Predicted: {counter} ({100 * counter / len(df_infer):2.2f}%)")

if counter_ == 0:
if counter == 0:
raise NotImplementedError(
"Pseudo Labeling done. Nothing more to label."
)
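
The script turns confident predictions into pseudo-labels: softmax the saved logits, take the per-sample max as confidence, and copy only images that clear --threshold. A minimal self-contained sketch of that selection step; the array values are illustrative, and whether the script's own softmax subtracts the row max is hidden behind the fold, so this uses the numerically stable variant:

import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    # Row-wise softmax; subtracting the max keeps exp() from overflowing.
    e_x = np.exp(x - x.max(axis=1, keepdims=True))
    return e_x / e_x.sum(axis=1, keepdims=True)

logits = np.array([[4.0, 0.1, 0.3], [0.2, 0.3, 0.25]])
probs = softmax(logits)
confidence = probs.max(axis=1)  # per-sample confidence
labels = probs.argmax(axis=1)   # pseudo-label = most probable class

threshold = 0.9
keep = confidence >= threshold  # mirrors the `confidence < threshold: continue` filter
print(labels[keep], confidence[keep])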
14 changes: 8 additions & 6 deletions scripts/prepare_config.py
@@ -1,4 +1,3 @@
#!/usr/bin/env python
# usage:
# python scripts/prepare_config.py \
# --in-template=./configs/templates/focal.yml \
@@ -18,6 +17,7 @@


def build_args(parser):
"""Constructs the command-line arguments for ``prepare_config``."""
parser.add_argument("--in-template", type=Path, required=True)
parser.add_argument("--out-config", type=Path, required=True)
parser.add_argument("--expdir", type=Path, required=True)
@@ -33,6 +33,7 @@ def build_args(parser):


def parse_args():
"""Parses the command line arguments for the main method."""
parser = argparse.ArgumentParser()
build_args(parser)
args = parser.parse_args()
@@ -50,15 +51,15 @@ def render_config(
balance_strategy: str,
criterion: str,
):
_template_path = in_template.absolute().parent

_env = Environment(
loader=FileSystemLoader([str(_template_path)]),
"""Render catalyst config with specified parameters."""
template_path = str(in_template.absolute().parent)
env = Environment(
loader=FileSystemLoader([template_path]),
trim_blocks=True,
lstrip_blocks=True,
)

template = _env.get_template(in_template.name)
template = env.get_template(in_template.name)

tag2class = safitty.load(dataset_path / "tag2class.json")
num_classes = len(tag2class)
@@ -84,6 +85,7 @@ def render_config(


def main(args, _=None):
"""Run the ``prepare_config`` script."""
args = args.__dict__
render_config(**args)

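
render_config builds a Jinja2 Environment rooted at the template's directory and renders the template by file name. A stripped-down sketch of the same pattern; the template path and render variables are placeholders, and the file must exist for get_template to succeed:

from pathlib import Path

from jinja2 import Environment, FileSystemLoader

in_template = Path("configs/templates/focal.yml")

env = Environment(
    loader=FileSystemLoader([str(in_template.absolute().parent)]),
    trim_blocks=True,    # drop the first newline after a block tag
    lstrip_blocks=True,  # strip leading whitespace on block-tag lines
)
template = env.get_template(in_template.name)
print(template.render(num_classes=10, criterion="FocalLossMultiClass"))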
41 changes: 36 additions & 5 deletions setup.cfg
@@ -1,15 +1,46 @@
[flake8]
exclude = .git,__pycache__,docs/source/conf.py,build,dist
ignore = C812,C813,C814,C815,C816,D100,D104,D200,D204,D205,D301,D400,D401,D402,D412,D413,DAR003,DAR103,DAR203,E203,E731,E800,E1101,N812,P101,RST201,RST203,RST210,RST213,RST301,RST304,S,W0221,W503,W504,W605,WPS0,WPS100,WPS101,WPS110,WPS111,WPS112,WPS125,WPS2,WPS300,WPS301,WPS305,WPS306,WPS309,WPS317,WPS323,WPS326,WPS331,WPS333,WPS335,WPS336,WPS337,WPS338,WPS342,WPS347,WPS348,WPS349,WPS350,WPS352,WPS402,WPS404,WPS405,WPS408,WPS410,WPS412,WPS414,WPS420,WPS421,WPS425,WPS426,WPS429,WPS430,WPS431,WPS432,WPS433,WPS434,WPS435,WPS440,WPS441,WPS5,WPS6
extend-ignore = WPS120
max-line-length = 79
max-doc-length = 79
inline-quotes = double
multiline-quotes = double
docstring-quotes = double
convention = google

[isort]
force_to_top = typing
skip_glob = **/__init__.py
line_length = 79
multi_line_output = 3
force_grid_wrap = 0
default_section = THIRDPARTY
no_lines_before = STDLIB,LOCALFOLDER
order_by_type = false
lines_between_types = 0
combine_as_imports = true
include_trailing_comma = true
use_parentheses = true
# TODO: check if this works fine
filter_files = **/__init__.py
force_sort_within_sections = true
# TODO: check if compatible with black
reverse_relative = true

# catalyst imports order:
# - typing
# - core python libs
# - python libs (known_third_party)
# - dl libs (known_dl)
# - catalyst imports
known_third_party = jinja2,numpy,pandas,safitty
sections = STDLIB,THIRDPARTY,DL,FIRSTPARTY,LOCALFOLDER
known_third_party = imageio,jinja2,numpy,pandas,safitty,skimage
known_dl = albumentations,torch,torchvision
known_first_party = catalyst,src
sections=STDLIB,THIRDPARTY,DL,FIRSTPARTY,LOCALFOLDER

[flake8]
ignore = D100,D101,D102,D103,D107
application-import-names = catalyst,src

[darglint]
docstring_style = google
strictness = short
ignore_regex = ^_(.*)
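
With the custom DL section, isort now sorts imports into five groups: typing (forced to top), stdlib, third-party, DL libraries, then first-party. An illustration of the ordering this config produces, using module names from the known_third_party/known_dl lists above (the first-party import is commented out so the snippet runs outside the repo):

from typing import Optional  # force_to_top = typing

import collections  # STDLIB

import numpy as np  # THIRDPARTY (known_third_party)
import pandas as pd

import torch  # DL (known_dl)
from torch import nn

# from src.model import MultiHeadNet  # FIRSTPARTY (known_first_party)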
13 changes: 12 additions & 1 deletion src/criterion.py
@@ -3,5 +3,16 @@


class EmbeddingsNormLoss(nn.Module):
def forward(self, embeddings, *args):
"""Embeddings loss."""

def forward(self, embeddings, *args) -> torch.Tensor:
"""Forward propagation method for the :class:`EmbeddingsNormLoss` loss.

Args:
embeddings (torch.Tensor): batch of embeddings
*args: additional arguments (unused)

Returns:
torch.Tensor: loss
"""
return torch.mean(torch.norm(embeddings, dim=1))
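
Usage sketch: the criterion reduces a batch of embeddings to the mean of their L2 norms, so minimizing it shrinks embeddings toward the origin (a norm regularizer). The shapes below are arbitrary:

import torch
from torch import nn

class EmbeddingsNormLoss(nn.Module):  # as defined above
    def forward(self, embeddings, *args) -> torch.Tensor:
        return torch.mean(torch.norm(embeddings, dim=1))

criterion = EmbeddingsNormLoss()
embeddings = torch.randn(8, 128)  # batch of 8 embeddings, 128-dim each
loss = criterion(embeddings)      # scalar: mean L2 norm of the batch
print(loss.item())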
66 changes: 49 additions & 17 deletions src/experiment.py
@@ -1,3 +1,4 @@
from typing import Optional
import collections

import numpy as np
@@ -18,37 +19,68 @@


class Experiment(ConfigExperiment):
"""Classification Experiment."""

def _postprocess_model_for_stage(self, stage: str, model: nn.Module):
model_ = model
if isinstance(model, torch.nn.DataParallel):
model_ = model_.module
model = (
model.module if isinstance(model, torch.nn.DataParallel) else model
)

if stage in ["debug", "stage1"]:
for param in model_.encoder_net.parameters():
for param in model.encoder_net.parameters():
param.requires_grad = False
elif stage == "stage2":
for param in model_.encoder_net.parameters():
for param in model.encoder_net.parameters():
param.requires_grad = True
return model_
return model

def get_datasets(
self,
stage: str,
datapath: str = None,
in_csv: str = None,
in_csv_train: str = None,
in_csv_valid: str = None,
in_csv_infer: str = None,
train_folds: str = None,
valid_folds: str = None,
tag2class: str = None,
class_column: str = None,
tag_column: str = None,
datapath: Optional[str] = None,
in_csv: Optional[str] = None,
in_csv_train: Optional[str] = None,
in_csv_valid: Optional[str] = None,
in_csv_infer: Optional[str] = None,
train_folds: Optional[str] = None,
valid_folds: Optional[str] = None,
tag2class: Optional[str] = None,
class_column: Optional[str] = None,
tag_column: Optional[str] = None,
folds_seed: int = 42,
n_folds: int = 5,
one_hot_classes: int = None,
one_hot_classes: Optional[int] = None,
balance_strategy: str = "upsampling",
):
"""Returns the datasets for a given stage and epoch.

Args:
stage (str): stage name of interest,
like "pretrain" / "train" / "finetune" / etc
datapath (Optional[str]): path to the folder with images
in_csv (Optional[str]): path to CSV annotation file. Look at
:func:`catalyst.contrib.utils.pandas.read_csv_data` for details
in_csv_train (Optional[str]): path to CSV annotation file
with train samples
in_csv_valid (Optional[str]): path to CSV annotation file
with validation samples
in_csv_infer (Optional[str]): path to CSV annotation file
with test samples
train_folds (Optional[str]): folds to use for training
valid_folds (Optional[str]): folds to use for validation
tag2class (Optional[str]): path to JSON file with the mapping from
class name (tag) to class index
class_column (Optional[str]): name of the class index column in the CSV
tag_column (Optional[str]): name of the class name column in the CSV file
folds_seed (int): random seed to use
n_folds (int): number of folds on which the data will be split
one_hot_classes (Optional[int]): number of one-hot classes
balance_strategy (str): strategy to handle imbalanced data,
look at :class:`catalyst.data.BalanceClassSampler` for details

Returns:
Dict: dictionary with datasets for current stage.
"""
datasets = collections.OrderedDict()
tag2class = safitty.load(tag2class) if tag2class is not None else None

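
The stage hook above unwraps DataParallel once and flips requires_grad on the encoder: frozen for "debug"/"stage1", trainable for "stage2". The same freeze/unfreeze pattern in isolation, with a dummy module standing in for the real encoder:

import torch
from torch import nn


def set_encoder_trainable(model: nn.Module, trainable: bool) -> nn.Module:
    # Unwrap DataParallel so attribute access reaches the real module.
    model = model.module if isinstance(model, torch.nn.DataParallel) else model
    for param in model.encoder_net.parameters():
        param.requires_grad = trainable
    return model


class DummyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder_net = nn.Linear(4, 4)


frozen = set_encoder_trainable(DummyNet(), trainable=False)  # stage1: freeze
print(any(p.requires_grad for p in frozen.encoder_net.parameters()))  # False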
51 changes: 48 additions & 3 deletions src/model.py
@@ -10,29 +10,64 @@


class MultiHeadNet(nn.Module):
"""Multi0head network."""

def __init__(
self,
encoder_net: nn.Module,
head_nets: nn.ModuleList,
embedding_net: nn.Module = None,
):
"""Constructor method for the :class:`MultiHeadNet` class.

Args:
encoder_net (nn.Module): encoder network (resnet)
common for all heads
head_nets (nn.ModuleList): list of network heads
embedding_net (nn.Module): network responsible for embeddings
extraction, common for all heads
"""
super().__init__()
self.encoder_net = encoder_net
self.embedding_net = embedding_net or (lambda *args: args)
self.head_nets = head_nets

def forward_embedding(self, x: torch.Tensor):
def forward_embedding(self, x: torch.Tensor) -> torch.Tensor:
"""Forward propagation method for the embedding network.

Args:
x (torch.Tensor): input batch

Returns:
torch.Tensor: batch of embeddings
"""
features = self.encoder_net(x)
embeddings = self.embedding_net(features)
return embeddings

def forward_class(self, x: torch.Tensor):
def forward_class(self, x: torch.Tensor) -> torch.Tensor:
"""Forward propagation method for the encoder and heads.

Args:
x (torch.Tensor): input batch of images

Returns:
torch.Tensor: batch of logits
"""
features = self.encoder_net(x)
embeddings = self.embedding_net(features)
logits = self.head_nets["logits"](embeddings)
return logits

def forward(self, x: torch.Tensor):
def forward(self, x: torch.Tensor) -> dict:
"""Forward propagation method for the network.

Args:
x (torch.Tensor): batch of images

Returns:
dict: dictionary of predictions
"""
features = self.encoder_net(x)
embeddings = self.embedding_net(features)
result = {"features": features, "embeddings": embeddings}
@@ -50,7 +85,17 @@ def get_from_params(
embedding_net_params: Dict = None,
heads_params: Dict = None,
) -> "MultiHeadNet":
"""Create neural network from config.

Args:
image_size (int): input image size, used to configure the first conv layer
encoder_params (dict): `ResnetEncoder` constructor params
embedding_net_params (dict): `SequentialNet` constructor params
heads_params (dict): heads params

Returns:
MultiHeadNet: network instance
"""
encoder_params_ = deepcopy(encoder_params)
embedding_net_params_ = deepcopy(embedding_net_params)
heads_params_ = deepcopy(heads_params)
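
MultiHeadNet's forward returns a dict with the raw encoder features, the embeddings, and (per the hidden tail of the method) one entry per head. A toy instantiation with hypothetical layer sizes; nn.ModuleDict stands in for the heads since they are looked up by name:

import torch
from torch import nn

# Stand-ins for ResnetEncoder / SequentialNet; sizes are made up.
encoder_net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 32))
embedding_net = nn.Linear(32, 16)
head_nets = nn.ModuleDict({"logits": nn.Linear(16, 10)})

x = torch.randn(2, 3, 8, 8)  # batch of 2 RGB 8x8 images
features = encoder_net(x)
embeddings = embedding_net(features)
result = {"features": features, "embeddings": embeddings}
result.update({name: head(embeddings) for name, head in head_nets.items()})
print({key: value.shape for key, value in result.items()})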