diff --git a/autoemxsp/core/EM_controller.py b/autoemxsp/core/EM_controller.py index b1e02ec..f76dfc7 100644 --- a/autoemxsp/core/EM_controller.py +++ b/autoemxsp/core/EM_controller.py @@ -161,6 +161,9 @@ import numpy as np from PIL import Image +#Drift estimation imports +from autoemxsp.utils.drift_correction import estimate_drift_of + # Typing (Python 3.5+) from typing import Any, List, Optional, Tuple, Union @@ -1270,7 +1273,8 @@ def get_image_drift_pixels(self): # ADD CODE TO EVALUATE drift HERE # CONSIDER THAT THIS VECTOR WILL BE ADDED TO THE PREVIOUSLY DETERMINED (X,Y) TO ONBTAIN THE NEW POSITION, # SO ENSURE THAT THE SIGNS ARE CORRECT - + drift_vector = estimate_drift_of(self.ref_image, current_img) + return drift_vector diff --git a/autoemxsp/utils/drift_correction.py b/autoemxsp/utils/drift_correction.py new file mode 100644 index 0000000..98590f4 --- /dev/null +++ b/autoemxsp/utils/drift_correction.py @@ -0,0 +1,134 @@ +import numpy as np +import cv2 +from skimage.registration import phase_cross_correlation +from skimage.filters import window +from typing import Tuple + +def preprocess_image(img: np.ndarray) -> np.ndarray: + """ + Normalize and apply a window function to reduce FFT edge effects. + Critical for non-periodic real-world images. + """ + img_float = img.astype(np.float32) + + # Range normalization [0, 1] + denom = img_float.max() - img_float.min() + if denom > 1e-6: + img_float = (img_float - img_float.min()) / denom + + # Apply Hann window + # window() creates an N-D window matching the image shape + w = window('hann', img_float.shape) + return img_float * w + + +def shift_image(img: np.ndarray, dx: float, dy: float) -> np.ndarray: + """Apply an affine shift (translation) to an image.""" + M = np.float32([[1, 0, dx], [0, 1, dy]]) + return cv2.warpAffine( + img, + M, + (img.shape[1], img.shape[0]), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_REFLECT, + ) + +def estimate_drift_of(img1: np.ndarray, img2: np.ndarray) -> Tuple[float, float]: + """ + Robust estimate of the translation (dx, dy) of img2 relative to img1. + + Uses a 2-stage hierarchical approach: + 1. Coarse estimation on downscaled images to handle large displacements. + 2. Fine estimation using an ensemble of crops on the full-resolution images, + corrected by the coarse shift. + 3. Robust averaging (IQR) to reject outliers. + """ + + def get_phase_corr_shift(im1, im2): + if np.std(im1) < 1e-3 or np.std(im2) < 1e-3: return None + try: + detected_shift, _error, _diffphase = phase_cross_correlation( + preprocess_image(im1), + preprocess_image(im2), + upsample_factor=20, # Moderate precision for speed + normalization=None + ) + return -detected_shift[1], -detected_shift[0] + except Exception: + return None + + # --- Stage 1: Coarse Estimation (Downscaled) --- + # Downscale for robustness against large drifts and noise + scale = 0.25 + h, w = img1.shape + small_h, small_w = int(h * scale), int(w * scale) + + img1_s = cv2.resize(img1, (small_w, small_h), interpolation=cv2.INTER_AREA) + img2_s = cv2.resize(img2, (small_w, small_h), interpolation=cv2.INTER_AREA) + + coarse_shift = get_phase_corr_shift(img1_s, img2_s) + + if coarse_shift is None: + coarse_dx, coarse_dy = 0.0, 0.0 + else: + coarse_dx = coarse_shift[0] / scale + coarse_dy = coarse_shift[1] / scale + + # --- Stage 2: Fine Refinement (Ensemble of Crops) --- + # We apply the inverse coarse shift to img2 so it roughly aligns with img1. + img2_aligned = shift_image(img2, -coarse_dx, -coarse_dy) + + shifts = [] + + # Define crops: 5 overlapping large windows (Center, TL, TR, BL, BR) + crop_h = int(h * 0.75) + crop_w = int(w * 0.75) + + crops = [ + (h//2 - crop_h//2, h//2 + crop_h//2, w//2 - crop_w//2, w//2 + crop_w//2), + (0, crop_h, 0, crop_w), + (0, crop_h, w - crop_w, w), + (h - crop_h, h, 0, crop_w), + (h - crop_h, h, w - crop_w, w), + ] + + for (r1, r2, c1, c2) in crops: + r1, r2 = max(0, r1), min(h, r2) + c1, c2 = max(0, c1), min(w, c2) + + try: + detected_shift, _error, _diffphase = phase_cross_correlation( + preprocess_image(img1[r1:r2, c1:c2]), + preprocess_image(img2_aligned[r1:r2, c1:c2]), + upsample_factor=100, + normalization=None + ) + res_dx = -detected_shift[1] + res_dy = -detected_shift[0] + shifts.append((coarse_dx + res_dx, coarse_dy + res_dy)) + except Exception: + pass + + if not shifts: + return float(coarse_dx), float(coarse_dy) + + # --- Stage 3: Robust Aggregation --- + shifts_arr = np.array(shifts) + + def robust_mean(values): + if len(values) < 3: return np.mean(values) + q1 = np.percentile(values, 25) + q3 = np.percentile(values, 75) + iqr = q3 - q1 + lower = q1 - 1.5 * iqr + upper = q3 + 1.5 * iqr + filtered = [v for v in values if lower <= v <= upper] + if not filtered: return np.median(values) + return np.mean(filtered) + + final_dx = robust_mean(shifts_arr[:, 0]) + final_dy = robust_mean(shifts_arr[:, 1]) + + # returns displacement of img2 relative to img1 + # img1(x, y) -> img2(x + dx, y + dy) + return float(final_dx), float(final_dy) diff --git a/tests/Test_Drift_Correction.py b/tests/Test_Drift_Correction.py new file mode 100644 index 0000000..95706c0 --- /dev/null +++ b/tests/Test_Drift_Correction.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from pathlib import Path + +import cv2 +import matplotlib.pyplot as plt +import numpy as np + +from autoemxsp.utils.drift_correction import estimate_drift_of, shift_image + + +def load_grayscale_image(image_path: Path) -> np.ndarray: + image = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE) + if image is None: + raise FileNotFoundError(f"Could not load image: {image_path}") + return image + + +def main() -> None: + script_dir = Path(__file__).resolve().parent + inputs_dir = script_dir / "inputs" + outputs_dir = script_dir / "outputs" + outputs_dir.mkdir(exist_ok=True) + + original = load_grayscale_image(inputs_dir / "pre_drift.png") + shifted = load_grayscale_image(inputs_dir / "post_drift.png") + + estimated_dx, estimated_dy = estimate_drift_of(original, shifted) + # apply minus shift to the shifted image to correct it to original + # for determining the infered x,y during real inference we need to add the shift to the + # estimated pixel because the estimated pixel is in the original image and we want it + # to be in the shifted image + reshifted = shift_image(shifted, -estimated_dx, -estimated_dy) + difference = np.abs(original.astype(np.float32) - reshifted.astype(np.float32)) + + fig, axes = plt.subplots(2, 2, figsize=(10, 10)) + fig.suptitle( + "Image drift correction algorithm\n" + f"estimated shift ({estimated_dx:.2f}, {estimated_dy:.2f})", + fontsize=14, + ) + + panels = [ + (axes[0, 0], original, "Original"), + (axes[0, 1], shifted, "Shifted"), + (axes[1, 0], reshifted, "Reshifted"), + (axes[1, 1], difference, "Pixelwise Error"), + ] + + for axis, image, title in panels: + cmap = "inferno" if title == "Pixelwise Error" else "gray" + axis.imshow(image, cmap=cmap) + axis.set_title(title) + axis.axis("off") + + plt.tight_layout(rect=(0, 0, 1, 0.95)) + figure_path = outputs_dir / "drift_correction_visualization.png" + fig.savefig(figure_path, dpi=200, bbox_inches="tight") + plt.show() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/inputs/post_drift.png b/tests/inputs/post_drift.png new file mode 100644 index 0000000..8780f26 Binary files /dev/null and b/tests/inputs/post_drift.png differ diff --git a/tests/inputs/pre_drift.png b/tests/inputs/pre_drift.png new file mode 100644 index 0000000..74e9248 Binary files /dev/null and b/tests/inputs/pre_drift.png differ