Skip to content

Commit 3e070c8

Browse files
database.api: cleaning (now uses lru_cache)
1 parent: 3c0a4a3 · commit: 3e070c8

File tree

1 file changed

+62
-101
lines changed
  • src/grodecoder/databases

1 file changed

+62
-101
lines changed

src/grodecoder/databases/api.py

Lines changed: 62 additions & 101 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,8 @@
1-
import itertools
21
import json
32
from collections import Counter
4-
from dataclasses import dataclass
3+
from functools import lru_cache
54
from pathlib import Path
6-
from typing import Iterable, TypeVar
5+
from typing import TypeVar
76

87
from loguru import logger
98

@@ -96,18 +95,32 @@ def read_csml_database() -> list[csml.Residue]:
9695
return _read_database(CSML_DB_PATH, csml.Residue)
9796

9897

99-
ION_DB: list[Ion] = read_ion_database()
100-
SOLVENT_DB: list[Solvent] = read_solvent_database()
101-
AMINO_ACIDS_DB: list[AminoAcid] = read_amino_acid_database()
102-
NUCLEOTIDES_DB: list[Nucleotide] = read_nucleotide_database()
98+
# ION_DB: list[Ion] = read_ion_database()
99+
# SOLVENT_DB: list[Solvent] = read_solvent_database()
100+
# AMINO_ACIDS_DB: list[AminoAcid] = read_amino_acid_database()
101+
# NUCLEOTIDES_DB: list[Nucleotide] = read_nucleotide_database()
102+
# MAD_DB: list[mad.Residue] = read_mad_database()
103+
# CSML_DB: list[csml.Residue] = read_csml_database()
104+
# LIPID_DB: list[Lipid] = _build_lipid_db()
105+
# OTHER_DB: list[Residue] = _build_other_db()
103106

104-
MAD_DB: list[mad.Residue] = read_mad_database()
105-
CSML_DB: list[csml.Residue] = read_csml_database()
107+
108+
@lru_cache(maxsize=1)
109+
def _get_csml_databse():
110+
return read_csml_database()
111+
112+
113+
@lru_cache(maxsize=1)
114+
def _get_mad_databse():
115+
return read_mad_database()
106116

107117

108118
def _build_lipid_db() -> list[Lipid]:
109-
_mad_lipid_resnames = {item.alias for item in MAD_DB if item.family == mad.ResidueFamily.LIPID}
110-
_csml_lipid_resnames = {residue.name for residue in CSML_DB if residue.family == csml.ResidueFamily.LIPID}
119+
mad_db = _get_mad_databse()
120+
csml_db = _get_csml_databse()
121+
122+
_mad_lipid_resnames = {item.alias for item in mad_db if item.family == mad.ResidueFamily.LIPID}
123+
_csml_lipid_resnames = {residue.name for residue in csml_db if residue.family == csml.ResidueFamily.LIPID}
111124

112125
if False:
113126
# IMPORTANT:
@@ -121,27 +134,27 @@ def _build_lipid_db() -> list[Lipid]:
121134

122135
db = {
123136
item.alias: Lipid(description=item.name, residue_name=item.alias)
124-
for item in MAD_DB
137+
for item in mad_db
125138
if item.family == mad.ResidueFamily.LIPID
126139
}
127140
db.update(
128141
{
129142
residue.name: Lipid(description=residue.description, residue_name=residue.name)
130-
for residue in CSML_DB
143+
for residue in csml_db
131144
if residue.family == csml.ResidueFamily.LIPID
132145
}
133146
)
134147
return list(db.values())
135148

136149

137-
LIPID_DB: list[Lipid] = _build_lipid_db()
138-
139-
140150
def _build_other_db() -> list[Residue]:
141151
"""Builds a database of other residues that are not ions, solvents, amino acids, or nucleotides."""
152+
mad_db = _get_mad_databse()
153+
csml_db = _get_csml_databse()
154+
142155
csml_other = {
143156
residue.name: Residue(residue_name=residue.name, description=residue.description)
144-
for residue in CSML_DB
157+
for residue in csml_db
145158
if residue.family
146159
not in {
147160
csml.ResidueFamily.PROTEIN,
@@ -154,7 +167,7 @@ def _build_other_db() -> list[Residue]:
154167

155168
mad_other = {
156169
residue.alias: Residue(residue_name=residue.alias, description=residue.name)
157-
for residue in MAD_DB
170+
for residue in mad_db
158171
if residue.family
159172
not in {
160173
mad.ResidueFamily.PROTEIN,
@@ -168,124 +181,72 @@ def _build_other_db() -> list[Residue]:
168181
return list(by_name.values())
169182

170183

171-
OTHER_DB: list[Residue] = _build_other_db()
172-
173-
174-
def get_other_definitions() -> list[Residue]:
175-
"""Returns the definitions of other residues in the database."""
176-
return OTHER_DB
177-
178-
184+
@lru_cache(maxsize=1)
179185
def get_ion_definitions() -> list[Ion]:
180186
"""Returns the definitions of the ions in the database."""
181-
return ION_DB
187+
return read_ion_database()
182188

183189

190+
@lru_cache(maxsize=1)
184191
def get_solvent_definitions() -> list[Solvent]:
185192
"""Returns the definitions of the solvents in the database."""
186-
return SOLVENT_DB
193+
return read_solvent_database()
187194

188195

196+
@lru_cache(maxsize=1)
189197
def get_amino_acid_definitions() -> list[AminoAcid]:
190198
"""Returns the definitions of the amino acids in the database."""
191-
return AMINO_ACIDS_DB
192-
193-
194-
def get_amino_acid_name_map() -> dict[str, str]:
195-
"""Returns a mapping of amino acid 3-letter names to 1-letter names."""
196-
return {aa.long_name: aa.short_name for aa in AMINO_ACIDS_DB}
197-
198-
199-
def get_nucleotide_name_map() -> dict[str, str]:
200-
"""Returns a mapping of nucleotide 3-letter names to 1-letter names."""
201-
return {nucleotide.residue_name: nucleotide.short_name for nucleotide in NUCLEOTIDES_DB}
199+
return read_amino_acid_database()
202200

203201

202+
@lru_cache(maxsize=1)
204203
def get_nucleotide_definitions() -> list[Nucleotide]:
205204
"""Returns the definitions of the nucleotides in the database."""
206-
return NUCLEOTIDES_DB
205+
return read_nucleotide_database()
207206

208207

208+
@lru_cache(maxsize=1)
209209
def get_lipid_definitions() -> list[Lipid]:
210210
"""Returns the definitions of the lipids in the database."""
211-
return LIPID_DB
211+
return _build_lipid_db()
212+
213+
214+
@lru_cache(maxsize=1)
215+
def get_other_definitions() -> list[Residue]:
216+
"""Returns the definitions of other residues in the database."""
217+
return _build_other_db()
218+
219+
220+
def get_amino_acid_name_map() -> dict[str, str]:
221+
"""Returns a mapping of amino acid 3-letter names to 1-letter names."""
222+
return {aa.long_name: aa.short_name for aa in get_amino_acid_definitions()}
223+
224+
225+
def get_nucleotide_name_map() -> dict[str, str]:
226+
"""Returns a mapping of nucleotide 3-letter names to 1-letter names."""
227+
return {nucleotide.residue_name: nucleotide.short_name for nucleotide in get_nucleotide_definitions()}
212228

213229

214230
def get_ion_names() -> set[str]:
215231
"""Returns the names of the ions in the database."""
216-
return set(ion.residue_name for ion in ION_DB)
232+
return set(ion.residue_name for ion in get_ion_definitions())
217233

218234

219235
def get_solvent_names() -> set[str]:
220236
"""Returns the names of the solvents in the database."""
221-
return set(solvent.residue_name for solvent in SOLVENT_DB)
237+
return set(solvent.residue_name for solvent in get_solvent_definitions())
222238

223239

224240
def get_amino_acid_names() -> set[str]:
225241
"""Returns the names of the amino acids in the database."""
226-
return set(aa.long_name for aa in AMINO_ACIDS_DB)
242+
return set(aa.long_name for aa in get_amino_acid_definitions())
227243

228244

229245
def get_nucleotide_names() -> set[str]:
230246
"""Returns the names of the nucleotides in the database."""
231-
return set(nucleotide.residue_name for nucleotide in NUCLEOTIDES_DB)
247+
return set(nucleotide.residue_name for nucleotide in get_nucleotide_definitions())
232248

233249

234250
def get_lipid_names() -> set[str]:
235251
"""Returns the names of the lipids in the database."""
236-
return set(lipid.residue_name for lipid in LIPID_DB)
237-
238-
239-
@dataclass(frozen=True)
240-
class ResidueDatabase:
241-
"""Database of residues."""
242-
243-
ions: list[Ion]
244-
solvents: list[Solvent]
245-
amino_acids: list[AminoAcid]
246-
nucleotides: list[Nucleotide]
247-
248-
def __post_init__(self):
249-
names = {
250-
"ions": {ion.residue_name for ion in self.ions},
251-
"solvents": {solvent.residue_name for solvent in self.solvents},
252-
"amino_acids": {aa.long_name for aa in self.amino_acids},
253-
"nucleotides": {nucleotide.residue_name for nucleotide in self.nucleotides},
254-
}
255-
256-
combinations = itertools.combinations(names.keys(), 2)
257-
for lhs, rhs in combinations:
258-
duplicates = names[lhs].intersection(names[rhs])
259-
if duplicates:
260-
logger.warning(
261-
f"Residue names {duplicates} are defined in multiple families: {lhs} and {rhs}"
262-
)
263-
264-
265-
class ResidueNotFound(Exception):
266-
"""Raised when a residue with a given name and atom names is not found in the database."""
267-
268-
269-
class DuplicateResidue(Exception):
270-
"""Raised when a residue with a given name and atom names is defined multiple times in the database."""
271-
272-
273-
def _find_using_atom_names(
274-
residue_name: str, atom_names: Iterable[str], database: list[Ion | Solvent]
275-
) -> Ion | Solvent | None:
276-
candidate_residues = [ion for ion in database if ion.residue_name == residue_name]
277-
if not candidate_residues:
278-
return None
279-
280-
actual_atom_names = set(atom_names)
281-
matching_residues = [ion for ion in candidate_residues if set(ion.atom_names) == actual_atom_names]
282-
283-
if len(matching_residues) == 0:
284-
raise ResidueNotFound(f"No residue '{residue_name}' found with atom names {actual_atom_names}")
285-
286-
elif len(matching_residues) > 1:
287-
raise DuplicateResidue(
288-
f"Multiple residues '{residue_name}' found with atom names {actual_atom_names}"
289-
)
290-
291-
return matching_residues[0]
252+
return set(lipid.residue_name for lipid in get_lipid_definitions())

0 commit comments

Comments
 (0)