20 changes: 20 additions & 0 deletions config.py
@@ -62,6 +62,26 @@
"OTHER": 2
}

GLC_POS_LABEL2ID = {
"PRON": 0,
"VERB": 1,
"ADJ": 2,
"ADP": 3,
"AUX": 4,
"CCONJ": 5,
"DET": 6,
"PUNCT": 7,
"SYM": 8,
"NOUN": 9,
"PART": 10,
"PROPN": 11,
"ADV": 12,
"SCONJ": 13,
"INTJ": 14,
"NUM": 15,
"X": 16
}

K_CROSSFOLD_VALIDATION_SPLITS = 10

# PROJECT CONFIGURATION
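A side note on the new GLC_POS_LABEL2ID entry (this sketch is not part of the diff): at prediction time such a map is usually paired with its inverse so that predicted class ids can be decoded back to UD POS tags. A minimal sketch, assuming the dict above is importable from config:

# Hypothetical decoding helper, not part of this PR.
from config import GLC_POS_LABEL2ID

GLC_POS_ID2LABEL = {idx: tag for tag, idx in GLC_POS_LABEL2ID.items()}

def decode_pos(pred_ids):
    # Map predicted class ids back to UD POS tag strings.
    return [GLC_POS_ID2LABEL[i] for i in pred_ids]

print(decode_pos([9, 1, 7]))  # ['NOUN', 'VERB', 'PUNCT']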
38 changes: 32 additions & 6 deletions main.py
@@ -1,5 +1,6 @@
import argparse
import torch
import wandb

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
@@ -31,7 +32,8 @@
NUM_WORKERS,
AVAIL_GPUS,
GLC_NER_LABEL2ID,
GLC_LID_LABEL2ID
GLC_LID_LABEL2ID,
GLC_POS_LABEL2ID
)

def test_dm(args):
@@ -49,6 +51,19 @@ def test_dm(args):
dm.setup()
print(next(iter(dm.train_dataloader())))

sweep_configuration = {
'method': 'random',
'name': 'sweep',
'metric': {'goal': 'maximize', 'name': 'val_acc'},
'parameters':
{
'batch_size': {'values': [16, 32, 64]},
'epochs': {'values': [5, 10, 15]},
'lr': {'max': 0.1, 'min': 0.01}
}
}

sweep_id = wandb.sweep(sweep=sweep_configuration, project='my-first-sweep')

def main(args):

@@ -203,12 +218,14 @@ def kcrossfold(args):
trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)



def multidataset(args):
seed_everything(42)

# important to keep the order of label2ids, tasknames and tasks same.
label2ids = [ GLC_NER_LABEL2ID, GLC_LID_LABEL2ID ]
tasknames = ['NER', 'LID']
label2ids = [ GLC_NER_LABEL2ID, GLC_LID_LABEL2ID, GLC_POS_LABEL2ID ]
tasknames = ['NER', 'LID', 'POS']
tasks = [
Task(GLC_NER_LABEL2ID,
'NER',
@@ -217,7 +234,11 @@ def multidataset(args):
Task(GLC_LID_LABEL2ID,
'LID',
'data/GLUECoS/LID/Romanized/train.txt',
'data/GLUECoS/LID/Romanized/validation.txt')
'data/GLUECoS/LID/Romanized/validation.txt'),
Task(GLC_POS_LABEL2ID,
'POS',
'data/GLUECoS/POS_EN_HI_UD/Romanized/train.txt',
'data/GLUECoS/POS_EN_HI_UD/Romanized/validation.txt')
]
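For context (an assumption, since the class itself is not shown in this diff), Task appears to be a plain container whose fields match the positional arguments above and the attribute accesses in the datamodule (label2id, name, train_path, val_path), roughly:

from typing import Dict, NamedTuple

class Task(NamedTuple):
    # Assumed shape only; the real Task class lives elsewhere in the repo.
    label2id: Dict[str, int]
    name: str
    train_path: str
    val_path: str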

isFreezed = args.freeze
@@ -244,7 +265,10 @@ def multidataset(args):
args.base_model,
args.padding,
args.lr,
args.weight_decay
args.weight_decay,
ner_learning_rate = 3e-4,
lid_learning_rate = 3e-5,
pos_learning_rate = 3e-4
)

logger = WandbLogger(
@@ -307,4 +331,6 @@ def multidataset(args):

# main(args)
# kcrossfold(args)
multidataset(args)
multidataset(args)

wandb.agent(sweep_id, function=main, count=4)
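One note on the sweep wiring (a sketch, not the PR's code): wandb.agent invokes its target function with no arguments, so a sweep objective normally calls wandb.init() and reads the sampled values from wandb.config, whereas main(args) as written expects argparse arguments. The usual pattern looks like this (train_one_run and its body are illustrative):

import wandb

# Same shape as the sweep_configuration added above in main.py.
sweep_configuration = {
    'method': 'random',
    'name': 'sweep',
    'metric': {'goal': 'maximize', 'name': 'val_acc'},
    'parameters': {
        'batch_size': {'values': [16, 32, 64]},
        'epochs': {'values': [5, 10, 15]},
        'lr': {'max': 0.1, 'min': 0.01},
    },
}

def train_one_run():
    # wandb.agent() calls this with no arguments; the sampled values
    # arrive through wandb.config after wandb.init().
    wandb.init()
    lr = wandb.config.lr
    batch_size = wandb.config.batch_size
    for epoch in range(wandb.config.epochs):
        val_acc = 0.0  # placeholder: log the real validation accuracy here
        wandb.log({'val_acc': val_acc, 'epoch': epoch})

sweep_id = wandb.sweep(sweep=sweep_configuration, project='my-first-sweep')
wandb.agent(sweep_id, function=train_one_run, count=4)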
11 changes: 7 additions & 4 deletions src/datamodules/gluecos/GLUECoSSequenceLabelDataModule.py
@@ -79,13 +79,13 @@ def prepare_data(self):
def setup(self, stage):
self.train_data = []
for task in self.tasks:
self.train_data.append(self._read_gluecos_(task.train_path))
self.train_data.append(self._read_gluecos_(task.train_path, task.name=='POS'))
self.train_data[-1] = self._mtokenize_(self.train_data[-1], task.label2id)
self.training_dataset = TaskDataset(self.train_data)

self.val_data = []
for task in self.tasks:
self.val_data.append(self._read_gluecos_(task.val_path))
self.val_data.append(self._read_gluecos_(task.val_path, task.name=='POS'))
self.val_data[-1] = self._mtokenize_(self.val_data[-1], task.label2id)
self.validation_dataset = TaskDataset(self.val_data)

@@ -108,7 +108,7 @@ def _mtokenize_(self, datapoints, word2id):
word2id
]

def _read_gluecos_(self, file_path):
def _read_gluecos_(self, file_path, is_pos=False):
toret = []
with open(file_path, encoding='utf-8') as f:
datapoint = [[], []]
@@ -123,7 +123,10 @@ def _read_gluecos_(self, file_path):
try:
split_line = line.split('\t')
datapoint[0].append(split_line[0])
datapoint[1].append(split_line[1])
if is_pos:
datapoint[1].append(split_line[2])
else:
datapoint[1].append(split_line[1])
except:
datapoint = [[], []] #just reset the datapoint.

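To illustrate the new is_pos branch (the column layout is inferred from the indexing above, and the sample lines below are made up, not taken from GLUECoS): POS files appear to carry the tag in the third tab-separated column, while NER/LID files carry the label in the second.

def pick_label(line, is_pos=False):
    # Mirrors the column selection in _read_gluecos_; assumed layouts:
    #   NER/LID: token<TAB>label
    #   POS:     token<TAB>lang<TAB>pos_tag
    cols = line.rstrip('\n').split('\t')
    return cols[0], cols[2] if is_pos else cols[1]

print(pick_label('khana\tHI\tNOUN', is_pos=True))   # ('khana', 'NOUN')
print(pick_label('khana\tHI'))                      # ('khana', 'HI')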
149 changes: 106 additions & 43 deletions src/models/multidataset/sequencemultitask.py
@@ -4,9 +4,21 @@
import pytorch_lightning as pl
from torchcrf import CRF
from torchmetrics.functional import accuracy, precision, recall, f1_score
from torch.optim import AdamW

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from config import (
LABEL2ID,
LEARNING_RATE,
LID2ID,
WARM_RESTARTS,
WEIGHT_DECAY,
DROPOUT_RATE,
MAX_SEQUENCE_LENGTH,
PADDING
)

class TaskHead(nn.Module):
def __init__(self, n_labels):
super(TaskHead, self).__init__()
@@ -38,7 +50,15 @@ def __init__(
model_name_or_path,
padding,
learning_rate,
weight_decay,
weight_decay: float = WEIGHT_DECAY,
ner_learning_rate: float = LEARNING_RATE,
lid_learning_rate: float = LEARNING_RATE,
pos_learning_rate: float = LEARNING_RATE,
warm_restart_epochs: int = WARM_RESTARTS,
ner_wd: float = WEIGHT_DECAY,
lid_wd: float = WEIGHT_DECAY,
pos_wd: float = WEIGHT_DECAY,
dropout_rate: float = DROPOUT_RATE
) -> None:
super(SequenceMultiTaskModel, self).__init__()

@@ -90,6 +110,13 @@ def __init__(
TaskHead( len(self.label2ids[i]) )
)

if task_names[i] == 'NER':
self.ner_net = self.task_heads[-1]
elif task_names[i] == 'LID':
self.lid_net = self.task_heads[-1]
else:
self.pos_net = self.task_heads[-1]

self.log_vars.append(nn.Parameter(torch.zeros(1)))

self.add_module(f"Task {task_names[task_id]} TaskHead", self.task_heads[task_id])
@@ -169,56 +196,92 @@ def _compute_metrics(self, preds: torch.Tensor, labels: torch.Tensor, mode: str,

metrics = {}
# metrics[f"acc/{mode}"] = accuracy(preds, labels, num_classes=len(self.label2id) + 1, ignore_index=self.special_tag_id)
metrics[f"prec/{mode}"] = precision(preds, labels, num_classes=len(self.label2ids[task_id]) + 1, ignore_index=self.special_tag_ids[task_id], average="macro")
metrics[f"rec/{mode}"] = recall(preds, labels, num_classes=len(self.label2ids[task_id]) + 1, ignore_index=self.special_tag_ids[task_id], average="macro")
metrics[f"f1/{mode}"] = f1_score(preds, labels, num_classes=len(self.label2ids[task_id]) + 1, ignore_index=self.special_tag_ids[task_id], average="macro")
metrics[f"prec/{mode}"] = precision(preds, labels, num_classes=len(self.label2ids[task_id]) + 1, ignore_index=self.special_tag_ids[task_id], average="macro", task='multiclass')
metrics[f"rec/{mode}"] = recall(preds, labels, num_classes=len(self.label2ids[task_id]) + 1, ignore_index=self.special_tag_ids[task_id], average="macro", task='multiclass')
metrics[f"f1/{mode}"] = f1_score(preds, labels, num_classes=len(self.label2ids[task_id]) + 1, ignore_index=self.special_tag_ids[task_id], average="macro", task='multiclass')

return metrics

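A brief aside on the metric calls above: recent torchmetrics releases require an explicit task argument in the functional API, which is what the added task='multiclass' keywords address. A minimal standalone check (tensor values are illustrative):

import torch
from torchmetrics.functional import f1_score

preds = torch.tensor([0, 2, 1, 2])   # illustrative predictions
labels = torch.tensor([0, 1, 1, 2])  # illustrative targets

# task='multiclass' is required by the newer functional API;
# num_classes and average keep their previous meaning.
print(f1_score(preds, labels, task='multiclass', num_classes=3, average='macro'))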
def configure_optimizers(self):

parameters = [
{
'params': self.baseModel.parameters()
},
{
'params': self.bi_lstm.parameters(),
'lr': 1e-5
},
{
'params': self.linear.parameters(),
'lr': 1e-6
}
]
# Same LR for shared params and different LR for different tasks params
# Same weight decay for shared params and different weight decay for different tasks params
# TODO: Experiment with Different LRs

for log_var in self.log_vars:
parameters.append(
no_decay = ["bias", "LayerNorm.weight"]

# * The params for which there is no lr or weight_decay key will use global lr and weight_decay
# * [ i.e. lr and weight_decay args in AdamW ]
optimizer_grouped_parameters = [
{
'params': [
p
for n, p in self.bi_lstm.named_parameters()
if not any(nd in n for nd in no_decay)
],

},
{
'params': [
p
for n, p in self.linear.named_parameters()
if not any(nd in n for nd in no_decay)
],
},
{
'params': [
p
for n, p in self.ner_net.named_parameters()
if not any(nd in n for nd in no_decay)
],
'lr': self.hparams.ner_learning_rate,
'weight_decay': self.hparams.ner_wd
},
{
'params': [
p
for n, p in self.lid_net.named_parameters()
if not any(nd in n for nd in no_decay)
],
'lr': self.hparams.lid_learning_rate,
'weight_decay': self.hparams.lid_wd
},
{
'params': log_var
'params': [
p
for n, p in self.pos_net.named_parameters()
if not any(nd in n for nd in no_decay)
],
'lr': self.hparams.pos_learning_rate,
'weight_decay': self.hparams.pos_wd
},
{
'params': [
p
for n, p in self.named_parameters()
if any(nd in n for nd in no_decay)
],
'weight_decay': 0.0
}
)
]

optimizer_grouped_parameters.append({
'params': [
p
for n, p in self.baseModel.named_parameters()
if not any(nd in n for nd in no_decay)
],
})

for name, module in self.named_modules():
if 'Task NER TaskHead' == name:
parameters.append(
{
'params': module.parameters(),
'lr': 2e-6
}
)

elif 'Task LID TaskHead' == name:
parameters.append(
{
'params': module.parameters(),
'lr': 5e-8
}
)

optimizer = torch.optim.AdamW(
params=parameters,
lr=self.learning_rate,
weight_decay=self.weight_decay
optimizer = AdamW(
params=optimizer_grouped_parameters,
lr=self.hparams.learning_rate,
weight_decay=self.hparams.weight_decay
)

lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
optimizer=optimizer,
T_0=self.hparams.warm_restart_epochs, # First restart after T_0 epochs [50 Initial value, 20 ]
)

return optimizer
return [optimizer], [lr_scheduler]
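Two remarks on the new configure_optimizers (assumptions, since the full class is not shown in this diff): reading self.hparams.ner_learning_rate and friends relies on save_hyperparameters() being called in __init__, and any group without its own lr/weight_decay key falls back to the optimizer-level defaults, as the comment in the diff notes. A small standalone sketch of the same grouping idea (module names and values are illustrative):

import torch
from torch import nn
from torch.optim import AdamW

# Stand-ins for the shared encoder and a task head.
shared = nn.Linear(16, 16)
ner_head = nn.Linear(16, 4)

no_decay = ('bias', 'LayerNorm.weight')

param_groups = [
    # Shared weights: fall back to the optimizer-level lr / weight_decay.
    {'params': [p for n, p in shared.named_parameters()
                if not any(nd in n for nd in no_decay)]},
    # Task-specific weights: their own lr and weight decay.
    {'params': [p for n, p in ner_head.named_parameters()
                if not any(nd in n for nd in no_decay)],
     'lr': 3e-4, 'weight_decay': 0.01},
    # Biases (and LayerNorm weights, if present): never decayed.
    {'params': [p for m in (shared, ner_head)
                for n, p in m.named_parameters()
                if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},
]

optimizer = AdamW(param_groups, lr=1e-5, weight_decay=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)

# In Lightning, returning [optimizer], [scheduler] steps the scheduler once per epoch by default.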