diff --git a/pyproject.toml b/pyproject.toml index 1395ed7..d3133f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "loguru>=0.7.3", "mdanalysis>=2.9.0", "pydantic>=2.11.3", + "pydantic-settings>=2.12.0", "streamlit>=1.50.0", ] diff --git a/src/grodecoder/cli/__init__.py b/src/grodecoder/cli/__init__.py index fab9282..0f07530 100644 --- a/src/grodecoder/cli/__init__.py +++ b/src/grodecoder/cli/__init__.py @@ -3,6 +3,7 @@ from ..main import main as grodecoder_main from .args import Arguments as CliArgs from .args import CoordinatesFile, StructureFile +from ..settings import get_settings from ..logging import setup_logging @@ -11,7 +12,6 @@ @click.argument("coordinates_file", type=CoordinatesFile, required=False) @click.option( "--bond-threshold", - default=5.0, type=float, help="Threshold for interchain bond detection (default: 5 Å)", ) @@ -29,12 +29,15 @@ def cli(**kwargs): args = CliArgs( structure_file=kwargs["structure_file"], coordinates_file=kwargs["coordinates_file"], + bond_threshold=kwargs["bond_threshold"], no_atom_ids=kwargs["no_atom_ids"], print_to_stdout=kwargs["stdout"], ) + get_settings().debug = kwargs["verbose"] + logfile = args.get_log_filename() - setup_logging(logfile, kwargs["verbose"]) + setup_logging(logfile) grodecoder_main(args) diff --git a/src/grodecoder/cli/args.py b/src/grodecoder/cli/args.py index 68f7c75..db51a0e 100644 --- a/src/grodecoder/cli/args.py +++ b/src/grodecoder/cli/args.py @@ -59,14 +59,14 @@ class Arguments: Attrs: structure_file (Path): Path to the structure file. coordinates_file (Path): Path to the coordinates file. - bond_threshold (float): Threshold for interchain bond detection. + bond_threshold (float | None): Threshold for interchain bond detection. no_atom_ids (bool): If True, use compact serialization (no atom indices). print_to_stdout (bool): Whether to output results to stdout. 
""" structure_file: StructureFile coordinates_file: CoordinatesFile | None = None - bond_threshold: float = 5.0 + bond_threshold: float | None = None no_atom_ids: bool = True print_to_stdout: bool = False diff --git a/src/grodecoder/core.py b/src/grodecoder/core.py index 986e45b..18a27db 100644 --- a/src/grodecoder/core.py +++ b/src/grodecoder/core.py @@ -7,6 +7,7 @@ from .io import read_universe from .models import Decoded from .toputils import guess_resolution +from .settings import get_settings def _now() -> str: @@ -14,19 +15,35 @@ def _now() -> str: return datetime.now().strftime("%Y-%m-%d %H:%M:%S") -def decode(universe: UniverseLike, bond_threshold: float = 5.0) -> Decoded: +def decode(universe: UniverseLike) -> Decoded: """Decodes the universe into an inventory of segments.""" + + settings = get_settings() + + resolution = guess_resolution(universe, cutoff_distance=settings.resolution_detection.distance_cutoff) + logger.info(f"Guessed resolution: {resolution}") + + # Guesses the chain detection distance cutoff if not provided by the user. 
+ chain_detection_settings = get_settings().chain_detection + + if chain_detection_settings.distance_cutoff.is_set(): + value = chain_detection_settings.distance_cutoff.get() + logger.debug(f"chain detection: using user-defined value: {value:.2f}") + else: + logger.debug("chain detection: guessing distance cutoff based on resolution") + chain_detection_settings.distance_cutoff.guess(resolution) + + distance_cutoff = chain_detection_settings.distance_cutoff.get() + return Decoded( - inventory=identify(universe, bond_threshold=bond_threshold), - resolution=guess_resolution(universe), + inventory=identify(universe, bond_threshold=distance_cutoff), + resolution=resolution, ) -def decode_structure( - structure_path: PathLike, coordinates_path: PathLike | None = None, bond_threshold: float = 5.0 -) -> Decoded: +def decode_structure(structure_path: PathLike, coordinates_path: PathLike | None = None) -> Decoded: """Reads a structure file and decodes it into an inventory of segments.""" universe = read_universe(structure_path, coordinates_path) assert universe.atoms is not None # required by type checker for some reason logger.debug(f"Universe has {len(universe.atoms):,d} atoms") - return decode(universe, bond_threshold=bond_threshold) + return decode(universe) diff --git a/src/grodecoder/identifier.py b/src/grodecoder/identifier.py index f7d4ee5..024a308 100644 --- a/src/grodecoder/identifier.py +++ b/src/grodecoder/identifier.py @@ -28,6 +28,9 @@ def identify_small_molecule( residue = SmallMolecule( atoms=selection, description=definition.description, molecular_type=molecular_type ) + logger.debug( + f"identified small molecule {residue.description}: {len(selection.residues)} residues, {len(selection.atoms)} atoms" + ) counts.append(residue) return counts @@ -66,18 +69,25 @@ def _select_protein(universe: UniverseLike) -> AtomGroup: selection_str = f"resname {' '.join(protein_residue_names)}" # Exclude methanol residues from the selection. 
+ logger.debug("excluding possible methanol residues (MET) from protein") methanol = _find_methanol(universe) if methanol: selection_str += f" and not index {' '.join(map(str, methanol))}" - return universe.select_atoms(selection_str) + logger.debug("selecting protein") + protein = universe.select_atoms(selection_str) + logger.debug("selecting protein - done") + return protein def _select_nucleic(universe: UniverseLike) -> AtomGroup: """Selects the nucleic acid atoms from the universe.""" nucleic_acid_residue_names = DB.get_nucleotide_names() selection_str = f"resname {' '.join(nucleic_acid_residue_names)}" - return universe.select_atoms(selection_str) + logger.debug("selecting nucleic") + nucleic = universe.select_atoms(selection_str) + logger.debug("selecting nucleic - done") + return nucleic def _iter_chains(atoms: AtomGroup, bond_threshold: float = 5.0) -> Iterator[AtomGroup]: @@ -87,8 +97,14 @@ def _iter_chains(atoms: AtomGroup, bond_threshold: float = 5.0) -> Iterator[Atom """ if len(atoms) == 0: return - segments = toputils.detect_chains(atoms, cutoff=bond_threshold) + logger.debug(f"detecting segments using cutoff distance {bond_threshold:.2f}") + segments = toputils.detect_chains(atoms, cutoff_distance=bond_threshold) + + n_seg_str = f"{len(segments)} segment" + ("s" if len(segments) > 1 else "") + logger.debug(f"detecting segments - done: found {n_seg_str}") + for start, end in segments: + logger.debug(f"yielding segment containing residues {start} to {end}") yield atoms.residues[start : end + 1].atoms @@ -156,6 +172,7 @@ def _log_identified_molecules(molecules: list[SmallMolecule], label: str) -> Non def _identify(universe: UniverseLike, bond_threshold: float = 5.0) -> Inventory: """Identifies the molecules in the universe.""" + logger.debug("Residu identification: start") # Ensure the universe is an AtomGroup. 
universe = universe.select_atoms("all") @@ -211,8 +228,12 @@ def _identify(universe: UniverseLike, bond_threshold: float = 5.0) -> Inventory: ) unknown_molecules.append(molecule) - return Inventory( + logger.debug("Creating inventory") + inventory = Inventory( segments=protein + nucleic, small_molecules=ions + solvents + lipids + others + unknown_molecules, total_number_of_atoms=total_number_of_atoms, ) + + logger.debug("Residu identification: end") + return inventory diff --git a/src/grodecoder/logging.py b/src/grodecoder/logging.py index a32df47..8430d71 100644 --- a/src/grodecoder/logging.py +++ b/src/grodecoder/logging.py @@ -6,13 +6,22 @@ from loguru import logger +from .settings import get_settings -def setup_logging(logfile: Path, debug: bool = False): + +def setup_logging(logfile: Path): """Sets up logging configuration.""" + debug = get_settings().debug + fmt = "{time:YYYY-MM-DD HH:mm:ss} {level}: {message}" level = "DEBUG" if debug else "INFO" + logger.remove() + + # Screen logger. 
logger.add(sys.stderr, level=level, format=fmt, colorize=True) + + # File logger logger.add(logfile, level=level, format=fmt, colorize=False, mode="w") # Sets up loguru to capture warnings (typically MDAnalysis warnings) @@ -24,6 +33,7 @@ def showwarning(message, *args, **kwargs): def is_logging_debug() -> bool: """Returns True if at least one logging handler is set to level DEBUG.""" + print("COUCOU", get_logging_level()) return "DEBUG" in get_logging_level() diff --git a/src/grodecoder/main.py b/src/grodecoder/main.py index d9e99d2..c5294e4 100644 --- a/src/grodecoder/main.py +++ b/src/grodecoder/main.py @@ -10,6 +10,7 @@ from .databases import get_database_version from .models import GrodecoderRunOutput from .version import get_version +from .settings import get_settings if TYPE_CHECKING: from .cli.args import Arguments as CliArgs @@ -27,22 +28,27 @@ def main(args: "CliArgs"): structure_path = args.structure_file.path coordinates_path = args.coordinates_file.path if args.coordinates_file else None + # Storing cli arguments into settings. + settings = get_settings() + settings.chain_detection.distance_cutoff = args.bond_threshold + settings.output.atom_ids = not args.no_atom_ids + logger.info(f"Processing structure file: {structure_path}") # Decoding. - decoded = decode_structure( - structure_path, coordinates_path=coordinates_path, bond_threshold=args.bond_threshold - ) + decoded = decode_structure(structure_path, coordinates_path=coordinates_path) output = GrodecoderRunOutput( decoded=decoded, structure_file_checksum=_get_checksum(structure_path), database_version=get_database_version(), grodecoder_version=get_version(), + input_settings=settings, ) # Serialization. - serialization_mode = "compact" if args.no_atom_ids else "full" + logger.debug("Creating json output") + serialization_mode = "full" if settings.output.atom_ids else "compact" # Updates run time as late as possible. 
output_json = output.model_dump(context={"serialization_mode": serialization_mode}) diff --git a/src/grodecoder/models.py b/src/grodecoder/models.py index 1803ea1..300762a 100644 --- a/src/grodecoder/models.py +++ b/src/grodecoder/models.py @@ -1,5 +1,4 @@ from __future__ import annotations -from pydantic import model_validator from enum import StrEnum from typing import Protocol @@ -14,9 +13,11 @@ computed_field, field_serializer, model_serializer, + model_validator, ) from . import toputils +from .settings import Settings class MolecularResolution(StrEnum): @@ -194,6 +195,7 @@ class GrodecoderRunOutput(BaseModel): structure_file_checksum: str database_version: str grodecoder_version: str + input_settings: Settings # ========================================================================================================= diff --git a/src/grodecoder/settings.py b/src/grodecoder/settings.py new file mode 100644 index 0000000..2b7915c --- /dev/null +++ b/src/grodecoder/settings.py @@ -0,0 +1,183 @@ +from dataclasses import dataclass +from functools import lru_cache +from typing import ClassVar, TYPE_CHECKING + +from loguru import logger +from pydantic import ConfigDict, Field, GetJsonSchemaHandler +from pydantic.json_schema import JsonSchemaValue +from pydantic_core import core_schema +from pydantic_settings import BaseSettings +from typing_extensions import Annotated + +if TYPE_CHECKING: + from .models import MolecularResolution + + +@dataclass(init=False) +class DistanceCutoff: + default_distance_cutoff_all_atom: ClassVar[float] = 5.0 + default_distance_cutoff_coarse_grain: ClassVar[float] = 6.0 + _user_distance_cutoff: float | None = None + _guessed_distance_cutoff: float | None = None + + def __init__(self, user_value: float | None = None): + if user_value is not None: + self.set(user_value) + + def is_defined(self) -> bool: + """Returns True if the distance cutoff has been set or guessed.""" + return any((self._user_distance_cutoff, 
self._guessed_distance_cutoff)) + + def is_set(self) -> bool: + """Returns True if the distance cutoff has been set.""" + return self._user_distance_cutoff is not None + + def is_guessed(self) -> bool: + """Returns True if the distance cutoff has been guessed.""" + return self._guessed_distance_cutoff is not None + + def get(self) -> float: + if not self.is_defined(): + raise ValueError("`distance_cutoff` must be set or guessed before it is used.") + return self._user_distance_cutoff or self._guessed_distance_cutoff # ty: ignore[invalid-return-type] + + def set(self, value: float): + if self.is_guessed(): + self._guessed_distance_cutoff = None + self._user_distance_cutoff = value + + def guess(self, resolution: "MolecularResolution"): + if resolution == "ALL_ATOM": + distance_cutoff = self.default_distance_cutoff_all_atom + logger.debug( + f"chain detection: using default distance cutoff for all atom structures: {distance_cutoff:.2f}" + ) + else: + distance_cutoff = self.default_distance_cutoff_coarse_grain + logger.debug( + f"chain detection: using default distance cutoff for coarse grain structures: {distance_cutoff:.2f}" + ) + self._guessed_distance_cutoff = distance_cutoff + + +class _DistanceCutoffPydanticAnnotation: + """Allows to serialize / validate DistanceCutoff using pydantic. 
+ + + Examples: + >>> from grodecoder.settings import ChainDetectionSettings + >>> cds = ChainDetectionSettings() + >>> cds.distance_cutoff + DistanceCutoff(_user_distance_cutoff=None, _guessed_distance_cutoff=None) + + >>> # Float assignment + >>> cds.distance_cutoff = 12 + >>> cds.distance_cutoff + DistanceCutoff(_user_distance_cutoff=12.0, _guessed_distance_cutoff=None) + + >>> # None assignment + >>> cds.distance_cutoff = None + >>> cds.distance_cutoff + DistanceCutoff(_user_distance_cutoff=None, _guessed_distance_cutoff=None) + + >>> # Serialization + >>> cds.distance_cutoff = 12 + >>> cds.model_dump() + {'distance_cutoff': 12.0} + + >>> # Validation + >>> as_json = cds.model_dump() + >>> ChainDetectionSettings.model_validate(as_json) + ChainDetectionSettings(distance_cutoff=DistanceCutoff(_user_distance_cutoff=12.0, _guessed_distance_cutoff=None)) + """ + + @classmethod + def __get_pydantic_core_schema__(cls, _source_type, _handler) -> core_schema.CoreSchema: + """ + We return a pydantic_core.CoreSchema that behaves in the following ways: + + * floats will be parsed as `DistanceCutoff` instances with the float as the `_user_distance_cutoff` attribute + * `DistanceCutoff` instances will be parsed as `DistanceCutoff` instances without any changes + * `None` (Python input only) will be parsed as an unset `DistanceCutoff`; nothing else will pass validation + * Serialization will always return just a float + """ + + def validate_from_none(value: None) -> DistanceCutoff: + return DistanceCutoff() + + def validate_from_float(value: float) -> DistanceCutoff: + result = DistanceCutoff() + result.set(value) + return result + + from_none_schema = core_schema.chain_schema( + [ + core_schema.none_schema(), + core_schema.no_info_plain_validator_function(validate_from_none), + ] + ) + from_float_schema = core_schema.chain_schema( + [ + core_schema.float_schema(), + core_schema.no_info_plain_validator_function(validate_from_float), + ] + ) + return core_schema.json_or_python_schema( + json_schema=from_float_schema, 
python_schema=core_schema.union_schema( + [ + # check if it's an instance first before doing any further work + core_schema.is_instance_schema(DistanceCutoff), + from_none_schema, + from_float_schema, + ] + ), + serialization=core_schema.plain_serializer_function_ser_schema(lambda instance: instance.get()), + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + return handler(core_schema.float_schema()) + + +# We now create an `Annotated` wrapper that we'll use as the annotation for fields on `BaseModel`s, etc. +PydanticDistanceCutoff = Annotated[DistanceCutoff, _DistanceCutoffPydanticAnnotation] + + +class ChainDetectionSettings(BaseSettings): + model_config = ConfigDict(validate_assignment=True) + distance_cutoff: PydanticDistanceCutoff = Field(default_factory=DistanceCutoff) + + +class ResolutionDetectionSettings(BaseSettings): + distance_cutoff: float = 1.6 + + +class OutputSettings(BaseSettings): + # should we output atom ids? 
+ atom_ids: bool = True + + +class Settings(BaseSettings): + resolution_detection: ResolutionDetectionSettings = ResolutionDetectionSettings() + chain_detection: ChainDetectionSettings = ChainDetectionSettings() + output: OutputSettings = OutputSettings() + + debug: bool = False + + +_settings: Settings | None = None + + +@lru_cache() +def get_settings(): + global _settings + if _settings is None: + _settings = Settings() + return _settings + + +def get_chain_detection_settings(): + return get_settings().chain_detection diff --git a/src/grodecoder/toputils.py b/src/grodecoder/toputils.py index 878c469..49df27d 100644 --- a/src/grodecoder/toputils.py +++ b/src/grodecoder/toputils.py @@ -1,11 +1,14 @@ """Defines utility functions for working with molecular structures.""" import collections +from itertools import islice from typing import Iterable import numpy as np +from loguru import logger from ._typing import Residue, UniverseLike +from .logging import is_logging_debug from .databases import get_amino_acid_name_map, get_nucleotide_name_map from .models import MolecularResolution @@ -42,20 +45,27 @@ def sequence(atoms: UniverseLike) -> str: return "".join(residue_names.get(residue.resname, "X") for residue in getattr(atoms, "residues", [])) -def has_bonds(residue: Residue, cutoff: float = 2.0): +def has_bonds(residue: Residue, cutoff_distance: float = 2.0): """Returns True if the residue has any bonds.""" + return bool(np.any(get_bonds(residue, cutoff_distance))) + + +def get_bonds(residue: Residue, cutoff: float = 2.0): + """Returns the bonds between the atoms of a residue.""" cutoff_squared = cutoff**2 distances_squared = ( np.linalg.norm(residue.atoms.positions[:, None] - residue.atoms.positions, axis=-1) ** 2 ) - np.fill_diagonal(distances_squared, np.inf) # Ignore self-pairs - bonded = distances_squared < cutoff_squared - return bool(np.any(bonded)) + # ignore self-pairs and permutations + distances_squared[np.tril_indices(distances_squared.shape[0])] = np.inf 
-def has_bonds_between(residue1: Residue, residue2: Residue, cutoff: float = 5.0): + return np.argwhere(distances_squared < cutoff_squared) + + +def has_bonds_between(residue1: Residue, residue2: Residue, cutoff_distance: float = 5.0): """Returns True if the two residues are bonded.""" - cutoff_squared = cutoff**2 + cutoff_squared = cutoff_distance**2 distances_squared = ( np.linalg.norm(residue1.atoms.positions[:, None] - residue2.atoms.positions, axis=-1) ** 2 ) @@ -63,7 +73,7 @@ def has_bonds_between(residue1: Residue, residue2: Residue, cutoff: float = 5.0) return bool(np.any(bonded)) -def detect_chains(universe: UniverseLike, cutoff: float = 5.0) -> list[tuple[int, int]]: +def detect_chains(universe: UniverseLike, cutoff_distance: float = 5.0) -> list[tuple[int, int]]: """Detects chains in a set of atoms. Iterates over the residues as detected by MDAnalysis and calculates the bonds @@ -75,7 +85,7 @@ def detect_chains(universe: UniverseLike, cutoff: float = 5.0) -> list[tuple[int universe : AtomGroup The universe to analyze. Typically a protein or a set of residues. - cutoff : float, optional + cutoff_distance : float, optional The cutoff distance to determine if two residues are bonded. Default is 5.0. Returns @@ -99,7 +109,7 @@ def detect_chains(universe: UniverseLike, cutoff: float = 5.0) -> list[tuple[int def end_of_chain(): """The end of a chain is defined as the point where two consecutive residues are not bonded.""" - return not has_bonds_between(current_residue, next_residue, cutoff) + return not has_bonds_between(current_residue, next_residue, cutoff_distance) segments = [] @@ -116,19 +126,54 @@ def end_of_chain(): return segments -def guess_resolution(universe: UniverseLike) -> MolecularResolution: +def guess_resolution(universe: UniverseLike, cutoff_distance: float) -> MolecularResolution: """Guesses the resolution (i.e. all-atom or coarse grain) of the universe. 
- The resolution is considered coarse-grained if a residue has at least two atoms within a distance of 2.0 Å. + The resolution is considered all-atom if a residue has at least two atoms within a distance of + `cutoff_distance` Å. Finds the first five residues with at least two atoms and checks if they have bonds. If any of them have bonds, the resolution is considered all-atom. If none of the first five residues have bonds, the resolution is considered coarse-grained. """ - # Select the first five residues with at least two atoms. + + def debug(msg): + where = f"{__name__}.guess_resolution" + logger.debug(f"{where}: {msg}") + + def print_bonds(residue): + """Print bonds between atoms inside a residue. Used for debug purposes.""" + + def distance(atom1, atom2): + return (np.linalg.norm(atom1.position - atom2.position) ** 2.0) ** 0.5 + + bonds = get_bonds(residue, cutoff_distance) + for bond in bonds: + left, right = residue.atoms[bond] + bond_str = f"residue {left.resname}:{left.resid}, atoms {left.name}-{right.name}" + debug(f"guess_resolution: Found bond: {bond_str} (distance={distance(left, right):.2f})") + pair_str = f"pair{'s' if len(bonds) > 1 else ''}" + debug(f"guess_resolution: detected {len(bonds)} {pair_str} with distance < {cutoff_distance=:.2f}") + + debug(f"start ; {cutoff_distance=:.2f}") + + # Makes ty happy. assert (residues := getattr(universe, "residues", [])) and len(residues) > 0 - residues = [residue for residue in residues if len(residue.atoms) >= 2][:5] + + # Selects the first ten residues with at least two atoms (the docstring still says five; TODO confirm which is intended). 
+ residues = list(islice((residue for residue in residues if len(residue.atoms) > 1), 10)) + for residue in residues: - if has_bonds(residue, cutoff=2.0): + if has_bonds(residue, cutoff_distance): + if is_logging_debug(): + try: + print_bonds(residue) # will not work during unit test as we use mocks + except Exception: + pass + debug("end: detected resolution: ALL_ATOM") return MolecularResolution.ALL_ATOM + debug( + f"No intra-atomic distance within {cutoff_distance:.2f} Å found in the first {len(residues)} residues" + ) + debug("end: detected resolution: COARSE_GRAINED") return MolecularResolution.COARSE_GRAINED diff --git a/tests/test_toputils/test_toputils.py b/tests/test_toputils/test_toputils.py index 49c14de..be36a9a 100644 --- a/tests/test_toputils/test_toputils.py +++ b/tests/test_toputils/test_toputils.py @@ -120,14 +120,14 @@ def test_has_bonds_true(self): mock_residue = Mock() mock_residue.atoms.positions = np.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]]) - result = has_bonds(mock_residue, cutoff=2.0) + result = has_bonds(mock_residue, cutoff_distance=2.0) assert result is True def test_has_bonds_false(self): mock_residue = Mock() mock_residue.atoms.positions = np.array([[0.0, 0.0, 0.0], [5.0, 0.0, 0.0]]) - result = has_bonds(mock_residue, cutoff=2.0) + result = has_bonds(mock_residue, cutoff_distance=2.0) assert result is False def test_has_bonds_single_atom(self): @@ -135,7 +135,7 @@ def test_has_bonds_single_atom(self): mock_residue = Mock() mock_residue.atoms.positions = np.array([[0.0, 0.0, 0.0]]) - result = has_bonds(mock_residue, cutoff=2.0) + result = has_bonds(mock_residue, cutoff_distance=2.0) assert result is False def test_has_bonds_custom_cutoff(self): @@ -144,11 +144,11 @@ def test_has_bonds_custom_cutoff(self): mock_residue.atoms.positions = np.array([[0.0, 0.0, 0.0], [3.0, 0.0, 0.0]]) # Should be False with default cutoff - result1 = has_bonds(mock_residue, cutoff=2.0) + result1 = has_bonds(mock_residue, cutoff_distance=2.0) assert result1 
is False # Should be True with larger cutoff - result2 = has_bonds(mock_residue, cutoff=4.0) + result2 = has_bonds(mock_residue, cutoff_distance=4.0) assert result2 is True @@ -162,7 +162,7 @@ def test_has_bonds_between_true(self): mock_residue2 = Mock() mock_residue2.atoms.positions = np.array([[2.0, 0.0, 0.0], [3.0, 0.0, 0.0]]) - result = has_bonds_between(mock_residue1, mock_residue2, cutoff=5.0) + result = has_bonds_between(mock_residue1, mock_residue2, cutoff_distance=5.0) assert result is True def test_has_bonds_between_false(self): @@ -172,7 +172,7 @@ def test_has_bonds_between_false(self): mock_residue2 = Mock() mock_residue2.atoms.positions = np.array([[10.0, 0.0, 0.0]]) - result = has_bonds_between(mock_residue1, mock_residue2, cutoff=5.0) + result = has_bonds_between(mock_residue1, mock_residue2, cutoff_distance=5.0) assert result is False def test_has_bonds_between_custom_cutoff(self): @@ -184,11 +184,11 @@ def test_has_bonds_between_custom_cutoff(self): mock_residue2.atoms.positions = np.array([[4.0, 0.0, 0.0]]) # Should be False with small cutoff - result1 = has_bonds_between(mock_residue1, mock_residue2, cutoff=3.0) + result1 = has_bonds_between(mock_residue1, mock_residue2, cutoff_distance=3.0) assert result1 is False # Should be True with larger cutoff - result2 = has_bonds_between(mock_residue1, mock_residue2, cutoff=6.0) + result2 = has_bonds_between(mock_residue1, mock_residue2, cutoff_distance=6.0) assert result2 is True @@ -205,7 +205,7 @@ def test_detect_chains_single_chain(self, monkeypatch): mock_universe = Mock() mock_universe.residues = [mock_residue1, mock_residue2, mock_residue3] - result = detect_chains(mock_universe, cutoff=5.0) + result = detect_chains(mock_universe, cutoff_distance=5.0) assert result == [(0, 2)] def test_detect_chains_multiple_chains(self, monkeypatch): @@ -224,7 +224,7 @@ def mock_bonds(res1, res2, cutoff): mock_universe = Mock() mock_universe.residues = [mock_residue1, mock_residue2, mock_residue3] - result = 
detect_chains(mock_universe, cutoff=5.0) + result = detect_chains(mock_universe, cutoff_distance=5.0) # result: first chain is residue 0 and 1, second chain is residue 2 (starts at 2, ends at 2) assert result == [(0, 1), (2, 2)] @@ -246,7 +246,7 @@ def mock_has_bonds(residue, cutoff): mock_universe = Mock() mock_universe.residues = [mock_residue] - result = guess_resolution(mock_universe) + result = guess_resolution(mock_universe, cutoff_distance=12) assert result == MolecularResolution.ALL_ATOM def test_guess_resolution_coarse_grained(self, monkeypatch): @@ -261,7 +261,7 @@ def mock_has_bonds(residue, cutoff): mock_universe = Mock() mock_universe.residues = [mock_residue] * 5 # Ensure we have enough residues - result = guess_resolution(mock_universe) + result = guess_resolution(mock_universe, cutoff_distance=12) assert result == MolecularResolution.COARSE_GRAINED def test_guess_resolution_mixed_first_has_bonds(self, monkeypatch): @@ -275,5 +275,5 @@ def test_guess_resolution_mixed_first_has_bonds(self, monkeypatch): mock_universe = Mock() mock_universe.residues = [mock_residue] * 5 - result = guess_resolution(mock_universe) + result = guess_resolution(mock_universe, cutoff_distance=12) assert result == MolecularResolution.ALL_ATOM diff --git a/tests/test_toputils/test_toputils_integration.py b/tests/test_toputils/test_toputils_integration.py index 2c6e1e0..f826081 100644 --- a/tests/test_toputils/test_toputils_integration.py +++ b/tests/test_toputils/test_toputils_integration.py @@ -14,6 +14,7 @@ guess_resolution, MolecularResolution, ) +from grodecoder.settings import get_settings TEST_DATA_DIR = Path(__file__).parent.parent / "data" GRO_SMALL = TEST_DATA_DIR / "barstar_water_ions.gro" @@ -49,7 +50,9 @@ def test_sequence(self, protein_atoms): result = sequence(protein_atoms) # Uniprot p11540 (first residue is missing in the structure) - expected = "MKKAVINGEQIRSISDLHQTLKKELALPEYYGENLDALWDCLTGWVEYPLVLEWRQFEQSKQLTENGAESVLQVFREAKAEGCDITIILS"[1:] + expected = ( + 
"MKKAVINGEQIRSISDLHQTLKKELALPEYYGENLDALWDCLTGWVEYPLVLEWRQFEQSKQLTENGAESVLQVFREAKAEGCDITIILS"[1:] + ) assert isinstance(result, str) assert len(result) == len(protein_atoms.residues) @@ -58,7 +61,7 @@ def test_sequence(self, protein_atoms): def test_has_bonds(self, small_universe): """Test has_bonds function with real residue data.""" first_residue = small_universe.residues[0] - result = has_bonds(first_residue, cutoff=2.0) + result = has_bonds(first_residue, cutoff_distance=2.0) assert result is True def test_has_bonds_between_residues(self, small_universe): @@ -66,12 +69,12 @@ def test_has_bonds_between_residues(self, small_universe): residue1 = small_universe.residues[0] residue2 = small_universe.residues[1] - result = has_bonds_between(residue1, residue2, cutoff=2.0) + result = has_bonds_between(residue1, residue2, cutoff_distance=2.0) assert result is True def test_detect_chains(self, protein_atoms): """Test detect_chains function with real protein data.""" - result = detect_chains(protein_atoms, cutoff=5.0) + result = detect_chains(protein_atoms, cutoff_distance=5.0) # Should return a list of tuples of (start, end) indices. 
assert isinstance(result, list) @@ -82,17 +85,18 @@ def test_detect_chains(self, protein_atoms): assert len(result) == 1 assert result[0] == (0, len(protein_atoms.residues) - 1) - def test_guess_resolution(self, small_universe): """Test guess_resolution function with real data.""" - result = guess_resolution(small_universe) + result = guess_resolution( + small_universe, cutoff_distance=get_settings().resolution_detection.distance_cutoff + ) assert result == MolecularResolution.ALL_ATOM def test_guess_resolution_cg(): """Test guess_resolution with coarse-grained data.""" universe = mda.Universe(GRO_CG) - result = guess_resolution(universe) + result = guess_resolution(universe, cutoff_distance=get_settings().resolution_detection.distance_cutoff) assert result == MolecularResolution.COARSE_GRAINED @@ -100,7 +104,7 @@ def test_detect_chains_big(): """Test detect_chains with a larger universe.""" universe = mda.Universe(GRO_BIG) protein_atoms = universe.select_atoms("protein") - + result = detect_chains(protein_atoms) # GRO_BIG (1BRS) is the complex barstar-barnase, 3 times, i.e. total of 6 chains. 
diff --git a/uv.lock b/uv.lock index d834a11..626bc4a 100644 --- a/uv.lock +++ b/uv.lock @@ -369,6 +369,7 @@ dependencies = [ { name = "loguru" }, { name = "mdanalysis" }, { name = "pydantic" }, + { name = "pydantic-settings" }, { name = "streamlit" }, ] @@ -393,6 +394,7 @@ requires-dist = [ { name = "loguru", specifier = ">=0.7.3" }, { name = "mdanalysis", specifier = ">=2.9.0" }, { name = "pydantic", specifier = ">=2.11.3" }, + { name = "pydantic-settings", specifier = ">=2.12.0" }, { name = "streamlit", specifier = ">=1.50.0" }, ] @@ -1061,6 +1063,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9f/ed/068e41660b832bb0b1aa5b58011dea2a3fe0ba7861ff38c4d4904c1c1a99/pydantic_core-2.41.5-cp314-cp314t-win_arm64.whl", hash = "sha256:35b44f37a3199f771c3eaa53051bc8a70cd7b54f333531c59e29fd4db5d15008", size = 1974769, upload-time = "2025-11-04T13:42:01.186Z" }, ] +[[package]] +name = "pydantic-settings" +version = "2.12.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/43/4b/ac7e0aae12027748076d72a8764ff1c9d82ca75a7a52622e67ed3f765c54/pydantic_settings-2.12.0.tar.gz", hash = "sha256:005538ef951e3c2a68e1c08b292b5f2e71490def8589d4221b95dab00dafcfd0", size = 194184, upload-time = "2025-11-10T14:25:47.013Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/60/5d4751ba3f4a40a6891f24eec885f51afd78d208498268c734e256fb13c4/pydantic_settings-2.12.0-py3-none-any.whl", hash = "sha256:fddb9fd99a5b18da837b29710391e945b1e30c135477f484084ee513adb93809", size = 51880, upload-time = "2025-11-10T14:25:45.546Z" }, +] + [[package]] name = "pydeck" version = "0.9.1" @@ -1132,6 +1148,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = 
"sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, ] +[[package]] +name = "python-dotenv" +version = "1.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f0/26/19cadc79a718c5edbec86fd4919a6b6d3f681039a2f6d66d14be94e75fb9/python_dotenv-1.2.1.tar.gz", hash = "sha256:42667e897e16ab0d66954af0e60a9caa94f0fd4ecf3aaf6d2d260eec1aa36ad6", size = 44221, upload-time = "2025-10-26T15:12:10.434Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/1b/a298b06749107c305e1fe0f814c6c74aea7b2f1e10989cb30f544a1b3253/python_dotenv-1.2.1-py3-none-any.whl", hash = "sha256:b81ee9561e9ca4004139c6cbba3a238c32b03e4894671e181b671e8cb8425d61", size = 21230, upload-time = "2025-10-26T15:12:09.109Z" }, +] + [[package]] name = "pytz" version = "2025.2"