Skip to content

Commit 9cb828f

Browse files
committed
Implement vocabulary expansion
1 parent 0b528c1 commit 9cb828f

File tree

3 files changed

+15
-4
lines changed

3 files changed

+15
-4
lines changed

anago/preprocess.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,8 @@ def dense_to_one_hot(labels_dense, num_classes, nlevels=1):
237237
raise ValueError('nlevels can take 1 or 2, not take {}.'.format(nlevels))
238238

239239

240-
def prepare_preprocessor(X, y, use_char=True, vocab_init=None):
    """Create a WordPreprocessor fitted on the given data.

    Args:
        X: training sequences (e.g. lists of word tokens).
        y: corresponding label sequences.
        use_char: kept for interface compatibility; not consumed here.
        vocab_init: optional pre-built vocabulary handed to the
            preprocessor so externally known words are kept (vocabulary
            expansion).

    Returns:
        The fitted WordPreprocessor instance.
    """
    preprocessor = WordPreprocessor(vocab_init=vocab_init)
    preprocessor.fit(X, y)
    return preprocessor

anago/wrapper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ def __init__(self, char_emb_size=25, word_emb_size=100, char_lstm_units=25,
3434
self.log_dir = log_dir
3535
self.embeddings = embeddings
3636

37-
def train(self, x_train, y_train, x_valid=None, y_valid=None):
38-
self.p = prepare_preprocessor(x_train, y_train)
37+
def train(self, x_train, y_train, x_valid=None, y_valid=None, vocab_init=None):
38+
self.p = prepare_preprocessor(x_train, y_train, vocab_init=vocab_init)
3939
embeddings = filter_embeddings(self.embeddings, self.p.vocab_word,
4040
self.model_config.word_embedding_size)
4141
self.model_config.vocab_size = len(self.p.vocab_word)

tests/wrapper_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import unittest
33
from pprint import pprint
44

5+
import numpy as np
6+
57
import anago
68
from anago.reader import load_data_and_labels, load_glove
79

@@ -80,3 +82,12 @@ def test_load(self):
8082

8183
model = anago.Sequence.load(self.dir_path)
8284
model.eval(self.x_test, self.y_test)
85+
86+
def test_train_vocab_init(self):
87+
vocab = set()
88+
for words in np.r_[self.x_train, self.x_valid, self.x_test]:
89+
for word in words:
90+
vocab.add(word)
91+
model = anago.Sequence(max_epoch=15, embeddings=self.embeddings, log_dir='logs')
92+
model.train(self.x_train, self.y_train, self.x_test, self.y_test, vocab_init=vocab)
93+
model.save(dir_path=self.dir_path)

0 commit comments

Comments (0)