diff --git a/encexp/__init__.py b/encexp/__init__.py
index 3f3f4b6..7bfdcf6 100644
--- a/encexp/__init__.py
+++ b/encexp/__init__.py
@@ -17,4 +17,4 @@ if not '-m' in sys.argv:
     from encexp.text_repr import EncExpT, SeqTM, TextModel
 
 
-__version__ = "0.1.4"
+__version__ = "0.1.5"
diff --git a/encexp/build_encexp.py b/encexp/build_encexp.py
index 4342ff2..76b63a0 100644
--- a/encexp/build_encexp.py
+++ b/encexp/build_encexp.py
@@ -26,7 +26,7 @@
 from microtc.utils import tweet_iterator, Counter
 import encexp
 from encexp.text_repr import SeqTM, EncExpT
-from encexp.utils import progress_bar
+from encexp.utils import progress_bar, uniform_sample
 from encexp.download import download
 
 
@@ -118,7 +118,7 @@ class Train:
     """Train"""
     text_model: SeqTM=None
     min_pos: int=512
-    max_pos: int=int(2**15)
+    max_pos: int=int(2**14)
     min_neg: int=int(2**14)
     filename: str=None
     use_tqdm: bool=True
@@ -159,17 +159,38 @@ def labels(self):
         """Labels"""
         if hasattr(self, '_labels'):
             return self._labels
-        cnt = Counter()
+        labels_freq = Counter()
         with open(self.filename, encoding='utf-8') as fpt:
             for line in fpt:
                 line = line.strip()
                 labels, text = line.split('\t')
                 labels = labels.split()
-                cnt.update(labels)
-        labels = sorted([k for k, v in cnt.items() if v >= self.min_pos])
+                labels_freq.update(labels)
+        labels = sorted([k for k, v in labels_freq.items() if v >= self.min_pos])
         self.labels = labels
-        self.labels_freq = cnt
+        self.labels_freq = labels_freq
+        if self.keep_unfreq and self.self_supervised:
+            cnt = Counter()
+            with open(self.filename, encoding='utf-8') as fpt:
+                for line in fpt:
+                    line = line.strip()
+                    labels, text = line.split('\t')
+                    labels = labels.split()
+                    _labels_freq = [(k, labels_freq[k])
+                                    for k in labels]
+                    klass, _ = min(_labels_freq, key=lambda x: x[1])
+                    cnt.update([klass])
+            self.neg_freq = cnt
         return labels
+
+    @property
+    def neg_freq(self):
+        """Frequency in the negative label"""
+        return self._neg_freq
+
+    @neg_freq.setter
+    def neg_freq(self, value):
+        self._neg_freq = value
 
     @labels.setter
     def labels(self, value):
@@ -195,20 +216,25 @@
     def filter_tokens(self, tokens, label):
         if not self.self_supervised:
             return tokens
         return [x for x in tokens if x != label]
-
-    def training_set(self, label):
-        """Training set"""
+
+    def training_set_texts(self, label):
+        """Training set texts"""
         self.text_model.disable_text_transformations = True
         tokenize = self.text_model.tokenize
         max_pos = min(self.max_pos, self.labels_freq[label])
         num_neg = max(max_pos, self.min_neg)
         POS = []
-        NEG = []
-        labels_freq = [(k, v) for k, v in self.labels_freq.items() if k != label]
+        labels_freq = self.labels_freq
+        if self.keep_unfreq and self.self_supervised:
+            labels_freq = self.neg_freq
+        labels_freq = {k: v for k, v in labels_freq.items() if k != label}
+        if not self.keep_unfreq:
+            labels_freq = {None: num_neg}
+        NEG = NegDataset(num_neg, labels_freq)
         with open(self.filename, encoding='utf-8') as fpt:
             for line in fpt:
-                if len(POS) >= max_pos and len(NEG) >= num_neg:
+                if len(POS) >= max_pos and NEG.full:
                     break
                 line = line.strip()
                 labels, text = line.split('\t')
@@ -219,20 +245,20 @@ def training_set(self, label):
                     _ = self.filter_tokens(tokens, label)
                     POS.append(_)
                     continue
-                klass, _ = min(labels_freq, key=lambda x: x[1])
-                neg = dict(tokens=tokens, label=klass)
-                if len(NEG) < num_neg:
-                    NEG.append(neg)
-                    continue
-                k = randint(0, len(NEG) - 1)
-                if not self.keep_unfreq:
-                    NEG[k] = neg
-                    continue
-                if self.labels_freq[NEG[k]['label']] > self.labels_freq[neg['label']]:
-                    NEG[k] = neg
+                if self.keep_unfreq:
+                    labels_freq = [(k, self.labels_freq[k])
+                                   for k in labels]
+                    klass, _ = min(labels_freq, key=lambda x: x[1])
+                else:
+                    klass = None
+                NEG.add(tokens, klass)
+        return NEG.dataset(), POS
+
+    def training_set(self, label):
+        """Training set"""
+        NEG, POS = self.training_set_texts(label)
         if len(NEG) == 0 or len(POS) == 0:
             return None
-        NEG = [x['tokens'] for x in NEG]
         X = self.transform(POS + NEG)
         y = [1] * len(POS) + [-1] * len(NEG)
         return X, np.array(y)
@@ -307,6 +333,39 @@ def delete_tmps(self, args):
         os.rmdir(self.identifier)
 
 
+class NegDataset:
+    """Uniform sample of the negatives"""
+    def __init__(self, N: int, freq: dict):
+        keys = list(freq)
+        cnt = uniform_sample(N,
+                             np.array([freq[x] for x in keys]))
+        self.cnt = {k: v for k, v in zip(keys, cnt)}
+        self.elements = {k: list() for k in keys}
+        self.tot = N
+        self.size = 0
+
+    def add(self, data: str, label: str):
+        """Add element"""
+        cnt = self.cnt[label]
+        dataset = self.elements[label]
+        if len(dataset) < cnt:
+            self.size += 1
+            dataset.append(data)
+
+    def dataset(self):
+        """Dataset"""
+        values = []
+        for v in self.elements.values():
+            values.extend(v)
+        shuffle(values)
+        return values
+
+    @property
+    def full(self):
+        """Indicate whether the dataset has all the elements required"""
+        return self.tot - self.size <= 0
+
+
 def main(args):
     """CLI"""
     filename = args.file[0]
diff --git a/encexp/tests/test_build_encexp.py b/encexp/tests/test_build_encexp.py
index 5443651..4bbcfac 100644
--- a/encexp/tests/test_build_encexp.py
+++ b/encexp/tests/test_build_encexp.py
@@ -19,7 +19,7 @@
 # from encexp.tests.test_utils import samples
 from encexp.utils import load_dataset
 from encexp.text_repr import SeqTM, EncExpT
-from encexp.build_encexp import Dataset, EncExpDataset, Train, main
+from encexp.build_encexp import Dataset, EncExpDataset, Train, main, NegDataset
 
 
 def test_Dataset_output_filename():
@@ -109,7 +109,7 @@ def test_Train_labels():
 
 def test_Train_training_set():
     """Test Train"""
-    
+
     dataset = load_dataset('mx')
     seq = SeqTM(lang='es', token_max_filter=2**13)
     ds = EncExpDataset(text_model=clone(seq))
@@ -121,14 +121,17 @@
     X, y = train.training_set(labels[0])
     assert X.shape[0] == len(y) and X.shape[1] == len(seq.names)
     # cnt = np.where((X > 0).sum(axis=0).A1)[0].shape
-    train.keep_unfreq = True
+    train = Train(text_model=seq, min_pos=32,
+                  keep_unfreq=True,
+                  filename=ds.output_filename)
+    labels = train.labels
     X, y = train.training_set(labels[0])
     _, freq = np.unique(y, return_counts=True)
     assert freq[0] > freq[1]
-    train.min_neg = 0
-    X, y = train.training_set(labels[0])
-    _, freq = np.unique(y, return_counts=True)
-    assert freq[0] == freq[1]
+    # train.min_neg = 0
+    # X, y = train.training_set(labels[0])
+    # _, freq = np.unique(y, return_counts=True)
+    # assert freq[0] == freq[1]
     os.unlink(ds.output_filename)
     # cnt2 = np.where((X > 0).sum(axis=0).A1)[0].shape
 
@@ -207,4 +210,24 @@ class A:
     A.file = [dataset]
     A.voc_size_exponent = 13
     A.n_jobs = -1
-    main(A)
\ No newline at end of file
+    main(A)
+
+
+def test_NegDataset():
+    """Test NegDataset"""
+    freq = {'mx': 1000, 'ar': 100, 'es': 10}
+    neg = NegDataset(500, freq)
+    for k in range(510):
+        neg.add(f'mx {k}', 'mx')
+    assert len(neg.elements['mx']) == 390
+    for k in range(110):
+        neg.add(f'ar {k}', 'ar')
+    assert len(neg.elements['ar']) == 100
+    for k in range(20):
+        neg.add(f'es {k}', 'es')
+    assert neg.full
+    assert len(neg.dataset()) == 500
+    neg = NegDataset(500, {None: 500})
+    for k in range(510):
+        neg.add(f'unico {k}', None)
+    assert neg.full
\ No newline at end of file
diff --git a/encexp/tests/test_text_repr.py b/encexp/tests/test_text_repr.py
index d056bc1..1b9061f 100644
--- a/encexp/tests/test_text_repr.py
+++ b/encexp/tests/test_text_repr.py
@@ -138,6 +138,7 @@ def test_EncExpT_tailored():
     D = list(tweet_iterator(dataset))
     enc = EncExpT(lang='es', pretrained=False)
     enc.tailored(D, tsv_filename='tailored.tsv',
+                 min_pos=32,
                  filename='tailored.json.gz')
     assert enc.weights.shape[0] == 2**14
     assert enc.weights.shape[1] == 90
@@ -162,6 +163,7 @@ def test_EncExpT_tailored_intercept():
     enc = EncExpT(lang='es', with_intercept=True,
                   pretrained=False)
     enc.tailored(D, tsv_filename='tailored.tsv',
+                 min_pos=32,
                  filename='tailored_intercept.json.gz')
     assert enc.weights.shape[0] == 2**14
     assert enc.weights.shape[1] == 90
@@ -185,7 +187,7 @@ def test_EncExpT_tailored_add():
     dataset = load_dataset('mx')
     D = list(tweet_iterator(dataset))
     enc = EncExpT(lang='es', token_max_filter=2**13)
-    enc.tailored(D)
+    enc.tailored(D, min_pos=32)
 
 
 def test_EncExpT_tailored_no_neg():
@@ -193,7 +195,7 @@ def test_EncExpT_tailored_no_neg():
     dataset = load_dataset('mx')
     D = [f'{text} de' for text in tweet_iterator(dataset)]
     enc = EncExpT(lang='es', token_max_filter=2**13)
-    enc.tailored(D)
+    enc.tailored(D, min_pos=32)
 
 
 def test_EncExpT_tailored_2cl():
@@ -203,7 +205,7 @@ def test_EncExpT_tailored_2cl():
     enc = EncExpT(lang='es', pretrained=False,
                   with_intercept=True,
                   token_max_filter=2**13)
-    enc.tailored(D, self_supervised=False)
+    enc.tailored(D, self_supervised=False, min_pos=32)
     assert enc.names.tolist() == ['ar', 'mx']
 
 
@@ -538,4 +540,4 @@ def test_TextModel_diac():
 #                         cv=sss,
 #                         n_jobs=1,
 #                         scoring='f1_macro').fit(mx + ar, y)
-#     assert grid.best_score_ > 0.7
\ No newline at end of file
+#     assert grid.best_score_ > 0.7
diff --git a/encexp/text_repr.py b/encexp/text_repr.py
index 3873ad6..ce2568f 100644
--- a/encexp/text_repr.py
+++ b/encexp/text_repr.py
@@ -441,7 +441,7 @@ class EncExpT(Identifier):
     with_intercept: bool=False
     merge_encode: bool=True
     distance: bool=False
-    keep_unfreq: bool=True
+    keep_unfreq: bool=False
 
     @property
     def seqTM(self):
@@ -609,8 +609,9 @@ def add(self, data: Iterable):
 
     def tailored(self, D: Iterable=None, filename: str=None,
                  tsv_filename: str=None,
-                 min_pos: int=32,
-                 max_pos: int=int(2**15),
+                 min_pos: int=512,
+                 min_neg: int=int(2**14),
+                 max_pos: int=int(2**14),
                  n_jobs: int=-1,
                  self_supervised: bool=True,
                  ds: object=None,
@@ -653,6 +654,7 @@ def set_weights(data):
                       filename=ds.output_filename,
                       use_tqdm=self.use_tqdm,
                       min_pos=min_pos,
+                      min_neg=min_neg,
                       max_pos=max_pos,
                       n_jobs=n_jobs,
                       with_intercept=self.with_intercept,
diff --git a/encexp/utils.py b/encexp/utils.py
index bc402b8..77a621e 100644
--- a/encexp/utils.py
+++ b/encexp/utils.py
@@ -181,6 +181,7 @@ def inner(texts):
 
 
 def load_dataset(country: Union[str, list],
+                 lang: str='es',
                  return_X_y:bool=False):
     """Country identification dataset"""
     if not isdir(MODELS):
@@ -188,8 +189,8 @@ def load_dataset(country: Union[str, list],
     if isinstance(country, str):
         country = [country]
     for cntr in country:
-        url = f'{DialectID_URL}/es-{cntr}-sample.json.zip'
-        filename=join(MODELS, f'es-{cntr}-sample.json.zip')
+        url = f'{DialectID_URL}/{lang}-{cntr}-sample.json.zip'
+        filename=join(MODELS, f'{lang}-{cntr}-sample.json.zip')
         if isfile(filename):
             continue
         Download(url, filename)
@@ -197,12 +198,12 @@ def load_dataset(country: Union[str, list],
         fpt.extractall(path=MODELS,
                        pwd="ingeotec".encode("utf-8"))
     if len(country) == 1 and return_X_y is False:
-        return join(MODELS, f'es-{country[0]}-sample.json')
+        return join(MODELS, f'{lang}-{country[0]}-sample.json')
     assert return_X_y
     X = []
    y = []
     for cntr in country:
-        _ = join(MODELS, f'es-{cntr}-sample.json')
+        _ = join(MODELS, f'{lang}-{cntr}-sample.json')
         _ = list(tweet_iterator(_))
         X.extend(_)
        y.extend([cntr] * len(_))
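
Note on uniform_sample: the new NegDataset relies on uniform_sample imported from encexp.utils, whose implementation is not part of this patch. The sketch below is a water-filling allocation inferred from test_NegDataset (with freq = {'mx': 1000, 'ar': 100, 'es': 10} and N = 500 it must allocate 390/100/10); it is an assumption about the intended behavior, not the shipped code.

import numpy as np


def uniform_sample(N: int, freq: np.ndarray) -> np.ndarray:
    """Sketch of a water-filling allocation: split the N negative slots
    as evenly as possible among the labels, never taking more from a
    label than its observed frequency allows; leftover slots go to the
    labels that still have capacity.

    >>> uniform_sample(500, np.array([1000, 100, 10]))
    array([390, 100,  10])
    """
    quota = np.zeros(len(freq), dtype=int)
    remaining = N
    while remaining > 0:
        active = np.where(freq > quota)[0]
        if len(active) == 0:  # every label is exhausted
            break
        share = max(remaining // len(active), 1)
        for i in active:
            # Cap the take by the label's remaining capacity and by
            # the slots still unassigned.
            take = min(freq[i] - quota[i], share, remaining)
            quota[i] += take
            remaining -= take
            if remaining == 0:
                break
    return quota

Under this assumption, the per-label quotas replace the previous reservoir-style sampling (k = randint(0, len(NEG) - 1)) removed above, so the negative class is spread as evenly as possible across the remaining labels instead of being dominated by the most frequent one.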