Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion autoemxsp/core/EM_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@
import numpy as np
from PIL import Image

#Drift estimation imports
from autoemxsp.utils.drift_correction import estimate_drift_of

# Typing (Python 3.5+)
from typing import Any, List, Optional, Tuple, Union

Expand Down Expand Up @@ -1270,7 +1273,8 @@ def get_image_drift_pixels(self):
# ADD CODE TO EVALUATE drift HERE
# CONSIDER THAT THIS VECTOR WILL BE ADDED TO THE PREVIOUSLY DETERMINED (X,Y) TO ONBTAIN THE NEW POSITION,
# SO ENSURE THAT THE SIGNS ARE CORRECT

drift_vector = estimate_drift_of(self.ref_image, current_img)

return drift_vector


Expand Down
134 changes: 134 additions & 0 deletions autoemxsp/utils/drift_correction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import numpy as np
import cv2
from skimage.registration import phase_cross_correlation
from skimage.filters import window
from typing import Tuple

def preprocess_image(img: np.ndarray) -> np.ndarray:
"""
Normalize and apply a window function to reduce FFT edge effects.
Critical for non-periodic real-world images.
"""
img_float = img.astype(np.float32)

# Range normalization [0, 1]
denom = img_float.max() - img_float.min()
if denom > 1e-6:
img_float = (img_float - img_float.min()) / denom

# Apply Hann window
# window() creates an N-D window matching the image shape
w = window('hann', img_float.shape)
return img_float * w


def shift_image(img: np.ndarray, dx: float, dy: float) -> np.ndarray:
"""Apply an affine shift (translation) to an image."""
M = np.float32([[1, 0, dx], [0, 1, dy]])
return cv2.warpAffine(
img,
M,
(img.shape[1], img.shape[0]),
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_REFLECT,
)

def estimate_drift_of(img1: np.ndarray, img2: np.ndarray) -> Tuple[float, float]:
"""
Robust estimate of the translation (dx, dy) of img2 relative to img1.

Uses a 2-stage hierarchical approach:
1. Coarse estimation on downscaled images to handle large displacements.
2. Fine estimation using an ensemble of crops on the full-resolution images,
corrected by the coarse shift.
3. Robust averaging (IQR) to reject outliers.
"""

def get_phase_corr_shift(im1, im2):
if np.std(im1) < 1e-3 or np.std(im2) < 1e-3: return None
try:
detected_shift, _error, _diffphase = phase_cross_correlation(
preprocess_image(im1),
preprocess_image(im2),
upsample_factor=20, # Moderate precision for speed
normalization=None
)
return -detected_shift[1], -detected_shift[0]
except Exception:
return None

# --- Stage 1: Coarse Estimation (Downscaled) ---
# Downscale for robustness against large drifts and noise
scale = 0.25
h, w = img1.shape
small_h, small_w = int(h * scale), int(w * scale)

img1_s = cv2.resize(img1, (small_w, small_h), interpolation=cv2.INTER_AREA)
img2_s = cv2.resize(img2, (small_w, small_h), interpolation=cv2.INTER_AREA)

coarse_shift = get_phase_corr_shift(img1_s, img2_s)

if coarse_shift is None:
coarse_dx, coarse_dy = 0.0, 0.0
else:
coarse_dx = coarse_shift[0] / scale
coarse_dy = coarse_shift[1] / scale

# --- Stage 2: Fine Refinement (Ensemble of Crops) ---
# We apply the inverse coarse shift to img2 so it roughly aligns with img1.
img2_aligned = shift_image(img2, -coarse_dx, -coarse_dy)

shifts = []

# Define crops: 5 overlapping large windows (Center, TL, TR, BL, BR)
crop_h = int(h * 0.75)
crop_w = int(w * 0.75)

crops = [
(h//2 - crop_h//2, h//2 + crop_h//2, w//2 - crop_w//2, w//2 + crop_w//2),
(0, crop_h, 0, crop_w),
(0, crop_h, w - crop_w, w),
(h - crop_h, h, 0, crop_w),
(h - crop_h, h, w - crop_w, w),
]

for (r1, r2, c1, c2) in crops:
r1, r2 = max(0, r1), min(h, r2)
c1, c2 = max(0, c1), min(w, c2)

try:
detected_shift, _error, _diffphase = phase_cross_correlation(
preprocess_image(img1[r1:r2, c1:c2]),
preprocess_image(img2_aligned[r1:r2, c1:c2]),
upsample_factor=100,
normalization=None
)
res_dx = -detected_shift[1]
res_dy = -detected_shift[0]
shifts.append((coarse_dx + res_dx, coarse_dy + res_dy))
except Exception:
pass

if not shifts:
return float(coarse_dx), float(coarse_dy)

# --- Stage 3: Robust Aggregation ---
shifts_arr = np.array(shifts)

def robust_mean(values):
if len(values) < 3: return np.mean(values)
q1 = np.percentile(values, 25)
q3 = np.percentile(values, 75)
iqr = q3 - q1
lower = q1 - 1.5 * iqr
upper = q3 + 1.5 * iqr
filtered = [v for v in values if lower <= v <= upper]
if not filtered: return np.median(values)
return np.mean(filtered)

final_dx = robust_mean(shifts_arr[:, 0])
final_dy = robust_mean(shifts_arr[:, 1])

# returns displacement of img2 relative to img1
# img1(x, y) -> img2(x + dx, y + dy)
return float(final_dx), float(final_dy)
64 changes: 64 additions & 0 deletions tests/Test_Drift_Correction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np

from autoemxsp.utils.drift_correction import estimate_drift_of, shift_image


def load_grayscale_image(image_path: Path) -> np.ndarray:
image = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
if image is None:
raise FileNotFoundError(f"Could not load image: {image_path}")
return image


def main() -> None:
script_dir = Path(__file__).resolve().parent
inputs_dir = script_dir / "inputs"
outputs_dir = script_dir / "outputs"
outputs_dir.mkdir(exist_ok=True)

original = load_grayscale_image(inputs_dir / "pre_drift.png")
shifted = load_grayscale_image(inputs_dir / "post_drift.png")

estimated_dx, estimated_dy = estimate_drift_of(original, shifted)
# apply minus shift to the shifted image to correct it to original
# for determining the infered x,y during real inference we need to add the shift to the
# estimated pixel because the estimated pixel is in the original image and we want it
# to be in the shifted image
reshifted = shift_image(shifted, -estimated_dx, -estimated_dy)
difference = np.abs(original.astype(np.float32) - reshifted.astype(np.float32))

fig, axes = plt.subplots(2, 2, figsize=(10, 10))
fig.suptitle(
"Image drift correction algorithm\n"
f"estimated shift ({estimated_dx:.2f}, {estimated_dy:.2f})",
fontsize=14,
)

panels = [
(axes[0, 0], original, "Original"),
(axes[0, 1], shifted, "Shifted"),
(axes[1, 0], reshifted, "Reshifted"),
(axes[1, 1], difference, "Pixelwise Error"),
]

for axis, image, title in panels:
cmap = "inferno" if title == "Pixelwise Error" else "gray"
axis.imshow(image, cmap=cmap)
axis.set_title(title)
axis.axis("off")

plt.tight_layout(rect=(0, 0, 1, 0.95))
figure_path = outputs_dir / "drift_correction_visualization.png"
fig.savefig(figure_path, dpi=200, bbox_inches="tight")
plt.show()


if __name__ == "__main__":
main()
Binary file added tests/inputs/post_drift.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/inputs/pre_drift.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.