Commit fa026e7

Merge pull request #59 from neurosynth/parallel
Add parallel processing to add articles
2 parents f652fbc + 8a6a530 commit fa026e7

3 files changed: +125, -44 lines

ace/ingest.py

Lines changed: 115 additions & 25 deletions
@@ -1,5 +1,6 @@
 from os import path
 import logging
+import re
 from . import sources, config
 from .scrape import _validate_scrape
 import multiprocessing as mp
@@ -23,8 +24,59 @@ def _process_file(f):
     return f, html
 
 
+def _process_file_with_source(args):
+    """Helper function to read, validate, and identify source for a single file."""
+    f, source_configs = args
+    logger.info("Processing article %s..." % f)
+    try:
+        html = open(f).read()
+    except Exception as e:
+        logger.warning("Failed to read file %s: %s" % (f, str(e)))
+        return f, None, None
+
+    if not _validate_scrape(html):
+        logger.warning("Invalid HTML for %s" % f)
+        return f, None, None
+
+    # Identify source from HTML using regex patterns
+    source_name = None
+    for name, identifiers in source_configs.items():
+        for patt in identifiers:
+            if re.search(patt, html):
+                logger.debug('Matched article to Source: %s' % name)
+                source_name = name
+                break
+        if source_name:
+            break
+
+    return f, html, source_name
+
+
+def _parse_article(args):
+    """Helper function to parse an article from HTML content."""
+    # Unpack arguments
+    f, html, source_name, pmid, manager, metadata_dir, force_ingest, kwargs = args
+
+    try:
+        # Get the actual source object
+        if source_name:
+            source = manager.sources[source_name]
+        else:
+            # Fallback to original source identification
+            source = manager.identify_source(html)
+            if source is None:
+                logger.warning("Could not identify source for %s" % f)
+                return f, None
+
+        article = source.parse_article(html, pmid, metadata_dir=metadata_dir, **kwargs)
+        return f, article
+    except Exception as e:
+        logger.warning("Error parsing article %s: %s" % (f, str(e)))
+        return f, None
+
+
 def add_articles(db, files, commit=True, table_dir=None, limit=None,
-                 pmid_filenames=False, metadata_dir=None, force_ingest=True, parallel=True, num_workers=None, **kwargs):
+                 pmid_filenames=False, metadata_dir=None, force_ingest=True, num_workers=None, **kwargs):
     ''' Process articles and add their data to the DB.
     Args:
         files: The path to the article(s) to process. Can be a single
@@ -46,13 +98,15 @@ def add_articles(db, files, commit=True, table_dir=None, limit=None,
             and will save the result of the query if it doesn't already
             exist.
         force_ingest: Ingest even if no source is identified.
-        parallel: Whether to process articles in parallel (default: True).
         num_workers: Number of worker processes to use when processing in parallel.
             If None (default), uses the number of CPUs available on the system.
         kwargs: Additional keyword arguments to pass to parse_article.
     '''
 
-    manager = sources.SourceManager(db, table_dir)
+    manager = sources.SourceManager(table_dir)
+
+    # Prepare source configurations for parallel processing
+    source_configs = {name: source.identifiers for name, source in manager.sources.items()}
 
     if isinstance(files, str):
         from glob import glob
@@ -64,38 +118,74 @@ def add_articles(db, files, commit=True, table_dir=None, limit=None,
 
     missing_sources = []
 
-    if parallel:
-        # Process files in parallel to extract HTML content
+    # Step 1: Process files in parallel to extract HTML content and identify sources
+    if num_workers is not None and num_workers != 1:
+        # Process files in parallel to extract HTML content and identify sources
+        process_args = [(f, source_configs) for f in files]
         with mp.Pool(processes=num_workers) as pool:
-            file_html_pairs = pool.map(_process_file, files)
+            file_html_source_tuples = pool.map(_process_file_with_source, process_args)
     else:
         # Process files sequentially
-        file_html_pairs = []
+        file_html_source_tuples = []
        for f in files:
-            file_html_pairs.append(_process_file(f))
+            result = _process_file_with_source((f, source_configs))
+            file_html_source_tuples.append(result)
+
+    # Step 2: In serial mode, use the db object to skip articles that have been already added
+    # Filter out files with reading/validation errors
+    valid_files = []
+    for f, html, source_name in file_html_source_tuples:
+        if html is not None:
+            valid_files.append((f, html, source_name))
+        # We'll handle missing sources later when we actually parse the articles
+
+    # Filter out articles that already exist in the database
+    files_to_process = []
+    missing_sources = []
 
-    # Process each file's HTML content
-    for i, (f, html) in enumerate(file_html_pairs):
-        if html is None:
-            # File reading or validation failed
-            missing_sources.append(f)
+    for f, html, source_name in valid_files:
+        pmid = path.splitext(path.basename(f))[0] if pmid_filenames else None
+
+        # Check if article already exists
+        if pmid is not None and db.article_exists(pmid) and not config.OVERWRITE_EXISTING_ROWS:
             continue
+
+        files_to_process.append((f, html, source_name, pmid))
 
-        source = manager.identify_source(html)
-        if source is None:
-            logger.warning("Could not identify source for %s" % f)
-            missing_sources.append(f)
-            if not force_ingest:
-                continue
-            else:
-                source = sources.DefaultSource(db)
+    # Step 3: Process remaining articles in parallel
+    # Prepare arguments for _parse_article
+    parse_args = [(f, html, source_name, pmid, manager, metadata_dir, force_ingest, kwargs)
+                  for f, html, source_name, pmid in files_to_process]
 
-        pmid = path.splitext(path.basename(f))[0] if pmid_filenames else None
-        article = source.parse_article(html, pmid, metadata_dir=metadata_dir, **kwargs)
-        if article and (config.SAVE_ARTICLES_WITHOUT_ACTIVATIONS or article.tables):
+    if num_workers is not None and num_workers != 1 and parse_args:
+        # Parse articles in parallel
+        with mp.Pool(processes=num_workers) as pool:
+            parsed_articles = pool.map(_parse_article, parse_args)
+    else:
+        # Parse articles sequentially
+        parsed_articles = []
+        for args in parse_args:
+            parsed_articles.append(_parse_article(args))
+
+    # Add successfully parsed articles to database
+    for i, (f, article) in enumerate(parsed_articles):
+        if article is None:
+            missing_sources.append(f)
+            continue
+
+        if config.SAVE_ARTICLES_WITHOUT_ACTIVATIONS or article.tables:
+            # Check again if article exists and handle overwrite
+            pmid = path.splitext(path.basename(f))[0] if pmid_filenames else None
+            if pmid is not None and db.article_exists(pmid):
+                if config.OVERWRITE_EXISTING_ROWS:
+                    db.delete_article(pmid)
+                else:
+                    continue
+
             db.add(article)
-            if commit and (i % 100 == 0 or i == len(file_html_pairs) - 1):
+            if commit and (i % 100 == 0 or i == len(parsed_articles) - 1):
                 db.save()
+
     db.save()
 
     return missing_sources
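In the new add_articles flow, worker processes only ever receive picklable values (file path, HTML string, source name, plain config dicts), while all database reads and writes stay in the parent process. Below is a minimal, self-contained sketch of the same two-stage Pool dispatch pattern; the read_and_identify and parse helpers are illustrative stand-ins, not part of ACE.

import multiprocessing as mp

def read_and_identify(args):
    # Stage 1 worker: read a file and tag it with a source name (toy matching).
    path, patterns = args
    try:
        text = open(path).read()
    except OSError:
        return path, None, None
    name = next((n for n, p in patterns.items() if p in text), None)
    return path, text, name

def parse(item):
    # Stage 2 worker: turn raw text into a "parsed" result (here, just its length).
    path, text, name = item
    return path, {"source": name, "chars": len(text)}

def run(files, patterns, num_workers=None):
    tasks = [(f, patterns) for f in files]
    # Same guard as add_articles: any value other than None or 1 goes parallel.
    if num_workers is not None and num_workers != 1:
        with mp.Pool(processes=num_workers) as pool:
            stage1 = pool.map(read_and_identify, tasks)
    else:
        stage1 = [read_and_identify(t) for t in tasks]
    # Drop unreadable files before the more expensive parse stage.
    todo = [t for t in stage1 if t[1] is not None]
    if num_workers is not None and num_workers != 1 and todo:
        with mp.Pool(processes=num_workers) as pool:
            return pool.map(parse, todo)
    return [parse(t) for t in todo]

Note that, as written, the guard treats num_workers=None as sequential, so callers who want one worker per CPU need to pass an explicit worker count greater than 1.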

ace/sources.py

Lines changed: 9 additions & 18 deletions
@@ -34,10 +34,9 @@ class SourceManager:
     associated directory of JSON config files and uses them to determine which parser
     to call when a new HTML file is passed. '''
 
-    def __init__(self, database, table_dir=None):
+    def __init__(self, table_dir=None):
         ''' SourceManager constructor.
         Args:
-            database: A Database instance to use with all Sources.
             table_dir: An optional directory name to save any downloaded tables to.
                 When table_dir is None, nothing will be saved (requiring new scraping
                 each time the article is processed).
@@ -47,7 +46,7 @@ def __init__(self, database, table_dir=None):
         source_dir = os.path.join(os.path.dirname(__file__), 'sources')
         for config_file in glob('%s/*json' % source_dir):
             class_name = config_file.split('/')[-1].split('.')[0]
-            cls = getattr(module, class_name + 'Source')(database, config=config_file, table_dir=table_dir)
+            cls = getattr(module, class_name + 'Source')(config=config_file, table_dir=table_dir)
             self.sources[class_name] = cls
 
     def identify_source(self, html):
@@ -161,8 +160,7 @@ def _safe_clean_html(self, html):
                 text_parts.append(text.strip())
         return '\n\n'.join(text_parts) if text_parts else soup.get_text()
 
-    def __init__(self, database, config=None, table_dir=None):
-        self.database = database
+    def __init__(self, config=None, table_dir=None):
         self.table_dir = table_dir
         self.entities = {}
 
@@ -181,16 +179,11 @@ def __init__(self, database, config=None, table_dir=None):
         else:
             self.entities.update(Source.ENTITIES)
 
-    @abc.abstractmethod
     def parse_article(self, html, pmid=None, metadata_dir=None):
-        ''' Takes HTML article as input and returns an Article. PMID Can also be
-        passed, which prevents having to scrape it from the article and/or look it
+        ''' Takes HTML article as input and returns an Article. PMID Can also be
+        passed, which prevents having to scrape it from the article and/or look it
         up in PubMed. '''
-
-        # Skip rest of processing if this record already exists
-        if pmid is not None and self.database.article_exists(pmid) and not config.OVERWRITE_EXISTING_ROWS:
-            return False
-
+
         html = self.decode_html_entities(html)
         soup = BeautifulSoup(html, "lxml")
         if pmid is None:
@@ -208,11 +201,6 @@ def parse_article(self, html, pmid=None, metadata_dir=None):
 
         # Get text using readability
         text = self._clean_html_with_readability(str(soup))
-        if self.database.article_exists(pmid):
-            if config.OVERWRITE_EXISTING_ROWS:
-                self.database.delete_article(pmid)
-            else:
-                return False
 
         self.article = database.Article(text, pmid=pmid, metadata=metadata)
         self.extract_neurovault(soup)
@@ -401,6 +389,9 @@ class DefaultSource(Source):
     3. JavaScript expansion detection: Identifies elements that might trigger
        table expansion via JavaScript (logging only, not implemented)
     """
+    def __init__(self, config=None, table_dir=None):
+        super().__init__(config=config, table_dir=table_dir)
+
    def parse_article(self, html, pmid=None, **kwargs):
         soup = super(DefaultSource, self).parse_article(html, pmid, **kwargs)
         if not soup:
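With the database argument removed, source identification becomes a pure function of the HTML and each source's identifier regexes, which is what lets add_articles ship the patterns to worker processes as a plain dict. A small self-contained sketch of that matching step (the pattern strings below are made-up placeholders, not ACE's real identifier configs):

import re

# Hypothetical identifier patterns keyed by source name (placeholders only).
source_configs = {
    'PLoS': [r'Public Library of Science'],
    'Frontiers': [r'frontiersin\.org'],
}

def identify(html, configs):
    # Return the first source whose regex matches the HTML, else None.
    for name, identifiers in configs.items():
        if any(re.search(patt, html) for patt in identifiers):
            return name
    return None

print(identify("<a href='https://www.frontiersin.org/articles/x'>link</a>", source_configs))  # Frontiers

The duplicate-PMID checks that previously lived inside Source.parse_article now happen in add_articles, so the caller is the only component that touches the db object.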

ace/tests/test_ace.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ def db():
 
 @pytest.fixture(scope="module")
 def source_manager(db):
-    return sources.SourceManager(db)
+    return sources.SourceManager()
 
 
 @pytest.mark.vcr(record_mode="once")
