Skip to content

Commit ff0e625

Browse files
committed
✨ English model
1 parent f028e6c commit ff0e625

File tree

3 files changed

+55
-9
lines changed

3 files changed

+55
-9
lines changed

generate_models.py

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import pickle
1010

1111
from punctuator import Punctuator
12+
import gdown
13+
1214

1315
CACHE_PATH = Path.home() / "cache"
1416
CACHE_PATH.mkdir(exist_ok=True)
@@ -19,22 +21,63 @@
1921
{
2022
"url": "http://ltdata1.informatik.uni-hamburg.de/subtitle2go/Model_subs_norm1_filt_5M_tageschau_euparl_h256_lr0.02.pcl",
2123
"tests": [{"input": "hallo ich bin ein testsatz", "expected": "Hallo, ich bin ein testsatz."}]
24+
},
25+
# Rehosted from https://drive.google.com/drive/folders/0B7BsN5f2F1fZQnFsbzJ3TWxxMms?resourcekey=0-6yhuY9FOeITBBWWNdyG2aw
26+
{
27+
"url": "gdrive://1CZ_Os38LjBwyd-jgDMsfpqiWPB6wwVKA",
28+
"name": "Demo-EUROPARL-EN.zip",
29+
"pickle_encoding": "latin-1",
30+
"tests": [
31+
{
32+
"input": "hello this is an example sentence",
33+
"expected": "Hello, this is an example sentence.",
34+
}
35+
],
2236
}
2337
]
2438

25-
for model in PUNCTUATOR_MODELS:
26-
url_path = urllib.parse.urlparse(model['url']).path
39+
40+
def download_model(model):
41+
if model["url"].startswith("gdrive://"):
42+
return download_gdrive_model(model)
43+
else:
44+
return download_http_model(model)
45+
46+
47+
def download_gdrive_model(model):
48+
url_path = urllib.parse.urlparse(model["url"]).netloc
49+
output_model_file_path = MODEL_PATH / model["name"]
50+
input_model_file_path = CACHE_PATH / url_path
51+
if not input_model_file_path.exists():
52+
print("Downloading", url_path)
53+
gdown.download(id=url_path, output=str(input_model_file_path), fuzzy=True)
54+
return input_model_file_path, output_model_file_path
55+
56+
57+
def download_http_model(model):
58+
url_path = urllib.parse.urlparse(model["url"]).path
2759
name = Path(url_path).name
2860
input_model_file_path = CACHE_PATH / name
2961
output_name = Path(url_path).with_suffix(".zip").name
3062
output_model_file_path = MODEL_PATH / output_name
3163
if not input_model_file_path.exists():
32-
req = requests.get(model['url'])
64+
print("Downloading", model["url"])
65+
req = requests.get(model["url"])
3366
with open(input_model_file_path, "wb") as f:
3467
f.write(req.content)
3568

69+
return input_model_file_path, output_model_file_path
70+
71+
72+
for model in PUNCTUATOR_MODELS:
73+
input_model_file_path, output_model_file_path = download_model(model)
3674
with open(input_model_file_path, "rb") as f:
37-
state = pickle.load(f)
75+
if 'pickle_encoding' in model:
76+
u = pickle._Unpickler(f)
77+
u.encoding = model['pickle_encoding']
78+
state = u.load()
79+
else:
80+
state = pickle.load(f)
3881

3982
with zipfile.ZipFile(output_model_file_path, "w") as model_zip:
4083
for k, v in state.items():
@@ -50,8 +93,10 @@
5093
with model_zip.open(f"{k}.json", "w") as f:
5194
f.write(json.dumps(v).encode())
5295

53-
if 'tests' in model and model['tests']:
96+
if "tests" in model and model["tests"]:
5497
punctuation_model = Punctuator(output_model_file_path)
55-
for test in model['tests']:
56-
actual = punctuation_model.punctuate(test['input'])
57-
assert actual == test['expected']
98+
for test in model["tests"]:
99+
actual = punctuation_model.punctuate(test["input"])
100+
assert (
101+
actual == test["expected"]
102+
), f"'{test['expected']}' expected, got {actual}"

poetry.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ numpy = ">1.22"
1010
scipy = ">=1.7.0"
1111
punctuator = {git = "https://github.com/audapolis/punctuator2", rev = "51eeb14cd30a4162e6a54e5234805651963f767f"}
1212
requests = "^2.27.1"
13+
gdown = "^4.4.0"
1314

1415
[tool.poetry.dev-dependencies]
1516

0 commit comments

Comments
 (0)