diff --git a/pyproject.toml b/pyproject.toml
index 1395ed7..d3133f7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,6 +9,7 @@ dependencies = [
"loguru>=0.7.3",
"mdanalysis>=2.9.0",
"pydantic>=2.11.3",
+ "pydantic-settings>=2.12.0",
"streamlit>=1.50.0",
]
diff --git a/src/grodecoder/cli/__init__.py b/src/grodecoder/cli/__init__.py
index fab9282..0f07530 100644
--- a/src/grodecoder/cli/__init__.py
+++ b/src/grodecoder/cli/__init__.py
@@ -3,6 +3,7 @@
from ..main import main as grodecoder_main
from .args import Arguments as CliArgs
from .args import CoordinatesFile, StructureFile
+from ..settings import get_settings
from ..logging import setup_logging
@@ -11,7 +12,6 @@
@click.argument("coordinates_file", type=CoordinatesFile, required=False)
@click.option(
    "--bond-threshold",
-    default=5.0,
    type=float,
-    help="Threshold for interchain bond detection (default: 5 Å)",
+    # No click-level default anymore: when the option is omitted the cutoff is
+    # guessed from the detected resolution, so the help text must not claim 5 Å.
+    help=(
+        "Threshold for interchain bond detection "
+        "(default: guessed from the structure resolution, 5 Å for all-atom, 6 Å for coarse-grained)"
+    ),
)
@@ -29,12 +29,15 @@ def cli(**kwargs):
args = CliArgs(
structure_file=kwargs["structure_file"],
coordinates_file=kwargs["coordinates_file"],
+ bond_threshold=kwargs["bond_threshold"],
no_atom_ids=kwargs["no_atom_ids"],
print_to_stdout=kwargs["stdout"],
)
+ get_settings().debug = kwargs["verbose"]
+
logfile = args.get_log_filename()
- setup_logging(logfile, kwargs["verbose"])
+ setup_logging(logfile)
grodecoder_main(args)
diff --git a/src/grodecoder/cli/args.py b/src/grodecoder/cli/args.py
index 68f7c75..db51a0e 100644
--- a/src/grodecoder/cli/args.py
+++ b/src/grodecoder/cli/args.py
@@ -59,14 +59,14 @@ class Arguments:
Attrs:
structure_file (Path): Path to the structure file.
coordinates_file (Path): Path to the coordinates file.
- bond_threshold (float): Threshold for interchain bond detection.
+ bond_threshold (float | None): Threshold for interchain bond detection.
no_atom_ids (bool): If True, use compact serialization (no atom indices).
print_to_stdout (bool): Whether to output results to stdout.
"""
structure_file: StructureFile
coordinates_file: CoordinatesFile | None = None
- bond_threshold: float = 5.0
+ bond_threshold: float | None = None
no_atom_ids: bool = True
print_to_stdout: bool = False
diff --git a/src/grodecoder/core.py b/src/grodecoder/core.py
index 986e45b..18a27db 100644
--- a/src/grodecoder/core.py
+++ b/src/grodecoder/core.py
@@ -7,6 +7,7 @@
from .io import read_universe
from .models import Decoded
from .toputils import guess_resolution
+from .settings import get_settings
def _now() -> str:
@@ -14,19 +15,35 @@ def _now() -> str:
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-def decode(universe: UniverseLike, bond_threshold: float = 5.0) -> Decoded:
+def decode(universe: UniverseLike) -> Decoded:
"""Decodes the universe into an inventory of segments."""
+
+ settings = get_settings()
+
+ resolution = guess_resolution(universe, cutoff_distance=settings.resolution_detection.distance_cutoff)
+ logger.info(f"Guessed resolution: {resolution}")
+
+ # Guesses the chain detection distance cutoff if not provided by the user.
+ chain_detection_settings = get_settings().chain_detection
+
+ if chain_detection_settings.distance_cutoff.is_set():
+ value = chain_detection_settings.distance_cutoff.get()
+ logger.debug(f"chain detection: using user-defined value: {value:.2f}")
+ else:
+ logger.debug("chain detection: guessing distance cutoff based on resolution")
+ chain_detection_settings.distance_cutoff.guess(resolution)
+
+ distance_cutoff = chain_detection_settings.distance_cutoff.get()
+
return Decoded(
- inventory=identify(universe, bond_threshold=bond_threshold),
- resolution=guess_resolution(universe),
+ inventory=identify(universe, bond_threshold=distance_cutoff),
+ resolution=resolution,
)
-def decode_structure(
- structure_path: PathLike, coordinates_path: PathLike | None = None, bond_threshold: float = 5.0
-) -> Decoded:
+def decode_structure(structure_path: PathLike, coordinates_path: PathLike | None = None) -> Decoded:
"""Reads a structure file and decodes it into an inventory of segments."""
universe = read_universe(structure_path, coordinates_path)
assert universe.atoms is not None # required by type checker for some reason
logger.debug(f"Universe has {len(universe.atoms):,d} atoms")
- return decode(universe, bond_threshold=bond_threshold)
+ return decode(universe)
diff --git a/src/grodecoder/identifier.py b/src/grodecoder/identifier.py
index f7d4ee5..024a308 100644
--- a/src/grodecoder/identifier.py
+++ b/src/grodecoder/identifier.py
@@ -28,6 +28,9 @@ def identify_small_molecule(
residue = SmallMolecule(
atoms=selection, description=definition.description, molecular_type=molecular_type
)
+ logger.debug(
+ f"identified small molecule {residue.description}: {len(selection.residues)} residues, {len(selection.atoms)} atoms"
+ )
counts.append(residue)
return counts
@@ -66,18 +69,25 @@ def _select_protein(universe: UniverseLike) -> AtomGroup:
selection_str = f"resname {' '.join(protein_residue_names)}"
# Exclude methanol residues from the selection.
+ logger.debug("excluding possible methanol residues (MET) from protein")
methanol = _find_methanol(universe)
if methanol:
selection_str += f" and not index {' '.join(map(str, methanol))}"
- return universe.select_atoms(selection_str)
+ logger.debug("selecting protein")
+ protein = universe.select_atoms(selection_str)
+ logger.debug("selecting protein - done")
+ return protein
def _select_nucleic(universe: UniverseLike) -> AtomGroup:
"""Selects the nucleic acid atoms from the universe."""
nucleic_acid_residue_names = DB.get_nucleotide_names()
selection_str = f"resname {' '.join(nucleic_acid_residue_names)}"
- return universe.select_atoms(selection_str)
+ logger.debug("selecting nucleic")
+ nucleic = universe.select_atoms(selection_str)
+ logger.debug("selecting nucleic - done")
+ return nucleic
def _iter_chains(atoms: AtomGroup, bond_threshold: float = 5.0) -> Iterator[AtomGroup]:
@@ -87,8 +97,14 @@ def _iter_chains(atoms: AtomGroup, bond_threshold: float = 5.0) -> Iterator[Atom
"""
if len(atoms) == 0:
return
- segments = toputils.detect_chains(atoms, cutoff=bond_threshold)
+    logger.debug(f"detecting segments using cutoff distance {bond_threshold:.2f}")
+    segments = toputils.detect_chains(atoms, cutoff_distance=bond_threshold)
+
+    # The plural suffix is computed inside the f-string: written as
+    # `"... segment" + "s" if n > 1 else ""`, the conditional expression applies to the
+    # whole concatenation and the message is empty when 0 or 1 segment is found.
+    n_seg_str = f"{len(segments)} segment{'s' if len(segments) > 1 else ''}"
+    logger.debug(f"detecting segments - done: found {n_seg_str}")
+
for start, end in segments:
+ logger.debug(f"yielding segment containing residues {start} to {end}")
yield atoms.residues[start : end + 1].atoms
@@ -156,6 +172,7 @@ def _log_identified_molecules(molecules: list[SmallMolecule], label: str) -> Non
def _identify(universe: UniverseLike, bond_threshold: float = 5.0) -> Inventory:
"""Identifies the molecules in the universe."""
+ logger.debug("Residu identification: start")
# Ensure the universe is an AtomGroup.
universe = universe.select_atoms("all")
@@ -211,8 +228,12 @@ def _identify(universe: UniverseLike, bond_threshold: float = 5.0) -> Inventory:
)
unknown_molecules.append(molecule)
- return Inventory(
+ logger.debug("Creating inventory")
+ inventory = Inventory(
segments=protein + nucleic,
small_molecules=ions + solvents + lipids + others + unknown_molecules,
total_number_of_atoms=total_number_of_atoms,
)
+
+ logger.debug("Residu identification: end")
+ return inventory
diff --git a/src/grodecoder/logging.py b/src/grodecoder/logging.py
index a32df47..8430d71 100644
--- a/src/grodecoder/logging.py
+++ b/src/grodecoder/logging.py
@@ -6,13 +6,22 @@
from loguru import logger
+from .settings import get_settings
-def setup_logging(logfile: Path, debug: bool = False):
+
+def setup_logging(logfile: Path):
"""Sets up logging configuration."""
+ debug = get_settings().debug
+
fmt = "{time:YYYY-MM-DD HH:mm:ss} {level}: {message}"
level = "DEBUG" if debug else "INFO"
+
logger.remove()
+
+ # Screen logger.
logger.add(sys.stderr, level=level, format=fmt, colorize=True)
+
+ # File logger
logger.add(logfile, level=level, format=fmt, colorize=False, mode="w")
# Sets up loguru to capture warnings (typically MDAnalysis warnings)
@@ -24,6 +33,7 @@ def showwarning(message, *args, **kwargs):
def is_logging_debug() -> bool:
    """Returns True if at least one logging handler is set to level DEBUG."""
    return "DEBUG" in get_logging_level()
diff --git a/src/grodecoder/main.py b/src/grodecoder/main.py
index d9e99d2..c5294e4 100644
--- a/src/grodecoder/main.py
+++ b/src/grodecoder/main.py
@@ -10,6 +10,7 @@
from .databases import get_database_version
from .models import GrodecoderRunOutput
from .version import get_version
+from .settings import get_settings
if TYPE_CHECKING:
from .cli.args import Arguments as CliArgs
@@ -27,22 +28,27 @@ def main(args: "CliArgs"):
structure_path = args.structure_file.path
coordinates_path = args.coordinates_file.path if args.coordinates_file else None
+ # Storing cli arguments into settings.
+ settings = get_settings()
+ settings.chain_detection.distance_cutoff = args.bond_threshold
+ settings.output.atom_ids = not args.no_atom_ids
+
logger.info(f"Processing structure file: {structure_path}")
# Decoding.
- decoded = decode_structure(
- structure_path, coordinates_path=coordinates_path, bond_threshold=args.bond_threshold
- )
+ decoded = decode_structure(structure_path, coordinates_path=coordinates_path)
output = GrodecoderRunOutput(
decoded=decoded,
structure_file_checksum=_get_checksum(structure_path),
database_version=get_database_version(),
grodecoder_version=get_version(),
+ input_settings=settings,
)
# Serialization.
- serialization_mode = "compact" if args.no_atom_ids else "full"
+ logger.debug("Creating json output")
+ serialization_mode = "full" if settings.output.atom_ids else "compact"
# Updates run time as late as possible.
output_json = output.model_dump(context={"serialization_mode": serialization_mode})
diff --git a/src/grodecoder/models.py b/src/grodecoder/models.py
index 1803ea1..300762a 100644
--- a/src/grodecoder/models.py
+++ b/src/grodecoder/models.py
@@ -1,5 +1,4 @@
from __future__ import annotations
-from pydantic import model_validator
from enum import StrEnum
from typing import Protocol
@@ -14,9 +13,11 @@
computed_field,
field_serializer,
model_serializer,
+ model_validator,
)
from . import toputils
+from .settings import Settings
class MolecularResolution(StrEnum):
@@ -194,6 +195,7 @@ class GrodecoderRunOutput(BaseModel):
structure_file_checksum: str
database_version: str
grodecoder_version: str
+ input_settings: Settings
# =========================================================================================================
diff --git a/src/grodecoder/settings.py b/src/grodecoder/settings.py
new file mode 100644
index 0000000..2b7915c
--- /dev/null
+++ b/src/grodecoder/settings.py
@@ -0,0 +1,183 @@
+from dataclasses import dataclass
+from functools import lru_cache
+from typing import ClassVar, TYPE_CHECKING
+
+from loguru import logger
+from pydantic import ConfigDict, Field, GetJsonSchemaHandler
+from pydantic.json_schema import JsonSchemaValue
+from pydantic_core import core_schema
+from pydantic_settings import BaseSettings
+from typing_extensions import Annotated
+
+if TYPE_CHECKING:
+ from .models import MolecularResolution
+
+
+@dataclass(init=False)
+class DistanceCutoff:
+    """Distance cutoff that is either set by the user or guessed from the resolution.
+
+    A user-defined value takes precedence over a guessed one. Values are compared
+    against `None` (never by truthiness), so a cutoff of 0.0 is a valid, defined value.
+    """
+
+    # Defaults used by `guess()`.
+    default_distance_cutoff_all_atom: ClassVar[float] = 5.0
+    default_distance_cutoff_coarse_grain: ClassVar[float] = 6.0
+    _user_distance_cutoff: float | None = None
+    _guessed_distance_cutoff: float | None = None
+
+    def __init__(self, user_value: float | None = None):
+        if user_value is not None:
+            self.set(user_value)
+
+    def is_defined(self) -> bool:
+        """Returns True if the distance cutoff has been set or guessed."""
+        return self.is_set() or self.is_guessed()
+
+    def is_set(self) -> bool:
+        """Returns True if the distance cutoff has been set."""
+        return self._user_distance_cutoff is not None
+
+    def is_guessed(self) -> bool:
+        """Returns True if the distance cutoff has been guessed."""
+        return self._guessed_distance_cutoff is not None
+
+    def get(self) -> float:
+        """Returns the user-defined cutoff if any, else the guessed one.
+
+        Raises:
+            ValueError: if the cutoff has neither been set nor guessed.
+        """
+        if self._user_distance_cutoff is not None:
+            return self._user_distance_cutoff
+        if self._guessed_distance_cutoff is not None:
+            return self._guessed_distance_cutoff
+        raise ValueError("`distance_cutoff` must be set or guessed before it is used.")
+
+    def set(self, value: float):
+        """Stores a user-defined cutoff and discards any previously guessed value."""
+        if self.is_guessed():
+            self._guessed_distance_cutoff = None
+        self._user_distance_cutoff = value
+
+    def guess(self, resolution: "MolecularResolution"):
+        """Stores the default cutoff matching `resolution` as the guessed value."""
+        # NOTE(review): the values of MolecularResolution are not visible from this module.
+        # Matching on the member name as well keeps the test correct if the enum value
+        # is not literally "ALL_ATOM" -- confirm against the enum definition.
+        is_all_atom = resolution == "ALL_ATOM" or getattr(resolution, "name", None) == "ALL_ATOM"
+        if is_all_atom:
+            distance_cutoff = self.default_distance_cutoff_all_atom
+            logger.debug(
+                f"chain detection: using default distance cutoff for all atom structures: {distance_cutoff:.2f}"
+            )
+        else:
+            distance_cutoff = self.default_distance_cutoff_coarse_grain
+            logger.debug(
+                f"chain detection: using default distance cutoff for coarse grain structures: {distance_cutoff:.2f}"
+            )
+        self._guessed_distance_cutoff = distance_cutoff
+
+
+class _DistanceCutoffPydanticAnnotation:
+ """Allows to serialize / validate DistanceCutoff using pydantic.
+
+
+ Examples:
+ >>> from grodecoder.settings import ChainDetectionSettings
+ >>> cds = ChainDetectionSettings()
+ >>> cds.distance_cutoff
+ DistanceCutoff(_user_distance_cutoff=None, _guessed_distance_cutoff=None)
+
+ >>> # Float assignment
+ >>> cds.distance_cutoff = 12
+ >>> cds.distance_cutoff
+ DistanceCutoff(_user_distance_cutoff=12.0, _guessed_distance_cutoff=None)
+
+ >>> # None assignment
+ >>> cds.distance_cutoff = None
+ >>> cds.distance_cutoff
+ DistanceCutoff(_user_distance_cutoff=None, _guessed_distance_cutoff=None)
+
+ >>> # Serialization
+ >>> cds.distance_cutoff = 12
+ >>> cds.model_dump()
+ {'distance_cutoff': 12.0}
+
+ >>> # Validation
+ >>> as_json = cds.model_dump()
+ >>> ChainDetectionSettings.model_validate(as_json)
+ ChainDetectionSettings(distance_cutoff=DistanceCutoff(_user_distance_cutoff=12.0, _guessed_distance_cutoff=None))
+ """
+
+ @classmethod
+ def __get_pydantic_core_schema__(cls, _source_type, _handler) -> core_schema.CoreSchema:
+ """
+ We return a pydantic_core.CoreSchema that behaves in the following ways:
+
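+ * `None` will be parsed as an unset `DistanceCutoff` instance (python input only)
+ * floats will be parsed as `DistanceCutoff` instances with the float as the `_user_distance_cutoff` attribute
+ * `DistanceCutoff` instances will be parsed as `DistanceCutoff` instances without any changes
+ * Nothing else will pass validation
+ * Serialization returns just a float; it fails if the cutoff is neither set nor guessed, because `get()` raises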
+ """
+
+ def validate_from_none(value: None) -> DistanceCutoff:
+ return DistanceCutoff()
+
+ def validate_from_float(value: float) -> DistanceCutoff:
+ result = DistanceCutoff()
+ result.set(value)
+ return result
+
+ from_none_schema = core_schema.chain_schema(
+ [
+ core_schema.none_schema(),
+ core_schema.no_info_plain_validator_function(validate_from_none),
+ ]
+ )
+ from_float_schema = core_schema.chain_schema(
+ [
+ core_schema.float_schema(),
+ core_schema.no_info_plain_validator_function(validate_from_float),
+ ]
+ )
+ return core_schema.json_or_python_schema(
+ json_schema=from_float_schema,
+ python_schema=core_schema.union_schema(
+ [
+ # check if it's an instance first before doing any further work
+ core_schema.is_instance_schema(DistanceCutoff),
+ from_none_schema,
+ from_float_schema,
+ ]
+ ),
+ serialization=core_schema.plain_serializer_function_ser_schema(lambda instance: instance.get()),
+ )
+
+ @classmethod
+ def __get_pydantic_json_schema__(
+ cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
+ ) -> JsonSchemaValue:
+ return handler(core_schema.float_schema())
+
+
+# We now create an `Annotated` wrapper that we'll use as the annotation for fields on `BaseModel`s, etc.
+PydanticDistanceCutoff = Annotated[DistanceCutoff, _DistanceCutoffPydanticAnnotation]
+
+
+class ChainDetectionSettings(BaseSettings):
+    """Settings used for chain detection."""
+
+    # Assignments are validated so that a plain float or None assigned to
+    # `distance_cutoff` is converted to a `DistanceCutoff` instance.
+    model_config = ConfigDict(validate_assignment=True)
+    # Unset by default: a fresh `DistanceCutoff` holds neither a user nor a guessed value.
+    distance_cutoff: PydanticDistanceCutoff = Field(default_factory=DistanceCutoff)
+
+
+class ResolutionDetectionSettings(BaseSettings):
+    """Settings used to guess the resolution (all-atom or coarse-grained) of a structure."""
+
+    # Passed by `decode` as the `cutoff_distance` argument of `guess_resolution`.
+    distance_cutoff: float = 1.6
+
+
+class OutputSettings(BaseSettings):
+    """Settings controlling how results are serialized."""
+
+    # should we output atom ids?
+    # main() uses the "full" serialization mode when True and "compact" when False.
+    atom_ids: bool = True
+
+
+class Settings(BaseSettings):
+    """Top-level settings, grouping the settings of each processing step."""
+
+    resolution_detection: ResolutionDetectionSettings = ResolutionDetectionSettings()
+    chain_detection: ChainDetectionSettings = ChainDetectionSettings()
+    output: OutputSettings = OutputSettings()
+
+    # Read by `setup_logging`: DEBUG log level when True, INFO otherwise.
+    debug: bool = False
+
+
+_settings: Settings | None = None
+
+
+@lru_cache()
+def get_settings():
+    """Returns the shared `Settings` instance, creating it on the first call.
+
+    The same object is returned on every call; callers mutate it in place
+    (e.g. the CLI sets `debug`, `main` stores the command-line arguments).
+    """
+    # NOTE(review): the instance is memoized twice (by `lru_cache` and by the
+    # module-level `_settings`); one of the two mechanisms looks redundant.
+    global _settings
+    if _settings is None:
+        _settings = Settings()
+    return _settings
+
+
+def get_chain_detection_settings():
+    """Returns the chain detection section of the shared settings."""
+    shared_settings = get_settings()
+    return shared_settings.chain_detection
diff --git a/src/grodecoder/toputils.py b/src/grodecoder/toputils.py
index 878c469..49df27d 100644
--- a/src/grodecoder/toputils.py
+++ b/src/grodecoder/toputils.py
@@ -1,11 +1,14 @@
"""Defines utility functions for working with molecular structures."""
import collections
+from itertools import islice
from typing import Iterable
import numpy as np
+from loguru import logger
from ._typing import Residue, UniverseLike
+from .logging import is_logging_debug
from .databases import get_amino_acid_name_map, get_nucleotide_name_map
from .models import MolecularResolution
@@ -42,20 +45,27 @@ def sequence(atoms: UniverseLike) -> str:
return "".join(residue_names.get(residue.resname, "X") for residue in getattr(atoms, "residues", []))
-def has_bonds(residue: Residue, cutoff: float = 2.0):
+def has_bonds(residue: Residue, cutoff_distance: float = 2.0):
    """Returns True if the residue has any bonds."""
+    # `get_bonds` returns the result of `np.argwhere`, i.e. one row of atom indices per
+    # bonded pair: test whether it is empty rather than the truthiness of the indices.
+    return len(get_bonds(residue, cutoff_distance)) > 0
+
+
+def get_bonds(residue: Residue, cutoff: float = 2.0):
+ """Returns the bonds between the atoms of a residue."""
cutoff_squared = cutoff**2
distances_squared = (
np.linalg.norm(residue.atoms.positions[:, None] - residue.atoms.positions, axis=-1) ** 2
)
- np.fill_diagonal(distances_squared, np.inf) # Ignore self-pairs
- bonded = distances_squared < cutoff_squared
- return bool(np.any(bonded))
+ # ignore self-pairs and permutations
+ distances_squared[np.tril_indices(distances_squared.shape[0])] = np.inf
-def has_bonds_between(residue1: Residue, residue2: Residue, cutoff: float = 5.0):
+ return np.argwhere(distances_squared < cutoff_squared)
+
+
+def has_bonds_between(residue1: Residue, residue2: Residue, cutoff_distance: float = 5.0):
"""Returns True if the two residues are bonded."""
- cutoff_squared = cutoff**2
+ cutoff_squared = cutoff_distance**2
distances_squared = (
np.linalg.norm(residue1.atoms.positions[:, None] - residue2.atoms.positions, axis=-1) ** 2
)
@@ -63,7 +73,7 @@ def has_bonds_between(residue1: Residue, residue2: Residue, cutoff: float = 5.0)
return bool(np.any(bonded))
-def detect_chains(universe: UniverseLike, cutoff: float = 5.0) -> list[tuple[int, int]]:
+def detect_chains(universe: UniverseLike, cutoff_distance: float = 5.0) -> list[tuple[int, int]]:
"""Detects chains in a set of atoms.
Iterates over the residues as detected by MDAnalysis and calculates the bonds
@@ -75,7 +85,7 @@ def detect_chains(universe: UniverseLike, cutoff: float = 5.0) -> list[tuple[int
universe : AtomGroup
The universe to analyze. Typically a protein or a set of residues.
- cutoff : float, optional
+ cutoff_distance : float, optional
The cutoff distance to determine if two residues are bonded. Default is 5.0.
Returns
@@ -99,7 +109,7 @@ def detect_chains(universe: UniverseLike, cutoff: float = 5.0) -> list[tuple[int
def end_of_chain():
"""The end of a chain is defined as the point where two consecutive residues are not bonded."""
- return not has_bonds_between(current_residue, next_residue, cutoff)
+ return not has_bonds_between(current_residue, next_residue, cutoff_distance)
segments = []
@@ -116,19 +126,54 @@ def end_of_chain():
return segments
-def guess_resolution(universe: UniverseLike) -> MolecularResolution:
+def guess_resolution(universe: UniverseLike, cutoff_distance: float) -> MolecularResolution:
"""Guesses the resolution (i.e. all-atom or coarse grain) of the universe.
- The resolution is considered coarse-grained if a residue has at least two atoms within a distance of 2.0 Å.
+ The resolution is considered all-atom if a residue has at least two atoms within a distance of
+ `cutoff_distance` Å.
Finds the first five residues with at least two atoms and checks if they have bonds.
If any of them have bonds, the resolution is considered all-atom.
If none of the first five residues have bonds, the resolution is considered coarse-grained.
"""
- # Select the first five residues with at least two atoms.
+
+ def debug(msg):
+ where = f"{__name__}.guess_resolution"
+ logger.debug(f"{where}: {msg}")
+
+    def print_bonds(residue):
+        """Logs the bonds between atoms inside a residue. Used for debug purposes."""
+
+        def distance(atom1, atom2):
+            # Euclidean distance between the two atom positions.
+            return float(np.linalg.norm(atom1.position - atom2.position))
+
+        bonds = get_bonds(residue, cutoff_distance)
+        for bond in bonds:
+            left, right = residue.atoms[bond]
+            bond_str = f"residue {left.resname}:{left.resid}, atoms {left.name}-{right.name}"
+            # `debug` already prefixes every message with "<module>.guess_resolution: ".
+            debug(f"Found bond: {bond_str} (distance={distance(left, right):.2f})")
+        pair_str = f"pair{'s' if len(bonds) > 1 else ''}"
+        debug(f"detected {len(bonds)} {pair_str} with distance < {cutoff_distance=:.2f}")
+
+ debug(f"start ; {cutoff_distance=:.2f}")
+
+ # Makes ty happy.
assert (residues := getattr(universe, "residues", [])) and len(residues) > 0
- residues = [residue for residue in residues if len(residue.atoms) >= 2][:5]
+
+    # Selects the first five residues with at least two atoms (lazily: stops scanning
+    # the universe as soon as five candidates are found).
+    residues = list(islice((residue for residue in residues if len(residue.atoms) > 1), 5))
+
for residue in residues:
- if has_bonds(residue, cutoff=2.0):
+ if has_bonds(residue, cutoff_distance):
+ if is_logging_debug():
+ try:
+ print_bonds(residue) # will not work during unit test as we use mocks
+ except Exception:
+ pass
+ debug("end: detected resolution: ALL_ATOM")
return MolecularResolution.ALL_ATOM
+ debug(
+ f"No intra-atomic distance within {cutoff_distance:.2f} Å found in the first {len(residues)} residues"
+ )
+ debug("end: detected resolution: COARSE_GRAINED")
return MolecularResolution.COARSE_GRAINED
diff --git a/tests/test_toputils/test_toputils.py b/tests/test_toputils/test_toputils.py
index 49c14de..be36a9a 100644
--- a/tests/test_toputils/test_toputils.py
+++ b/tests/test_toputils/test_toputils.py
@@ -120,14 +120,14 @@ def test_has_bonds_true(self):
mock_residue = Mock()
mock_residue.atoms.positions = np.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]])
- result = has_bonds(mock_residue, cutoff=2.0)
+ result = has_bonds(mock_residue, cutoff_distance=2.0)
assert result is True
def test_has_bonds_false(self):
mock_residue = Mock()
mock_residue.atoms.positions = np.array([[0.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
- result = has_bonds(mock_residue, cutoff=2.0)
+ result = has_bonds(mock_residue, cutoff_distance=2.0)
assert result is False
def test_has_bonds_single_atom(self):
@@ -135,7 +135,7 @@ def test_has_bonds_single_atom(self):
mock_residue = Mock()
mock_residue.atoms.positions = np.array([[0.0, 0.0, 0.0]])
- result = has_bonds(mock_residue, cutoff=2.0)
+ result = has_bonds(mock_residue, cutoff_distance=2.0)
assert result is False
def test_has_bonds_custom_cutoff(self):
@@ -144,11 +144,11 @@ def test_has_bonds_custom_cutoff(self):
mock_residue.atoms.positions = np.array([[0.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
# Should be False with default cutoff
- result1 = has_bonds(mock_residue, cutoff=2.0)
+ result1 = has_bonds(mock_residue, cutoff_distance=2.0)
assert result1 is False
# Should be True with larger cutoff
- result2 = has_bonds(mock_residue, cutoff=4.0)
+ result2 = has_bonds(mock_residue, cutoff_distance=4.0)
assert result2 is True
@@ -162,7 +162,7 @@ def test_has_bonds_between_true(self):
mock_residue2 = Mock()
mock_residue2.atoms.positions = np.array([[2.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
- result = has_bonds_between(mock_residue1, mock_residue2, cutoff=5.0)
+ result = has_bonds_between(mock_residue1, mock_residue2, cutoff_distance=5.0)
assert result is True
def test_has_bonds_between_false(self):
@@ -172,7 +172,7 @@ def test_has_bonds_between_false(self):
mock_residue2 = Mock()
mock_residue2.atoms.positions = np.array([[10.0, 0.0, 0.0]])
- result = has_bonds_between(mock_residue1, mock_residue2, cutoff=5.0)
+ result = has_bonds_between(mock_residue1, mock_residue2, cutoff_distance=5.0)
assert result is False
def test_has_bonds_between_custom_cutoff(self):
@@ -184,11 +184,11 @@ def test_has_bonds_between_custom_cutoff(self):
mock_residue2.atoms.positions = np.array([[4.0, 0.0, 0.0]])
# Should be False with small cutoff
- result1 = has_bonds_between(mock_residue1, mock_residue2, cutoff=3.0)
+ result1 = has_bonds_between(mock_residue1, mock_residue2, cutoff_distance=3.0)
assert result1 is False
# Should be True with larger cutoff
- result2 = has_bonds_between(mock_residue1, mock_residue2, cutoff=6.0)
+ result2 = has_bonds_between(mock_residue1, mock_residue2, cutoff_distance=6.0)
assert result2 is True
@@ -205,7 +205,7 @@ def test_detect_chains_single_chain(self, monkeypatch):
mock_universe = Mock()
mock_universe.residues = [mock_residue1, mock_residue2, mock_residue3]
- result = detect_chains(mock_universe, cutoff=5.0)
+ result = detect_chains(mock_universe, cutoff_distance=5.0)
assert result == [(0, 2)]
def test_detect_chains_multiple_chains(self, monkeypatch):
@@ -224,7 +224,7 @@ def mock_bonds(res1, res2, cutoff):
mock_universe = Mock()
mock_universe.residues = [mock_residue1, mock_residue2, mock_residue3]
- result = detect_chains(mock_universe, cutoff=5.0)
+ result = detect_chains(mock_universe, cutoff_distance=5.0)
# result: first chain is residue 0 and 1, second chain is residue 2 (starts at 2, ends at 2)
assert result == [(0, 1), (2, 2)]
@@ -246,7 +246,7 @@ def mock_has_bonds(residue, cutoff):
mock_universe = Mock()
mock_universe.residues = [mock_residue]
- result = guess_resolution(mock_universe)
+ result = guess_resolution(mock_universe, cutoff_distance=12)
assert result == MolecularResolution.ALL_ATOM
def test_guess_resolution_coarse_grained(self, monkeypatch):
@@ -261,7 +261,7 @@ def mock_has_bonds(residue, cutoff):
mock_universe = Mock()
mock_universe.residues = [mock_residue] * 5 # Ensure we have enough residues
- result = guess_resolution(mock_universe)
+ result = guess_resolution(mock_universe, cutoff_distance=12)
assert result == MolecularResolution.COARSE_GRAINED
def test_guess_resolution_mixed_first_has_bonds(self, monkeypatch):
@@ -275,5 +275,5 @@ def test_guess_resolution_mixed_first_has_bonds(self, monkeypatch):
mock_universe = Mock()
mock_universe.residues = [mock_residue] * 5
- result = guess_resolution(mock_universe)
+ result = guess_resolution(mock_universe, cutoff_distance=12)
assert result == MolecularResolution.ALL_ATOM
diff --git a/tests/test_toputils/test_toputils_integration.py b/tests/test_toputils/test_toputils_integration.py
index 2c6e1e0..f826081 100644
--- a/tests/test_toputils/test_toputils_integration.py
+++ b/tests/test_toputils/test_toputils_integration.py
@@ -14,6 +14,7 @@
guess_resolution,
MolecularResolution,
)
+from grodecoder.settings import get_settings
TEST_DATA_DIR = Path(__file__).parent.parent / "data"
GRO_SMALL = TEST_DATA_DIR / "barstar_water_ions.gro"
@@ -49,7 +50,9 @@ def test_sequence(self, protein_atoms):
result = sequence(protein_atoms)
# Uniprot p11540 (first residue is missing in the structure)
- expected = "MKKAVINGEQIRSISDLHQTLKKELALPEYYGENLDALWDCLTGWVEYPLVLEWRQFEQSKQLTENGAESVLQVFREAKAEGCDITIILS"[1:]
+ expected = (
+ "MKKAVINGEQIRSISDLHQTLKKELALPEYYGENLDALWDCLTGWVEYPLVLEWRQFEQSKQLTENGAESVLQVFREAKAEGCDITIILS"[1:]
+ )
assert isinstance(result, str)
assert len(result) == len(protein_atoms.residues)
@@ -58,7 +61,7 @@ def test_sequence(self, protein_atoms):
def test_has_bonds(self, small_universe):
"""Test has_bonds function with real residue data."""
first_residue = small_universe.residues[0]
- result = has_bonds(first_residue, cutoff=2.0)
+ result = has_bonds(first_residue, cutoff_distance=2.0)
assert result is True
def test_has_bonds_between_residues(self, small_universe):
@@ -66,12 +69,12 @@ def test_has_bonds_between_residues(self, small_universe):
residue1 = small_universe.residues[0]
residue2 = small_universe.residues[1]
- result = has_bonds_between(residue1, residue2, cutoff=2.0)
+ result = has_bonds_between(residue1, residue2, cutoff_distance=2.0)
assert result is True
def test_detect_chains(self, protein_atoms):
"""Test detect_chains function with real protein data."""
- result = detect_chains(protein_atoms, cutoff=5.0)
+ result = detect_chains(protein_atoms, cutoff_distance=5.0)
# Should return a list of tuples of (start, end) indices.
assert isinstance(result, list)
@@ -82,17 +85,18 @@ def test_detect_chains(self, protein_atoms):
assert len(result) == 1
assert result[0] == (0, len(protein_atoms.residues) - 1)
-
def test_guess_resolution(self, small_universe):
"""Test guess_resolution function with real data."""
- result = guess_resolution(small_universe)
+ result = guess_resolution(
+ small_universe, cutoff_distance=get_settings().resolution_detection.distance_cutoff
+ )
assert result == MolecularResolution.ALL_ATOM
def test_guess_resolution_cg():
"""Test guess_resolution with coarse-grained data."""
universe = mda.Universe(GRO_CG)
- result = guess_resolution(universe)
+ result = guess_resolution(universe, cutoff_distance=get_settings().resolution_detection.distance_cutoff)
assert result == MolecularResolution.COARSE_GRAINED
@@ -100,7 +104,7 @@ def test_detect_chains_big():
"""Test detect_chains with a larger universe."""
universe = mda.Universe(GRO_BIG)
protein_atoms = universe.select_atoms("protein")
-
+
result = detect_chains(protein_atoms)
# GRO_BIG (1BRS) is the complex barstar-barnase, 3 times, i.e. total of 6 chains.
diff --git a/uv.lock b/uv.lock
index d834a11..626bc4a 100644
--- a/uv.lock
+++ b/uv.lock
@@ -369,6 +369,7 @@ dependencies = [
{ name = "loguru" },
{ name = "mdanalysis" },
{ name = "pydantic" },
+ { name = "pydantic-settings" },
{ name = "streamlit" },
]
@@ -393,6 +394,7 @@ requires-dist = [
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mdanalysis", specifier = ">=2.9.0" },
{ name = "pydantic", specifier = ">=2.11.3" },
+ { name = "pydantic-settings", specifier = ">=2.12.0" },
{ name = "streamlit", specifier = ">=1.50.0" },
]
@@ -1061,6 +1063,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/9f/ed/068e41660b832bb0b1aa5b58011dea2a3fe0ba7861ff38c4d4904c1c1a99/pydantic_core-2.41.5-cp314-cp314t-win_arm64.whl", hash = "sha256:35b44f37a3199f771c3eaa53051bc8a70cd7b54f333531c59e29fd4db5d15008", size = 1974769, upload-time = "2025-11-04T13:42:01.186Z" },
]
+[[package]]
+name = "pydantic-settings"
+version = "2.12.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "pydantic" },
+ { name = "python-dotenv" },
+ { name = "typing-inspection" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/43/4b/ac7e0aae12027748076d72a8764ff1c9d82ca75a7a52622e67ed3f765c54/pydantic_settings-2.12.0.tar.gz", hash = "sha256:005538ef951e3c2a68e1c08b292b5f2e71490def8589d4221b95dab00dafcfd0", size = 194184, upload-time = "2025-11-10T14:25:47.013Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/c1/60/5d4751ba3f4a40a6891f24eec885f51afd78d208498268c734e256fb13c4/pydantic_settings-2.12.0-py3-none-any.whl", hash = "sha256:fddb9fd99a5b18da837b29710391e945b1e30c135477f484084ee513adb93809", size = 51880, upload-time = "2025-11-10T14:25:45.546Z" },
+]
+
[[package]]
name = "pydeck"
version = "0.9.1"
@@ -1132,6 +1148,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" },
]
+[[package]]
+name = "python-dotenv"
+version = "1.2.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f0/26/19cadc79a718c5edbec86fd4919a6b6d3f681039a2f6d66d14be94e75fb9/python_dotenv-1.2.1.tar.gz", hash = "sha256:42667e897e16ab0d66954af0e60a9caa94f0fd4ecf3aaf6d2d260eec1aa36ad6", size = 44221, upload-time = "2025-10-26T15:12:10.434Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/14/1b/a298b06749107c305e1fe0f814c6c74aea7b2f1e10989cb30f544a1b3253/python_dotenv-1.2.1-py3-none-any.whl", hash = "sha256:b81ee9561e9ca4004139c6cbba3a238c32b03e4894671e181b671e8cb8425d61", size = 21230, upload-time = "2025-10-26T15:12:09.109Z" },
+]
+
[[package]]
name = "pytz"
version = "2025.2"