diff --git a/encexp/__init__.py b/encexp/__init__.py index 0499652..8f77ddf 100644 --- a/encexp/__init__.py +++ b/encexp/__init__.py @@ -17,4 +17,4 @@ if not '-m' in sys.argv: from encexp.text_repr import EncExpT, SeqTM, TextModel -__version__ = "0.1.8" +__version__ = "0.1.9" diff --git a/encexp/tests/test_build_encexp.py b/encexp/tests/test_build_encexp.py index 4bbcfac..0563003 100644 --- a/encexp/tests/test_build_encexp.py +++ b/encexp/tests/test_build_encexp.py @@ -32,8 +32,7 @@ def test_Dataset_output_filename(): def test_Dataset_process(): """Test Dataset process""" - dataset = load_dataset('mx') - iter = list(tweet_iterator(dataset)) + iter = load_dataset(dataset='dev')[:2048] for x in iter: x['klass'] = 'mx' seq = SeqTM(lang='es', token_max_filter=2**13) @@ -42,7 +41,7 @@ def test_Dataset_process(): data = open(ds.output_filename, encoding='utf-8').readlines() assert data[0][:2] == 'mx' - iter = list(tweet_iterator(dataset)) + # iter = list(tweet_iterator(dataset)) seq = SeqTM(lang='es', token_max_filter=2**13) words = [str(x) for x in seq.names if x[:2] != 'q:' and x[:2] != 'e:'] @@ -76,10 +75,9 @@ def test_Dataset_self_supervise(): def test_EncExpDataset(): """Test EncExpDataset""" - dataset = load_dataset('mx') + iter = load_dataset(dataset='dev')[:2048] seq = SeqTM(lang='es', token_max_filter=2**13) ds = EncExpDataset(text_model=seq) - iter = list(tweet_iterator(dataset)) ds.process(iter) data = open(ds.output_filename, encoding='utf-8').readlines() assert len(data) <= len(iter) @@ -89,13 +87,13 @@ def test_EncExpDataset(): def test_Train_labels(): """Test labels""" - dataset = load_dataset('mx') + dataset = load_dataset(dataset='dev')[:2048] seq = SeqTM(lang='es', token_max_filter=2**13) ds = EncExpDataset(text_model=clone(seq)) - ds.process(tweet_iterator(dataset)) + ds.process(dataset) train = Train(text_model=seq, min_pos=32, filename=ds.output_filename) - assert len(train.labels) == 90 + assert len(train.labels) == 91 X, y = load_dataset(['mx', 
'ar', 'es'], return_X_y=True) D = [dict(text=text, klass=label) for text, label in zip(X, y)] ds = EncExpDataset(text_model=clone(seq), self_supervised=False) @@ -110,11 +108,11 @@ def test_Train_labels(): def test_Train_training_set(): """Test Train""" - dataset = load_dataset('mx') + dataset = load_dataset(dataset='dev')[:2048] seq = SeqTM(lang='es', token_max_filter=2**13) ds = EncExpDataset(text_model=clone(seq)) # if not isfile(ds.output_filename): - ds.process(tweet_iterator(dataset)) + ds.process(dataset) train = Train(text_model=seq, min_pos=32, filename=ds.output_filename) labels = train.labels @@ -141,11 +139,11 @@ def test_Train_training_set(): def test_Train_parameters(): """Test Train""" - dataset = load_dataset('mx') + dataset = load_dataset(dataset='dev')[:2048] seq = SeqTM(lang='es', token_max_filter=2**13) ds = EncExpDataset(text_model=clone(seq)) if not isfile(ds.output_filename): - ds.process(tweet_iterator(dataset)) + ds.process(dataset) train = Train(text_model=seq, min_pos=32, filename=ds.output_filename) labels = train.labels @@ -157,14 +155,14 @@ def test_Train_parameters(): def test_Train_store_model(): """Test Train""" - dataset = load_dataset('mx') + dataset = load_dataset(dataset='dev')[:2048] enc = EncExpT(lang='es', token_max_filter=2**13, pretrained=False) enc.pretrained = True ds = EncExpDataset(text_model=clone(enc.seqTM)) ds.identifier = enc.identifier if not isfile(ds.output_filename): - ds.process(tweet_iterator(dataset)) + ds.process(dataset) train = Train(text_model=enc.seqTM, min_pos=32, filename=ds.output_filename) train.identifier = enc.identifier @@ -201,13 +199,15 @@ def test_Train_2cl(): def test_seqtm_build(): """Test SeqTM CLI""" + from encexp.utils import MODELS class A: """Dummy""" - dataset = load_dataset('mx') + load_dataset(dataset='dev') + filename = join(MODELS, 'dialectid_es_dev.json') A.lang = 'es' - A.file = [dataset] + A.file = [filename] A.voc_size_exponent = 13 A.n_jobs = -1 main(A) diff --git 
a/encexp/tests/test_build_voc.py b/encexp/tests/test_build_voc.py index 734b923..17e293a 100644 --- a/encexp/tests/test_build_voc.py +++ b/encexp/tests/test_build_voc.py @@ -11,20 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from os.path import join from microtc.utils import tweet_iterator # from encexp.tests.test_utils import samples from encexp.utils import load_dataset from encexp.build_voc import compute_TextModel_vocabulary, compute_SeqTM_vocabulary +from encexp.utils import MODELS def test_compute_TextModel_vocabulary(): """Compute vocabulary""" def iterator(): """iterator""" - return tweet_iterator(dataset) + for x in dataset: + yield x - dataset = load_dataset('mx') - data = compute_TextModel_vocabulary(dataset, + + dataset = load_dataset(dataset='dev')[:2048] + filename = join(MODELS, 'dialectid_es_dev.json') + data = compute_TextModel_vocabulary(filename, pretrained=False, token_max_filter=20) assert len(data['vocabulary']['dict']) == 20 @@ -37,10 +42,15 @@ def iterator(): def test_compute_SeqTM_vocabulary(): """test SeqTM vocabulary""" - dataset = load_dataset('mx') - params = compute_TextModel_vocabulary(dataset, + def iterator(): + """iterator""" + for x in dataset: + yield x + + dataset = load_dataset(dataset='dev')[:2048] + params = compute_TextModel_vocabulary(iterator, pretrained=False) - data = compute_SeqTM_vocabulary(dataset, + data = compute_SeqTM_vocabulary(iterator, params, pretrained=False, token_max_filter=2**13) diff --git a/encexp/tests/test_text_repr.py b/encexp/tests/test_text_repr.py index d54c5a6..e113dd7 100644 --- a/encexp/tests/test_text_repr.py +++ b/encexp/tests/test_text_repr.py @@ -134,42 +134,40 @@ def test_EncExpT_identifier(): def test_EncExpT_tailored(): """Test EncExpT tailored""" - dataset = load_dataset('mx') - D = list(tweet_iterator(dataset)) + D = 
load_dataset(dataset='dev')[:2048] enc = EncExpT(lang='es', pretrained=False) enc.tailored(D, tsv_filename='tailored.tsv', min_pos=32, filename='tailored.json.gz') assert enc.weights.shape[0] == 2**14 - assert enc.weights.shape[1] == 90 + assert enc.weights.shape[1] == 91 W = enc.encode('buenos dias') - assert W.shape == (1, 90) + assert W.shape == (1, 91) X = enc.transform(D) - assert X.shape == (2048, 90) + assert X.shape == (2048, 91) def test_EncExpT_pretrained(): """Test EncExpT pretrained""" enc = EncExpT(lang='es', token_max_filter=2**13) X = enc.transform(['buenos dias']) - assert X.shape == (1, 4985) - assert len(enc.names) == 4985 + assert X.shape == (1, 4977) + assert len(enc.names) == 4977 def test_EncExpT_tailored_intercept(): """Test EncExpT tailored""" - dataset = load_dataset('mx') - D = list(tweet_iterator(dataset)) + D = load_dataset(dataset='dev')[:2048] enc = EncExpT(lang='es', with_intercept=True, pretrained=False) enc.tailored(D, tsv_filename='tailored.tsv', min_pos=32, filename='tailored_intercept.json.gz') assert enc.weights.shape[0] == 2**14 - assert enc.weights.shape[1] == 90 - assert enc.intercept.shape[0] == 90 + assert enc.weights.shape[1] == 91 + assert enc.intercept.shape[0] == 91 X = enc.transform(['buenos dias']) - assert X.shape[1] == 90 + assert X.shape[1] == 91 enc.with_intercept = False assert np.fabs(X - enc.transform(['buenos dias'])).sum() != 0 enc.with_intercept = True @@ -184,16 +182,15 @@ def test_EncExpT_tailored_intercept(): def test_EncExpT_tailored_add(): """Test EncExpT tailored""" - dataset = load_dataset('mx') - D = list(tweet_iterator(dataset)) + D = load_dataset(dataset='dev')[:2048] enc = EncExpT(lang='es', token_max_filter=2**13) enc.tailored(D, min_pos=32) def test_EncExpT_tailored_no_neg(): """Test EncExpT tailored""" - dataset = load_dataset('mx') - D = [f'{text} de' for text in tweet_iterator(dataset)] + dataset = load_dataset(dataset='dev')[:2048] + D = [f'{text} de' for text in dataset] enc = 
EncExpT(lang='es', token_max_filter=2**13) enc.tailored(D, min_pos=32) @@ -224,8 +221,7 @@ def test_EncExpT_norm(): def test_TextModel_diac(): """Test TextModel diac""" from unicodedata import normalize - dataset = load_dataset('mx') - D = list(tweet_iterator(dataset)) + D = load_dataset(dataset='dev')[:2048] tm = TextModel(del_diac=False, pretrained=False).fit(D) cdn = normalize('NFD', 'ñ') lst = [x for x in tm.names if cdn in x] @@ -240,4 +236,19 @@ def test_EncExpT_transform_dtype(): enc = EncExpT(lang='es', token_max_filter=2**13) X = enc.transform(['buenos dias']) - assert X.dtype == enc.precision \ No newline at end of file + assert X.dtype == enc.precision + + +def test_EncExpT_encode(): + """Test EncExpT encode""" + enc = EncExpT(lang='es', merge_encode=False, + token_max_filter=2**13) + text = 'el infarto tiene que ver con el organo' + index = [k for k, v in enumerate(enc.seqTM.tokenize(text)) if v == 'el'] + X = enc.encode(text) + for otro in index[1:]: + assert_almost_equal(X[index[0]], X[otro]) + enc.merge_encode = True + X2 = enc.encode(text) + assert_almost_equal(X.sum(axis=0), X2.sum(axis=0)) + \ No newline at end of file diff --git a/encexp/utils.py b/encexp/utils.py index fa75bf1..652118d 100644 --- a/encexp/utils.py +++ b/encexp/utils.py @@ -179,31 +179,34 @@ def inner(texts): return inner -def load_dataset(country: Union[str, list], +def load_dataset(country: Union[str, list]=None, lang: str='es', + dataset='train', return_X_y:bool=False): """Country identification dataset""" + def filter_func(ele): + if country is None: + return True + elif isinstance(country, str): + return ele['country'] == country + return ele['country'] in country + + if not isdir(MODELS): - os.mkdir(MODELS) - if isinstance(country, str): - country = [country] - for cntr in country: - url = f'{DialectID_URL}/{lang}-{cntr}-sample.json.zip' - filename=join(MODELS, f'{lang}-{cntr}-sample.json.zip') - if isfile(filename): - continue + os.mkdir(MODELS) + url = 
f'{DialectID_URL}/dialectid_{lang}_{dataset}.json.zip' + filename=join(MODELS, f'dialectid_{lang}_{dataset}.json.zip') + json_filename = filename[:-4] + if not isfile(json_filename): Download(url, filename) with ZipFile(filename, "r") as fpt: fpt.extractall(path=MODELS, pwd="ingeotec".encode("utf-8")) - if len(country) == 1 and return_X_y is False: - return join(MODELS, f'{lang}-{country[0]}-sample.json') - assert return_X_y - X = [] - y = [] - for cntr in country: - _ = join(MODELS, f'{lang}-{cntr}-sample.json') - _ = list(tweet_iterator(_)) - X.extend(_) - y.extend([cntr] * len(_)) - return X, y \ No newline at end of file + os.unlink(filename) + if country is not None: + assert dataset == 'train' + data = list(tweet_iterator(json_filename)) + if not return_X_y: + return [x for x in data if filter_func(x)] + _ = [x for x in data if filter_func(x)] + return [i['text'] for i in _], [i['country'] for i in _] \ No newline at end of file