# -*- coding: utf-8 -*-
"""
Slice interpolation utilities for missing or degraded serial sections.

Consolidated from linum_interpolate_missing_slice.py.
"""
import numpy as np
import SimpleITK as sitk
from scipy.ndimage import distance_transform_edt, gaussian_filter

from linumpy.stitching.registration import register_2d_images_sitk, apply_transform
from linumpy.utils.image_quality import (
    compute_ssim_3d, compute_edge_score, compute_variance_score
)


def compute_half_affine_transform(transform):
    """Compute a transform that is 'halfway' to the given transform.

    For affine transforms: decomposes the transform matrix via
    eigendecomposition and applies the matrix square root (half rotation +
    half translation).  The 2D and 3D cases share one dimension-generic
    code path.

    Parameters
    ----------
    transform : sitk.Transform
        Full transform from image A to image B.  A CompositeTransform is
        reduced to its first component (assumed affine).

    Returns
    -------
    sitk.AffineTransform
        Transform H such that applying H twice reproduces the input
        transform (H(H(x)) == T(x)).

    Raises
    ------
    ValueError
        If the transform dimension is neither 2 nor 3.
    """
    if isinstance(transform, sitk.CompositeTransform):
        # Only the first component of a composite is used here.
        transform = sitk.AffineTransform(transform.GetNthTransform(0))

    dim = transform.GetDimension()
    if dim not in (2, 3):
        raise ValueError(f"Unsupported transform dimension: {dim}")

    matrix = np.array(transform.GetMatrix()).reshape(dim, dim)
    translation = np.array(transform.GetTranslation())
    center = np.array(transform.GetCenter())

    # Matrix square root via eigendecomposition: H = V sqrt(D) V^-1,
    # so H @ H == M.  Complex intermediates can appear for rotations;
    # the real part is kept (imaginary parts cancel for valid inputs).
    eigenvalues, eigenvectors = np.linalg.eig(matrix)
    sqrt_eigenvalues = np.sqrt(eigenvalues.astype(complex))
    half_matrix = (eigenvectors @ np.diag(sqrt_eigenvalues) @
                   np.linalg.inv(eigenvectors)).real

    # Correct half-translation: h(h(x)) = T(x) requires
    # (H_m + I) * h_t = t  =>  h_t = (H_m + I)^{-1} * t
    half_translation = np.linalg.solve(half_matrix + np.eye(dim), translation)

    half_transform = sitk.AffineTransform(dim)
    half_transform.SetMatrix(half_matrix.flatten().tolist())
    half_transform.SetTranslation(half_translation.tolist())
    half_transform.SetCenter(center.tolist())
    return half_transform


def interpolate_average(vol_before: np.ndarray, vol_after: np.ndarray) -> np.ndarray:
    """Simple 50/50 average of two adjacent volumes.

    Parameters
    ----------
    vol_before : np.ndarray
        Volume before missing slice (Z, X, Y).
    vol_after : np.ndarray
        Volume after missing slice (Z, X, Y).

    Returns
    -------
    np.ndarray
        Average volume (float32).
    """
    return 0.5 * vol_before.astype(np.float32) + 0.5 * vol_after.astype(np.float32)
def interpolate_weighted(vol_before: np.ndarray,
                         vol_after: np.ndarray,
                         sigma: float = 2.0) -> np.ndarray:
    """Average of two volumes, Gaussian-smoothed along the Z-axis.

    Parameters
    ----------
    vol_before : np.ndarray
        Volume before missing slice.
    vol_after : np.ndarray
        Volume after missing slice.
    sigma : float
        Gaussian smoothing sigma applied along the Z-axis only.

    Returns
    -------
    np.ndarray
        Smoothed 50/50 average (float32).
    """
    halves = 0.5 * vol_before.astype(np.float32) + 0.5 * vol_after.astype(np.float32)
    # Smooth along Z only; XY structure is left untouched.
    return gaussian_filter(halves, sigma=(sigma, 0, 0))


def find_best_overlap_planes(vol_before: np.ndarray,
                             vol_after: np.ndarray,
                             search_window: int = 5):
    """Find the best-correlated plane pair at the boundary between two volumes.

    In serial sectioning, the physically adjacent tissue sits near the
    bottom of *vol_before* and the top of *vol_after*.  Since the exact cut
    depth varies slightly, the last *search_window* z-planes of *vol_before*
    are compared against the first *search_window* z-planes of *vol_after*
    using normalized cross-correlation on a central ROI.

    The returned correlation doubles as a quality gate: a low score (e.g.
    below ~0.1) means no reliable structural match was found and the caller
    should fall back to a simpler interpolation strategy.

    Parameters
    ----------
    vol_before : np.ndarray
        3D volume (Z, X, Y) before the missing slice.
    vol_after : np.ndarray
        3D volume (Z, X, Y) after the missing slice.
    search_window : int
        Number of z-planes to search at each boundary. Default 5.

    Returns
    -------
    ref_before : int
        Best z-index in *vol_before*.
    ref_after : int
        Best z-index in *vol_after*.
    best_corr : float
        Normalized cross-correlation at the best pair (approximately
        in [-1, 1]; higher is better).
    """
    depth_before = vol_before.shape[0]
    depth_after = vol_after.shape[0]
    height, width = vol_before.shape[1], vol_before.shape[2]

    # Central ROI: drop a quarter of the smaller dimension on each side.
    border = min(height, width) // 4
    window = (slice(border, height - border), slice(border, width - border))

    def _standardize(plane):
        # Percentile-clip the foreground (values > 0), then z-score so the
        # product of two standardized patches is a normalized correlation.
        patch = plane[window].astype(np.float32)
        nonzero = patch > 0
        if nonzero.any():
            lo = float(np.percentile(patch[nonzero], 5))
            hi = float(np.percentile(patch[nonzero], 95))
            patch = np.clip((patch - lo) / max(hi - lo, 1e-8), 0, 1)
        return (patch - patch.mean()) / (patch.std() + 1e-8)

    tail_zs = range(max(0, depth_before - search_window), depth_before)
    head_zs = range(0, min(search_window, depth_after))

    # Pre-standardize every candidate plane once.
    tail_norm = {z: _standardize(vol_before[z]) for z in tail_zs}
    head_norm = {z: _standardize(vol_after[z]) for z in head_zs}

    ref_before, ref_after, best_corr = depth_before - 1, 0, -np.inf
    for zb in tail_zs:
        plane_b = tail_norm[zb]
        for za in head_zs:
            score = float((plane_b * head_norm[za]).mean())
            if score > best_corr:
                ref_before, ref_after, best_corr = zb, za, score

    return ref_before, ref_after, best_corr
def interpolate_registration_based(vol_before: np.ndarray,
                                   vol_after: np.ndarray,
                                   metric: str = 'MSE',
                                   max_iterations: int = 1000,
                                   reference_slice: int | None = None,
                                   blend_method: str = 'gaussian',
                                   overlap_search_window: int = 5,
                                   min_overlap_correlation: float = 0.1) -> np.ndarray:
    """Interpolate a missing slice using registration-based morphing.

    1. Finds the best-correlated plane pair at the volume boundary using
       ``find_best_overlap_planes`` (quality gate + best reference selection)
    2. Registers that pair of 2D planes to obtain the XY alignment transform
    3. Computes the half-transform representing the midpoint transformation
    4. Warps both volumes toward the midpoint
    5. Blends the results using linear or feathered (Gaussian) blending

    If the best overlap correlation is below *min_overlap_correlation*, the
    volumes cannot be reliably aligned and a simple average is returned
    instead.

    Parameters
    ----------
    vol_before : np.ndarray
        3D volume (Z, X, Y) before the missing slice.
    vol_after : np.ndarray
        3D volume (Z, X, Y) after the missing slice.  Assumed to share the
        XY shape of *vol_before* (callers pad beforehand) -- TODO confirm.
    metric : str
        Registration metric: 'MSE', 'CC', or 'MI'.
    max_iterations : int
        Maximum registration iterations.
    reference_slice : int or None
        When provided, overrides automatic plane selection and uses this
        z-index (clamped to each volume's bounds) as the registration
        reference in both volumes.  When *None* (default),
        ``find_best_overlap_planes`` selects the best plane pair within
        *overlap_search_window* planes from each boundary.
    blend_method : str
        'linear' (50/50) or 'gaussian' (feathered distance-transform blend).
    overlap_search_window : int
        Number of z-planes to search at each boundary when selecting the
        reference plane pair automatically.  Ignored when *reference_slice*
        is set. Default 5.
    min_overlap_correlation : float
        Minimum normalized cross-correlation required to proceed with
        registration.  Below this threshold the volumes are considered
        mismatched and a plain average is returned. Default 0.1.

    Returns
    -------
    np.ndarray
        Interpolated 3D volume (float32).

    Raises
    ------
    ValueError
        If *blend_method* is not 'linear' or 'gaussian'.
    """
    # Fail fast on an invalid blend method instead of discovering it only
    # after the expensive registration and per-plane warping steps.
    if blend_method not in ('linear', 'gaussian'):
        raise ValueError(f"Unknown blend_method: {blend_method}")

    nz_before, nx, ny = vol_before.shape
    nz_after = vol_after.shape[0]
    nz_out = min(nz_before, nz_after)

    if reference_slice is None:
        ref_before, ref_after, best_corr = find_best_overlap_planes(
            vol_before, vol_after, search_window=overlap_search_window
        )
        if best_corr < min_overlap_correlation:
            # Quality gate failed: no reliable structural match at the
            # boundary, so registration would be meaningless.
            print(f" [interpolation] Overlap correlation {best_corr:.3f} is below threshold "
                  f"{min_overlap_correlation:.3f} — falling back to simple average.")
            return interpolate_average(vol_before[:nz_out], vol_after[:nz_out])
        print(f" [interpolation] Best overlap: before[{ref_before}] ↔ after[{ref_after}] "
              f"(corr={best_corr:.3f})")
    else:
        # Manual reference, clamped to each volume's valid z-range.
        ref_before = min(reference_slice, nz_before - 1)
        ref_after = min(reference_slice, nz_after - 1)

    fixed_2d = vol_after[ref_after].astype(np.float32)
    moving_2d = vol_before[ref_before].astype(np.float32)

    # Normalize both planes to [0, 1] so the registration metric is not
    # dominated by global intensity-scale differences.
    mn, mx = fixed_2d.min(), fixed_2d.max()
    if mx > mn:
        fixed_2d = (fixed_2d - mn) / (mx - mn)
    mn, mx = moving_2d.min(), moving_2d.max()
    if mx > mn:
        moving_2d = (moving_2d - mn) / (mx - mn)

    transform_2d, _, _ = register_2d_images_sitk(
        fixed_2d, moving_2d,
        method='affine',
        metric=metric,
        max_iterations=max_iterations,
        return_3d_transform=False,
        verbose=False
    )

    # Half-transform morphs 'before' forward and 'after' backward so both
    # meet at the midpoint of the motion.
    half_transform = compute_half_affine_transform(transform_2d)
    inv_half_transform = half_transform.GetInverse()

    warped_before = np.zeros((nz_out, nx, ny), dtype=np.float32)
    warped_after = np.zeros((nz_out, nx, ny), dtype=np.float32)

    for z in range(nz_out):
        warped_before[z] = apply_transform(vol_before[z].astype(np.float32), half_transform)
        warped_after[z] = apply_transform(vol_after[z].astype(np.float32), inv_half_transform)

    if blend_method == 'linear':
        return 0.5 * warped_before + 0.5 * warped_after

    # Gaussian (feathered) blend: weight each voxel by its smoothed
    # distance to the nearest data-free region so seams fade gradually.
    mask_before = warped_before > 0
    mask_after = warped_after > 0

    dist_before = np.zeros((nz_out, nx, ny), dtype=np.float32)
    dist_after = np.zeros((nz_out, nx, ny), dtype=np.float32)

    for z in range(nz_out):
        if np.any(mask_before[z]):
            dist_before[z] = distance_transform_edt(mask_before[z])
        if np.any(mask_after[z]):
            dist_after[z] = distance_transform_edt(mask_after[z])

    # Smooth the distance maps in XY only (no mixing across z-planes).
    dist_before = gaussian_filter(dist_before, sigma=(0, 2, 2))
    dist_after = gaussian_filter(dist_after, sigma=(0, 2, 2))

    total_dist = dist_before + dist_after + 1e-10
    w_before = dist_before / total_dist
    w_after = dist_after / total_dist

    # Regions covered by only one warped volume take full weight from it.
    only_before = mask_before & ~mask_after
    only_after = mask_after & ~mask_before
    w_before[only_before] = 1.0
    w_after[only_before] = 0.0
    w_before[only_after] = 0.0
    w_after[only_after] = 1.0

    return w_before * warped_before + w_after * warped_after
+ """ + reference = 0.5 * vol_before.astype(np.float32) + 0.5 * vol_after.astype(np.float32) + + ssim_before = compute_ssim_3d(vol_degraded, vol_before) + ssim_after = compute_ssim_3d(vol_degraded, vol_after) + ssim_score = (ssim_before + ssim_after) / 2 + + edge_score = compute_edge_score(vol_degraded, reference) + variance_score = compute_variance_score(vol_degraded, reference) + + quality_score = 0.5 * ssim_score + 0.3 * edge_score + 0.2 * variance_score + + metrics = { + 'ssim_before': ssim_before, + 'ssim_after': ssim_after, + 'ssim_mean': ssim_score, + 'edge_preservation': edge_score, + 'variance_ratio': variance_score, + 'overall': quality_score + } + + return quality_score, metrics + + +def blend_with_degraded(interpolated: np.ndarray, + degraded: np.ndarray, + quality_weight: float) -> np.ndarray: + """Blend an interpolated result with a degraded slice weighted by quality. + + Parameters + ---------- + interpolated : np.ndarray + Pure interpolated volume. + degraded : np.ndarray + Degraded slice volume. + quality_weight : float + Weight for degraded slice (0 = use interpolated, 1 = use degraded). + + Returns + ------- + np.ndarray + Blended result. 
+ """ + w = quality_weight + return w * degraded.astype(np.float32) + (1 - w) * interpolated.astype(np.float32) diff --git a/linumpy/tests/test_stitching_interpolation.py b/linumpy/tests/test_stitching_interpolation.py new file mode 100644 index 00000000..0179f714 --- /dev/null +++ b/linumpy/tests/test_stitching_interpolation.py @@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- +"""Tests for linumpy/stitching/interpolation.py""" +import numpy as np +import pytest + +from linumpy.stitching.interpolation import ( + assess_degraded_slice_quality, + blend_with_degraded, + interpolate_average, + interpolate_weighted, +) + + +def _vol(shape=(8, 16, 16), seed=0): + rng = np.random.default_rng(seed) + return (rng.random(shape) * 100.0).astype(np.float32) + + +# --------------------------------------------------------------------------- +# interpolate_average +# --------------------------------------------------------------------------- + +def test_interpolate_average_shape(): + before = _vol() + after = _vol(seed=1) + result = interpolate_average(before, after) + assert result.shape == before.shape + + +def test_interpolate_average_midpoint(): + """Result should be the exact arithmetic mean.""" + before = np.zeros((4, 8, 8), dtype=np.float32) + after = np.full((4, 8, 8), 2.0, dtype=np.float32) + result = interpolate_average(before, after) + np.testing.assert_allclose(result, 1.0) + + +def test_interpolate_average_identical(): + """Averaging a volume with itself preserves values.""" + vol = _vol() + result = interpolate_average(vol, vol) + np.testing.assert_allclose(result, vol, rtol=1e-5) + + +def test_interpolate_average_dtype_float32(): + """Output dtype should be float32.""" + before = np.ones((4, 8, 8), dtype=np.uint8) + after = np.ones((4, 8, 8), dtype=np.uint8) + result = interpolate_average(before, after) + assert result.dtype == np.float32 + + +# --------------------------------------------------------------------------- +# interpolate_weighted +# 
--------------------------------------------------------------------------- + +def test_interpolate_weighted_shape(): + before = _vol() + after = _vol(seed=1) + result = interpolate_weighted(before, after, sigma=1.0) + assert result.shape == before.shape + + +def test_interpolate_weighted_same_result_as_average_when_sigma_zero(): + """With sigma ≈ 0, weighted should be close to simple average.""" + before = np.zeros((6, 8, 8), dtype=np.float32) + after = np.full((6, 8, 8), 4.0, dtype=np.float32) + avg = interpolate_average(before, after) + weighted = interpolate_weighted(before, after, sigma=0.01) + np.testing.assert_allclose(weighted, avg, rtol=0.05) + + +def test_interpolate_weighted_smoothing_reduces_variance(): + """Larger sigma should produce smoother output (lower std dev along Z).""" + rng = np.random.default_rng(42) + before = rng.random((20, 8, 8)).astype(np.float32) + after = rng.random((20, 8, 8)).astype(np.float32) + std_low_sigma = interpolate_weighted(before, after, sigma=0.1).std() + std_high_sigma = interpolate_weighted(before, after, sigma=3.0).std() + assert std_high_sigma < std_low_sigma + + +# --------------------------------------------------------------------------- +# blend_with_degraded +# --------------------------------------------------------------------------- + +def test_blend_with_degraded_pure_interpolated(): + """quality_weight=0 → output equals interpolated.""" + interp = np.ones((4, 8, 8), dtype=np.float32) + degraded = np.full((4, 8, 8), 10.0, dtype=np.float32) + result = blend_with_degraded(interp, degraded, quality_weight=0.0) + np.testing.assert_allclose(result, interp) + + +def test_blend_with_degraded_pure_degraded(): + """quality_weight=1 → output equals degraded.""" + interp = np.ones((4, 8, 8), dtype=np.float32) + degraded = np.full((4, 8, 8), 10.0, dtype=np.float32) + result = blend_with_degraded(interp, degraded, quality_weight=1.0) + np.testing.assert_allclose(result, degraded) + + +def 
test_blend_with_degraded_half_weight(): + """quality_weight=0.5 → average of interpolated and degraded.""" + interp = np.zeros((4, 8, 8), dtype=np.float32) + degraded = np.full((4, 8, 8), 4.0, dtype=np.float32) + result = blend_with_degraded(interp, degraded, quality_weight=0.5) + np.testing.assert_allclose(result, 2.0) + + +def test_blend_with_degraded_shape_preserved(): + interp = _vol() + degraded = _vol(seed=1) + result = blend_with_degraded(interp, degraded, quality_weight=0.3) + assert result.shape == interp.shape + + +# --------------------------------------------------------------------------- +# assess_degraded_slice_quality +# --------------------------------------------------------------------------- + +def test_assess_degraded_slice_quality_perfect_quality(): + """If degraded == reference, quality score should be near 1.""" + rng = np.random.default_rng(10) + vol = (rng.random((8, 16, 16)) * 100.0).astype(np.float32) + score, metrics = assess_degraded_slice_quality(vol, vol, vol) + assert 0.0 <= score <= 1.0 + # Perfect match → quality near 1 + assert score > 0.8 + + +def test_assess_degraded_slice_quality_zeros_degrade_score(): + """Zero-filled degraded slice should have low quality score.""" + rng = np.random.default_rng(11) + before = (rng.random((8, 16, 16)) * 100.0 + 1.0).astype(np.float32) + after = (rng.random((8, 16, 16)) * 100.0 + 1.0).astype(np.float32) + degraded = np.zeros_like(before) + score, metrics = assess_degraded_slice_quality(degraded, before, after) + assert 0.0 <= score <= 1.0 + assert score < 0.5 + + +def test_assess_degraded_slice_quality_returns_metrics_dict(): + rng = np.random.default_rng(12) + vol = rng.random((6, 12, 12)).astype(np.float32) + _, metrics = assess_degraded_slice_quality(vol, vol, vol) + expected_keys = {'ssim_before', 'ssim_after', 'ssim_mean', + 'edge_preservation', 'variance_ratio', 'overall'} + assert expected_keys.issubset(set(metrics.keys())) diff --git a/scripts/linum_interpolate_missing_slice.py 
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Interpolate a missing or degraded slice using information from adjacent slices.

This script implements registration-based morphing interpolation to reconstruct
a missing slice in a serial sectioning dataset. The method:

1. Registers the slice before the gap to the slice after
2. Computes a half-transform (midpoint transformation)
3. Warps both adjacent slices toward the midpoint
4. Blends the warped slices to create the interpolated result
5. Optionally blends with a degraded slice if available

If a degraded slice is provided, its quality is automatically assessed using
structural similarity (SSIM) and edge preservation metrics. The degraded slice
is blended with the interpolated result based on its quality score.

This approach is based on motion-compensated interpolation techniques commonly
used in video frame interpolation, adapted for 3D microscopy volumes.

References:
- Lee et al. (1991) "Shape-based interpolation of multidimensional grey-level
  images", IEEE Trans. Medical Imaging
- Bao et al. (2019) "Depth-Aware Video Frame Interpolation", CVPR
- Penney et al. (2004) "A comparison of similarity measures for use in 2-D-3-D
  medical image registration", IEEE Trans. Medical Imaging

Note: This method is only suitable for interpolating a SINGLE missing slice.
When two or more consecutive slices are missing, there is insufficient
information for accurate reconstruction.

Example usage:
    # Without degraded slice (pure interpolation)
    linum_interpolate_missing_slice.py slice_z00.ome.zarr slice_z02.ome.zarr \\
        slice_z01_interpolated.ome.zarr --method registration

    # With degraded slice (quality-weighted blend)
    linum_interpolate_missing_slice.py slice_z00.ome.zarr slice_z02.ome.zarr \\
        slice_z01_interpolated.ome.zarr --degraded_slice slice_z01_bad.ome.zarr
"""
# Configure thread limits before numpy/scipy imports
import linumpy._thread_config  # noqa: F401

import argparse
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt
from pathlib import Path

from linumpy.io.zarr import read_omezarr, save_omezarr
from linumpy.stitching.interpolation import (
    compute_half_affine_transform,
    interpolate_average,
    interpolate_weighted,
    interpolate_registration_based,
    assess_degraded_slice_quality,
    blend_with_degraded,
)
from linumpy.stitching.registration import register_2d_images_sitk, apply_transform
from linumpy.utils.io import add_overwrite_arg, assert_output_exists
from linumpy.utils.image_quality import (
    compute_ssim_3d,
    compute_edge_score,
    compute_variance_score,
)
import dask.array as da

# Configure all libraries (especially SimpleITK) to respect thread limits
from linumpy._thread_config import configure_all_libraries
configure_all_libraries()


def _build_arg_parser():
    """Build the CLI argument parser for the interpolation script."""
    p = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
    p.add_argument("slice_before",
                   help="Path to the slice BEFORE the missing slice (*.ome.zarr)")
    p.add_argument("slice_after",
                   help="Path to the slice AFTER the missing slice (*.ome.zarr)")
    p.add_argument("output",
                   help="Output path for the interpolated slice (*.ome.zarr)")
    # NOTE: the 'weighted' description matches interpolate_weighted, which
    # applies Gaussian smoothing along Z (not a distance falloff).
    p.add_argument("--method", choices=['registration', 'average', 'weighted'],
                   default='registration',
                   help="Interpolation method:\n"
                        " registration - Registration-based morphing (recommended)\n"
                        " average - Simple average of adjacent slices\n"
                        " weighted - Weighted average with Gaussian smoothing along Z\n"
                        "[default: %(default)s]")
    p.add_argument("--blend_method", choices=['linear', 'gaussian'],
                   default='gaussian',
                   help="Blending method for combining warped slices:\n"
                        " linear - Equal 50/50 blend (may show edges)\n"
                        " gaussian - Feathered blend using distance transform (recommended)\n"
                        "[default: %(default)s]")
    p.add_argument("--registration_metric", choices=['MSE', 'CC', 'MI'],
                   default='MSE',
                   help="Metric for registration [default: %(default)s]")
    p.add_argument("--max_iterations", type=int, default=1000,
                   help="Maximum iterations for registration [default: %(default)s]")
    p.add_argument("--reference_slice", type=int, default=None,
                   help="Z-index (in each volume, clamped to that volume's bounds)\n"
                        "to use as the registration reference plane.\n"
                        "If not specified, the best-correlated plane pair within\n"
                        "--overlap_search_window planes from each boundary is used\n"
                        "automatically (recommended).")
    p.add_argument("--overlap_search_window", type=int, default=5,
                   help="Number of z-planes to search at each volume boundary\n"
                        "when selecting the registration reference pair automatically.\n"
                        "Ignored when --reference_slice is set. [default: %(default)s]")
    p.add_argument("--min_overlap_correlation", type=float, default=0.1,
                   help="Minimum normalized cross-correlation required between the\n"
                        "boundary planes to proceed with registration. Below this\n"
                        "threshold the method falls back to a simple average.\n"
                        "[default: %(default)s]")

    # Degraded slice options
    degraded_group = p.add_argument_group('Degraded Slice Options',
                                          'Use a degraded/damaged slice to improve interpolation')
    degraded_group.add_argument("--degraded_slice", type=str, default=None,
                                help="Path to a degraded slice that has usable data (*.ome.zarr).\n"
                                     "Quality is automatically assessed and blended accordingly.")
    degraded_group.add_argument("--degraded_weight", type=float, default=None,
                                help="Manual override for degraded slice weight (0.0-1.0).\n"
                                     "If not specified, weight is automatically computed from quality.")
    degraded_group.add_argument("--min_quality_threshold", type=float, default=0.2,
                                help="Minimum quality score to use degraded slice.\n"
                                     "Below this threshold, degraded slice is ignored.\n"
                                     "[default: %(default)s]")

    # Preview/debug options
    preview_group = p.add_argument_group('Preview Options',
                                         'Generate visual previews for quality checking')
    preview_group.add_argument("--preview", type=str, default=None,
                               help="Path to save a preview image (PNG) showing:\n"
                                    "- Slice before, slice after\n"
                                    "- Interpolated result\n"
                                    "- Degraded slice (if provided)\n"
                                    "Useful for verifying interpolation quality.")
    preview_group.add_argument("--preview_slice", type=int, default=None,
                               help="Z-index to use for preview. Default: middle slice.")
    preview_group.add_argument("--preview_dpi", type=int, default=150,
                               help="DPI for preview image [default: %(default)s]")

    add_overwrite_arg(p)
    return p
+ """ + return transform.GetInverse() + + +def register_slices_2d(fixed_slice, moving_slice, metric='MSE', max_iterations=1000): + """ + Register a 2D slice from moving volume to fixed volume. + + Parameters + ---------- + fixed_slice : np.ndarray + 2D fixed image. + moving_slice : np.ndarray + 2D moving image. + metric : str + Registration metric. + max_iterations : int + Maximum iterations. + + Returns + ------- + sitk.Transform + The computed transform. + float + Registration error metric. + """ + # Normalize images for registration + fixed_norm = fixed_slice.astype(np.float32) + moving_norm = moving_slice.astype(np.float32) + + # Normalize to [0, 1] + if fixed_norm.max() > fixed_norm.min(): + fixed_norm = (fixed_norm - fixed_norm.min()) / (fixed_norm.max() - fixed_norm.min()) + if moving_norm.max() > moving_norm.min(): + moving_norm = (moving_norm - moving_norm.min()) / (moving_norm.max() - moving_norm.min()) + + transform, _, error = register_2d_images_sitk( + fixed_norm, moving_norm, + method='affine', + metric=metric, + max_iterations=max_iterations, + return_3d_transform=False, + verbose=False + ) + + return transform, error + + +def generate_preview(vol_before, vol_after, interpolated, output_path, + vol_degraded=None, final_result=None, + preview_slice=None, dpi=150, + degraded_weight=None, quality_threshold=None): + """ + Generate a preview image showing the interpolation results. + + Parameters + ---------- + vol_before : np.ndarray + Volume before the missing slice. + vol_after : np.ndarray + Volume after the missing slice. + interpolated : np.ndarray + Pure interpolated result. + output_path : str or Path + Path to save the preview image. + vol_degraded : np.ndarray, optional + Degraded slice volume if provided. + final_result : np.ndarray, optional + Final result after blending with degraded (if different from interpolated). + preview_slice : int, optional + Z-index to use for preview. Default: middle slice. + dpi : int + DPI for the output image. 
+ degraded_weight : float, optional + Weight used for degraded slice blending. + quality_threshold : float, optional + Quality threshold used. + """ + # Determine slice index + if preview_slice is None: + preview_slice = vol_before.shape[0] // 2 + preview_slice = max(0, min(preview_slice, vol_before.shape[0] - 1)) + + # Normalize function for display + def normalize_for_display(img): + img = img.astype(np.float32) + p1, p99 = np.percentile(img[img > 0], [1, 99]) if np.any(img > 0) else (0, 1) + if p99 > p1: + img = (img - p1) / (p99 - p1) + return np.clip(img, 0, 1) + + # Extract slices + before_slice = normalize_for_display(vol_before[preview_slice]) + after_slice = normalize_for_display(vol_after[preview_slice]) + interp_slice = normalize_for_display(interpolated[preview_slice]) + + # Determine layout based on whether we have degraded slice + has_degraded = vol_degraded is not None + has_final = final_result is not None and not np.allclose(final_result, interpolated) + + if has_degraded and has_final: + # 2x3 layout: before, after, degraded, interpolated, final, difference + fig, axes = plt.subplots(2, 3, figsize=(15, 10)) + axes = axes.flatten() + + degraded_slice = normalize_for_display(vol_degraded[preview_slice]) + final_slice = normalize_for_display(final_result[preview_slice]) + + # Row 1: inputs + axes[0].imshow(before_slice, cmap='gray') + axes[0].set_title('Slice Before (input)') + axes[0].axis('off') + + axes[1].imshow(after_slice, cmap='gray') + axes[1].set_title('Slice After (input)') + axes[1].axis('off') + + axes[2].imshow(degraded_slice, cmap='gray') + title = f'Degraded Slice' + if degraded_weight is not None: + title += f'\n(quality={degraded_weight:.2f})' + axes[2].set_title(title) + axes[2].axis('off') + + # Row 2: outputs + axes[3].imshow(interp_slice, cmap='gray') + axes[3].set_title('Pure Interpolation') + axes[3].axis('off') + + axes[4].imshow(final_slice, cmap='gray') + title = 'Final Result (blended)' + if degraded_weight is not None: + 
title += f'\n(w={degraded_weight:.2f})' + axes[4].set_title(title) + axes[4].axis('off') + + # Difference image + diff = np.abs(interp_slice - degraded_slice) + axes[5].imshow(diff, cmap='hot') + axes[5].set_title('|Interpolated - Degraded|') + axes[5].axis('off') + + elif has_degraded: + # 2x2 layout: before, after, interpolated, degraded + fig, axes = plt.subplots(2, 2, figsize=(12, 10)) + axes = axes.flatten() + + degraded_slice = normalize_for_display(vol_degraded[preview_slice]) + + axes[0].imshow(before_slice, cmap='gray') + axes[0].set_title('Slice Before (input)') + axes[0].axis('off') + + axes[1].imshow(after_slice, cmap='gray') + axes[1].set_title('Slice After (input)') + axes[1].axis('off') + + axes[2].imshow(interp_slice, cmap='gray') + axes[2].set_title('Interpolated (output)') + axes[2].axis('off') + + axes[3].imshow(degraded_slice, cmap='gray') + title = f'Degraded (not used)' + if degraded_weight is not None: + title = f'Degraded (q={degraded_weight:.2f} < {quality_threshold})' + axes[3].set_title(title) + axes[3].axis('off') + + else: + # 2x2 layout: before, after, interpolated, XZ view + fig, axes = plt.subplots(2, 2, figsize=(12, 10)) + axes = axes.flatten() + + axes[0].imshow(before_slice, cmap='gray') + axes[0].set_title('Slice Before (input)') + axes[0].axis('off') + + axes[1].imshow(after_slice, cmap='gray') + axes[1].set_title('Slice After (input)') + axes[1].axis('off') + + axes[2].imshow(interp_slice, cmap='gray') + axes[2].set_title('Interpolated (output)') + axes[2].axis('off') + + # Show XZ cross-section to visualize z-continuity + y_mid = vol_before.shape[1] // 2 + xz_before = normalize_for_display(vol_before[:, y_mid, :]) + xz_interp = normalize_for_display(interpolated[:, y_mid, :]) + xz_after = normalize_for_display(vol_after[:, y_mid, :]) + + # Stack them for comparison + xz_combined = np.vstack([xz_before, xz_interp, xz_after]) + axes[3].imshow(xz_combined, cmap='gray', aspect='auto') + axes[3].set_title('XZ View: Before | Interp 
# NOTE(review): the lines below are the tail of generate_preview(); its `def`
# and figure/axes setup are above this chunk. The first statement was cut
# mid-string by the paste — reconstructed as the axes[3] title; confirm the
# exact label against the original source.
axes[3].set_title('Before | Interp | After')
# Cyan separators mark the boundaries between the stacked XZ sub-images.
axes[3].axhline(y=xz_before.shape[0], color='cyan', linestyle='--', linewidth=0.5)
axes[3].axhline(y=xz_before.shape[0] + xz_interp.shape[0], color='cyan', linestyle='--', linewidth=0.5)
axes[3].axis('off')

fig.suptitle(f'Slice Interpolation Preview (z={preview_slice})', fontsize=14)
fig.tight_layout()
fig.savefig(output_path, dpi=dpi, bbox_inches='tight')
plt.close(fig)
print(f"Preview saved to: {output_path}")


def _match_neighbor_shapes(vol_before: np.ndarray, vol_after: np.ndarray):
    """Force the two neighbour volumes to a common (Z, X, Y) shape.

    Z is truncated to the common minimum (extra depth carries no paired
    information), while X/Y are zero-padded to the common maximum so no
    lateral data is discarded.

    Returns
    -------
    tuple of np.ndarray
        (vol_before, vol_after) with identical shapes.
    """
    if vol_before.shape == vol_after.shape:
        return vol_before, vol_after

    print(f"Shape mismatch detected: {vol_before.shape} vs {vol_after.shape}")

    # Handle z-dimension mismatch by truncating to the minimum depth.
    min_z = min(vol_before.shape[0], vol_after.shape[0])
    if vol_before.shape[0] != vol_after.shape[0]:
        print(f" Truncating z-dimension to minimum: {min_z}")
        vol_before = vol_before[:min_z]
        vol_after = vol_after[:min_z]

    # Handle X/Y mismatch by zero-padding both volumes to the maximum extent.
    if vol_before.shape[1:] != vol_after.shape[1:]:
        max_x = max(vol_before.shape[1], vol_after.shape[1])
        max_y = max(vol_before.shape[2], vol_after.shape[2])
        print(f" Adjusting X/Y dimensions to: ({max_x}, {max_y})")

        if vol_before.shape[1] < max_x or vol_before.shape[2] < max_y:
            padded = np.zeros((min_z, max_x, max_y), dtype=vol_before.dtype)
            padded[:, :vol_before.shape[1], :vol_before.shape[2]] = vol_before
            vol_before = padded

        if vol_after.shape[1] < max_x or vol_after.shape[2] < max_y:
            padded = np.zeros((min_z, max_x, max_y), dtype=vol_after.dtype)
            padded[:, :vol_after.shape[1], :vol_after.shape[2]] = vol_after
            vol_after = padded

        print(f" Adjusted shapes: {vol_before.shape}")

    return vol_before, vol_after


def _load_degraded_slice(degraded_arg, target_shape):
    """Load the optional degraded acquisition and match it to *target_shape*.

    Best-effort: any failure (missing path, un-adjustable shape) returns
    None so the pipeline falls back to pure interpolation instead of
    crashing.
    """
    if degraded_arg is None:
        return None

    degraded_path = Path(degraded_arg)
    if not degraded_path.exists():
        print(f"Warning: Degraded slice not found, proceeding without it: {degraded_path}")
        return None

    print(f"Loading degraded slice: {degraded_path}")
    vol_degraded, _res_degraded = read_omezarr(degraded_path)
    vol_degraded = np.array(vol_degraded)

    if vol_degraded.shape == target_shape:
        return vol_degraded

    print(f"Degraded slice shape mismatch: {vol_degraded.shape} vs {target_shape}")
    try:
        # Z: truncate if too deep, zero-pad if too shallow.
        if vol_degraded.shape[0] > target_shape[0]:
            vol_degraded = vol_degraded[:target_shape[0]]
        elif vol_degraded.shape[0] < target_shape[0]:
            padded = np.zeros(target_shape, dtype=vol_degraded.dtype)
            padded[:vol_degraded.shape[0]] = vol_degraded
            vol_degraded = padded

        # X/Y: crop-or-pad into a zero canvas of the target extent.
        if vol_degraded.shape[1:] != target_shape[1:]:
            padded = np.zeros(target_shape, dtype=vol_degraded.dtype)
            min_x = min(vol_degraded.shape[1], target_shape[1])
            min_y = min(vol_degraded.shape[2], target_shape[2])
            padded[:, :min_x, :min_y] = vol_degraded[:, :min_x, :min_y]
            vol_degraded = padded

        print(f" Adjusted degraded slice shape to: {vol_degraded.shape}")
        return vol_degraded
    except Exception as e:
        # Deliberate best-effort: an unusable degraded slice is ignored.
        print(f" Could not adjust degraded slice shape, ignoring: {e}")
        return None


def main():
    """CLI entry point: interpolate a missing serial section from its two
    neighbouring slices, optionally blending in a degraded acquisition of
    the missing section when its quality is good enough."""
    p = _build_arg_parser()
    args = p.parse_args()

    # Validate inputs.
    slice_before_path = Path(args.slice_before)
    slice_after_path = Path(args.slice_after)
    output_path = Path(args.output)

    if not slice_before_path.exists():
        p.error(f"Slice before not found: {slice_before_path}")
    if not slice_after_path.exists():
        p.error(f"Slice after not found: {slice_after_path}")

    assert_output_exists(output_path, p, args)

    print(f"Loading slice before: {slice_before_path}")
    vol_before, res_before = read_omezarr(slice_before_path)
    vol_before = np.array(vol_before)

    print(f"Loading slice after: {slice_after_path}")
    vol_after, res_after = read_omezarr(slice_after_path)
    vol_after = np.array(vol_after)

    # Bring both neighbours to a common shape (z-truncate, X/Y zero-pad).
    vol_before, vol_after = _match_neighbor_shapes(vol_before, vol_after)

    # Resolutions are expected to match; a mismatch is warned, not fatal,
    # and res_before is used for the output.
    if res_before != res_after:
        print(f"Warning: Resolution mismatch: {res_before} vs {res_after}")

    # Optional degraded acquisition of the missing slice.
    vol_degraded = _load_degraded_slice(args.degraded_slice, vol_before.shape)

    print(f"Volume shape: {vol_before.shape}")
    print(f"Resolution: {res_before}")
    print(f"Method: {args.method}")

    # Perform interpolation with the requested method.
    if args.method == 'registration':
        print("Performing registration-based interpolation...")
        interpolated = interpolate_registration_based(
            vol_before, vol_after,
            metric=args.registration_metric,
            max_iterations=args.max_iterations,
            reference_slice=args.reference_slice,
            blend_method=args.blend_method,
            overlap_search_window=args.overlap_search_window,
            min_overlap_correlation=args.min_overlap_correlation,
        )
    elif args.method == 'average':
        print("Performing simple average interpolation...")
        interpolated = interpolate_average(vol_before, vol_after)
    elif args.method == 'weighted':
        print("Performing weighted average interpolation...")
        interpolated = interpolate_weighted(vol_before, vol_after)
    else:
        p.error(f"Unknown method: {args.method}")

    # Blend with the degraded slice if available and of sufficient quality.
    final_result = interpolated
    quality_weight = None
    used_degraded = False
    if vol_degraded is not None:
        if args.degraded_weight is not None:
            # Manual override of the blending weight.
            # NOTE(review): the threshold below still applies to a manual
            # weight — confirm that is the intended behaviour.
            quality_weight = args.degraded_weight
            print(f"Using manual degraded weight: {quality_weight:.3f}")
        else:
            # Automatic quality assessment (per-metric details unused here).
            quality_weight, _metrics = assess_degraded_slice_quality(
                vol_degraded, vol_before, vol_after
            )

        if quality_weight >= args.min_quality_threshold:
            print(f"Blending with degraded slice (weight={quality_weight:.3f})")
            final_result = blend_with_degraded(interpolated, vol_degraded, quality_weight)
            used_degraded = True
        else:
            print(f"Degraded slice quality ({quality_weight:.3f}) below threshold "
                  f"({args.min_quality_threshold}), using pure interpolation")

    # Generate preview if requested.
    if args.preview is not None:
        print("Generating preview...")  # fixed: f-string had no placeholders
        generate_preview(
            vol_before, vol_after, interpolated,
            output_path=args.preview,
            vol_degraded=vol_degraded,
            final_result=final_result if used_degraded else None,
            preview_slice=args.preview_slice,
            dpi=args.preview_dpi,
            degraded_weight=quality_weight,
            quality_threshold=args.min_quality_threshold
        )

    # Restore the input dtype for integer volumes (blend results are float).
    original_dtype = vol_before.dtype
    if np.issubdtype(original_dtype, np.integer):
        final_result = np.clip(final_result, 0, np.iinfo(original_dtype).max)
        final_result = final_result.astype(original_dtype)

    # Save result.
    print(f"Saving interpolated slice to: {output_path}")
    save_omezarr(da.from_array(final_result), str(output_path), res_before)

    print("Done!")


if __name__ == "__main__":
    main()


# ---------------------------------------------------------------------------
# scripts/tests/test_interpolate_missing_slice.py
# ---------------------------------------------------------------------------

def test_help(script_runner):
    """The script exposes a working --help."""
    ret = script_runner.run(['linum_interpolate_missing_slice.py', '--help'])
    assert ret.success


def test_average_method(script_runner, tmp_path):
    """Test basic interpolation with average method using synthetic data."""
    import numpy as np
    from linumpy.io.zarr import save_omezarr
    import dask.array as da

    # Create synthetic test volumes.
    shape = (10, 32, 32)
    resolution = (0.001, 0.01, 0.01)  # mm/pixel

    # Two slightly different random volumes stand in for adjacent slices.
    vol_before = np.random.rand(*shape).astype(np.float32) * 100
    vol_after = np.random.rand(*shape).astype(np.float32) * 100

    # Save as ome.zarr.
    slice_before = tmp_path / 'slice_z00.ome.zarr'
    slice_after = tmp_path / 'slice_z02.ome.zarr'
    output = tmp_path / 'slice_z01_interpolated.ome.zarr'

    save_omezarr(da.from_array(vol_before), str(slice_before), resolution)
    save_omezarr(da.from_array(vol_after), str(slice_after), resolution)

    # Run interpolation with the average method (fastest for testing).
    ret = script_runner.run([
        'linum_interpolate_missing_slice.py',
        str(slice_before),
        str(slice_after),
        str(output),
        '--method', 'average'
    ])

    assert ret.success
    assert output.exists()