Skip to content
159 changes: 103 additions & 56 deletions src/murfey/workflows/clem/register_preprocessing_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
import logging
import traceback
from collections.abc import Collection
from functools import cached_property
from importlib.metadata import entry_points
from pathlib import Path
from typing import Literal, Optional

from pydantic import BaseModel
from pydantic import BaseModel, computed_field
from sqlmodel import Session, select

import murfey.util.db as MurfeyDB
Expand Down Expand Up @@ -53,16 +54,42 @@ class CLEMPreprocessingResult(BaseModel):
resolution: float
extent: list[float] # [x0, x1, y0, y1]


def _is_clem_atlas(result: CLEMPreprocessingResult):
# If an image has a width/height of at least 1.5 mm, it should qualify as an atlas
return (
max(
result.pixels_x * result.pixel_size,
result.pixels_y * result.pixel_size,
# Valid Pydantic decorator not supported by MyPy
@computed_field # type: ignore
@cached_property
def is_denoised(self) -> bool:
"""
The "_Lng_LVCC" suffix appended to a CLEM dataset's position name indicates
that it's a denoised image set of the same position. These results should
override or supersede the original ones once they're available.
"""
return "_Lng_LVCC" in self.series_name

# Valid Pydantic decorator not supported by MyPy
@computed_field # type: ignore
@cached_property
def site_name(self) -> str:
"""
Extract just the name of the site by removing the "_Lng_LVCC" suffix from
the series name.
"""
return self.series_name.replace("_Lng_LVCC", "")

# Valid Pydantic decorator not supported by MyPy
@computed_field # type: ignore
@cached_property
def is_atlas(self) -> bool:
"""
Incoming image sets with a width/height greater/equal to the pre-set threshold
should qualify as an atlas.
"""
return (
max(
self.pixels_x * self.pixel_size,
self.pixels_y * self.pixel_size,
)
>= processing_params.atlas_threshold
)
>= processing_params.atlas_threshold
)


COLOR_FLAGS_MURFEY = {
Expand Down Expand Up @@ -91,51 +118,71 @@ def _register_clem_imaging_site(
result: CLEMPreprocessingResult,
murfey_db: Session,
):
def _register(
entry: MurfeyDB.ImagingSite,
result: CLEMPreprocessingResult,
):
"""
Helper function to update the ImagingSite column values with.
"""

# Is this an atlas or grid square
entry.data_type = "atlas" if result.is_atlas else "grid_square"
# Register file paths
output_file = list(result.output_files.values())[0]
entry.image_path = str(output_file.parent / "*.tiff")
# Shape and resolution information
entry.image_pixels_x = result.pixels_x
entry.image_pixels_y = result.pixels_y
entry.image_pixel_size = result.pixel_size
entry.units = result.units
# Extent of imaged area in real space
entry.x0 = result.extent[0]
entry.x1 = result.extent[1]
entry.y0 = result.extent[2]
entry.y1 = result.extent[3]

# Iteratively add colour channel information
entry.number_of_members = result.number_of_members
for col_name, value in _get_color_flags(result.output_files.keys()).items():
setattr(entry, col_name, value)
entry.collection_mode = _determine_collection_mode(result.output_files.keys())

# Register thumbnail information if present
if result.thumbnails and result.thumbnail_size:
thumbnail = list(result.thumbnails.values())[0]
entry.thumbnail_path = str(thumbnail.parent / "*.png")

thumbnail_height, thumbnail_width = result.thumbnail_size
scaling_factor = min(
thumbnail_height / result.pixels_y, thumbnail_width / result.pixels_x
)
entry.thumbnail_pixel_size = result.pixel_size / scaling_factor
entry.thumbnail_pixels_x = int(round(result.pixels_x * scaling_factor)) or 1
entry.thumbnail_pixels_y = int(round(result.pixels_y * scaling_factor)) or 1
return entry

# Create a new entry if one doesn't already exist
if not (
clem_img_site := murfey_db.exec(
select(MurfeyDB.ImagingSite)
.where(MurfeyDB.ImagingSite.session_id == session_id)
.where(MurfeyDB.ImagingSite.site_name == result.series_name)
.where(MurfeyDB.ImagingSite.site_name == result.site_name)
).one_or_none()
):
clem_img_site = MurfeyDB.ImagingSite(
session_id=session_id, site_name=result.series_name
session_id=session_id,
site_name=result.site_name,
)
clem_img_site = _register(clem_img_site, result)

# Prepare to overwrite existing entry if current result is a denoised dataset
if result.is_denoised:
# Proceed with overwrite if current result is different from existing entry
output_file = list(result.output_files.values())[0]
if str(output_file.parent / "*.tiff") != clem_img_site.image_path:
clem_img_site = _register(clem_img_site, result)

# Add metadata for this series
output_file = list(result.output_files.values())[0]
clem_img_site.image_path = str(output_file.parent / "*tiff")
clem_img_site.data_type = "atlas" if _is_clem_atlas(result) else "grid_square"
clem_img_site.number_of_members = result.number_of_members
for col_name, value in _get_color_flags(result.output_files.keys()).items():
setattr(clem_img_site, col_name, value)
clem_img_site.collection_mode = _determine_collection_mode(
result.output_files.keys()
)
clem_img_site.image_pixels_x = result.pixels_x
clem_img_site.image_pixels_y = result.pixels_y
clem_img_site.image_pixel_size = result.pixel_size
clem_img_site.units = result.units
clem_img_site.x0 = result.extent[0]
clem_img_site.x1 = result.extent[1]
clem_img_site.y0 = result.extent[2]
clem_img_site.y1 = result.extent[3]
# Register thumbnails if they are present
if result.thumbnails and result.thumbnail_size:
thumbnail = list(result.thumbnails.values())[0]
clem_img_site.thumbnail_path = str(thumbnail.parent / "*.png")

thumbnail_height, thumbnail_width = result.thumbnail_size
scaling_factor = min(
thumbnail_height / result.pixels_y, thumbnail_width / result.pixels_x
)
clem_img_site.thumbnail_pixel_size = result.pixel_size / scaling_factor
clem_img_site.thumbnail_pixels_x = (
int(round(result.pixels_x * scaling_factor)) or 1
)
clem_img_site.thumbnail_pixels_y = (
int(round(result.pixels_y * scaling_factor)) or 1
)
murfey_db.add(clem_img_site)
murfey_db.commit()
murfey_db.close()
Expand Down Expand Up @@ -183,12 +230,12 @@ def _register_dcg_and_atlas(
visit_number = visit_name.split("-")[-1]

# Generate name/tag for data collection group based on series name
dcg_name = result.series_name.split("--")[0]
if result.series_name.split("--")[1].isdigit():
dcg_name += f"--{result.series_name.split('--')[1]}"
dcg_name = result.site_name.split("--")[0]
if result.site_name.split("--")[1].isdigit():
dcg_name += f"--{result.site_name.split('--')[1]}"

# Determine values for atlas
if _is_clem_atlas(result):
if result.is_atlas:
output_file = list(result.output_files.values())[0]
# Register the thumbnail entries if they are provided
if result.thumbnails and result.thumbnail_size is not None:
Expand Down Expand Up @@ -227,7 +274,7 @@ def _register_dcg_and_atlas(
dcg_entry = dcg_search[0]
# Update atlas if registering atlas dataset
# and data collection group already exists
if _is_clem_atlas(result):
if result.is_atlas:
atlas_message = {
"session_id": session_id,
"dcgid": dcg_entry.id,
Expand Down Expand Up @@ -287,11 +334,11 @@ def _register_dcg_and_atlas(
clem_img_site := murfey_db.exec(
select(MurfeyDB.ImagingSite)
.where(MurfeyDB.ImagingSite.session_id == session_id)
.where(MurfeyDB.ImagingSite.site_name == result.series_name)
.where(MurfeyDB.ImagingSite.site_name == result.site_name)
).one_or_none()
):
clem_img_site = MurfeyDB.ImagingSite(
session_id=session_id, site_name=result.series_name
session_id=session_id, site_name=result.site_name
)

clem_img_site.dcg_id = dcg_entry.id
Expand All @@ -311,9 +358,9 @@ def _register_grid_square(
logger.error("Unable to find transport manager")
return
# Load all entries for the current data collection group
dcg_name = result.series_name.split("--")[0]
if result.series_name.split("--")[1].isdigit():
dcg_name += f"--{result.series_name.split('--')[1]}"
dcg_name = result.site_name.split("--")[0]
if result.site_name.split("--")[1].isdigit():
dcg_name += f"--{result.site_name.split('--')[1]}"

# Check if an atlas has been registered
if not (
Expand Down
52 changes: 35 additions & 17 deletions tests/workflows/clem/test_register_preprocessing_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def generate_preprocessing_messages(
# Construct all the datasets to be tested
datasets: list[tuple[Path, bool, bool, tuple[int, int], float, list[float]]] = [
(
grid_dir / "Overview_1" / "Image_1",
grid_dir / "Overview 1" / "Image 1",
False,
True,
(2400, 2400),
Expand All @@ -59,22 +59,38 @@ def generate_preprocessing_messages(
datasets.extend(
[
(
grid_dir / "TileScan_1" / f"Position_{n}",
grid_dir / "TileScan 1" / f"Position {n + 1}",
True,
False,
(2048, 2048),
1.6e-7,
[0.003, 0.00332768, 0.003, 0.00332768],
)
for n in range(5)
for n in range(3)
]
)
datasets.extend(
[
(
grid_dir / "TileScan 1" / f"Position {n + 1}_Lng_LVCC",
True,
False,
(2048, 2048),
1.6e-7,
[0.003, 0.00332768, 0.003, 0.00332768],
)
for n in range(3)
]
)

messages: list[dict[str, Any]] = []
for dataset in datasets:
for series_path, is_stack, is_montage, shape, pixel_size, extent in datasets:
# Unpack items from list of dataset parameters
series_path = dataset[0]
series_name = str(series_path.relative_to(processed_dir)).replace("/", "--")
series_name = (
str(series_path.relative_to(processed_dir))
.replace("/", "--")
.replace(" ", "_")
)
metadata = series_path / "metadata" / f"{series_path.stem}.xml"
metadata.parent.mkdir(parents=True, exist_ok=True)
metadata.touch(exist_ok=True)
Expand All @@ -89,11 +105,6 @@ def generate_preprocessing_messages(
thumbnail.parent.mkdir(parents=True)
thumbnail.touch(exist_ok=True)
thumbnail_size = (512, 512)
is_stack = dataset[1]
is_montage = dataset[2]
shape = dataset[3]
pixel_size = dataset[4]
extent = dataset[5]

message = {
"session_id": session_id,
Expand Down Expand Up @@ -373,21 +384,23 @@ def test_run_with_db(
else:
assert mock_align_and_merge_call.call_count == len(preprocessing_messages) * 3

# Both databases should have entries for data collection group, and grid squares
# ISPyB database should additionally have an atlas entry
# Murfey's DataCollectionGroup should have an entry
murfey_dcg_search = murfey_db_session.exec(
sm_select(MurfeyDB.DataCollectionGroup).where(
MurfeyDB.DataCollectionGroup.session_id == murfey_session.id
)
).all()
assert len(murfey_dcg_search) == 1

# GridSquare entries should be half the initial number of entries due to overwrites
murfey_gs_search = murfey_db_session.exec(
sm_select(MurfeyDB.GridSquare).where(
MurfeyDB.GridSquare.session_id == murfey_session.id
)
).all()
assert len(murfey_gs_search) == len(preprocessing_messages) - 1
assert len(murfey_gs_search) == (len(preprocessing_messages) - 1) // 2

# ISPyB's DataCollectionGroup should have an entry
murfey_dcg = murfey_dcg_search[0]
ispyb_dcg_search = (
ispyb_db_session.execute(
Expand All @@ -400,6 +413,7 @@ def test_run_with_db(
)
assert len(ispyb_dcg_search) == 1

# Atlas should have an entry
ispyb_dcg = ispyb_dcg_search[0]
ispyb_atlas_search = (
ispyb_db_session.execute(
Expand All @@ -419,12 +433,13 @@ def test_run_with_db(
}
collection_mode = _determine_collection_mode(colors)

# Atlas color flags and collection mode should be set correctly
ispyb_atlas = ispyb_atlas_search[0]
# Check that the Atlas color flags and collection mode are set correctly
for flag, value in color_flags.items():
assert getattr(ispyb_atlas, flag) == value
assert ispyb_atlas.mode == collection_mode

# ISPyB's GridSquare should have half the initial number of entries
ispyb_gs_search = (
ispyb_db_session.execute(
sa_select(ISPyBDB.GridSquare).where(
Expand All @@ -434,9 +449,12 @@ def test_run_with_db(
.scalars()
.all()
)
assert len(ispyb_gs_search) == len(preprocessing_messages) - 1
assert len(ispyb_gs_search) == (len(preprocessing_messages) - 1) // 2
for gs in ispyb_gs_search:
# Check that the Atlas color flags and collection mode are set correctly
# Check that all entries point to the denoised images ("_Lng_LVCC")
assert gs.gridSquareImage is not None and "_Lng_LVCC" in gs.gridSquareImage

# Check that the GridSquare color flags and collection mode are set correctly
for flag, value in color_flags.items():
assert getattr(gs, flag) == value
assert gs.mode == collection_mode
Expand Down
Loading