Skip to content
This repository was archived by the owner on Jul 4, 2023. It is now read-only.

Commit 133a54c

Browse files
authored
Merge pull request #56 from PetrochukM/update
Release 0.3.7 - 5 fixed issues and a new label_encoder
2 parents 2d10f0e + c501355 commit 133a54c

35 files changed

+190
-77
lines changed

.flake8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
[flake8]
2-
ignore = E402, E722, E731
2+
ignore = E402, E722, E731, W504
33
max-line-length = 100
44
exclude = examples/

examples/snli/train.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@
7777
best_dev_acc = -1
7878
header = ' Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss Accuracy Dev/Accuracy'
7979
dev_log_template = ' '.join(
80-
'{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.
81-
split(','))
80+
'{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'
81+
.split(','))
8282
log_template = ' '.join(
8383
'{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{:12.4f},{}'.split(','))
8484
makedirs(args.save_path)
@@ -108,8 +108,7 @@
108108
answer = model(premise_batch, hypothesis_batch)
109109

110110
# calculate accuracy of predictions in the current batch
111-
n_correct += (torch.max(answer,
112-
1)[1].view(label_batch.size()) == label_batch).sum()
111+
n_correct += (torch.max(answer, 1)[1].view(label_batch.size()) == label_batch).sum()
113112
n_total += premise_batch.size()[1]
114113
train_acc = 100. * n_correct / n_total
115114

@@ -150,8 +149,8 @@
150149
for dev_batch_idx, (premise_batch, hypothesis_batch,
151150
label_batch) in enumerate(dev_iterator):
152151
answer = model(premise_batch, hypothesis_batch)
153-
n_dev_correct += (torch.max(answer, 1)[1].view(
154-
label_batch.size()) == label_batch).sum()
152+
n_dev_correct += (torch.max(answer,
153+
1)[1].view(label_batch.size()) == label_batch).sum()
155154
dev_loss = criterion(answer, label_batch)
156155
dev_acc = 100. * n_dev_correct / len(dev)
157156

examples/snli/util.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,9 @@ def collate_fn(batch, train=True):
5757
""" list of tensors to a batch tensors """
5858
premise_batch, _ = pad_batch([row['premise'] for row in batch])
5959
hypothesis_batch, _ = pad_batch([row['hypothesis'] for row in batch])
60-
label_batch = [row['label'] for row in batch]
60+
label_batch = torch.stack([row['label'] for row in batch])
6161

6262
# PyTorch RNN requires batches to be transposed for speed and integration with CUDA
63-
transpose = (
64-
lambda b: torch.stack(b).t_().squeeze(0).contiguous())
63+
transpose = (lambda b: b.t_().squeeze(0).contiguous())
6564

6665
return (transpose(premise_batch), transpose(hypothesis_batch), transpose(label_batch))

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ flake8
1919
# Mocking
2020
mock
2121

22-
# # Optional NLP Utilties
22+
# Optional NLP Utilties
2323
# nltk
2424
# spacy
2525
# sacremoses
2626

27-
# # Optional CUDA Utilties
27+
# Optional CUDA Utilties
2828
# pynvrtc
2929
# cupy
3030

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def find_version(*file_paths):
3636
long_description=long_description,
3737
long_description_content_type='text/markdown',
3838
license='BSD',
39-
install_requires=['numpy', 'pandas', 'tqdm', 'ujson', 'requests'],
39+
install_requires=['numpy', 'pandas', 'tqdm', 'requests'],
4040
classifiers=[
4141
'Development Status :: 4 - Beta',
4242
'Intended Audience :: Developers',

tests/datasets/test_simple_qa.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
import os
22
import shutil
3-
import pytest
43

54
import mock
5+
import pytest
66

77
from torchnlp.datasets import simple_qa_dataset
88
from tests.datasets.utils import urlretrieve_side_effect
99

1010
directory = 'tests/_test_data/'
1111

1212

13-
@pytest.mark.skip(reason="Simple Questions dataset url returns 404.")
13+
@pytest.mark.skip(reason="Simple Questions dataset url sometimes returns 404.")
1414
@mock.patch("urllib.request.urlretrieve")
1515
def test_simple_qa_dataset_row(mock_urlretrieve):
1616
mock_urlretrieve.side_effect = urlretrieve_side_effect

tests/datasets/test_smt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_smt_dataset_row(mock_urlretrieve):
3535
" splash even greater than Arnold Schwarzenegger , Jean-Claud Van Damme or Steven" +
3636
" Segal .",
3737
'label':
38-
'positive'
38+
'very positive'
3939
}
4040

4141
# Clean up

tests/nn/test_weight_drop.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_weight_drop_lstm():
2727
run2 = [x.sum() for x in wd_lstm(input_)[0].data]
2828

2929
# First time step, not influenced by hidden to hidden weights, should be equal
30-
assert pytest.approx(run1[0]) == pytest.approx(run2[0])
30+
assert pytest.approx(run1[0].item()) == pytest.approx(run2[0].item())
3131
# Second step should not
3232
assert run1[1] != run2[1]
3333

@@ -40,7 +40,7 @@ def test_weight_drop_gru():
4040
run2 = [x.sum() for x in wd_lstm(input_)[0].data]
4141

4242
# First time step, not influenced by hidden to hidden weights, should be equal
43-
assert pytest.approx(run1[0]) == pytest.approx(run2[0])
43+
assert pytest.approx(run1[0].item()) == pytest.approx(run2[0].item())
4444
# Second step should not
4545
assert run1[1] != run2[1]
4646

@@ -53,7 +53,7 @@ def test_weight_drop():
5353
run2 = [x.sum() for x in wd_lstm(input_)[0].data]
5454

5555
# First time step, not influenced by hidden to hidden weights, should be equal
56-
assert pytest.approx(run1[0]) == pytest.approx(run2[0])
56+
assert pytest.approx(run1[0].item()) == pytest.approx(run2[0].item())
5757
# Second step should not
5858
assert run1[1] != run2[1]
5959

@@ -66,6 +66,6 @@ def test_weight_drop_zero():
6666
run2 = [x.sum() for x in wd_lstm(input_)[0].data]
6767

6868
# First time step, not influenced by hidden to hidden weights, should be equal
69-
assert pytest.approx(run1[0]) == pytest.approx(run2[0])
69+
assert pytest.approx(run1[0].item()) == pytest.approx(run2[0].item())
7070
# Second step should not
71-
assert pytest.approx(run1[1]) == pytest.approx(run2[1])
71+
assert pytest.approx(run1[1].item()) == pytest.approx(run2[1].item())

tests/test_label_encoder.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import pickle
2+
3+
import pytest
4+
5+
from torchnlp.label_encoder import LabelEncoder
6+
from torchnlp.label_encoder import UNKNOWN_TOKEN
7+
8+
9+
@pytest.fixture
10+
def encoder():
11+
sample = ['people/deceased_person/place_of_death', 'symbols/name_source/namesakes']
12+
return LabelEncoder(sample)
13+
14+
15+
def test_label_encoder_vocab(encoder):
16+
assert len(encoder.vocab) == 3
17+
assert len(encoder.vocab) == encoder.vocab_size
18+
19+
20+
def test_label_encoder_scalar(encoder):
21+
input_ = 'symbols/namesake/named_after'
22+
output = encoder.encode(input_)[0]
23+
assert encoder.decode(output) == UNKNOWN_TOKEN
24+
25+
26+
def test_label_encoder_unknown(encoder):
27+
input_ = 'symbols/namesake/named_after'
28+
output = encoder.encode(input_)
29+
assert len(output) == 1
30+
assert encoder.decode(output) == UNKNOWN_TOKEN
31+
32+
33+
def test_label_encoder_known():
34+
input_ = 'symbols/namesake/named_after'
35+
sample = ['people/deceased_person/place_of_death', 'symbols/name_source/namesakes']
36+
sample.append(input_)
37+
encoder = LabelEncoder(sample)
38+
output = encoder.encode(input_)
39+
assert len(output) == 1
40+
assert encoder.decode(output) == input_
41+
42+
43+
def test_is_pickleable(encoder):
44+
pickle.dumps(encoder)

tests/text_encoders/test_subword_tokenizer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,10 @@ def test_encode_decode(self):
8989

9090
original = 'This is a coded sentence encoded by the SubwordTextTokenizer.'
9191

92-
encoder = SubwordTextTokenizer.build_to_target_size_from_corpus(
93-
[corpus, original], target_size=100, min_val=2, max_val=10)
92+
encoder = SubwordTextTokenizer.build_to_target_size_from_corpus([corpus, original],
93+
target_size=100,
94+
min_val=2,
95+
max_val=10)
9496

9597
# Encoding should be reversible.
9698
encoded = encoder.encode(original)

0 commit comments

Comments (0)