From 90c370f256be2566129838f55545824105137ae8 Mon Sep 17 00:00:00 2001 From: benoistlaurent Date: Thu, 20 Nov 2025 13:43:03 +0100 Subject: [PATCH 01/10] guess_resolution: add debug info --- src/grodecoder/logging.py | 4 ++++ src/grodecoder/toputils.py | 47 +++++++++++++++++++++++++++++++++----- 2 files changed, 45 insertions(+), 6 deletions(-) diff --git a/src/grodecoder/logging.py b/src/grodecoder/logging.py index 3fb69bd..a4ed78f 100644 --- a/src/grodecoder/logging.py +++ b/src/grodecoder/logging.py @@ -12,7 +12,11 @@ def setup_logging(logfile: Path, debug: bool = False): fmt = "{time:YYYY-MM-DD HH:mm:ss} {level}: {message}" level = "DEBUG" if debug else "INFO" logger.remove() + + # Screen logger. logger.add(sys.stderr, level=level, format=fmt, colorize=True) + + # File logger logger.add(logfile, level=level, format=fmt, colorize=False, mode="w") # Sets up loguru to capture warnings (typically MDAnalysis warnings) diff --git a/src/grodecoder/toputils.py b/src/grodecoder/toputils.py index 878c469..1bd18eb 100644 --- a/src/grodecoder/toputils.py +++ b/src/grodecoder/toputils.py @@ -4,8 +4,10 @@ from typing import Iterable import numpy as np +from loguru import logger from ._typing import Residue, UniverseLike +from .logging import is_logging_debug from .databases import get_amino_acid_name_map, get_nucleotide_name_map from .models import MolecularResolution @@ -42,15 +44,22 @@ def sequence(atoms: UniverseLike) -> str: return "".join(residue_names.get(residue.resname, "X") for residue in getattr(atoms, "residues", [])) -def has_bonds(residue: Residue, cutoff: float = 2.0): +def has_bonds(residue: Residue, distance_cutoff: float = 2.0): """Returns True if the residue has any bonds.""" + return bool(np.any(get_bonds(residue, distance_cutoff))) + + +def get_bonds(residue: Residue, cutoff: float = 2.0): + """Returns the bonds between the atoms of a residue.""" cutoff_squared = cutoff**2 distances_squared = ( np.linalg.norm(residue.atoms.positions[:, None] - residue.atoms.positions, axis=-1) ** 2 ) - np.fill_diagonal(distances_squared, np.inf) # Ignore self-pairs - bonded = distances_squared < cutoff_squared - return bool(np.any(bonded)) + + # ignore self-pairs and permutations + distances_squared[np.tril_indices(distances_squared.shape[0])] = np.inf + + return np.argwhere(distances_squared < cutoff_squared) def has_bonds_between(residue1: Residue, residue2: Residue, cutoff: float = 5.0): @@ -116,7 +125,7 @@ def end_of_chain(): return segments -def guess_resolution(universe: UniverseLike) -> MolecularResolution: +def guess_resolution(universe: UniverseLike, cutoff_distance: float = 2.0) -> MolecularResolution: """Guesses the resolution (i.e. all-atom or coarse grain) of the universe. The resolution is considered coarse-grained if a residue has at least two atoms within a distance of 2.0 Å. @@ -125,10 +134,36 @@ def guess_resolution(universe: UniverseLike) -> MolecularResolution: If any of them have bonds, the resolution is considered all-atom. If none of the first five residues have bonds, the resolution is considered coarse-grained. """ + + def debug(msg): + where = f"{__name__}.guess_resolution" + logger.debug(f"{where}: {msg}") + + def print_bonds(residue): + """Print bonds between atoms inside a residue. Used for debug purposes.""" + + def distance(atom1, atom2): + return (np.linalg.norm(atom1.position - atom2.position) ** 2.0) ** 0.5 + + bonds = get_bonds(residue, cutoff_distance) + for bond in bonds: + left, right = residue.atoms[bond] + bond_str = f"residue {left.resname}:{left.resid}, atoms {left.name}-{right.name}" + debug(f"guess_resolution: Found bond: {bond_str} (distance={distance(left, right):.2f})") + pair_str = f"pair{'s' if len(bonds) > 1 else ''}" + debug(f"guess_resolution: detected {len(bonds)} {pair_str} with distance < {cutoff_distance=:.2f}") + # Select the first five residues with at least two atoms. assert (residues := getattr(universe, "residues", [])) and len(residues) > 0 residues = [residue for residue in residues if len(residue.atoms) >= 2][:5] for residue in residues: - if has_bonds(residue, cutoff=2.0): + if has_bonds(residue, cutoff_distance): + if is_logging_debug(): + print_bonds(residue) + logger.debug("Detected resolution ALL_ATOM") return MolecularResolution.ALL_ATOM + debug( + f"No intra-atomic distance within {cutoff_distance:.2f} Å found in the first {len(residues)} residues" + ) + debug("Detected resolution COARSE_GRAINED") return MolecularResolution.COARSE_GRAINED From 8f8f307cdb1ba560753866f3c84a1673bb52c134 Mon Sep 17 00:00:00 2001 From: benoistlaurent Date: Thu, 20 Nov 2025 16:50:29 +0100 Subject: [PATCH 02/10] introducing the Settings class --- pyproject.toml | 1 + src/grodecoder/core.py | 9 +++++++-- src/grodecoder/main.py | 5 ++++- src/grodecoder/settings.py | 20 ++++++++++++++++++++ uv.lock | 25 +++++++++++++++++++++++++ 5 files changed, 57 insertions(+), 3 deletions(-) create mode 100644 src/grodecoder/settings.py diff --git a/pyproject.toml b/pyproject.toml index 1395ed7..d3133f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "loguru>=0.7.3", "mdanalysis>=2.9.0", "pydantic>=2.11.3", + "pydantic-settings>=2.12.0", "streamlit>=1.50.0", ] diff --git a/src/grodecoder/core.py b/src/grodecoder/core.py index 986e45b..0f19747 100644 --- a/src/grodecoder/core.py +++ b/src/grodecoder/core.py @@ -7,6 +7,7 @@ from .io import read_universe from .models import Decoded from .toputils import guess_resolution +from .settings import Settings def _now() -> str: @@ -23,10 +24,14 @@ def decode(universe: UniverseLike, bond_threshold: float = 5.0) -> Decoded: def decode_structure( - structure_path: PathLike, coordinates_path: PathLike | None = None, bond_threshold: float = 5.0 + structure_path: PathLike, + settings: Settings, + coordinates_path: PathLike | None = None, ) -> Decoded: """Reads a structure file and decodes it into an inventory of segments.""" universe = read_universe(structure_path, coordinates_path) assert universe.atoms is not None # required by type checker for some reason logger.debug(f"Universe has {len(universe.atoms):,d} atoms") - return decode(universe, bond_threshold=bond_threshold) + + cutoff = settings.chain_detection.distance_cutoff or settings.chain_detection.default_distance_cutoff + return decode(universe, bond_threshold=cutoff) diff --git a/src/grodecoder/main.py b/src/grodecoder/main.py index d9e99d2..37a8aef 100644 --- a/src/grodecoder/main.py +++ b/src/grodecoder/main.py @@ -10,6 +10,7 @@ from .databases import get_database_version from .models import GrodecoderRunOutput from .version import get_version +from .settings import Settings if TYPE_CHECKING: from .cli.args import Arguments as CliArgs @@ -27,11 +28,13 @@ def main(args: "CliArgs"): structure_path = args.structure_file.path coordinates_path = args.coordinates_file.path if args.coordinates_file else None + settings = Settings() + logger.info(f"Processing structure file: {structure_path}") # Decoding. decoded = decode_structure( - structure_path, coordinates_path=coordinates_path, bond_threshold=args.bond_threshold + structure_path, coordinates_path=coordinates_path, settings = settings ) output = GrodecoderRunOutput( diff --git a/src/grodecoder/settings.py b/src/grodecoder/settings.py new file mode 100644 index 0000000..a70e8a9 --- /dev/null +++ b/src/grodecoder/settings.py @@ -0,0 +1,20 @@ +from typing import ClassVar + +from pydantic_settings import BaseSettings + + +class ResolutionDetectionSettings(BaseSettings): + default_distance_cutoff: ClassVar[float] = 2.0 + distance_cutoff: float | None = None + + +class ChainDetectionSettings(BaseSettings): + default_distance_cutoff: ClassVar[float] = 5.0 + distance_cutoff: float | None = None + + +class Settings(BaseSettings): + resolution_detection: ResolutionDetectionSettings = ResolutionDetectionSettings() + chain_detection: ChainDetectionSettings = ChainDetectionSettings() + + debug: bool = False diff --git a/uv.lock b/uv.lock index d834a11..626bc4a 100644 --- a/uv.lock +++ b/uv.lock @@ -369,6 +369,7 @@ dependencies = [ { name = "loguru" }, { name = "mdanalysis" }, { name = "pydantic" }, + { name = "pydantic-settings" }, { name = "streamlit" }, ] @@ -393,6 +394,7 @@ requires-dist = [ { name = "loguru", specifier = ">=0.7.3" }, { name = "mdanalysis", specifier = ">=2.9.0" }, { name = "pydantic", specifier = ">=2.11.3" }, + { name = "pydantic-settings", specifier = ">=2.12.0" }, { name = "streamlit", specifier = ">=1.50.0" }, ] @@ -1061,6 +1063,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9f/ed/068e41660b832bb0b1aa5b58011dea2a3fe0ba7861ff38c4d4904c1c1a99/pydantic_core-2.41.5-cp314-cp314t-win_arm64.whl", hash = "sha256:35b44f37a3199f771c3eaa53051bc8a70cd7b54f333531c59e29fd4db5d15008", size = 1974769, upload-time = "2025-11-04T13:42:01.186Z" }, ] +[[package]] +name = "pydantic-settings" +version = "2.12.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/43/4b/ac7e0aae12027748076d72a8764ff1c9d82ca75a7a52622e67ed3f765c54/pydantic_settings-2.12.0.tar.gz", hash = "sha256:005538ef951e3c2a68e1c08b292b5f2e71490def8589d4221b95dab00dafcfd0", size = 194184, upload-time = "2025-11-10T14:25:47.013Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/60/5d4751ba3f4a40a6891f24eec885f51afd78d208498268c734e256fb13c4/pydantic_settings-2.12.0-py3-none-any.whl", hash = "sha256:fddb9fd99a5b18da837b29710391e945b1e30c135477f484084ee513adb93809", size = 51880, upload-time = "2025-11-10T14:25:45.546Z" }, +] + [[package]] name = "pydeck" version = "0.9.1" @@ -1132,6 +1148,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, ] +[[package]] +name = "python-dotenv" +version = "1.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f0/26/19cadc79a718c5edbec86fd4919a6b6d3f681039a2f6d66d14be94e75fb9/python_dotenv-1.2.1.tar.gz", hash = "sha256:42667e897e16ab0d66954af0e60a9caa94f0fd4ecf3aaf6d2d260eec1aa36ad6", size = 44221, upload-time = "2025-10-26T15:12:10.434Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/1b/a298b06749107c305e1fe0f814c6c74aea7b2f1e10989cb30f544a1b3253/python_dotenv-1.2.1-py3-none-any.whl", hash = "sha256:b81ee9561e9ca4004139c6cbba3a238c32b03e4894671e181b671e8cb8425d61", size = 21230, upload-time = "2025-10-26T15:12:09.109Z" }, +] + [[package]] name = "pytz" version = "2025.2" From ee11931578cb717caff713a53524876499f28370 Mon Sep 17 00:00:00 2001 From: benoistlaurent Date: Fri, 21 Nov 2025 14:34:23 +0100 Subject: [PATCH 03/10] guess resolution: more efficient on large systems --- src/grodecoder/core.py | 4 +++- src/grodecoder/identifier.py | 7 ++++++- src/grodecoder/toputils.py | 14 ++++++++++---- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/grodecoder/core.py b/src/grodecoder/core.py index 0f19747..0a7bbdb 100644 --- a/src/grodecoder/core.py +++ b/src/grodecoder/core.py @@ -17,9 +17,11 @@ def _now() -> str: def decode(universe: UniverseLike, bond_threshold: float = 5.0) -> Decoded: """Decodes the universe into an inventory of segments.""" + resolution = guess_resolution(universe, cutoff_distance=1.60) + logger.info(f"Guessed resolution: {resolution}") return Decoded( inventory=identify(universe, bond_threshold=bond_threshold), - resolution=guess_resolution(universe), + resolution=resolution, ) diff --git a/src/grodecoder/identifier.py b/src/grodecoder/identifier.py index f7d4ee5..6a53ae9 100644 --- a/src/grodecoder/identifier.py +++ b/src/grodecoder/identifier.py @@ -156,6 +156,7 @@ def _log_identified_molecules(molecules: list[SmallMolecule], label: str) -> Non def _identify(universe: UniverseLike, bond_threshold: float = 5.0) -> Inventory: """Identifies the molecules in the universe.""" + logger.debug("Residu identification: start") # Ensure the universe is an AtomGroup. universe = universe.select_atoms("all") @@ -211,8 +212,12 @@ def _identify(universe: UniverseLike, bond_threshold: float = 5.0) -> Inventory: ) unknown_molecules.append(molecule) - return Inventory( + logger.debug("Creating inventory") + inventory = Inventory( segments=protein + nucleic, small_molecules=ions + solvents + lipids + others + unknown_molecules, total_number_of_atoms=total_number_of_atoms, ) + + logger.debug("Residu identification: end") + return inventory diff --git a/src/grodecoder/toputils.py b/src/grodecoder/toputils.py index 1bd18eb..390da2f 100644 --- a/src/grodecoder/toputils.py +++ b/src/grodecoder/toputils.py @@ -1,6 +1,7 @@ """Defines utility functions for working with molecular structures.""" import collections +from itertools import islice from typing import Iterable import numpy as np @@ -153,17 +154,22 @@ def distance(atom1, atom2): pair_str = f"pair{'s' if len(bonds) > 1 else ''}" debug(f"guess_resolution: detected {len(bonds)} {pair_str} with distance < {cutoff_distance=:.2f}") - # Select the first five residues with at least two atoms. + debug("start") + + # Makes ty happy. assert (residues := getattr(universe, "residues", [])) and len(residues) > 0 - residues = [residue for residue in residues if len(residue.atoms) >= 2][:5] + + # Selects the first five residues with at least two atoms. + residues = list(islice((residue for residue in residues if len(residue.atoms) > 1), 5)) + for residue in residues: if has_bonds(residue, cutoff_distance): if is_logging_debug(): print_bonds(residue) - logger.debug("Detected resolution ALL_ATOM") + debug("end: detected resolution: ALL_ATOM") return MolecularResolution.ALL_ATOM debug( f"No intra-atomic distance within {cutoff_distance:.2f} Å found in the first {len(residues)} residues" ) - debug("Detected resolution COARSE_GRAINED") + debug("end: detected resolution: COARSE_GRAINED") return MolecularResolution.COARSE_GRAINED From 55282562fd93509e6ff69aa557c5bb15eeff0561 Mon Sep 17 00:00:00 2001 From: benoistlaurent Date: Mon, 24 Nov 2025 16:49:05 +0100 Subject: [PATCH 04/10] proper implementation and use of settings --- src/grodecoder/cli/__init__.py | 2 +- src/grodecoder/cli/args.py | 4 +- src/grodecoder/core.py | 28 ++++++++----- src/grodecoder/identifier.py | 20 ++++++++- src/grodecoder/main.py | 10 ++--- src/grodecoder/settings.py | 75 +++++++++++++++++++++++++++++++++- src/grodecoder/toputils.py | 2 +- 7 files changed, 117 insertions(+), 24 deletions(-) diff --git a/src/grodecoder/cli/__init__.py b/src/grodecoder/cli/__init__.py index fab9282..efa4044 100644 --- a/src/grodecoder/cli/__init__.py +++ b/src/grodecoder/cli/__init__.py @@ -11,7 +11,6 @@ @click.argument("coordinates_file", type=CoordinatesFile, required=False) @click.option( "--bond-threshold", - default=5.0, type=float, help="Threshold for interchain bond detection (default: 5 Å)", ) @@ -29,6 +28,7 @@ def cli(**kwargs): args = CliArgs( structure_file=kwargs["structure_file"], coordinates_file=kwargs["coordinates_file"], + bond_threshold=kwargs["bond_threshold"], no_atom_ids=kwargs["no_atom_ids"], print_to_stdout=kwargs["stdout"], ) diff --git a/src/grodecoder/cli/args.py b/src/grodecoder/cli/args.py index 68f7c75..db51a0e 100644 --- a/src/grodecoder/cli/args.py +++ b/src/grodecoder/cli/args.py @@ -59,14 +59,14 @@ class Arguments: Attrs: structure_file (Path): Path to the structure file. coordinates_file (Path): Path to the coordinates file. - bond_threshold (float): Threshold for interchain bond detection. + bond_threshold (float | None): Threshold for interchain bond detection. no_atom_ids (bool): If True, use compact serialization (no atom indices). print_to_stdout (bool): Whether to output results to stdout. """ structure_file: StructureFile coordinates_file: CoordinatesFile | None = None - bond_threshold: float = 5.0 + bond_threshold: float | None = None no_atom_ids: bool = True print_to_stdout: bool = False diff --git a/src/grodecoder/core.py b/src/grodecoder/core.py index 0a7bbdb..f6fb269 100644 --- a/src/grodecoder/core.py +++ b/src/grodecoder/core.py @@ -7,7 +7,7 @@ from .io import read_universe from .models import Decoded from .toputils import guess_resolution -from .settings import Settings +from .settings import get_settings def _now() -> str: @@ -15,25 +15,31 @@ def _now() -> str: return datetime.now().strftime("%Y-%m-%d %H:%M:%S") -def decode(universe: UniverseLike, bond_threshold: float = 5.0) -> Decoded: +def decode(universe: UniverseLike) -> Decoded: """Decodes the universe into an inventory of segments.""" resolution = guess_resolution(universe, cutoff_distance=1.60) logger.info(f"Guessed resolution: {resolution}") + + chain_detection_settings = get_settings().chain_detection + + if chain_detection_settings.distance_cutoff.is_set(): + value = chain_detection_settings.distance_cutoff.get() + logger.debug(f"chain detection: using user-defined value: {value:.2f}") + else: + logger.debug("chain detection: guessing distance cutoff based on resolution") + chain_detection_settings.distance_cutoff.guess(resolution) + + distance_cutoff = chain_detection_settings.distance_cutoff.get() + return Decoded( - inventory=identify(universe, bond_threshold=bond_threshold), + inventory=identify(universe, bond_threshold=distance_cutoff), resolution=resolution, ) -def decode_structure( - structure_path: PathLike, - settings: Settings, - coordinates_path: PathLike | None = None, -) -> Decoded: +def decode_structure(structure_path: PathLike, coordinates_path: PathLike | None = None) -> Decoded: """Reads a structure file and decodes it into an inventory of segments.""" universe = read_universe(structure_path, coordinates_path) assert universe.atoms is not None # required by type checker for some reason logger.debug(f"Universe has {len(universe.atoms):,d} atoms") - - cutoff = settings.chain_detection.distance_cutoff or settings.chain_detection.default_distance_cutoff - return decode(universe, bond_threshold=cutoff) + return decode(universe) diff --git a/src/grodecoder/identifier.py b/src/grodecoder/identifier.py index 6a53ae9..c62e5e4 100644 --- a/src/grodecoder/identifier.py +++ b/src/grodecoder/identifier.py @@ -28,6 +28,9 @@ def identify_small_molecule( residue = SmallMolecule( atoms=selection, description=definition.description, molecular_type=molecular_type ) + logger.debug( + f"identified small molecule {residue.description}: {len(selection.residues)} residues, {len(selection.atoms)} atoms" + ) counts.append(residue) return counts @@ -66,18 +69,25 @@ def _select_protein(universe: UniverseLike) -> AtomGroup: selection_str = f"resname {' '.join(protein_residue_names)}" # Exclude methanol residues from the selection. + logger.debug("excluding possible methanol residues (MET) from protein") methanol = _find_methanol(universe) if methanol: selection_str += f" and not index {' '.join(map(str, methanol))}" - return universe.select_atoms(selection_str) + logger.debug("selecting protein") + protein = universe.select_atoms(selection_str) + logger.debug("selecting protein - done") + return protein def _select_nucleic(universe: UniverseLike) -> AtomGroup: """Selects the nucleic acid atoms from the universe.""" nucleic_acid_residue_names = DB.get_nucleotide_names() selection_str = f"resname {' '.join(nucleic_acid_residue_names)}" - return universe.select_atoms(selection_str) + logger.debug("selecting nucleic") + nucleic = universe.select_atoms(selection_str) + logger.debug("selecting nucleic - done") + return nucleic def _iter_chains(atoms: AtomGroup, bond_threshold: float = 5.0) -> Iterator[AtomGroup]: @@ -87,8 +97,14 @@ def _iter_chains(atoms: AtomGroup, bond_threshold: float = 5.0) -> Iterator[Atom """ if len(atoms) == 0: return + logger.debug(f"detecting segments using distance cutoff {bond_threshold:.2f}") segments = toputils.detect_chains(atoms, cutoff=bond_threshold) + + n_seg_str = f"{len(segments)} segment" + "s" if len(segments) > 1 else "" + logger.debug(f"detecting segments - done: found {n_seg_str}") + for start, end in segments: + logger.debug(f"yielding segment containing residues {start} to {end}") yield atoms.residues[start : end + 1].atoms diff --git a/src/grodecoder/main.py b/src/grodecoder/main.py index 37a8aef..ca5cd3d 100644 --- a/src/grodecoder/main.py +++ b/src/grodecoder/main.py @@ -10,7 +10,7 @@ from .databases import get_database_version from .models import GrodecoderRunOutput from .version import get_version -from .settings import Settings +from .settings import get_settings if TYPE_CHECKING: from .cli.args import Arguments as CliArgs @@ -28,14 +28,14 @@ def main(args: "CliArgs"): structure_path = args.structure_file.path coordinates_path = args.coordinates_file.path if args.coordinates_file else None - settings = Settings() + # Storing cli arguments into settings. + s = get_settings() + s.chain_detection.distance_cutoff = args.bond_threshold logger.info(f"Processing structure file: {structure_path}") # Decoding. - decoded = decode_structure( - structure_path, coordinates_path=coordinates_path, settings = settings - ) + decoded = decode_structure(structure_path, coordinates_path=coordinates_path) output = GrodecoderRunOutput( decoded=decoded, diff --git a/src/grodecoder/settings.py b/src/grodecoder/settings.py index a70e8a9..e84768f 100644 --- a/src/grodecoder/settings.py +++ b/src/grodecoder/settings.py @@ -1,16 +1,72 @@ +from functools import lru_cache from typing import ClassVar +from loguru import logger from pydantic_settings import BaseSettings +from .models import MolecularResolution + class ResolutionDetectionSettings(BaseSettings): default_distance_cutoff: ClassVar[float] = 2.0 distance_cutoff: float | None = None +from dataclasses import dataclass + +@dataclass +class DistanceCutoffSetting: + default_distance_cutoff_all_atom: ClassVar[float] = 5.0 + default_distance_cutoff_coarse_grain: ClassVar[float] = 7.0 + _user_distance_cutoff: float | None = None + _guessed_distance_cutoff: float | None = None + + def is_defined(self) -> bool: + """Returns True if the distance cutoff has been set or guessed.""" + return self._user_distance_cutoff or self._guessed_distance_cutoff + + def is_set(self) -> bool: + """Returns True if the distance cutoff has been set.""" + return self._user_distance_cutoff is not None + + def is_guessed(self) -> bool: + """Returns True if the distance cutoff has been guessed.""" + return self._guessed_distance_cutoff is not None + + def get(self) -> float: + if not self.is_defined(): + raise ValueError("`distance_cutoff` must be set or guessed before it is used.") + return self._user_distance_cutoff or self._guessed_distance_cutoff + + def set(self, value: float): + if self.is_guessed(): + self._guessed_distance_cutoff = None + self._user_distance_cutoff = value + + def guess(self, resolution: MolecularResolution): + if resolution == MolecularResolution.ALL_ATOM: + distance_cutoff = self.default_distance_cutoff_all_atom + logger.debug( + f"chain detection: using default distance cutoff for all atom structures: {distance_cutoff:.2f}" + ) + else: + distance_cutoff = self.default_distance_cutoff_coarse_grain + logger.debug( + f"chain detection: using default distance cutoff for coarse grain structures: {distance_cutoff:.2f}" + ) + self._guessed_distance_cutoff = distance_cutoff + + class ChainDetectionSettings(BaseSettings): - default_distance_cutoff: ClassVar[float] = 5.0 - distance_cutoff: float | None = None + _distance_cutoff: DistanceCutoffSetting = DistanceCutoffSetting() + + @property + def distance_cutoff(self) -> DistanceCutoffSetting: + return self._distance_cutoff + + @distance_cutoff.setter + def distance_cutoff(self, value: float): + self._distance_cutoff.set(value) class Settings(BaseSettings): @@ -18,3 +74,18 @@ class Settings(BaseSettings): chain_detection: ChainDetectionSettings = ChainDetectionSettings() debug: bool = False + + +_settings: Settings | None = None + + +@lru_cache() +def get_settings(): + global _settings + if _settings is None: + _settings = Settings() + return _settings + + +def get_chain_detection_settings(): + return get_settings().chain_detection diff --git a/src/grodecoder/toputils.py b/src/grodecoder/toputils.py index 390da2f..5705897 100644 --- a/src/grodecoder/toputils.py +++ b/src/grodecoder/toputils.py @@ -160,7 +160,7 @@ def distance(atom1, atom2): assert (residues := getattr(universe, "residues", [])) and len(residues) > 0 # Selects the first five residues with at least two atoms. - residues = list(islice((residue for residue in residues if len(residue.atoms) > 1), 5)) + residues = list(islice((residue for residue in residues if len(residue.atoms) > 1), 10)) for residue in residues: if has_bonds(residue, cutoff_distance): From f4553fbbb22fa9df773a611a8d780b09ff7f5f6d Mon Sep 17 00:00:00 2001 From: benoistlaurent Date: Tue, 25 Nov 2025 10:57:39 +0100 Subject: [PATCH 05/10] proper management of chain detection distance cutoff setting --- src/grodecoder/core.py | 1 + src/grodecoder/settings.py | 110 +++++++++++++++++++++++++++++++++---- 2 files changed, 99 insertions(+), 12 deletions(-) diff --git a/src/grodecoder/core.py b/src/grodecoder/core.py index f6fb269..1d51e21 100644 --- a/src/grodecoder/core.py +++ b/src/grodecoder/core.py @@ -20,6 +20,7 @@ def decode(universe: UniverseLike) -> Decoded: resolution = guess_resolution(universe, cutoff_distance=1.60) logger.info(f"Guessed resolution: {resolution}") + # Guesses the chain dection distance cutoff if not provided by the user. chain_detection_settings = get_settings().chain_detection if chain_detection_settings.distance_cutoff.is_set(): diff --git a/src/grodecoder/settings.py b/src/grodecoder/settings.py index e84768f..d23f5b6 100644 --- a/src/grodecoder/settings.py +++ b/src/grodecoder/settings.py @@ -1,8 +1,13 @@ +from dataclasses import dataclass from functools import lru_cache from typing import ClassVar from loguru import logger +from pydantic import ConfigDict, Field, GetJsonSchemaHandler +from pydantic.json_schema import JsonSchemaValue +from pydantic_core import core_schema from pydantic_settings import BaseSettings +from typing_extensions import Annotated from .models import MolecularResolution @@ -12,15 +17,17 @@ class ResolutionDetectionSettings(BaseSettings): distance_cutoff: float | None = None -from dataclasses import dataclass - -@dataclass -class DistanceCutoffSetting: +@dataclass(init=False) +class DistanceCutoff: default_distance_cutoff_all_atom: ClassVar[float] = 5.0 default_distance_cutoff_coarse_grain: ClassVar[float] = 7.0 _user_distance_cutoff: float | None = None _guessed_distance_cutoff: float | None = None + def __init__(self, user_value: float | None = None): + if user_value is not None: + self.set(user_value) + def is_defined(self) -> bool: """Returns True if the distance cutoff has been set or guessed.""" return self._user_distance_cutoff or self._guessed_distance_cutoff @@ -57,16 +64,95 @@ def guess(self, resolution: MolecularResolution): self._guessed_distance_cutoff = distance_cutoff -class ChainDetectionSettings(BaseSettings): - _distance_cutoff: DistanceCutoffSetting = DistanceCutoffSetting() +class _DistanceCutoffPydanticAnnotation: + """Allows to serialize / validate DistanceCutoff using pydantic. + + + Examples: + >>> from grodecoder.settings import ChainDetectionSettings + >>> cds = ChainDetectionSettings() + >>> cds.distance_cutoff + DistanceCutoff(_user_distance_cutoff=None, _guessed_distance_cutoff=None) + + >>> # Float assignement + >>> cds.distance_cutoff = 12 + >>> cds.distance_cutoff + DistanceCutoff(_user_distance_cutoff=12.0, _guessed_distance_cutoff=None) + + >>> # None assignment + >>> cds.distance_cutoff = None + >>> cds.distance_cutoff + DistanceCutoff(_user_distance_cutoff=None, _guessed_distance_cutoff=None) + + >>> # Serialization + >>> cds.distance_cutoff = 12 + >>> cds.model_dump() + {'distance_cutoff': 12.0} + + >>> # Validation + >>> as_json = cds.model_dump() + >>> ChainDetectionSettings.model_validate(as_json) + ChainDetectionSettings(distance_cutoff=DistanceCutoff(_user_distance_cutoff=12.0, _guessed_distance_cutoff=None)) + """ + + @classmethod + def __get_pydantic_core_schema__(cls, _source_type, _handler) -> core_schema.CoreSchema: + """ + We return a pydantic_core.CoreSchema that behaves in the following ways: + + * floats will be parsed as `DistanceCutoff` instances with the float as the `_user_distance_cutoff` attribute + * `DistanceCutoff` instances will be parsed as `DistanceCutoff` instances without any changes + * Nothing else will pass validation + * Serialization will always return just a float + """ + + def validate_from_none(value: None) -> DistanceCutoff: + return DistanceCutoff() + + def validate_from_float(value: float) -> DistanceCutoff: + result = DistanceCutoff() + result.set(value) + return result + + from_none_schema = core_schema.chain_schema( + [ + core_schema.none_schema(), + core_schema.no_info_plain_validator_function(validate_from_none), + ] + ) + from_float_schema = core_schema.chain_schema( + [ + core_schema.float_schema(), + core_schema.no_info_plain_validator_function(validate_from_float), + ] + ) + return core_schema.json_or_python_schema( + json_schema=from_float_schema, + python_schema=core_schema.union_schema( + [ + # check if it's an instance first before doing any further work + core_schema.is_instance_schema(DistanceCutoff), + from_none_schema, + from_float_schema, + ] + ), + serialization=core_schema.plain_serializer_function_ser_schema(lambda instance: instance.get()), + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + return handler(core_schema.float_schema()) + + +# We now create an `Annotated` wrapper that we'll use as the annotation for fields on `BaseModel`s, etc. +PydanticDistanceCutoff = Annotated[DistanceCutoff, _DistanceCutoffPydanticAnnotation] - @property - def distance_cutoff(self) -> DistanceCutoffSetting: - return self._distance_cutoff - @distance_cutoff.setter - def distance_cutoff(self, value: float): - self._distance_cutoff.set(value) +class ChainDetectionSettings(BaseSettings): + model_config = ConfigDict(validate_assignment=True) + distance_cutoff: PydanticDistanceCutoff = Field(default_factory=DistanceCutoff) class Settings(BaseSettings): From dd2807c72e765539e0e8407e0804ded4d8e517a7 Mon Sep 17 00:00:00 2001 From: benoistlaurent Date: Tue, 25 Nov 2025 11:29:27 +0100 Subject: [PATCH 06/10] output input settings to json --- src/grodecoder/cli/__init__.py | 5 ++++- src/grodecoder/core.py | 5 ++++- src/grodecoder/logging.py | 6 +++++- src/grodecoder/main.py | 1 + src/grodecoder/models.py | 6 +++++- src/grodecoder/settings.py | 18 +++++++++--------- src/grodecoder/toputils.py | 7 ++++--- 7 files changed, 32 insertions(+), 16 deletions(-) diff --git a/src/grodecoder/cli/__init__.py b/src/grodecoder/cli/__init__.py index efa4044..0f07530 100644 --- a/src/grodecoder/cli/__init__.py +++ b/src/grodecoder/cli/__init__.py @@ -3,6 +3,7 @@ from ..main import main as grodecoder_main from .args import Arguments as CliArgs from .args import CoordinatesFile, StructureFile +from ..settings import get_settings from ..logging import setup_logging @@ -33,8 +34,10 @@ def cli(**kwargs): print_to_stdout=kwargs["stdout"], ) + get_settings().debug = kwargs["verbose"] + logfile = args.get_log_filename() - setup_logging(logfile, kwargs["verbose"]) + setup_logging(logfile) grodecoder_main(args) diff --git a/src/grodecoder/core.py b/src/grodecoder/core.py index 1d51e21..18a27db 100644 --- a/src/grodecoder/core.py +++ b/src/grodecoder/core.py @@ -17,7 +17,10 @@ def _now() -> str: def decode(universe: UniverseLike) -> Decoded: """Decodes the universe into an inventory of segments.""" - resolution = guess_resolution(universe, cutoff_distance=1.60) + + settings = get_settings() + + resolution = guess_resolution(universe, cutoff_distance=settings.resolution_detection.distance_cutoff) logger.info(f"Guessed resolution: {resolution}") # Guesses the chain dection distance cutoff if not provided by the user. diff --git a/src/grodecoder/logging.py b/src/grodecoder/logging.py index d576868..89f0107 100644 --- a/src/grodecoder/logging.py +++ b/src/grodecoder/logging.py @@ -6,9 +6,13 @@ from loguru import logger +from .settings import get_settings -def setup_logging(logfile: Path, debug: bool = False): + +def setup_logging(logfile: Path): """Sets up logging configuration.""" + debug = get_settings().debug + fmt = "{time:YYYY-MM-DD HH:mm:ss} {level}: {message}" level = "DEBUG" if debug else "INFO" logger.remove() diff --git a/src/grodecoder/main.py b/src/grodecoder/main.py index ca5cd3d..254f6b9 100644 --- a/src/grodecoder/main.py +++ b/src/grodecoder/main.py @@ -42,6 +42,7 @@ def main(args: "CliArgs"): structure_file_checksum=_get_checksum(structure_path), database_version=get_database_version(), grodecoder_version=get_version(), + input_settings = s, ) # Serialization. diff --git a/src/grodecoder/models.py b/src/grodecoder/models.py index 1803ea1..d26c70d 100644 --- a/src/grodecoder/models.py +++ b/src/grodecoder/models.py @@ -1,5 +1,4 @@ from __future__ import annotations -from pydantic import model_validator from enum import StrEnum from typing import Protocol @@ -14,6 +13,7 @@ computed_field, field_serializer, model_serializer, + model_validator, ) from . import toputils @@ -176,6 +176,9 @@ class Decoded(FrozenModel): resolution: MolecularResolution +from .settings import Settings + + class GrodecoderRunOutput(BaseModel): """Output model for grodecoder results. @@ -194,6 +197,7 @@ class GrodecoderRunOutput(BaseModel): structure_file_checksum: str database_version: str grodecoder_version: str + input_settings: Settings # ========================================================================================================= diff --git a/src/grodecoder/settings.py b/src/grodecoder/settings.py index d23f5b6..2aa3ed6 100644 --- a/src/grodecoder/settings.py +++ b/src/grodecoder/settings.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from functools import lru_cache -from typing import ClassVar +from typing import ClassVar, TYPE_CHECKING from loguru import logger from pydantic import ConfigDict, Field, GetJsonSchemaHandler @@ -9,12 +9,8 @@ from pydantic_settings import BaseSettings from typing_extensions import Annotated -from .models import MolecularResolution - - -class ResolutionDetectionSettings(BaseSettings): - default_distance_cutoff: ClassVar[float] = 2.0 - distance_cutoff: float | None = None +if TYPE_CHECKING: + from .models import MolecularResolution @dataclass(init=False) @@ -50,8 +46,8 @@ def set(self, value: float): self._guessed_distance_cutoff = None self._user_distance_cutoff = value - def guess(self, resolution: MolecularResolution): - if resolution == MolecularResolution.ALL_ATOM: + def guess(self, resolution: "MolecularResolution"): + if resolution == "ALL_ATOM": distance_cutoff = self.default_distance_cutoff_all_atom logger.debug( f"chain detection: using default distance cutoff for all atom structures: {distance_cutoff:.2f}" @@ -155,6 +151,10 @@ class ChainDetectionSettings(BaseSettings): distance_cutoff: PydanticDistanceCutoff = Field(default_factory=DistanceCutoff) +class ResolutionDetectionSettings(BaseSettings): + distance_cutoff: float = 1.6 + + class Settings(BaseSettings): resolution_detection: ResolutionDetectionSettings = ResolutionDetectionSettings() chain_detection: ChainDetectionSettings = ChainDetectionSettings() diff --git a/src/grodecoder/toputils.py b/src/grodecoder/toputils.py index 5705897..71e778a 100644 --- a/src/grodecoder/toputils.py +++ b/src/grodecoder/toputils.py @@ -126,10 +126,11 @@ def end_of_chain(): return segments -def guess_resolution(universe: UniverseLike, cutoff_distance: float = 2.0) -> MolecularResolution: +def guess_resolution(universe: UniverseLike, cutoff_distance: float) -> MolecularResolution: """Guesses the resolution (i.e. all-atom or coarse grain) of the universe. - The resolution is considered coarse-grained if a residue has at least two atoms within a distance of 2.0 Å. + The resolution is considered coarse-grained if a residue has at least two atoms within a distance of + `cutoff_distance` Å. Finds the first five residues with at least two atoms and checks if they have bonds. If any of them have bonds, the resolution is considered all-atom. @@ -154,7 +155,7 @@ def distance(atom1, atom2): pair_str = f"pair{'s' if len(bonds) > 1 else ''}" debug(f"guess_resolution: detected {len(bonds)} {pair_str} with distance < {cutoff_distance=:.2f}") - debug("start") + debug(f"start ; {cutoff_distance=:.2f}") # Makes ty happy. assert (residues := getattr(universe, "residues", [])) and len(residues) > 0 From 94019c6e26710a32d838f60900bc13916e5c5a9f Mon Sep 17 00:00:00 2001 From: benoistlaurent Date: Tue, 25 Nov 2025 11:43:26 +0100 Subject: [PATCH 07/10] fixes typing issues --- src/grodecoder/models.py | 4 +--- src/grodecoder/settings.py | 4 ++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/grodecoder/models.py b/src/grodecoder/models.py index d26c70d..300762a 100644 --- a/src/grodecoder/models.py +++ b/src/grodecoder/models.py @@ -17,6 +17,7 @@ ) from . import toputils +from .settings import Settings class MolecularResolution(StrEnum): @@ -176,9 +177,6 @@ class Decoded(FrozenModel): resolution: MolecularResolution -from .settings import Settings - - class GrodecoderRunOutput(BaseModel): """Output model for grodecoder results. diff --git a/src/grodecoder/settings.py b/src/grodecoder/settings.py index 2aa3ed6..bb4a584 100644 --- a/src/grodecoder/settings.py +++ b/src/grodecoder/settings.py @@ -26,7 +26,7 @@ def __init__(self, user_value: float | None = None): def is_defined(self) -> bool: """Returns True if the distance cutoff has been set or guessed.""" - return self._user_distance_cutoff or self._guessed_distance_cutoff + return any((self._user_distance_cutoff, self._guessed_distance_cutoff)) def is_set(self) -> bool: """Returns True if the distance cutoff has been set.""" @@ -39,7 +39,7 @@ def is_guessed(self) -> bool: def get(self) -> float: if not self.is_defined(): raise ValueError("`distance_cutoff` must be set or guessed before it is used.") - return self._user_distance_cutoff or self._guessed_distance_cutoff + return self._user_distance_cutoff or self._guessed_distance_cutoff # ty: ignore[invalid-return-type] def set(self, value: float): if self.is_guessed(): From 5d1aba3360cf2800c38638373446e362359e14aa Mon Sep 17 00:00:00 2001 From: benoistlaurent Date: Thu, 27 Nov 2025 09:24:47 +0100 Subject: [PATCH 08/10] --no-atom-ids stored into the Settings class --- src/grodecoder/main.py | 10 ++++++---- src/grodecoder/settings.py | 6 ++++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/grodecoder/main.py b/src/grodecoder/main.py index 254f6b9..c5294e4 100644 --- a/src/grodecoder/main.py +++ b/src/grodecoder/main.py @@ -29,8 +29,9 @@ def main(args: "CliArgs"): coordinates_path = args.coordinates_file.path if args.coordinates_file else None # Storing cli arguments into settings. - s = get_settings() - s.chain_detection.distance_cutoff = args.bond_threshold + settings = get_settings() + settings.chain_detection.distance_cutoff = args.bond_threshold + settings.output.atom_ids = not args.no_atom_ids logger.info(f"Processing structure file: {structure_path}") @@ -42,11 +43,12 @@ def main(args: "CliArgs"): structure_file_checksum=_get_checksum(structure_path), database_version=get_database_version(), grodecoder_version=get_version(), - input_settings = s, + input_settings=settings, ) # Serialization. - serialization_mode = "compact" if args.no_atom_ids else "full" + logger.debug("Creating json output") + serialization_mode = "full" if settings.output.atom_ids else "compact" # Updates run time as late as possible. output_json = output.model_dump(context={"serialization_mode": serialization_mode}) diff --git a/src/grodecoder/settings.py b/src/grodecoder/settings.py index bb4a584..c34b92a 100644 --- a/src/grodecoder/settings.py +++ b/src/grodecoder/settings.py @@ -155,9 +155,15 @@ class ResolutionDetectionSettings(BaseSettings): distance_cutoff: float = 1.6 +class OutputSettings(BaseSettings): + # should we output atom ids? + atom_ids: bool = True + + class Settings(BaseSettings): resolution_detection: ResolutionDetectionSettings = ResolutionDetectionSettings() chain_detection: ChainDetectionSettings = ChainDetectionSettings() + output: OutputSettings = OutputSettings() debug: bool = False From 11531624c195e5d4e0bc3e2e036cc5fc056afc37 Mon Sep 17 00:00:00 2001 From: benoistlaurent Date: Tue, 2 Dec 2025 13:09:07 +0100 Subject: [PATCH 09/10] bugfix: fixes changes in API that were unaccounted for in tests --- src/grodecoder/identifier.py | 4 +-- src/grodecoder/logging.py | 2 ++ src/grodecoder/toputils.py | 19 +++++++------ tests/test_toputils/test_toputils.py | 28 +++++++++---------- .../test_toputils_integration.py | 20 +++++++------ 5 files changed, 41 insertions(+), 32 deletions(-) diff --git a/src/grodecoder/identifier.py b/src/grodecoder/identifier.py index c62e5e4..024a308 100644 --- a/src/grodecoder/identifier.py +++ b/src/grodecoder/identifier.py @@ -97,8 +97,8 @@ def _iter_chains(atoms: AtomGroup, bond_threshold: float = 5.0) -> Iterator[Atom """ if len(atoms) == 0: return - logger.debug(f"detecting segments using distance cutoff {bond_threshold:.2f}") - segments = toputils.detect_chains(atoms, cutoff=bond_threshold) + logger.debug(f"detecting segments using cutoff distance {bond_threshold:.2f}") + segments = toputils.detect_chains(atoms, cutoff_distance=bond_threshold) n_seg_str = f"{len(segments)} segment" + "s" if len(segments) > 1 else "" logger.debug(f"detecting segments - done: found {n_seg_str}") diff --git a/src/grodecoder/logging.py b/src/grodecoder/logging.py index 89f0107..8430d71 100644 --- a/src/grodecoder/logging.py +++ b/src/grodecoder/logging.py @@ -15,6 +15,7 @@ def setup_logging(logfile: Path): fmt = "{time:YYYY-MM-DD HH:mm:ss} {level}: {message}" level = "DEBUG" if debug else "INFO" + logger.remove() # Screen logger. @@ -32,6 +33,7 @@ def showwarning(message, *args, **kwargs): def is_logging_debug() -> bool: """Returns True if at least one logging handler is set to level DEBUG.""" + print("COUCOU", get_logging_level()) return "DEBUG" in get_logging_level() diff --git a/src/grodecoder/toputils.py b/src/grodecoder/toputils.py index 71e778a..49df27d 100644 --- a/src/grodecoder/toputils.py +++ b/src/grodecoder/toputils.py @@ -45,9 +45,9 @@ def sequence(atoms: UniverseLike) -> str: return "".join(residue_names.get(residue.resname, "X") for residue in getattr(atoms, "residues", [])) -def has_bonds(residue: Residue, distance_cutoff: float = 2.0): +def has_bonds(residue: Residue, cutoff_distance: float = 2.0): """Returns True if the residue has any bonds.""" - return bool(np.any(get_bonds(residue, distance_cutoff))) + return bool(np.any(get_bonds(residue, cutoff_distance))) def get_bonds(residue: Residue, cutoff: float = 2.0): @@ -63,9 +63,9 @@ def get_bonds(residue: Residue, cutoff: float = 2.0): return np.argwhere(distances_squared < cutoff_squared) -def has_bonds_between(residue1: Residue, residue2: Residue, cutoff: float = 5.0): +def has_bonds_between(residue1: Residue, residue2: Residue, cutoff_distance: float = 5.0): """Returns True if the two residues are bonded.""" - cutoff_squared = cutoff**2 + cutoff_squared = cutoff_distance**2 distances_squared = ( np.linalg.norm(residue1.atoms.positions[:, None] - residue2.atoms.positions, axis=-1) ** 2 ) @@ -73,7 +73,7 @@ def has_bonds_between(residue1: Residue, residue2: Residue, cutoff: float = 5.0) return bool(np.any(bonded)) -def detect_chains(universe: UniverseLike, cutoff: float = 5.0) -> list[tuple[int, int]]: +def detect_chains(universe: UniverseLike, cutoff_distance: float = 5.0) -> list[tuple[int, int]]: """Detects chains in a set of atoms. Iterates over the residues as detected by MDAnalysis and calculates the bonds @@ -85,7 +85,7 @@ def detect_chains(universe: UniverseLike, cutoff: float = 5.0) -> list[tuple[int universe : AtomGroup The universe to analyze. Typically a protein or a set of residues. - cutoff : float, optional + cutoff_distance : float, optional The cutoff distance to determine if two residues are bonded. Default is 5.0. Returns @@ -109,7 +109,7 @@ def detect_chains(universe: UniverseLike, cutoff: float = 5.0) -> list[tuple[int def end_of_chain(): """The end of a chain is defined as the point where two consecutive residues are not bonded.""" - return not has_bonds_between(current_residue, next_residue, cutoff) + return not has_bonds_between(current_residue, next_residue, cutoff_distance) segments = [] @@ -166,7 +166,10 @@ def distance(atom1, atom2): for residue in residues: if has_bonds(residue, cutoff_distance): if is_logging_debug(): - print_bonds(residue) + try: + print_bonds(residue) # will not work during unit test as we use mocks + except Exception: + pass debug("end: detected resolution: ALL_ATOM") return MolecularResolution.ALL_ATOM debug( diff --git a/tests/test_toputils/test_toputils.py b/tests/test_toputils/test_toputils.py index 49c14de..be36a9a 100644 --- a/tests/test_toputils/test_toputils.py +++ b/tests/test_toputils/test_toputils.py @@ -120,14 +120,14 @@ def test_has_bonds_true(self): mock_residue = Mock() mock_residue.atoms.positions = np.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]]) - result = has_bonds(mock_residue, cutoff=2.0) + result = has_bonds(mock_residue, cutoff_distance=2.0) assert result is True def test_has_bonds_false(self): mock_residue = Mock() mock_residue.atoms.positions = np.array([[0.0, 0.0, 0.0], [5.0, 0.0, 0.0]]) - result = has_bonds(mock_residue, cutoff=2.0) + result = has_bonds(mock_residue, cutoff_distance=2.0) assert result is False def test_has_bonds_single_atom(self): @@ -135,7 +135,7 @@ def test_has_bonds_single_atom(self): mock_residue = Mock() mock_residue.atoms.positions = np.array([[0.0, 0.0, 0.0]]) - result = has_bonds(mock_residue, cutoff=2.0) + result = has_bonds(mock_residue, cutoff_distance=2.0) assert result is False def test_has_bonds_custom_cutoff(self): @@ -144,11 +144,11 @@ def test_has_bonds_custom_cutoff(self): mock_residue.atoms.positions = np.array([[0.0, 0.0, 0.0], [3.0, 0.0, 0.0]]) # Should be False with default cutoff - result1 = has_bonds(mock_residue, cutoff=2.0) + result1 = has_bonds(mock_residue, cutoff_distance=2.0) assert result1 is False # Should be True with larger cutoff - result2 = has_bonds(mock_residue, cutoff=4.0) + result2 = has_bonds(mock_residue, cutoff_distance=4.0) assert result2 is True @@ -162,7 +162,7 @@ def test_has_bonds_between_true(self): mock_residue2 = Mock() mock_residue2.atoms.positions = np.array([[2.0, 0.0, 0.0], [3.0, 0.0, 0.0]]) - result = has_bonds_between(mock_residue1, mock_residue2, cutoff=5.0) + result = has_bonds_between(mock_residue1, mock_residue2, cutoff_distance=5.0) assert result is True def test_has_bonds_between_false(self): @@ -172,7 +172,7 @@ def test_has_bonds_between_false(self): mock_residue2 = Mock() mock_residue2.atoms.positions = np.array([[10.0, 0.0, 0.0]]) - result = has_bonds_between(mock_residue1, mock_residue2, cutoff=5.0) + result = has_bonds_between(mock_residue1, mock_residue2, cutoff_distance=5.0) assert result is False def test_has_bonds_between_custom_cutoff(self): @@ -184,11 +184,11 @@ def test_has_bonds_between_custom_cutoff(self): mock_residue2.atoms.positions = np.array([[4.0, 0.0, 0.0]]) # Should be False with small cutoff - result1 = has_bonds_between(mock_residue1, mock_residue2, cutoff=3.0) + result1 = has_bonds_between(mock_residue1, mock_residue2, cutoff_distance=3.0) assert result1 is False # Should be True with larger cutoff - result2 = has_bonds_between(mock_residue1, mock_residue2, cutoff=6.0) + result2 = has_bonds_between(mock_residue1, mock_residue2, cutoff_distance=6.0) assert result2 is True @@ -205,7 +205,7 @@ def test_detect_chains_single_chain(self, monkeypatch): mock_universe = Mock() mock_universe.residues = [mock_residue1, mock_residue2, mock_residue3] - result = detect_chains(mock_universe, cutoff=5.0) + result = detect_chains(mock_universe, cutoff_distance=5.0) assert result == [(0, 2)] def test_detect_chains_multiple_chains(self, monkeypatch): @@ -224,7 +224,7 @@ def mock_bonds(res1, res2, cutoff): mock_universe = Mock() mock_universe.residues = [mock_residue1, mock_residue2, mock_residue3] - result = detect_chains(mock_universe, cutoff=5.0) + result = detect_chains(mock_universe, cutoff_distance=5.0) # result: first chain is residue 0 and 1, second chain is residue 2 (starts at 2, ends at 2) assert result == [(0, 1), (2, 2)] @@ -246,7 +246,7 @@ def mock_has_bonds(residue, cutoff): mock_universe = Mock() mock_universe.residues = [mock_residue] - result = guess_resolution(mock_universe) + result = guess_resolution(mock_universe, cutoff_distance=12) assert result == MolecularResolution.ALL_ATOM def test_guess_resolution_coarse_grained(self, monkeypatch): @@ -261,7 +261,7 @@ def mock_has_bonds(residue, cutoff): mock_universe = Mock() mock_universe.residues = [mock_residue] * 5 # Ensure we have enough residues - result = guess_resolution(mock_universe) + result = guess_resolution(mock_universe, cutoff_distance=12) assert result == MolecularResolution.COARSE_GRAINED def test_guess_resolution_mixed_first_has_bonds(self, monkeypatch): @@ -275,5 +275,5 @@ def test_guess_resolution_mixed_first_has_bonds(self, monkeypatch): mock_universe = Mock() mock_universe.residues = [mock_residue] * 5 - result = guess_resolution(mock_universe) + result = guess_resolution(mock_universe, cutoff_distance=12) assert result == MolecularResolution.ALL_ATOM diff --git a/tests/test_toputils/test_toputils_integration.py b/tests/test_toputils/test_toputils_integration.py index 2c6e1e0..f826081 100644 --- a/tests/test_toputils/test_toputils_integration.py +++ b/tests/test_toputils/test_toputils_integration.py @@ -14,6 +14,7 @@ guess_resolution, MolecularResolution, ) +from grodecoder.settings import get_settings TEST_DATA_DIR = Path(__file__).parent.parent / "data" GRO_SMALL = TEST_DATA_DIR / "barstar_water_ions.gro" @@ -49,7 +50,9 @@ def test_sequence(self, protein_atoms): result = sequence(protein_atoms) # Uniprot p11540 (first residue is missing in the structure) - expected = "MKKAVINGEQIRSISDLHQTLKKELALPEYYGENLDALWDCLTGWVEYPLVLEWRQFEQSKQLTENGAESVLQVFREAKAEGCDITIILS"[1:] + expected = ( + "MKKAVINGEQIRSISDLHQTLKKELALPEYYGENLDALWDCLTGWVEYPLVLEWRQFEQSKQLTENGAESVLQVFREAKAEGCDITIILS"[1:] + ) assert isinstance(result, str) assert len(result) == len(protein_atoms.residues) @@ -58,7 +61,7 @@ def test_sequence(self, protein_atoms): def test_has_bonds(self, small_universe): """Test has_bonds function with real residue data.""" first_residue = small_universe.residues[0] - result = has_bonds(first_residue, cutoff=2.0) + result = has_bonds(first_residue, cutoff_distance=2.0) assert result is True def test_has_bonds_between_residues(self, small_universe): @@ -66,12 +69,12 @@ def test_has_bonds_between_residues(self, small_universe): residue1 = small_universe.residues[0] residue2 = small_universe.residues[1] - result = has_bonds_between(residue1, residue2, cutoff=2.0) + result = has_bonds_between(residue1, residue2, cutoff_distance=2.0) assert result is True def test_detect_chains(self, protein_atoms): """Test detect_chains function with real protein data.""" - result = detect_chains(protein_atoms, cutoff=5.0) + result = detect_chains(protein_atoms, cutoff_distance=5.0) # Should return a list of tuples of (start, end) indices. assert isinstance(result, list) @@ -82,17 +85,18 @@ def test_detect_chains(self, protein_atoms): assert len(result) == 1 assert result[0] == (0, len(protein_atoms.residues) - 1) - def test_guess_resolution(self, small_universe): """Test guess_resolution function with real data.""" - result = guess_resolution(small_universe) + result = guess_resolution( + small_universe, cutoff_distance=get_settings().resolution_detection.distance_cutoff + ) assert result == MolecularResolution.ALL_ATOM def test_guess_resolution_cg(): """Test guess_resolution with coarse-grained data.""" universe = mda.Universe(GRO_CG) - result = guess_resolution(universe) + result = guess_resolution(universe, cutoff_distance=get_settings().resolution_detection.distance_cutoff) assert result == MolecularResolution.COARSE_GRAINED @@ -100,7 +104,7 @@ def test_detect_chains_big(): """Test detect_chains with a larger universe.""" universe = mda.Universe(GRO_BIG) protein_atoms = universe.select_atoms("protein") - + result = detect_chains(protein_atoms) # GRO_BIG (1BRS) is the complex barstar-barnase, 3 times, i.e. total of 6 chains. From 5c050a3dec231539fa8a0ef000d0ce50d05bd253 Mon Sep 17 00:00:00 2001 From: benoistlaurent Date: Tue, 2 Dec 2025 16:07:45 +0100 Subject: [PATCH 10/10] bugfix: fixes default cutoff distance for coarse-grain detection to preserve regression test results --- src/grodecoder/settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/grodecoder/settings.py b/src/grodecoder/settings.py index c34b92a..2b7915c 100644 --- a/src/grodecoder/settings.py +++ b/src/grodecoder/settings.py @@ -16,7 +16,7 @@ @dataclass(init=False) class DistanceCutoff: default_distance_cutoff_all_atom: ClassVar[float] = 5.0 - default_distance_cutoff_coarse_grain: ClassVar[float] = 7.0 + default_distance_cutoff_coarse_grain: ClassVar[float] = 6.0 _user_distance_cutoff: float | None = None _guessed_distance_cutoff: float | None = None