diff --git a/.gitignore b/.gitignore index 116e7a3..c709e6e 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,9 @@ build/ .eggs/ *.egg +# Version file +_version.py + # Installer logs pip-log.txt pip-delete-this-directory.txt diff --git a/notebooks/Test pnanolocz-lib.ipynb b/notebooks/Test pnanolocz-lib.ipynb index f3af24b..b440811 100644 --- a/notebooks/Test pnanolocz-lib.ipynb +++ b/notebooks/Test pnanolocz-lib.ipynb @@ -17,7 +17,6 @@ "source": [ "data_path = r\"C:/Users/ggjh246/OneDrive - University of Leeds/Code/playNano_testdata/save-2025.05.20-12.57.06.187.h5-jpk\"\n", "from playnano.io.loader import load_afm_stack\n", - "\n", "afm_stack = load_afm_stack(data_path, channel = \"height_trace\")" ] }, @@ -41,7 +40,7 @@ "metadata": {}, "outputs": [], "source": [ - "from pnanolocz_lib import level, level_auto\n", + "from pnanolocz_lib import level_auto, level, level_weighted\n", "\n", "frames = afm_stack.n_frames\n", "frame_ind = range(0, frames)\n", @@ -68,7 +67,7 @@ "metadata": {}, "outputs": [], "source": [ - "auto_leveled = level_auto.apply_level_auto(afm_stack.data, \"iterative 1nm high\")" + "auto_leveled = level_auto.apply_level_auto(afm_stack.data, \"multi-plane-otsu\")" ] }, { @@ -89,7 +88,6 @@ "outputs": [], "source": [ "from pnanolocz_lib.thresholder import thresholder\n", - "\n", "mask_hist = thresholder(plane_levelled, 'histogram', limits = (0.2, 100),invert = False)" ] }, @@ -109,6 +107,18 @@ "id": "9", "metadata": {}, "outputs": [], + "source": [ + "from pnanolocz_lib.level_weighted import apply_level_weighted\n", + "lev_weight = apply_level_weighted(plane_levelled, 1, 1, \"smed_line\", mask=mask_hist)\n", + "plt.imshow(lev_weight[5], cmap=\"afmhot\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], "source": [ "masked_line_levelled = level.apply_level(plane_levelled, 1, 0, \"line\", mask_hist)\n", "plt.imshow(masked_line_levelled[5], cmap=\"afmhot\")" @@ -117,7 +127,15 @@ { "cell_type": "code", "execution_count": null, - "id": "10", + "id": "11", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", "metadata": {}, "outputs": [], "source": [ @@ -128,7 +146,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "13", "metadata": {}, "outputs": [], "source": [ @@ -139,17 +157,17 @@ { "cell_type": "code", "execution_count": null, - "id": "12", + "id": "14", "metadata": {}, "outputs": [], "source": [ - "mask_hist2 =thresholder(maksed_line_levelled, 'histogram', limits= (0.5,100), invert=False)" + "mask_hist2 =thresholder(maksed_linemed_levelled, 'histogram', limits= (0.5,100), invert=False)" ] }, { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -160,7 +178,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "16", "metadata": {}, "outputs": [], "source": [ @@ -171,7 +189,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "17", "metadata": {}, "outputs": [], "source": [] @@ -179,7 +197,7 @@ { "cell_type": "code", "execution_count": null, - "id": "16", + "id": "18", "metadata": {}, "outputs": [], "source": [ @@ -189,7 +207,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -199,7 +217,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18", + "id": "20", "metadata": {}, "outputs": [], "source": [ @@ -209,7 +227,7 @@ { "cell_type": "code", "execution_count": null, - "id": "19", + "id": "21", "metadata": {}, "outputs": [], "source": [] diff --git a/pyproject.toml b/pyproject.toml index af420c5..832d47b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,7 +71,6 @@ classifiers = [ "Topic :: Scientific/Engineering :: Image Processing", ] - [project.optional-dependencies] dev = [ "pytest>=7.0", @@ -82,6 +81,7 @@ dev = [ "black", "codespell>=2.3.0", "mypy==1.11.1", + "nbstripout", ] [project.urls] @@ -93,6 +93,9 @@ package-dir = { "" = "src" } [tool.setuptools.packages.find] where = ["src"] +[tool.setuptools_scm] +write_to = "src/pnanolocz_lib/_version.py" + [tool.black] line-length = 88 target-version = ["py310"] @@ -138,7 +141,7 @@ ignore = ["D100", "D104", "D203", "D212", "D401"] convention = "numpy" [tool.codespell] -skip = "*.egg-info,*.png,*.jpg,*.jpeg,*.gif,*.bmp,*.svg,*.ico,*.pdf,*.zip,*.gz,*.bz2,*.xz,*.7z,*.mp3,*.wav,*.mp4,*.mov,*.tif,*.tiff,*.ipynb,docs/_build/*,build/*,dist/*" +skip = "*.egg-info,*.png,*.jpg,*.jpeg,*.gif,*.bmp,*.svg,*.ico,*.pdf,*.zip,*.gz,*.bz2,*.xz,*.7z,*.mp3,*.wav,*.mp4,*.mov,*.tif,*.tiff,*.ipynb,docs/_build/*,build/*,dist/*,*.html,*.css,*.js" ignore-words-list = "tabel,crate" check-filenames = false check-hidden = false @@ -156,12 +159,18 @@ namespace_packages = false # Your existing strictness strict = true -warn_unused_ignores = true +warn_unused_ignores = false disallow_untyped_defs = true exclude = '^tests/' [[tool.mypy.overrides]] -module = ["scipy.*", "skimage.*", "ruptures", "sknw"] +module = [ + "scipy.*", + "skimage.*", + "ruptures", + "sknw", + "numpy.*" +] ignore_missing_imports = true [tool.pytest.ini_options] diff --git a/src/pnanolocz_lib/__init__.py b/src/pnanolocz_lib/__init__.py index 2bcb06c..8ab1595 100644 --- a/src/pnanolocz_lib/__init__.py +++ b/src/pnanolocz_lib/__init__.py @@ -3,6 +3,9 @@ from importlib.metadata import PackageNotFoundError, version try: - __version__ = version("pnanolocz_lib") -except PackageNotFoundError: - __version__ = "0.0.0" + from ._version import version as __version__ # type: ignore[import-not-found] +except Exception: + try: + __version__ = version("pnanolocz_lib") + except PackageNotFoundError: + __version__ = "0.0.0" diff --git a/src/pnanolocz_lib/level.py b/src/pnanolocz_lib/level.py index 8570c85..3d8e42c 100644 --- a/src/pnanolocz_lib/level.py +++ b/src/pnanolocz_lib/level.py @@ -9,7 +9,7 @@ The functions here were ported from the original MATLAB NanoLocz Library, and maintain compatibility with high-speed AFM, localization AFM, and static -imaging data. +imaging data. () Supported Leveling Methods -------------------------- @@ -45,10 +45,10 @@ """ import warnings -from typing import Literal, Optional +from typing import Any, Literal, Optional import numpy as np -from numpy.polynomial.polyutils import RankWarning # type: ignore +from numpy.polynomial.polyutils import RankWarning # type: ignore[attr-defined] from scipy.optimize import curve_fit # Constants @@ -57,8 +57,11 @@ def level_plane( - img: np.ndarray, mask: Optional[np.ndarray], polyx: int, polyy: int -) -> np.ndarray: + img: np.ndarray[Any, np.dtype[np.float64]], + mask: Optional[np.ndarray[Any, np.dtype[np.bool_]]], + polyx: int, + polyy: int, +) -> np.ndarray[Any, np.dtype[np.float64]]: """ Plane leveling fitting by subtracting polynomial curves in X and Y. @@ -136,7 +139,7 @@ def level_plane( row_indices = np.flatnonzero(valid_rows) if row_indices.size <= polyy: # Not enough points to fit Y polynomial - return leveled_img + return np.asarray(leveled_img) # Center & scale row indices # replicate MATLAB centering @@ -159,12 +162,15 @@ def level_plane( y_plane = np.polyval(y_coeffs, standardized_all_rows)[:, None] # Subtract Y-plane - return leveled_img - y_plane + return np.asarray(leveled_img - y_plane) def level_line( - img: np.ndarray, mask: Optional[np.ndarray], polyx: int, polyy: int -) -> np.ndarray: + img: np.ndarray[Any, np.dtype[np.float64]], + mask: Optional[np.ndarray[Any, np.dtype[np.bool_]]], + polyx: int, + polyy: int, +) -> np.ndarray[Any, np.dtype[np.float64]]: """ Polynomial line leveling, correcting each row and column separately. @@ -256,11 +262,11 @@ def level_line( def level_med_line( - img: np.ndarray, - mask: Optional[np.ndarray], + img: np.ndarray[Any, np.dtype[np.float64]], + mask: Optional[np.ndarray[Any, np.dtype[np.bool_]]], polyx: int, polyy: int, # unused (MATLAB semantics) -) -> np.ndarray: +) -> np.ndarray[Any, np.dtype[np.float64]]: """ Row-wise median line leveling for AFM images. @@ -307,8 +313,11 @@ def level_med_line( def level_med_line_y( - img: np.ndarray, mask: Optional[np.ndarray], polyx: int, polyy: int -) -> np.ndarray: + img: np.ndarray[Any, np.dtype[np.float64]], + mask: Optional[np.ndarray[Any, np.dtype[np.bool_]]], + polyx: int, + polyy: int, +) -> np.ndarray[Any, np.dtype[np.float64]]: """ Column-wise median line leveling. @@ -344,8 +353,11 @@ def level_med_line_y( def level_smed_line( - img: np.ndarray, mask: Optional[np.ndarray], polyx: int, polyy: int -) -> np.ndarray: + img: np.ndarray[Any, np.dtype[np.float64]], + mask: Optional[np.ndarray[Any, np.dtype[np.bool_]]], + polyx: int, + polyy: int, +) -> np.ndarray[Any, np.dtype[np.float64]]: """ Smoothed median line subtraction. @@ -389,8 +401,11 @@ def level_smed_line( def level_mean_plane( - img: np.ndarray, mask: Optional[np.ndarray], polyx: int, polyy: int -) -> np.ndarray: + img: np.ndarray[Any, np.dtype[np.float64]], + mask: Optional[np.ndarray[Any, np.dtype[np.bool_]]], + polyx: int, + polyy: int, +) -> np.ndarray[Any, np.dtype[np.float64]]: """ Mean plane subtraction. @@ -418,13 +433,13 @@ def level_mean_plane( def level_log_y( - img: np.ndarray, - mask: Optional[np.ndarray], + img: np.ndarray[Any, np.dtype[np.float64]], + mask: Optional[np.ndarray[Any, np.dtype[np.bool_]]], polyx: int, polyy: int, *, orientation: str = "auto", # "auto" | "normal" | "reverse" -) -> np.ndarray: +) -> np.ndarray[Any, np.dtype[np.float64]]: """ Logarithmic curve subtraction along the Y-axis. @@ -452,7 +467,11 @@ def level_log_y( y = np.mean(img, axis=1) correction = _log_y_correction(y, polyy) - def _apply(img_: np.ndarray, corr: np.ndarray, rev: bool) -> np.ndarray: + def _apply( + img_: np.ndarray[Any, np.dtype[np.float64]], + corr: np.ndarray[Any, np.dtype[np.float64]], + rev: bool, + ) -> np.ndarray[Any, np.dtype[np.float64]]: if rev: corr = corr[::-1] return np.asarray(img_ - corr[:, None]) @@ -470,7 +489,9 @@ def _apply(img_: np.ndarray, corr: np.ndarray, rev: bool) -> np.ndarray: return np.asarray(cand1 if rng1 <= rng2 else cand2) -def _log_y_correction(y: np.ndarray, scale: float) -> np.ndarray: +def _log_y_correction( + y: np.ndarray[Any, np.dtype[Any]], scale: float +) -> np.ndarray[Any, np.dtype[Any]]: """ Fit and return a logarithmic correction curve. @@ -492,8 +513,10 @@ def _log_y_correction(y: np.ndarray, scale: float) -> np.ndarray: x_fit = x[pos] y_fit = y[pos] - def _log_model(x: np.ndarray, a: float, b: float, c: float) -> np.ndarray: - return a * np.log(c * x + b) + def _log_model( + x: np.ndarray[Any, np.dtype[Any]], a: float, b: float, c: float + ) -> np.ndarray[Any, np.dtype[Any]]: + return np.asarray(a * np.log(c * x + b)) try: popt, _ = curve_fit( @@ -505,7 +528,7 @@ def _log_model(x: np.ndarray, a: float, b: float, c: float) -> np.ndarray: def apply_level( - img: np.ndarray, + img: np.ndarray[Any, np.dtype[np.float64]], polyx: int, polyy: int, method: Literal[ @@ -517,8 +540,8 @@ def apply_level( "mean_plane", "log_y", ], - mask: Optional[np.ndarray] = None, -) -> np.ndarray: + mask: Optional[np.ndarray[Any, np.dtype[np.bool_]]] = None, +) -> np.ndarray[Any, np.dtype[np.float64]]: """ Apply a function to level or flatten AFM images or stacks. @@ -592,7 +615,7 @@ def apply_level( def get_background( - img: np.ndarray, + img: np.ndarray[Any, np.dtype[np.float64]], polyx: int, polyy: int, method: Literal[ @@ -604,8 +627,8 @@ def get_background( "mean_plane", "log_y", ], - mask: Optional[np.ndarray] = None, -) -> np.ndarray: + mask: Optional[np.ndarray[Any, np.dtype[np.bool_]]] = None, +) -> np.ndarray[Any, np.dtype[np.float64]]: """ Compute a background surface/lines that would be subtracted by `apply_level(...)`. diff --git a/src/pnanolocz_lib/level_auto.py b/src/pnanolocz_lib/level_auto.py index 78d5e6f..e14623c 100644 --- a/src/pnanolocz_lib/level_auto.py +++ b/src/pnanolocz_lib/level_auto.py @@ -57,6 +57,7 @@ from scipy import stats from pnanolocz_lib.level import apply_level +from pnanolocz_lib.level_weighted import apply_level_weighted from pnanolocz_lib.thresholder import thresholder # Data‑driven routine definitions @@ -403,8 +404,6 @@ "method": "line", }, ], -} -""" # Multi plane edges level 0uses level_weighted "multi-plane-edges": [ { @@ -420,7 +419,7 @@ "invert": False, }, { - "func": level_weighted, + "func": apply_level_weighted, "polyx": 2, "polyy": 2, "method": "plane", @@ -432,13 +431,13 @@ "invert": False, }, { - "func": level_weighted, + "func": apply_level_weighted, "polyx": 2, "polyy": 2, "method": "plane", }, { - "func": level_weighted, + "func": apply_level_weighted, "polyx": 0, "polyy": 0, "method": "med_line", @@ -471,7 +470,7 @@ "invert": False, }, { - "func": level_weighted, + "func": apply_level_weighted, "polyx": 2, "polyy": 2, "method": "plane", @@ -483,7 +482,7 @@ "invert": False, }, { - "func": level_weighted, + "func": apply_level_weighted, "polyx": 2, "polyy": 2, "method": "plane", @@ -495,13 +494,13 @@ "invert": False, }, { - "func": level_weighted, + "func": apply_level_weighted, "polyx": 2, "polyy": 2, "method": "plane", }, { - "func": level_weighted, + "func": apply_level_weighted, "polyx": 0, "polyy": 0, "method": "med_line", @@ -519,12 +518,12 @@ "method": "mean_plane", }, ], - } -""" -def _compute_gauss_limits(image: np.ndarray, kind: str) -> tuple[float, float]: +def _compute_gauss_limits( + image: np.ndarray[Any, np.dtype[np.float64]], kind: str +) -> tuple[float, float]: """ Compute intensity threshold limits from a Gaussian fit to the image data. @@ -583,9 +582,9 @@ def _compute_gauss_limits(image: np.ndarray, kind: str) -> tuple[float, float]: def apply_level_auto( - img_stack: np.ndarray, + img_stack: np.ndarray[Any, np.dtype[np.float64]], routine: str, -) -> np.ndarray: +) -> np.ndarray[Any, np.dtype[np.float64]]: """ Apply leveling "routines" across specified frames of an AFM image stack. @@ -635,7 +634,6 @@ def apply_level_auto( for step in steps: func = step["func"] params = {k: v for k, v in step.items() if k != "func"} - if func is thresholder: method = params["method"] args = params.get("args", None) diff --git a/src/pnanolocz_lib/level_weighted.py b/src/pnanolocz_lib/level_weighted.py new file mode 100644 index 0000000..44a7829 --- /dev/null +++ b/src/pnanolocz_lib/level_weighted.py @@ -0,0 +1,864 @@ +""" +Weighted-region AFM image flattening and background leveling tools. + +This module provides a Python port of the MATLAB ``level_weighted.m`` +function used in the NanoLocz workflow. It implements region-wise weighted +polynomial and median-based background estimation for Atomic Force Microscopy +(AFM) images, enabling correction of multi-region drift, structured background, +and non-uniform masking effects. + +The original Nanolocz-lib script was adapted from FindSteps.m and PolyfitLineMasked.m +scripts from the SPIW project () and combined +with NanoLocz leveling methods. + +Supported Leveling Methods +-------------------------- +- 'plane' : Region-weighted polynomial plane subtraction in X and Y. +- 'line' : Region-weighted row/column polynomial leveling. +- 'med_line' : Region-weighted row-wise median line flattening. +- 'med_line_y' : Region-weighted column-wise median line flattening. +- 'smed_line' : Region-weighted smoothed median line subtraction. + +Typical usage involves calling the :func:`apply_level_weighted` dispatcher with +an AFM image (2D) or a stack (3D) and choosing one of the methods above. + +Examples +-------- +>>> from pnanolocz_lib.filters.level_weighted import apply_level_weighted +>>> leveled = apply_level_weighted(img, polyx=2, polyy=1, method='plane', mask=mask) + +Authors +------- +George Heath, University of Leeds (2025) +Daniel E. Rollins, University of Leeds (2025) +""" + +from __future__ import annotations + +import warnings +from typing import Any, List, Optional, Tuple + +import numpy as np +from numpy.polynomial.polyutils import RankWarning # type: ignore[attr-defined] +from scipy import ndimage + +# --------------------- +# Low-level helpers +# --------------------- + + +def _center_scale_indices( + indices: np.ndarray[Any, np.dtype[np.float64]], +) -> tuple[np.ndarray[Any, np.dtype[np.float64]], float, float]: + """Center and scale a 1-D index array. + + Parameters + ---------- + indices + 1-D integer index positions (e.g. column or row indices). + + Returns + ------- + std_indices : np.ndarray + Centered and scaled indices (float). + centroid : float + Mean of the original indices. + scale : float + Sample standard deviation (ddof=1) of the original indices; guaranteed + non-zero (defaults to 1.0 when degenerate). + """ + if indices.size == 0: + return indices.astype(float), 0.0, 1.0 + + centroid = float(indices.mean()) + scale = float(indices.std(ddof=1)) if indices.size > 1 else 1.0 + if scale == 0: + scale = 1.0 + std_indices = (indices - centroid) / scale + return std_indices, centroid, scale + + +def _polyfit_centered( + x: np.ndarray[Any, np.dtype[np.float64]], + y: np.ndarray[Any, np.dtype[np.float64]], + order: int, +) -> tuple[np.ndarray[Any, np.dtype[np.float64]], Tuple[float, float]]: + """Fit polynomial to ``y`` vs ``x`` after centering and scaling ``x``. + + This tries to replicate the MATLAB polyfit function. + + Parameters + ---------- + x : np.ndarray + 1-D positions used as the independent variable. + y : np.ndarray + 1-D values (dependent variable). + order : int + Polynomial order. + + Returns + ------- + coeffs : np.ndarray + Coefficients in decreasing power order compatible with ``np.polyval``. + (centroid, scale) : tuple + The centering and scaling applied to ``x`` (so evaluation may use the + same parameters). + """ + if x.size == 0 or y.size == 0 or x.size <= order: + return np.zeros(order + 1, dtype=float), (0.0, 1.0) + + std_x, centroid, scale = _center_scale_indices(x) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RankWarning) + coeffs = np.polyfit(std_x, y, order) + return coeffs, (centroid, scale) + + +def _polyval_centered( + coeffs: np.ndarray[Any, np.dtype[np.float64]], + centering: Tuple[float, float], + points: np.ndarray[Any, np.dtype[np.int64]], +) -> np.ndarray[Any, np.dtype[np.float64]]: + """ + Evaluate a polynomial fitted on centered-and-scaled x. + + Equivalent to MATLAB function: polyval. + + Parameters + ---------- + coeffs : np.ndarray + Polynomial coefficients from :func:`_polyfit_centered`. + centering : tuple + ``(centroid, scale)`` used to standardise x during fitting. + points : np.ndarray + Points (original coordinate space) at which to evaluate. + """ + centroid, scale = centering + if scale == 0: + scale = 1.0 + std_points = (points - centroid) / scale + return np.polyval(coeffs, std_points) + + +def _find_regions( + mask: np.ndarray[Any, np.dtype[np.bool_]], min_area: int +) -> List[np.ndarray[Any, np.dtype[np.int64]]]: + """Find connected foreground regions and return their flat indices. + + Parameters + ---------- + mask : np.ndarray + Boolean mask where True indicates foreground. + min_area : int + Minimum number of pixels for a region to be kept. + + Returns + ------- + regions : list of np.ndarray + Each element is a 1-D array of flat indices for that region. + """ + structure = np.ones((3, 3), dtype=int) + labeled, num_features = ndimage.label(mask, structure=structure) + regions: List[np.ndarray[Any, np.dtype[np.int64]]] = [] + for lab in range(1, num_features + 1): + flat_idx = np.flatnonzero(labeled.ravel() == lab) + if flat_idx.size >= min_area: + regions.append(flat_idx) + return regions + + +# --------------------- +# Per-method implementations +# --------------------- + + +def level_weighted_plane( + img: np.ndarray[Any, np.dtype[np.float64]], + regions: List[np.ndarray[Any, np.dtype[np.int64]]], + polyx: int, + polyy: int, +) -> np.ndarray[Any, np.dtype[np.float64]]: + """ + Region-weighted polynomial plane subtraction along X and Y. + + The function computes per-region polynomial fits of the mean profile in the + X- and Y-directions and forms a weighted average of those per-region fits + (weights proportional to region pixel counts). The combined plane is + subtracted from ``img``. + + Parameters + ---------- + img : np.ndarray + 2-D AFM image. + regions : list of np.ndarray + List of flat index arrays describing foreground regions. + polyx, polyy : int + Polynomial orders for the X (columns) and Y (rows) directions. + + Returns + ------- + np.ndarray + The leveled image (float64). + """ + rows, cols = img.shape + img_f = np.asarray(img, dtype=float) + + n_regions = len(regions) + region_pixel_counts = np.zeros(n_regions, dtype=float) + + x_poly_list: List[np.ndarray[Any, np.dtype[np.float64]]] = [] + x_centroid_list: List[float] = [] + x_scale_list: List[float] = [] + + y_poly_list: List[np.ndarray[Any, np.dtype[np.float64]]] = [] + y_centroid_list: List[float] = [] + y_scale_list: List[float] = [] + + for i, region_indices in enumerate(regions): + # Nanolocz- build regionMatrix (here region_masked) + region_masked = np.full(img_f.shape, np.nan, dtype=float) + region_masked.flat[region_indices] = img_f.flat[region_indices] + region_pixel_counts[i] = region_indices.size # w(i) in Nanolocz + + # X-direction: mean of each column within region + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=RuntimeWarning) + mean_by_col = np.nanmean( + region_masked, axis=0 + ) # Nanolocz-mean_by_col is xp + valid_cols = ~np.isnan(mean_by_col) # Nanolocz xf + col_values = mean_by_col[valid_cols] + col_positions = np.flatnonzero(valid_cols) + + if col_positions.size > polyx: + coeffs_x, (cent_x, scale_x) = _polyfit_centered( + col_positions.astype(float), col_values.astype(float), polyx + ) + else: + coeffs_x = np.zeros(polyx + 1, dtype=float) + cent_x, scale_x = 0.0, 1.0 + + x_poly_list.append(coeffs_x) + x_centroid_list.append(cent_x) + x_scale_list.append(scale_x) + + # Y-direction: mean of each row within region + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=RuntimeWarning) + mean_by_row = np.nanmean(region_masked, axis=1) + valid_rows = ~np.isnan(mean_by_row) + row_values = mean_by_row[valid_rows] + row_positions = np.flatnonzero(valid_rows) + + if row_positions.size > polyy: + coeffs_y, (cent_y, scale_y) = _polyfit_centered( + row_positions.astype(float), row_values.astype(float), polyy + ) + else: + coeffs_y = np.zeros(polyy + 1, dtype=float) + cent_y, scale_y = 0.0, 1.0 + + y_poly_list.append(coeffs_y) + y_centroid_list.append(cent_y) + y_scale_list.append(scale_y) + + # region weights (normalized; exclude tiny regions by thresholding) + # Weights are W in Nanlocz-lib + weights = region_pixel_counts / ( + region_pixel_counts.sum() if region_pixel_counts.sum() > 0 else 1.0 + ) + # Exclude regions with less that 2% area + weights = np.where(weights > 0.02, weights, 0.0) + + if weights.sum() == 0: + weights = region_pixel_counts / ( + region_pixel_counts.sum() if region_pixel_counts.sum() > 0 else 1.0 + ) + + # Pad coefficient arrays to the same length then take weighted sum + max_len_x = max((p.size for p in x_poly_list), default=0) + max_len_y = max((p.size for p in y_poly_list), default=0) + x_poly_arr = np.stack( + [np.pad(p, (0, max_len_x - p.size), mode="constant") for p in x_poly_list], + axis=1, + ) + y_poly_arr = np.stack( + [np.pad(p, (0, max_len_y - p.size), mode="constant") for p in y_poly_list], + axis=1, + ) + + weighted_x_coeffs = (x_poly_arr * weights[None, :]).sum(axis=1) + weighted_y_coeffs = (y_poly_arr * weights[None, :]).sum(axis=1) + weighted_x_centroid = (np.array(x_centroid_list) * weights).sum() + weighted_x_scale = (np.array(x_scale_list) * weights).sum() + weighted_y_centroid = (np.array(y_centroid_list) * weights).sum() + weighted_y_scale = (np.array(y_scale_list) * weights).sum() + + all_cols = np.arange(cols) + all_rows = np.arange(rows) + background_x = _polyval_centered( + weighted_x_coeffs, (weighted_x_centroid, weighted_x_scale), all_cols + )[None, :] + background_y = _polyval_centered( + weighted_y_coeffs, (weighted_y_centroid, weighted_y_scale), all_rows + )[:, None] + + background_plane = background_x + background_y + return img_f - background_plane + + +def level_weighted_line( + img: np.ndarray[Any, np.dtype[np.float64]], + regions: List[np.ndarray[Any, np.dtype[np.int64]]], + polyx: int, + polyy: int, +) -> np.ndarray[Any, np.dtype[np.float64]]: + """ + Region-weighted per-row and per-column polynomial leveling. + + This function removes large-scale background trends from an image by fitting + per-row and/or per-column polynomials to user-provided regions, then + subtracting the region-weighted background from the image. + + For each region, a polynomial is fit independently to each row (if + ``polyx > 0``) and/or each column (if ``polyy > 0``) using only pixels + inside that region. The polynomial coefficients and the centering parameters + (mean and scale) for each row/column are aggregated across regions using + weights proportional to the per-row/column counts of valid pixels in each + region. Extremely small weights (< 0.02) are nulled to reduce noise; rows or + columns whose weights zero out are reweighted by raw pixel counts. The + resulting weighted polynomial background is evaluated and subtracted from + the image (rows first, then columns if both are enabled). + + Parameters + ---------- + img : ndarray of float64, shape (H, W) + Input image to level. + regions : list of 1D ndarray of int64 + A list of flat (ravelled) indices specifying disjoint or overlapping + regions within ``img``. Each array contains indices into + ``np.ravel(img)`` (i.e., C-order flattening). Only pixels belonging to a + given region are used to fit that region's per-row/column polynomials. + polyx : int + Polynomial degree for row-wise fitting. If ``polyx <= 0``, no row-wise + leveling is performed. + polyy : int + Polynomial degree for column-wise fitting. If ``polyy <= 0``, no + column-wise leveling is performed. + + Returns + ------- + ndarray of float64, shape (H, W) + The leveled image. If both ``polyx > 0`` and ``polyy > 0``, the result is + the input with the row-wise background subtracted first, followed by the + column-wise background subtraction. + + Notes + ----- + - For each row/column and region, polynomial fitting is performed on the + coordinate positions that have valid pixels within the region. A fit + requires at least ``degree + 2`` valid points; otherwise, a neutral + centering ``(0.0, 1.0)`` is recorded and coefficients are left at zero + for that row/column in that region. + - Region aggregation uses normalized pixel-count weights per row/column. + Weights below 0.02 are set to 0 to suppress weak regions; if all weights + become zero for a given row/column, raw pixel-count normalization is used + as a fallback. + - Polynomial evaluation is done with centered/scaled coordinates obtained + from each fit (mean, scale) aggregated across regions using the same + weights as for the coefficients. + - This function relies on helper routines + ``_polyfit_centered(x, y, degree) -> (coeffs, (mean, scale))`` and + ``_polyval_centered(coeffs, (mean, scale), x) -> y``. + + Examples + -------- + >>> H, W = 128, 256 + >>> img = np.random.randn(H, W).astype(float) + >>> # Define two rectangular regions via flat indices + >>> r1 = np.ravel_multi_index( + ... np.mgrid[10:60, 20:120].reshape(2, -1), dims=img.shape, order='C' + ... ) + >>> r2 = np.ravel_multi_index( + ... np.mgrid[70:120, 100:220].reshape(2, -1), dims=img.shape, order='C' + ... ) + >>> leveled = level_weighted_line(img, [r1, r2], polyx=2, polyy=1) + """ + rows, cols = img.shape + img_f = np.asarray(img, dtype=float) + + leveled_image = img_f.copy() + + # Row-wise polynomial fitting per region + if polyx > 0 and len(regions) > 0: + # MATLAB: px{k}(ii, i) -> per-row polynomial coefficients per region + # Python: we store a (rows, polyx+1) array per region, then stack to (rows, + # polyx+1, n_regions) + row_coeffs_regions: List[np.ndarray[Any, np.dtype[np.float64]]] = [] + # MATLAB Nanolocz-lib: mux{1}(ii, i) and mux{2}(ii, i) -> centering (mean, + # scale) per row and region + # Python: store as (rows, 2) per region, then combine with weights + row_centering_regions: List[np.ndarray[Any, np.dtype[np.float64]]] = [] + # MATLAB: w(ii, i) -> per-row valid-pixel counts per region + # (row_pixel_counts_regions) + # Python: store as (rows,) per region; later stack to (rows, n_regions) + row_pixel_counts_regions: List[np.ndarray[Any, np.dtype[np.float64]]] = [] + + for region_indices in regions: + region_masked = np.full(img_f.shape, np.nan, dtype=float) + region_masked.flat[region_indices] = img_f.flat[region_indices] + + coeffs_for_rows = np.zeros( + (rows, polyx + 1), dtype=float + ) # coeffs_for_rows ~ MATLAB Nanolocz-lib px{k} matrices collected by k + centering_for_rows = np.zeros((rows, 2), dtype=float) # + pixel_counts_per_row = np.zeros(rows, dtype=float) + + for row_idx in range(rows): + valid_columns = ~np.isnan(region_masked[row_idx, :]) + pixel_counts_per_row[row_idx] = valid_columns.sum() + + if valid_columns.sum() > polyx + 1: + col_positions = np.flatnonzero(valid_columns).astype(float) + values = img_f[row_idx, valid_columns] + coeffs, centering = _polyfit_centered(col_positions, values, polyx) + coeffs_for_rows[row_idx, : coeffs.size] = coeffs + centering_for_rows[row_idx, :] = centering + else: + centering_for_rows[row_idx, :] = (0.0, 1.0) + + row_coeffs_regions.append(coeffs_for_rows) + row_centering_regions.append(centering_for_rows) + row_pixel_counts_regions.append(pixel_counts_per_row) + + row_coeffs_stack = np.stack( + row_coeffs_regions, axis=2 + ) # (rows, poly+1, n_regions) + row_pixel_counts_array = np.stack( + row_pixel_counts_regions, axis=1 + ) # (rows, n_regions) + + total_counts_per_row = row_pixel_counts_array.sum(axis=1, keepdims=True) + row_weights = row_pixel_counts_array / np.where( + total_counts_per_row == 0, 1.0, total_counts_per_row + ) + row_weights = np.where(row_weights > 0.02, row_weights, 0.0) + + zero_weight_rows = row_weights.sum(axis=1) == 0 + if zero_weight_rows.any(): + row_weights[zero_weight_rows, :] = row_pixel_counts_array[ + zero_weight_rows, : + ] / np.maximum(total_counts_per_row[zero_weight_rows], 1.0) + + row_weights_expanded = row_weights[:, None, :] + weighted_row_coeffs = (row_coeffs_stack * row_weights_expanded).sum(axis=2) + + row_cent0 = np.stack([c[:, 0] for c in row_centering_regions], axis=1) + row_cent1 = np.stack([c[:, 1] for c in row_centering_regions], axis=1) + weighted_row_centroid = (row_cent0 * row_weights).sum(axis=1) + weighted_row_scale = (row_cent1 * row_weights).sum(axis=1) + + # Evaluate row background and subtract + row_background = np.zeros_like(img_f) + col_positions_all = np.arange(cols) + for r_idx in range(rows): + row_background[r_idx, :] = _polyval_centered( + weighted_row_coeffs[r_idx], + ( + weighted_row_centroid[r_idx], + ( + weighted_row_scale[r_idx] + if weighted_row_scale[r_idx] != 0 + else 1.0 + ), + ), + col_positions_all, + ) + + leveled_image = img_f - row_background + + # Column-wise polynomial fitting per region + if polyy > 0 and len(regions) > 0: + # IN MATLAB Nanolocz-lib col_coeffs_regions is py + col_coeffs_regions: List[np.ndarray[Any, np.dtype[np.float64]]] = [] + # IN MATLAB Nanolocz-lib col_centering_regions is muy + col_centering_regions: List[np.ndarray[Any, np.dtype[np.float64]]] = [] + # IN MATLAB Nanolocz-lib col_pixel_counts_regions is w + col_pixel_counts_regions: List[np.ndarray[Any, np.dtype[np.float64]]] = [] + + for region_indices in regions: + region_masked = np.full(img_f.shape, np.nan, dtype=float) + region_masked.flat[region_indices] = img_f.flat[region_indices] + + coeffs_for_cols = np.zeros((cols, polyy + 1), dtype=float) + centering_for_cols = np.zeros((cols, 2), dtype=float) + pixel_counts_per_col = np.zeros(cols, dtype=float) + + for col_idx in range(cols): + valid_rows = ~np.isnan(region_masked[:, col_idx]) + pixel_counts_per_col[col_idx] = valid_rows.sum() + + if valid_rows.sum() > polyy + 1: + row_positions = np.flatnonzero(valid_rows).astype(float) + values = img_f[valid_rows, col_idx] + coeffs, centering = _polyfit_centered(row_positions, values, polyy) + coeffs_for_cols[col_idx, : coeffs.size] = coeffs + centering_for_cols[col_idx, :] = centering + else: + centering_for_cols[col_idx, :] = (0.0, 1.0) + + col_coeffs_regions.append(coeffs_for_cols) + col_centering_regions.append(centering_for_cols) + col_pixel_counts_regions.append(pixel_counts_per_col) + + col_coeffs_stack = np.stack( + col_coeffs_regions, axis=2 + ) # (cols, poly+1, n_regions) + col_pixel_counts_array = np.stack( + col_pixel_counts_regions, axis=1 + ) # (cols, n_regions) + + total_counts_per_col = col_pixel_counts_array.sum(axis=1, keepdims=True) + col_weights = col_pixel_counts_array / np.where( + total_counts_per_col == 0, 1.0, total_counts_per_col + ) + col_weights = np.where(col_weights > 0.02, col_weights, 0.0) + + zero_weight_cols = col_weights.sum(axis=1) == 0 + if zero_weight_cols.any(): + col_weights[zero_weight_cols, :] = col_pixel_counts_array[ + zero_weight_cols, : + ] / np.maximum(total_counts_per_col[zero_weight_cols], 1.0) + + col_weights_expanded = col_weights[:, None, :] + weighted_col_coeffs = (col_coeffs_stack * col_weights_expanded).sum(axis=2) + + # After computing weighted_row_coeffs and weighted_col_coeffs + # Force constant term to zero like MATLAB Nanolocz: + weighted_row_coeffs[:, -1] = 0.0 + weighted_col_coeffs[:, -1] = 0.0 + + col_cent0 = np.stack([c[:, 0] for c in col_centering_regions], axis=1) + col_cent1 = np.stack([c[:, 1] for c in col_centering_regions], axis=1) + weighted_col_centroid = (col_cent0 * col_weights).sum(axis=1) + weighted_col_scale = (col_cent1 * col_weights).sum(axis=1) + + # Evaluate column background and subtract + col_background = np.zeros_like(img_f) + row_positions_all = np.arange(rows) + for c_idx in range(cols): + col_background[:, c_idx] = _polyval_centered( + weighted_col_coeffs[c_idx], + ( + weighted_col_centroid[c_idx], + ( + weighted_col_scale[c_idx] + if weighted_col_scale[c_idx] != 0 + else 1.0 + ), + ), + row_positions_all, + ) + + leveled_image = leveled_image - col_background + + return np.asarray(leveled_image) + + +def level_weighted_med_line( + image: np.ndarray[Any, np.dtype[np.float64]], + regions: List[np.ndarray[Any, np.dtype[np.int64]]], +) -> np.ndarray[Any, np.dtype[np.float64]]: + """Region-weighted median line subtraction along image rows. + + Computes a region-weighted median background per row and subtracts it + from the image. Behaviour mirrors the MATLAB ``med_line`` case but uses + descriptive variable names and NumPy-style docstrings. + + Parameters + ---------- + image : np.ndarray + 2-D AFM image (rows * columns). + regions : list of np.ndarray + List of flat-index arrays describing connected foreground regions. + + Returns + ------- + np.ndarray + Row-leveled image (float64). + """ + image_float = np.asarray(image, dtype=float) + n_rows, n_cols = image_float.shape + n_regions = len(regions) + + # Per-row counts and per-row median offsets for each region + per_row_counts = np.zeros((n_rows, n_regions), dtype=float) + per_row_offsets = np.zeros((n_rows, n_regions), dtype=float) + region_baselines = np.zeros(n_regions, dtype=float) + + for r_idx, region_indices in enumerate(regions): + region_masked = np.full(image_float.shape, np.nan, dtype=float) + region_masked.flat[region_indices] = image_float.flat[region_indices] + + region_baselines[r_idx] = np.nanmedian(region_masked) + + for row_idx in range(n_rows): + valid = ~np.isnan(region_masked[row_idx, :]) + per_row_counts[row_idx, r_idx] = valid.sum() + if valid.sum() > 2: + per_row_offsets[row_idx, r_idx] = ( + np.nanmedian(image_float[row_idx, valid]) - region_baselines[r_idx] + ) + else: + per_row_offsets[row_idx, r_idx] = -region_baselines[r_idx] + + # Compute normalized weights per row + totals = per_row_counts.sum(axis=1, keepdims=True) + denom = np.where(totals == 0, 1.0, totals) + weights = per_row_counts / denom + weights = np.where(weights > 0.02, weights, 0.0) + + zero_weight_rows = weights.sum(axis=1) == 0 + if zero_weight_rows.any(): + weights[zero_weight_rows, :] = per_row_counts[zero_weight_rows, :] / np.maximum( + denom[zero_weight_rows], 1.0 + ) + + weighted_row_background = (weights * per_row_offsets).sum(axis=1) + has_data = per_row_counts.sum(axis=1) > 0 + + leveled = image_float.copy() + leveled[has_data, :] = ( + image_float[has_data, :] - weighted_row_background[has_data, None] + ) + return np.asarray(leveled) + + +def level_weighted_med_line_y( + image: np.ndarray[Any, np.dtype[np.float64]], + regions: List[np.ndarray[Any, np.dtype[np.int64]]], +) -> np.ndarray[Any, np.dtype[np.float64]]: + """Region-weighted median line subtraction along image columns. + + Parameters + ---------- + image : np.ndarray + 2-D AFM image (rows * columns). + regions : list of np.ndarray + List of flat-index arrays describing connected foreground regions. + + Returns + ------- + np.ndarray + Column-leveled image (float64). + """ + image_float = np.asarray(image, dtype=float) + n_rows, n_cols = image_float.shape + n_regions = len(regions) + + per_col_counts = np.zeros((n_cols, n_regions), dtype=float) + per_col_offsets = np.zeros((n_cols, n_regions), dtype=float) + region_baselines = np.zeros(n_regions, dtype=float) + + for r_idx, region_indices in enumerate(regions): + region_masked = np.full(image_float.shape, np.nan, dtype=float) + region_masked.flat[region_indices] = image_float.flat[region_indices] + + region_baselines[r_idx] = np.nanmedian(region_masked) + + for col_idx in range(n_cols): + valid = ~np.isnan(region_masked[:, col_idx]) + per_col_counts[col_idx, r_idx] = valid.sum() + if valid.sum() > 2: + per_col_offsets[col_idx, r_idx] = ( + np.nanmedian(image_float[valid, col_idx]) - region_baselines[r_idx] + ) + else: + per_col_offsets[col_idx, r_idx] = -region_baselines[r_idx] + + totals = per_col_counts.sum(axis=1, keepdims=True) + denom = np.where(totals == 0, 1.0, totals) + weights = per_col_counts / denom + weights = np.where(weights > 0.02, weights, 0.0) + + zero_weight_cols = weights.sum(axis=1) == 0 + if zero_weight_cols.any(): + weights[zero_weight_cols, :] = per_col_counts[zero_weight_cols, :] / np.maximum( + denom[zero_weight_cols], 1.0 + ) + + weighted_col_background = (weights * per_col_offsets).sum(axis=1) + has_data = per_col_counts.sum(axis=1) > 0 + + leveled = image_float.copy() + cols_with_data = has_data + leveled[:, cols_with_data] = ( + image_float[:, cols_with_data] + - weighted_col_background[cols_with_data][None, :] + ) + return np.asarray(leveled) + + +def level_weighted_smed_line( + image: np.ndarray[Any, np.dtype[np.float64]], + regions: List[np.ndarray[Any, np.dtype[np.int64]]], + smoothing_window: int = 10, +) -> np.ndarray[Any, np.dtype[np.float64]]: + """Region-weighted smoothed median line subtraction along rows. + + Computes a weighted median profile per row and then subtracts the difference + between that profile and a moving-median-smoothed version of it (MATLAB + ``smed_line`` behaviour). + + Parameters + ---------- + image : np.ndarray + 2-D AFM image (rows * columns). + regions : list of np.ndarray + List of flat-index arrays describing connected foreground regions. + smoothing_window : int, optional + Window length for moving-median smoothing (default 10). + + Returns + ------- + np.ndarray + Smoothed-median-leveled image. + """ + image_float = np.asarray(image, dtype=float) + n_rows, n_cols = image_float.shape + n_regions = len(regions) + + median_per_row = np.zeros(n_rows, dtype=float) + per_row_counts = np.zeros((n_rows, n_regions), dtype=float) + region_baselines = np.zeros(n_regions, dtype=float) + + for r_idx, region_indices in enumerate(regions): + region_masked = np.full(image_float.shape, np.nan, dtype=float) + region_masked.flat[region_indices] = image_float.flat[region_indices] + + region_baselines[r_idx] = np.nanmedian(region_masked) + + for row_idx in range(n_rows): + valid = ~np.isnan(region_masked[row_idx, :]) + per_row_counts[row_idx, r_idx] = valid.sum() + if valid.sum() > 2: + median_per_row[row_idx] = np.nanmedian(image_float[row_idx, valid]) + else: + median_per_row[row_idx] = -region_baselines[r_idx] + + totals = per_row_counts.sum(axis=1, keepdims=True) + denom = np.where(totals == 0, 1.0, totals) + weights = per_row_counts / denom + weights = np.where(weights > 0.02, weights, 0.0) + + zero_weight_rows = weights.sum(axis=1) == 0 + if zero_weight_rows.any(): + weights[zero_weight_rows, :] = per_row_counts[zero_weight_rows, :] / np.maximum( + denom[zero_weight_rows], 1.0 + ) + + weighted_background = (weights * median_per_row[:, None]).sum(axis=1) + + # moving median smoothing + k = int(smoothing_window) + if k <= 1: + smoothed = weighted_background.copy() + else: + pad = k // 2 + padded = np.pad(weighted_background, pad, mode="edge") + smoothed = np.empty_like(weighted_background) + for i in range(n_rows): + smoothed[i] = np.median(padded[i : i + k]) + + return np.asarray(image_float - (weighted_background[:, None] - smoothed[:, None])) + + +def apply_level_weighted( + img: np.ndarray[Any, np.dtype[np.float64]], + polyx: int, + polyy: int, + method: str, + mask: Optional[np.ndarray[Any, np.dtype[np.bool_]]] = None, + smoothing_window: int = 10, +) -> np.ndarray[Any, np.dtype[np.float64]]: + """ + Apply a weighted-region leveling method to a 2D AFM image or stack. + + Dispatcher for the level_weighted functions. + + Parameters + ---------- + img : np.ndarray + 2-D image (H * W) or 3-D stack (N * H * W). + polyx, polyy : int + Polynomial orders for X (columns) and Y (rows) fits when relevant. + method : str + One of ``'plane'``, ``'line'``, ``'med_line'``, ``'med_line_y'``, + or ``'smed_line'``. + mask : Optional[np.ndarray] + Mask with same shape as ``img`` (or H * W for single image). Non-zero + values are treated as foreground. If ``None``, the entire image is used. + smoothing_window : int + Window for ``smed_line`` smoothing. + + Returns + ------- + np.ndarray + Leveled image with same shape as ``img`` (or stack). + """ + arr = np.asarray(img) + is_stack = arr.ndim == 3 + + frames = arr if is_stack else arr[np.newaxis, ...] + + if mask is not None: + mask_arr = np.asarray(mask) + if mask_arr.ndim == 2: + mask_arr = mask_arr[np.newaxis, ...] + if mask_arr.shape != frames.shape: + raise ValueError("mask must have the same shape as img or stack") + else: + mask_arr = None + + leveled_frames: List[np.ndarray[Any, np.dtype[np.float64]]] = [] + for frame_idx in range(frames.shape[0]): + frame = frames[frame_idx] + frame_mask = ( + mask_arr[frame_idx] + if mask_arr is not None + else np.ones_like(frame, dtype=bool) + ) + mask_bool = frame_mask.astype(bool) + + n_rows, n_cols = frame.shape + min_area = max(1, int(0.01 * n_rows * n_cols)) + regions = _find_regions(mask_bool, min_area) + + method = method.lower() + if method == "plane": + leveled = level_weighted_plane(frame, regions, polyx, polyy) + elif method == "line": + leveled = level_weighted_line(frame, regions, polyx, polyy) + elif method == "med_line": + leveled = level_weighted_med_line(frame, regions) + elif method == "med_line_y": + leveled = level_weighted_med_line_y(frame, regions) + elif method == "smed_line": + leveled = level_weighted_smed_line(frame, regions, smoothing_window) + else: + raise ValueError(f"Unknown leveling method: {method}") + + leveled_frames.append(leveled) + + stacked = np.stack(leveled_frames, axis=0) + return np.asarray(stacked if is_stack else stacked[0]) + + +__all__ = [ + "apply_level_weighted", + "level_weighted_plane", + "level_weighted_line", + "level_weighted_med_line", + "level_weighted_med_line_y", + "level_weighted_smed_line", +] diff --git a/src/pnanolocz_lib/thresholder.py b/src/pnanolocz_lib/thresholder.py index d2f9115..cf556b3 100644 --- a/src/pnanolocz_lib/thresholder.py +++ b/src/pnanolocz_lib/thresholder.py @@ -52,7 +52,7 @@ This module is part of the pNanoLocz-Lib Python library for AFM analysis. """ -from typing import Callable, TypeVar +from typing import Any, Callable, TypeVar import numpy as np import ruptures as rpt @@ -72,7 +72,7 @@ # Map method names to handler functions _METHOD_MAP = {} -F = TypeVar("F", bound=Callable[..., np.ndarray]) +F = TypeVar("F", bound=Callable[..., np.ndarray[Any, np.dtype[Any]]]) def _register(name: str) -> Callable[[F], F]: @@ -85,7 +85,9 @@ def decorator(func: F) -> F: return decorator -def to_nan_mask(binary_mask: np.ndarray) -> np.ndarray: +def to_nan_mask( + binary_mask: np.ndarray[Any, np.dtype[Any]], +) -> np.ndarray[Any, np.dtype[Any]]: """ Convert a boolean mask to a float mask with NaNs in False positions. @@ -96,7 +98,7 @@ def to_nan_mask(binary_mask: np.ndarray) -> np.ndarray: Returns ------- - np.ndarray + mask : np.ndarray Float array where True becomes 1.0 and False becomes NaN. """ mask = binary_mask.astype(float) @@ -105,21 +107,21 @@ def to_nan_mask(binary_mask: np.ndarray) -> np.ndarray: def prune_skeleton_min_branch_length( - skel: np.ndarray, min_branch_length: int -) -> np.ndarray: + skel: np.ndarray[Any, np.dtype[Any]], min_branch_length: int +) -> np.ndarray[Any, np.dtype[Any]]: """ - Prune branches shorter than min_branch_length from a skeleton image. + Prune branches shorter than `min_branch_length` from a skeleton image. Parameters ---------- - skel : ndarray + skel : np.ndarray Binary skeleton image (bool or 0/1). min_branch_length : int Minimum branch length to keep. Returns ------- - pruned_skel : ndarray + pruned_skel : np.ndarray Binary skeleton with short branches removed. """ # Build graph from skeleton @@ -148,8 +150,9 @@ def prune_skeleton_min_branch_length( @_register("selection") def selection( - img: np.ndarray, limits: tuple[float, float] | list[float] | str | None = None -) -> np.ndarray: + img: np.ndarray[Any, np.dtype[np.float64]], + limits: tuple[float, float] | list[float] | str | None = None, +) -> np.ndarray[Any, np.dtype[np.float64]]: """ Pass-through user-provided mask (interpreted as boolean). @@ -171,8 +174,9 @@ def selection( @_register("histogram") def histogram( - img: np.ndarray, limits: tuple[float, float] | list[float] | str | None = None -) -> np.ndarray: + img: np.ndarray[Any, np.dtype[np.float64]], + limits: tuple[float, float] | list[float] | str | None = None, +) -> np.ndarray[Any, np.dtype[np.float64]]: """ Threshold image based on intensity limits. @@ -206,8 +210,9 @@ def histogram( @_register("otsu") def otsu( - img: np.ndarray, limits: tuple[float, float] | list[float] | str | None = None -) -> np.ndarray: + img: np.ndarray[Any, np.dtype[np.float64]], + limits: tuple[float, float] | list[float] | str | None = None, +) -> np.ndarray[Any, np.dtype[np.float64]]: """ Apply single-level Otsu thresholding. @@ -230,8 +235,9 @@ def otsu( @_register("auto edges") def auto_edges( - img: np.ndarray, limits: tuple[float, float] | list[float] | str | None = None -) -> np.ndarray: + img: np.ndarray[Any, np.dtype[np.float64]], + limits: tuple[float, float] | list[float] | str | None = None, +) -> np.ndarray[Any, np.dtype[np.float64]]: """ Detect edges using Sobel gradient and morphological filtering. @@ -263,9 +269,9 @@ def auto_edges( @_register("hist edges") def hist_edges( - img: np.ndarray, + img: np.ndarray[Any, np.dtype[np.float64]], limits: tuple[float, float] | list[float] | str | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, np.dtype[np.float64]]: """ Detect edges by thresholding with histogram limits and morphological operations. @@ -299,9 +305,9 @@ def hist_edges( @_register("otsu edges") def otsu_edges( - img: np.ndarray, + img: np.ndarray[Any, np.dtype[np.float64]], limits: tuple[float, float] | list[float] | str | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, np.dtype[np.float64]]: """ Detect edges after Otsu thresholding using morphological operations. @@ -321,7 +327,9 @@ def otsu_edges( thresh = threshold_otsu(sm) binary = sm <= thresh - def process_slice(slice_: np.ndarray) -> np.ndarray: + def process_slice( + slice_: np.ndarray[tuple[int, int], np.dtype[np.bool_]], + ) -> np.ndarray[tuple[int, int], np.dtype[np.float64]]: e = binary_erosion(~slice_) ^ ~slice_ e = remove_small_objects(e, 100) e = ~remove_small_objects(~e, 50) @@ -329,13 +337,14 @@ def process_slice(slice_: np.ndarray) -> np.ndarray: edges = process_slice(binary) - return to_nan_mask(~edges) + return to_nan_mask(np.logical_not(edges)) @_register("otsu skel") def otsu_skel( - img: np.ndarray, limits: tuple[float, float] | list[float] | str | None = None -) -> np.ndarray: + img: np.ndarray[Any, np.dtype[np.float64]], + limits: tuple[float, float] | list[float] | str | None = None, +) -> np.ndarray[Any, np.dtype[np.float64]]: """ Skeletonize regions selected by Otsu thresholding. @@ -358,7 +367,9 @@ def otsu_skel( binary = ~(sm <= thresh) mbl = 10 # Minimum branch length - def _process_slice(slice_: np.ndarray) -> np.ndarray: + def _process_slice( + slice_: np.ndarray[tuple[int, int], np.dtype[np.bool_]], + ) -> np.ndarray[tuple[int, int], np.dtype[np.float64]]: labeled = label(slice_) thin_mask = thin(labeled) skel = skeletonize(thin_mask) @@ -371,13 +382,14 @@ def _process_slice(slice_: np.ndarray) -> np.ndarray: skeleton = _process_slice(binary) - return to_nan_mask(~skeleton) + return to_nan_mask(np.logical_not(skeleton)) @_register("hist skel") def hist_skel( - img: np.ndarray, limits: tuple[float, float] | list[float] | str | None = None -) -> np.ndarray: + img: np.ndarray[Any, np.dtype[np.float64]], + limits: tuple[float, float] | list[float] | str | None = None, +) -> np.ndarray[Any, np.dtype[np.float64]]: """ Skeletonize regions selected by histogram thresholding. @@ -404,7 +416,9 @@ def hist_skel( binary = ~((sm >= low) & (sm <= high)) mbl = 10 # Minimum branch length - def _process_slice(slice_: np.ndarray) -> np.ndarray: + def _process_slice( + slice_: np.ndarray[tuple[int, int], np.dtype[np.bool_]], + ) -> np.ndarray[tuple[int, int], np.dtype[np.float64]]: labeled = label(slice_) thin_mask = thin(labeled) skel = skeletonize(thin_mask) @@ -418,13 +432,14 @@ def _process_slice(slice_: np.ndarray) -> np.ndarray: skeleton = _process_slice(binary) - return to_nan_mask(~skeleton) + return to_nan_mask(np.logical_not(skeleton)) @_register("line_step") def line_step( - img: np.ndarray, limits: tuple[float, float] | list[float] | str | None = None -) -> np.ndarray: + img: np.ndarray[Any, np.dtype[np.float64]], + limits: tuple[float, float] | list[float] | str | None = None, +) -> np.ndarray[Any, np.dtype[np.float64]]: """ Detect step changes along each row using PELT change point detection. @@ -477,11 +492,11 @@ def line_step( def thresholder( - img: np.ndarray, + img: np.ndarray[Any, np.dtype[np.float64]], method: str, - limits: tuple[float, float] | list[float] | str | list[float] | str | None = None, + limits: tuple[float, float] | list[float] | str | None = None, invert: bool = False, -) -> np.ndarray: +) -> np.ndarray[Any, np.dtype[np.float64]]: """ Apply a thresholding or edge detection method to an image or stack. @@ -507,10 +522,12 @@ def thresholder( func = _METHOD_MAP[method] - result: np.ndarray + result: np.ndarray[Any, np.dtype[np.float64]] # Handle 3D stacks frame-by-frame if img.ndim == 3: - masks: list[np.ndarray] = [func(frame, limits) for frame in img] + masks: list[np.ndarray[Any, np.dtype[np.float64]]] = [ + func(frame, limits) for frame in img + ] result = np.stack(masks) else: result = func(img, limits) diff --git a/tests/test_level_weighted.py b/tests/test_level_weighted.py new file mode 100644 index 0000000..2488cad --- /dev/null +++ b/tests/test_level_weighted.py @@ -0,0 +1,188 @@ +"""Tests for the level_weighted module.""" + +import numpy as np +import pytest + +from pnanolocz_lib.level_weighted import ( + _center_scale_indices, + _find_regions, + _polyfit_centered, + _polyval_centered, + apply_level_weighted, + level_weighted_line, + level_weighted_med_line, + level_weighted_med_line_y, + level_weighted_plane, + level_weighted_smed_line, +) + + +def test_center_scale_indices_basic(): + """Test centering and scaling of a simple index array.""" + arr = np.array([0, 1, 2, 3, 4], dtype=float) + std, c, s = _center_scale_indices(arr) + np.testing.assert_allclose(std.mean(), 0, atol=1e-12) + assert s > 0 + + +def test_center_scale_indices_empty(): + """Test that empty index array returns defaults.""" + std, c, s = _center_scale_indices(np.array([], dtype=float)) + assert std.size == 0 + assert c == 0 + assert s == 1 + + +def test_polyfit_and_polyval_centered(): + """Test polynomial fitting and evaluation with centering/scaling.""" + x = np.array([0, 1, 2, 3, 4], dtype=float) + y = x**2 + coeffs, cent_scale = _polyfit_centered(x, y, 2) + y_fit = _polyval_centered(coeffs, cent_scale, x) + np.testing.assert_allclose(y_fit, y, atol=1e-12) + + +def test_find_regions_simple(): + """Test that connected foreground regions are correctly identified.""" + mask = np.zeros((5, 5), dtype=bool) + mask[1:3, 1:3] = True + mask[4, 4] = True + regions = _find_regions(mask, min_area=1) + assert len(regions) == 2 + assert all(isinstance(r, np.ndarray) for r in regions) + + +def test_level_weighted_plane_basic(): + """Test plane leveling on a small synthetic image.""" + img = np.arange(16).reshape(4, 4).astype(float) + regions = [np.arange(16)] + leveled = level_weighted_plane(img, regions, 1, 1) + assert leveled.shape == img.shape + + +def test_level_weighted_line_basic(): + """Test line leveling on a small synthetic image.""" + img = np.arange(16).reshape(4, 4).astype(float) + regions = [np.arange(16)] + leveled = level_weighted_line(img, regions, 1, 1) + assert leveled.shape == img.shape + + +def test_level_weighted_med_line_basic(): + """Test median line leveling along rows.""" + img = np.arange(16).reshape(4, 4).astype(float) + regions = [np.arange(16)] + leveled = level_weighted_med_line(img, regions) + assert leveled.shape == img.shape + + +def test_level_weighted_med_line_y_basic(): + """Test median line leveling along columns.""" + img = np.arange(16).reshape(4, 4).astype(float) + regions = [np.arange(16)] + leveled = level_weighted_med_line_y(img, regions) + assert leveled.shape == img.shape + + +def test_level_weighted_smed_line_basic(): + """Test smoothed median line leveling along rows.""" + img = np.arange(16).reshape(4, 4).astype(float) + regions = [np.arange(16)] + leveled = level_weighted_smed_line(img, regions, smoothing_window=2) + assert leveled.shape == img.shape + + +def test_apply_level_weighted_dispatch_plane(): + """Test dispatcher applies plane leveling to an image.""" + img = np.arange(16).reshape(4, 4).astype(float) + leveled = apply_level_weighted(img, 1, 1, method="plane") + assert leveled.shape == img.shape + + +def test_apply_level_weighted_dispatch_line(): + """Test dispatcher applies line leveling to an image.""" + img = np.arange(16).reshape(4, 4).astype(float) + leveled = apply_level_weighted(img, 1, 1, method="line") + assert leveled.shape == img.shape + + +def test_apply_level_weighted_dispatch_med_line(): + """Test dispatcher applies med_line leveling to an image.""" + img = np.arange(16).reshape(4, 4).astype(float) + leveled = apply_level_weighted(img, 0, 0, method="med_line") + assert leveled.shape == img.shape + + +def test_apply_level_weighted_dispatch_med_line_y(): + """Test dispatcher applies med_line_y leveling to an image.""" + img = np.arange(16).reshape(4, 4).astype(float) + leveled = apply_level_weighted(img, 0, 0, method="med_line_y") + assert leveled.shape == img.shape + + +def test_apply_level_weighted_dispatch_smed_line(): + """Test dispatcher applies smed_line leveling to an image.""" + img = np.arange(16).reshape(4, 4).astype(float) + leveled = apply_level_weighted(img, 0, 0, method="smed_line", smoothing_window=2) + assert leveled.shape == img.shape + + +def test_apply_level_weighted_with_mask(): + """Test that apply_level_weighted works correctly with a mask.""" + img = np.arange(16).reshape(4, 4).astype(float) + mask = np.zeros_like(img, dtype=bool) + mask[1:3, 1:3] = True + leveled = apply_level_weighted(img, 1, 1, method="plane", mask=mask) + assert leveled.shape == img.shape + + +@pytest.mark.parametrize( + "img,mask_coords,method", + [ + (np.ones((3, 3), dtype=float), [(0, 0)], "plane"), + (np.ones((3, 3), dtype=float), [(0, 0), (1, 1)], "line"), + (np.arange(9, dtype=float).reshape(3, 3), [(0, 1), (2, 2)], "med_line"), + (np.arange(25, dtype=float).reshape(5, 5), [(1, 1), (3, 3)], "med_line_y"), + (np.arange(16, dtype=float).reshape(4, 4), [(0, 0), (0, 1)], "smed_line"), + ], +) +def test_apply_level_weighted_basic_masks(img, mask_coords, method): + """Test apply_level_weighted on small images with simple masks.""" + mask = np.zeros_like(img, dtype=bool) + for i, j in mask_coords: + mask[i, j] = True + + leveled = apply_level_weighted(img, polyx=1, polyy=1, method=method, mask=mask) + + # Assert the output shape matches input + assert leveled.shape == img.shape + + # Assert that masked pixels are influenced but not NaN + assert np.all(np.isfinite(leveled[mask])) + + +@pytest.mark.parametrize( + "method", ["plane", "line", "med_line", "med_line_y", "smed_line"] +) +def test_apply_level_weighted_full_mask(method): + """Fully masked image should either return unchanged or handle gracefully.""" + img = np.ones((3, 3), dtype=float) + mask = np.ones_like(img, dtype=bool) + + leveled = apply_level_weighted(img, polyx=1, polyy=1, method=method, mask=mask) + + # In your implementation, full mask should not crash; output should be finite + assert np.all(np.isfinite(leveled)) + assert leveled.shape == img.shape + + +@pytest.mark.parametrize( + "method", ["plane", "line", "med_line", "med_line_y", "smed_line"] +) +def test_apply_level_weighted_no_mask(method): + """Test that leveling runs without a mask (whole image considered).""" + img = np.arange(9, dtype=float).reshape(3, 3) + leveled = apply_level_weighted(img, polyx=1, polyy=1, method=method) + + assert leveled.shape == img.shape + assert np.all(np.isfinite(leveled))