99import pickle
1010
1111from punctuator import Punctuator
12+ import gdown
13+
1214
1315CACHE_PATH = Path .home () / "cache"
1416CACHE_PATH .mkdir (exist_ok = True )
1921 {
2022 "url" : "http://ltdata1.informatik.uni-hamburg.de/subtitle2go/Model_subs_norm1_filt_5M_tageschau_euparl_h256_lr0.02.pcl" ,
2123 "tests" : [{"input" : "hallo ich bin ein testsatz" , "expected" : "Hallo, ich bin ein testsatz." }]
24+ },
25+ # Rehosted from https://drive.google.com/drive/folders/0B7BsN5f2F1fZQnFsbzJ3TWxxMms?resourcekey=0-6yhuY9FOeITBBWWNdyG2aw
26+ {
27+ "url" : "gdrive://1CZ_Os38LjBwyd-jgDMsfpqiWPB6wwVKA" ,
28+ "name" : "Demo-EUROPARL-EN.zip" ,
29+ "pickle_encoding" : "latin-1" ,
30+ "tests" : [
31+ {
32+ "input" : "hello this is an example sentence" ,
33+ "expected" : "Hello, this is an example sentence." ,
34+ }
35+ ],
2236 }
2337]
2438
25- for model in PUNCTUATOR_MODELS :
26- url_path = urllib .parse .urlparse (model ['url' ]).path
39+
def download_model(model):
    """Download a model using the fetcher that matches its URL scheme.

    Args:
        model: Model description dict; its ``"url"`` entry decides the fetcher.
            URLs starting with ``gdrive://`` go to ``download_gdrive_model``,
            everything else goes to ``download_http_model``.

    Returns:
        Whatever the selected fetcher returns: a tuple
        ``(input_model_file_path, output_model_file_path)``.
    """
    fetch = (
        download_gdrive_model
        if model["url"].startswith("gdrive://")
        else download_http_model
    )
    return fetch(model)
45+
46+
def download_gdrive_model(model):
    """Download a Google Drive hosted model into the local cache.

    The file id is read from the host part of the ``gdrive://<id>`` URL and is
    also used as the cache file name, so an id that is already cached is not
    fetched again.

    Args:
        model: Model description dict with a ``"url"`` of the form
            ``gdrive://<id>`` and a ``"name"`` giving the output file name.

    Returns:
        Tuple ``(input_model_file_path, output_model_file_path)``: the cached
        download under ``CACHE_PATH`` and the target path under ``MODEL_PATH``.
    """
    file_id = urllib.parse.urlparse(model["url"]).netloc
    target_file = MODEL_PATH / model["name"]
    cached_file = CACHE_PATH / file_id
    if not cached_file.exists():
        print("Downloading", file_id)
        gdown.download(id=file_id, output=str(cached_file), fuzzy=True)
    return cached_file, target_file
55+
56+
def download_http_model(model):
    """Download a model over HTTP(S) into the local cache.

    The cache file keeps the last path component of the URL as its name; the
    output path uses the same name with its suffix replaced by ``.zip``. A file
    that is already present in the cache is not downloaded again.

    Args:
        model: Model description dict with a ``"url"`` entry.

    Returns:
        Tuple ``(input_model_file_path, output_model_file_path)``: the cached
        download under ``CACHE_PATH`` and the target path under ``MODEL_PATH``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    url = model["url"]
    remote_path = Path(urllib.parse.urlparse(url).path)
    input_model_file_path = CACHE_PATH / remote_path.name
    output_model_file_path = MODEL_PATH / remote_path.with_suffix(".zip").name
    if not input_model_file_path.exists():
        print("Downloading", url)
        req = requests.get(url)
        # Fail before touching the cache: without this check an error page
        # would be stored under the model's name and, because the cache is only
        # refreshed when the file is missing, reused on every later run.
        req.raise_for_status()
        with open(input_model_file_path, "wb") as f:
            f.write(req.content)

    return input_model_file_path, output_model_file_path
70+
71+
for model in PUNCTUATOR_MODELS:
    # Fetch the pickled model (or reuse the cached copy) and get the path the
    # converted archive should be written to.
    input_model_file_path, output_model_file_path = download_model(model)
    with open(input_model_file_path, "rb") as f:
        # NOTE(security): unpickling executes code embedded in the file, and
        # these files are downloaded. Only list URLs from trusted sources.
        #
        # Models may name a "pickle_encoding" used to decode the byte strings
        # stored in the pickle; "ASCII" is pickle.load's own default, so models
        # without the key are loaded exactly as a plain pickle.load(f) would.
        state = pickle.load(f, encoding=model.get("pickle_encoding", "ASCII"))
3881
3982 with zipfile .ZipFile (output_model_file_path , "w" ) as model_zip :
4083 for k , v in state .items ():
5093 with model_zip .open (f"{ k } .json" , "w" ) as f :
5194 f .write (json .dumps (v ).encode ())
5295
53- if ' tests' in model and model [' tests' ]:
96+ if " tests" in model and model [" tests" ]:
5497 punctuation_model = Punctuator (output_model_file_path )
55- for test in model ['tests' ]:
56- actual = punctuation_model .punctuate (test ['input' ])
57- assert actual == test ['expected' ]
98+ for test in model ["tests" ]:
99+ actual = punctuation_model .punctuate (test ["input" ])
100+ assert (
101+ actual == test ["expected" ]
102+ ), f"'{ test ['expected' ]} ' expected, got { actual } "
0 commit comments