from pathlib import Path

import SimpleITK as sitk
import numpy as np
import requests
from tqdm import tqdm

AVAILABLE_RESOLUTIONS = [10, 25, 50, 100]


def numpy_to_sitk_image(volume: np.ndarray, spacing: tuple, cast_dtype=None) -> sitk.Image:
    """Convert a (Z, X, Y) numpy volume into a SimpleITK image.

    Parameters
    ----------
    volume : np.ndarray
        3D volume with shape (Z, X, Y).
    spacing : tuple
        Voxel spacing in mm, ordered (res_z, res_x, res_y).
    cast_dtype : numpy dtype or None
        When given, the data is cast to this dtype before conversion
        (registration typically expects float32). When None, the input
        dtype is preserved.

    Returns
    -------
    sitk.Image
        Image with zero origin, identity direction and the given spacing.
    """
    # SimpleITK's GetImageFromArray interprets a numpy array as (Z, Y, X),
    # so swap the last two axes of our (Z, X, Y) volume first.
    arr = volume.transpose(0, 2, 1)
    # astype() always copies; the explicit copy() in the other branch keeps
    # the buffer contiguous and detached from the caller's array.
    arr = arr.astype(cast_dtype) if cast_dtype is not None else arr.copy()

    image = sitk.GetImageFromArray(arr)
    # SimpleITK spacing order is (x, y, z) = (res_x, res_y, res_z).
    image.SetSpacing([spacing[1], spacing[2], spacing[0]])
    image.SetOrigin([0, 0, 0])
    image.SetDirection([1, 0, 0, 0, 1, 0, 0, 0, 1])
    return image
def download_template_ras_aligned(resolution: int, cache: bool = True, cache_dir: str = ".data/") -> sitk.Image:
    """Download the Allen average mouse brain and reorient it to RAS+.

    Parameters
    ----------
    resolution
        Allen template resolution in micron. Must be 10, 25, 50 or 100.
    cache
        Keep the downloaded volume in cache.
    cache_dir
        Cache directory.

    Returns
    -------
    Allen average mouse brain in RAS+ orientation.
    """
    template = download_template(resolution, cache, cache_dir)

    # Standardize the physical-space metadata first so physical coordinates
    # are stable; the resolution is converted from micron to mm.
    voxel_mm = resolution / 1e3
    template.SetSpacing([voxel_mm] * 3)
    template.SetOrigin([0.0, 0.0, 0.0])
    template.SetDirection([1, 0, 0, 0, 1, 0, 0, 0, 1])

    # Axis permutation followed by a flip brings the template into RAS+.
    template = sitk.PermuteAxes(template, (2, 0, 1))
    template = sitk.Flip(template, (False, False, True))
    # PermuteAxes/Flip modify the metadata; reset to zero origin and
    # identity direction afterwards.
    template.SetOrigin([0.0, 0.0, 0.0])
    template.SetDirection([1, 0, 0, 0, 1, 0, 0, 0, 1])

    return template


def register_3d_rigid_to_allen(moving_image: np.ndarray, moving_spacing: tuple,
                               allen_resolution: int = 100, metric: str = 'MI',
                               max_iterations: int = 1000, verbose: bool = False,
                               progress_callback=None,
                               initial_rotation_deg: tuple = (0.0, 0.0, 0.0)):
    """Rigidly register a brain volume to the Allen atlas.

    Parameters
    ----------
    moving_image : np.ndarray
        3D brain volume to register, shape (Z, X, Y).
    moving_spacing : tuple
        Voxel spacing in mm (res_z, res_x, res_y).
    allen_resolution : int
        Allen template resolution in micron (default: 100).
    metric : str
        Similarity metric: 'MI' (mutual information), 'MSE', 'CC'
        (correlation) or 'AntsCC' (ANTS neighborhood correlation).
    max_iterations : int
        Maximum number of optimizer iterations per level.
    verbose : bool
        Print registration progress.
    progress_callback : callable, optional
        Called on each iteration with the registration method:
        ``callback(registration_method)``.
    initial_rotation_deg : tuple
        Initial rotation (Rx, Ry, Rz) in degrees applied to the initial
        transform.

    Returns
    -------
    transform : sitk.Euler3DTransform
        Rigid transform aligning ``moving_image`` to the Allen atlas
        (fixed-to-moving, ResampleImageFilter convention).
    stop_condition : str
        Optimizer stopping condition description.
    error : float
        Final registration metric value.
    """
    # Fixed image: Allen atlas in RAS orientation. Moving image: input brain.
    allen_atlas = download_template_ras_aligned(allen_resolution, cache=True)
    moving_sitk = numpy_to_sitk_image(moving_image, moving_spacing)

    # Geometric centre of the moving image BEFORE any resampling. Used as the
    # centre-alignment fallback when no resampling happens (both images then
    # share the same physical space). When resampling IS performed, this is
    # replaced below by the centroid of the clipped brain, because the
    # full-brain centre can lie far outside the Allen extent and would push
    # every Allen voxel outside the resampled buffer.
    original_moving_size = moving_sitk.GetSize()
    original_moving_center_idx = [s / 2.0 for s in original_moving_size]
    original_moving_center = np.array(
        moving_sitk.TransformContinuousIndexToPhysicalPoint(original_moving_center_idx)
    )

    # Decide whether the moving image must be resampled onto the Allen grid.
    allen_spacing = allen_atlas.GetSpacing()
    allen_size = allen_atlas.GetSize()
    moving_spacing_sitk = moving_sitk.GetSpacing()
    moving_size_sitk = moving_sitk.GetSize()

    spacing_ratio = np.array(allen_spacing) / np.array(moving_spacing_sitk)
    size_ratio = np.array(allen_size, dtype=float) / np.array(moving_size_sitk, dtype=float)

    # Resample when spacing differs by more than 10% or the sizes differ by 2x.
    needs_resample = (np.any(np.abs(spacing_ratio - 1.0) > 0.1) or
                      np.any(size_ratio < 0.5) or np.any(size_ratio > 2.0))

    if needs_resample:
        if verbose:
            print(f"Resampling moving image from {moving_spacing_sitk} mm, size {moving_size_sitk} "
                  f"to {allen_spacing} mm, size {allen_size}")
        resampler = sitk.ResampleImageFilter()
        resampler.SetReferenceImage(allen_atlas)
        resampler.SetInterpolator(sitk.sitkLinear)
        resampler.SetDefaultPixelValue(0)
        moving_sitk = resampler.Execute(moving_sitk)

        # Recompute the effective brain centre from the RESAMPLED image using
        # the centroid of the non-zero (tissue) voxels that survived clipping
        # into the Allen domain; the pre-resampling centre may be tens of mm
        # outside the atlas extent.
        moving_arr = sitk.GetArrayFromImage(moving_sitk)  # numpy order (Z, Y, X)
        nonzero_idx = np.argwhere(moving_arr > 0)  # rows are (z, y, x)
        if len(nonzero_idx) > 0:
            centroid_zyx = nonzero_idx.mean(axis=0)
            # SITK index order (x, y, z) is the reverse of numpy (z, y, x).
            centroid_xyz = [float(centroid_zyx[2]), float(centroid_zyx[1]), float(centroid_zyx[0])]
            original_moving_center = np.array(
                moving_sitk.TransformContinuousIndexToPhysicalPoint(centroid_xyz)
            )
            if verbose:
                print(f"Resampled brain centroid (physical): {original_moving_center} mm")
        # If every voxel is zero (brain entirely outside the Allen domain),
        # keep the pre-resampling centre and accept a poor initialization.

    # Intensity normalization improves metric behaviour.
    fixed = sitk.Normalize(allen_atlas)
    moving_norm = sitk.Normalize(moving_sitk)

    if verbose:
        print(f"Fixed (Allen) image: size={fixed.GetSize()}, spacing={fixed.GetSpacing()}")
        print(f"Moving (brain) image: size={moving_norm.GetSize()}, spacing={moving_norm.GetSpacing()}")

    reg = sitk.ImageRegistrationMethod()

    # Metric selection (optimizer maximizes MI/CC, minimizes MSE).
    metric_key = metric.upper()
    if metric_key == 'MI':
        reg.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    elif metric_key == 'MSE':
        reg.SetMetricAsMeanSquares()
    elif metric_key == 'CC':
        reg.SetMetricAsCorrelation()
    elif metric_key == 'ANTSCC':
        reg.SetMetricAsANTSNeighborhoodCorrelation(radius=20)
    else:
        raise ValueError(f"Unknown metric: {metric}. Choose from: MI, MSE, CC, AntsCC")

    # Regular sampling for reproducibility; 25% of pixels is usually enough.
    reg.SetMetricSamplingStrategy(reg.REGULAR)
    reg.SetMetricSamplingPercentage(0.25)

    # Conservative gradient-descent settings to prevent overshooting.
    reg.SetOptimizerAsRegularStepGradientDescent(
        learningRate=0.5,
        minStep=0.0001,
        numberOfIterations=max_iterations,
        relaxationFactor=0.5,
        gradientMagnitudeTolerance=1e-8
    )

    # Scale parameters by the metric response to a physical shift — suited to
    # registration in physical coordinates.
    reg.SetOptimizerScalesFromPhysicalShift()

    # Coarse-to-fine multi-resolution pyramid.
    reg.SetShrinkFactorsPerLevel([8, 4, 2, 1])
    reg.SetSmoothingSigmasPerLevel([4, 2, 1, 0])
    reg.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    # Initial rigid transform with guaranteed overlap: align the brain centre
    # (computed above) with the Allen centre.
    initial_transform = sitk.Euler3DTransform()

    fixed_size = fixed.GetSize()
    fixed_center_idx = [s / 2.0 for s in fixed_size]
    fixed_center = np.array(fixed.TransformContinuousIndexToPhysicalPoint(fixed_center_idx))

    # ITK convention: T(p) = R(p − c) + c + t. With identity rotation and
    # c = fixed_center, T(fixed_center) = fixed_center + t, so
    # t = moving_center − fixed_center aligns the two centres.
    translation = tuple(original_moving_center - fixed_center)
    initial_transform.SetCenter(fixed_center)

    rx_rad = np.deg2rad(initial_rotation_deg[0])
    ry_rad = np.deg2rad(initial_rotation_deg[1])
    rz_rad = np.deg2rad(initial_rotation_deg[2])

    initial_transform.SetTranslation(translation)
    initial_transform.SetRotation(rx_rad, ry_rad, rz_rad)

    if verbose:
        print(f"Initial center alignment: fixed={fixed_center}, moving (original)={original_moving_center}")
        print(f"Translation to align centers: {translation}")
        if any(r != 0 for r in initial_rotation_deg):
            print(f"Initial rotation (deg): {initial_rotation_deg}")

    # Try MOMENTS initialization only when the user gave no initial rotation
    # (user input takes precedence) and no resampling occurred: after
    # resampling, Normalize gives the large zero-padded background a uniform
    # negative value that dominates the centre-of-mass computation and yields
    # a useless near-zero translation.
    if all(r == 0 for r in initial_rotation_deg) and not needs_resample:
        try:
            init_transform = sitk.Euler3DTransform()
            init_transform = sitk.CenteredTransformInitializer(
                fixed,
                moving_norm,
                init_transform,
                sitk.CenteredTransformInitializerFilter.MOMENTS
            )
            # Accept the MOMENTS result only if its translation is plausible
            # (below half the fixed-image diagonal); otherwise keep the
            # centre-aligned initialization.
            init_params = init_transform.GetParameters()
            init_translation = np.array(init_params[3:6])

            translation_magnitude = np.linalg.norm(init_translation)
            fixed_size_mm = np.array(fixed.GetSpacing()) * np.array(fixed.GetSize())
            max_reasonable_translation = np.linalg.norm(fixed_size_mm) * 0.5

            if translation_magnitude < max_reasonable_translation:
                initial_transform = init_transform
                if verbose:
                    print(f"Using MOMENTS initialization (translation magnitude: {translation_magnitude:.2f} mm)")
            else:
                if verbose:
                    print(
                        f"MOMENTS initialization translation too large ({translation_magnitude:.2f} mm), using center-aligned")
        except Exception as e:
            if verbose:
                print(f"MOMENTS initialization failed: {e}, using center-aligned translation")

    if verbose:
        final_params = initial_transform.GetParameters()
        final_center = initial_transform.GetCenter()
        print(f"Final initial transform: rotation={final_params[:3]}, translation={final_params[3:]}")
        print(f"Transform center: {final_center}")

    reg.SetInitialTransform(initial_transform)
    reg.SetInterpolator(sitk.sitkLinear)

    # Optional per-iteration reporting/callback.
    if verbose or progress_callback is not None:
        def command_iteration(method):
            if verbose:
                if method.GetOptimizerIteration() == 0:
                    print(f"Estimated scales: {method.GetOptimizerScales()}")
                print(f"Iteration {method.GetOptimizerIteration():3d} = "
                      f"{method.GetMetricValue():7.5f} : "
                      f"{method.GetOptimizerPosition()}")
            if progress_callback is not None:
                progress_callback(method)

        reg.AddCommand(sitk.sitkIterationEvent,
                       lambda: command_iteration(reg))

    final_transform = reg.Execute(fixed, moving_norm)

    stop_condition = reg.GetOptimizerStopConditionDescription()
    error = reg.GetMetricValue()

    if verbose:
        print(f"Registration complete: {stop_condition}")
        print(f"Final metric value: {error:.6f}")
        final_params = final_transform.GetParameters()
        print(f"Final transform: rotation={final_params[:3]}, translation={final_params[3:]}")
        print(f"Fixed image size: {fixed.GetSize()}, spacing: {fixed.GetSpacing()}")
        print(f"Moving image size: {moving_norm.GetSize()}, spacing: {moving_norm.GetSpacing()}")

    return final_transform, stop_condition, error
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Align a 3D brain volume to RAS orientation using rigid registration to the Allen atlas.

This script computes a rigid transform from the input brain volume to a RAS-aligned
version by registering it to the Allen Brain Atlas. The transform can be applied
directly to the zarr file (resampling) or stored in OME-Zarr metadata.
"""

# Configure thread limits before numpy/scipy imports
import linumpy._thread_config  # noqa: F401

import argparse
import json
from pathlib import Path
from typing import Optional

import SimpleITK as sitk
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm

from linumpy.io import allen
from linumpy.io.zarr import read_omezarr, AnalysisOmeZarrWriter
from linumpy.utils.orientation import (
    parse_orientation_code,
    apply_orientation_transform,
    reorder_resolution,
)

matplotlib.use('Agg')  # Non-interactive backend

# Constants
DEFAULT_ALLEN_RESOLUTION = 100
DEFAULT_MAX_ITERATIONS = 1000
DEFAULT_METRIC = "MI"


def _build_arg_parser():
    """Build and return the command-line argument parser."""
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("input_zarr", help="Input OME-Zarr file from 3D reconstruction pipeline")
    parser.add_argument("output_zarr", help="Output OME-Zarr file (RAS-aligned)")
    parser.add_argument(
        "--allen-resolution", type=int, default=DEFAULT_ALLEN_RESOLUTION,
        choices=allen.AVAILABLE_RESOLUTIONS,
        help=f"Allen atlas resolution in micron (default: {DEFAULT_ALLEN_RESOLUTION})"
    )
    parser.add_argument(
        "--metric", type=str, default=DEFAULT_METRIC,
        choices=["MI", "MSE", "CC", "AntsCC"],
        help=f"Registration metric (default: {DEFAULT_METRIC})"
    )
    parser.add_argument(
        "--max-iterations", type=int, default=DEFAULT_MAX_ITERATIONS,
        help=f"Maximum registration iterations (default: {DEFAULT_MAX_ITERATIONS})"
    )
    parser.add_argument(
        "--store-transform-only", action="store_true",
        help="Store transform in metadata only (don't resample volume)"
    )
    parser.add_argument(
        "--level", type=int, default=0,
        help="Pyramid level to use for registration (default: 0 = full resolution)"
    )
    parser.add_argument(
        "--chunks", type=int, nargs=3, default=None,
        help="Chunk size for output zarr (default: use input chunks)"
    )
    parser.add_argument(
        "--n-levels", type=int, default=None,
        help="Number of pyramid levels for output (default: use Allen atlas resolutions)"
    )
    parser.add_argument(
        "--pyramid_resolutions", type=float, nargs="+", default=None,
        help="Target pyramid resolution levels in µm (e.g. 10 25 50 100).\n"
             "If omitted, inherits levels from input zarr metadata or uses Allen resolutions."
    )
    parser.add_argument("--make_isotropic", action="store_true", default=True,
                        help="Resample to isotropic voxels at each pyramid level.")
    parser.add_argument("--no_isotropic", dest="make_isotropic", action="store_false")
    parser.add_argument("--verbose", action="store_true", help="Print registration progress")
    parser.add_argument(
        "--preview", type=str, default=None,
        help="Generate preview image showing alignment comparison"
    )
    parser.add_argument(
        "--input-orientation", type=str, default=None,
        help="Input volume orientation code (3 letters: R/L, A/P, S/I)\n"
             "Examples: 'RAS' (Allen), 'LPI', 'PIR'"
    )
    parser.add_argument(
        "--initial-rotation", type=float, nargs=3, default=[0.0, 0.0, 0.0],
        metavar=("RX", "RY", "RZ"),
        help="Initial rotation angles in degrees (Rx, Ry, Rz).\n"
             "Use to provide initial orientation hint for registration."
    )
    parser.add_argument(
        "--preview-only", action="store_true",
        help="Only generate preview of input volume (no registration)"
    )
    parser.add_argument(
        "--orientation-preview", type=str, default=None,
        metavar="PATH",
        help="Save a 3-panel preview of the volume after --input-orientation and\n"
             "--initial-rotation are applied. Use to verify these parameters\n"
             "before committing to a full registration run."
    )
    parser.add_argument(
        "--orientation-preview-only", action="store_true",
        help="Generate --orientation-preview and exit without running registration."
    )
    return parser


# =============================================================================
# Orientation utilities — imported from linumpy.utils.orientation
# (parse_orientation_code, apply_orientation_transform, reorder_resolution)
# =============================================================================


def create_registration_progress_callback(
    max_iterations: int,
    n_resolution_levels: int = 3,
    pbar: Optional[tqdm] = None,
    registration_start_step: int = 0,
    registration_steps: int = 0
):
    """
    Create a per-iteration progress callback for registration.

    Parameters
    ----------
    max_iterations : int
        Maximum iterations per level.
    n_resolution_levels : int
        Number of resolution levels in the registration pyramid.
    pbar : tqdm, optional
        Progress bar to update.
    registration_start_step : int
        Step number where registration starts in the progress bar.
    registration_steps : int
        Number of steps allocated for registration.

    Returns
    -------
    callable
        Callback compatible with SimpleITK registration; called with the
        registration method on each iteration event.
    """
    iteration_history = []
    # Mutable cell so the closure can count iterations across calls.
    total_iterations = [0]
    # Heuristic: each pyramid level typically converges well before
    # max_iterations, hence the 0.6 factor.
    estimated_total = max_iterations * n_resolution_levels * 0.6

    def callback(method):
        """Update progress-bar state from the current registration iteration."""
        iteration = method.GetOptimizerIteration()
        metric = method.GetMetricValue()

        # Iteration numbers restart at each resolution level; the history is
        # kept so a reset (iteration <= previous) can be detected.
        iteration_history.append(iteration)
        total_iterations[0] += 1

        if pbar is not None:
            progress_ratio = min(1.0, total_iterations[0] / estimated_total)
            target_step = registration_start_step + int(registration_steps * progress_ratio)
            # Only move the bar forward, never backwards.
            if target_step > pbar.n:
                pbar.n = target_step
            pbar.set_postfix_str(f"metric={metric:.6f}")
            pbar.refresh()

    return callback
# =============================================================================
# Transform utilities
# =============================================================================

def sitk_transform_to_affine_matrix(transform: sitk.Transform) -> np.ndarray:
    """
    Convert a SimpleITK transform to a 4x4 affine matrix.

    Parameters
    ----------
    transform : sitk.Transform
        SimpleITK Euler3DTransform or AffineTransform.

    Returns
    -------
    np.ndarray
        4x4 affine matrix in (Z, X, Y) coordinate ordering.

    Raises
    ------
    ValueError
        If the transform is neither an Euler3DTransform nor an AffineTransform.
    """
    if isinstance(transform, (sitk.Euler3DTransform, sitk.AffineTransform)):
        # Both transform types expose their effective 3x3 matrix via
        # GetMatrix(). Reading it directly — instead of rebuilding the
        # rotation from the Euler angles — avoids depending on ITK's Euler
        # composition order (Z*X*Y when ComputeZYX is off, the default),
        # which a manual Rz@Ry@Rx formula would get wrong.
        R = np.array(transform.GetMatrix()).reshape(3, 3)
        translation = np.array(transform.GetTranslation())
        center = np.array(transform.GetCenter())

        # ITK convention: T(p) = R(p - c) + c + t, so the homogeneous
        # offset is t + c - R @ c.
        matrix = np.eye(4)
        matrix[:3, :3] = R
        matrix[:3, 3] = translation + center - R @ center
    else:
        raise ValueError(f"Unsupported transform type: {type(transform)}")

    # Permute from SimpleITK (X, Y, Z) to our (Z, X, Y) ordering.
    permute = np.array([
        [0, 0, 1, 0],
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1]
    ])
    return permute @ matrix @ permute.T


def store_transform_in_metadata(zarr_path: str, transform: sitk.Transform):
    """Store transform in OME-Zarr metadata as affine coordinate transformation.

    The affine is prepended to each dataset's ``coordinateTransformations``
    list in the first multiscales entry.

    Parameters
    ----------
    zarr_path : str
        Path to the OME-Zarr directory containing a ``.zattrs`` file.
    transform : sitk.Transform
        Transform to serialize (Euler3D or Affine).

    Raises
    ------
    FileNotFoundError
        If no ``.zattrs`` file exists under ``zarr_path``.
    ValueError
        If the metadata has no ``multiscales`` entry.
    """
    affine_matrix = sitk_transform_to_affine_matrix(transform)
    zattrs_path = Path(zarr_path) / ".zattrs"

    if not zattrs_path.exists():
        raise FileNotFoundError(f".zattrs not found: {zarr_path}")

    with open(zattrs_path, 'r', encoding='utf-8') as f:
        metadata = json.load(f)

    affine_transform = {
        "type": "affine",
        "affine": affine_matrix.flatten().tolist()
    }

    multiscales = metadata.get('multiscales', [])
    if not multiscales:
        raise ValueError("No multiscales entry found in metadata")

    # Prepend the affine so it is applied before any existing transforms.
    for dataset in multiscales[0].get('datasets', []):
        existing = dataset.get('coordinateTransformations', [])
        dataset['coordinateTransformations'] = [affine_transform] + existing

    with open(zattrs_path, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    print(f"Stored affine transform in metadata: {zattrs_path}")


# =============================================================================
# Resolution utilities
# =============================================================================

def get_pyramid_resolutions_from_zarr(zarr_path: Path) -> Optional[list[float]]:
    """
    Extract pyramid resolution levels from OME-Zarr metadata.

    Parameters
    ----------
    zarr_path : Path
        Path to OME-Zarr file.

    Returns
    -------
    list of float or None
        Target resolutions in microns (finest spatial scale per dataset),
        or None if no usable metadata is found.
    """
    # Try zarr v3 metadata first, then the v2 .zattrs file.
    for metadata_file in ["zarr.json", ".zattrs"]:
        metadata_path = zarr_path / metadata_file
        if not metadata_path.exists():
            continue

        try:
            with open(metadata_path, 'r', encoding='utf-8') as f:
                metadata = json.load(f)
        except (json.JSONDecodeError, IOError):
            # Unreadable/corrupt metadata file: try the next candidate.
            continue

        multiscales = metadata.get('multiscales', [])
        if not multiscales:
            continue

        resolutions = []
        for dataset in multiscales[0].get('datasets', []):
            transforms = dataset.get('coordinateTransformations', [])
            for tr in transforms:
                if tr.get('type') == 'scale' and 'scale' in tr:
                    # Finest spatial dimension of the last 3 axes, mm -> µm.
                    scale = tr['scale'][-3:]
                    res_um = min(float(s) for s in scale) * 1000
                    resolutions.append(res_um)
                    break

        if resolutions:
            return resolutions

    return None
# =============================================================================
# Core processing functions
# =============================================================================

def compute_centered_reference_and_transform(
    moving_sitk: sitk.Image,
    transform: sitk.Transform,
    output_spacing: Optional[tuple] = None
) -> tuple[sitk.Image, sitk.Transform]:
    """
    Compute a reference image and composite transform that center the output.

    The resulting reference image has its origin at zero and exactly spans
    the transformed brain's bounding box, so the resampled brain is centered
    in the output volume at the original resolution.

    Parameters
    ----------
    moving_sitk : sitk.Image
        The input moving image.
    transform : sitk.Transform
        Transform to apply (moving -> fixed/RAS space).
    output_spacing : tuple, optional
        Output voxel spacing. If None, uses the moving image spacing.

    Returns
    -------
    ref : sitk.Image
        Reference image for resampling, with origin at 0.
    composite_transform : sitk.Transform
        Modified transform that maps the moving image to the centered output.
    """
    spacing_out = output_spacing if output_spacing is not None else moving_sitk.GetSpacing()

    # All 8 corners of the moving image, in index space.
    size = moving_sitk.GetSize()
    corners = [(i, j, k)
               for i in (0, size[0] - 1)
               for j in (0, size[1] - 1)
               for k in (0, size[2] - 1)]

    # The registration transform maps fixed -> moving (ResampleImageFilter
    # convention); its inverse maps the brain corners into the fixed
    # (RAS/Allen) coordinate system, where the bounding box is computed.
    inv_transform = transform.GetInverse()
    mapped = []
    for idx in corners:
        phys = moving_sitk.TransformContinuousIndexToPhysicalPoint(idx)
        mapped.append(inv_transform.TransformPoint(phys))

    pts = np.array(mapped)
    pts_min = pts.min(axis=0)
    pts_max = pts.max(axis=0)

    # Output grid large enough to cover the transformed brain extent.
    spacing = np.array(spacing_out)
    new_size = np.ceil((pts_max - pts_min) / spacing).astype(int)

    # Reference image with origin (0,0,0): output voxel p corresponds to
    # fixed-space coordinate p + pts_min.
    ref = sitk.Image([int(s) for s in new_size], moving_sitk.GetPixelIDValue())
    ref.SetSpacing(tuple(spacing))
    ref.SetOrigin((0.0, 0.0, 0.0))
    ref.SetDirection((1, 0, 0, 0, 1, 0, 0, 0, 1))  # identity direction (RAS)

    # Translation that maps output origin to the fixed-space bbox minimum.
    shift_transform = sitk.TranslationTransform(3)
    shift_transform.SetOffset(tuple(pts_min))

    # Composite for resampling:
    #   output point -> (shift) -> fixed space -> (transform) -> moving space
    # NOTE(review): the original comment states CompositeTransform applies
    # transforms in the order added (first added = first applied); SimpleITK's
    # documentation describes the back (last added) transform as applied
    # first — verify the composition order against the produced volumes.
    composite = sitk.CompositeTransform(3)
    composite.AddTransform(shift_transform)  # output -> fixed
    composite.AddTransform(transform)        # fixed -> moving

    return ref, composite


def apply_transform_to_zarr(
    input_path: str,
    output_path: str,
    transform: sitk.Transform,
    chunks: Optional[tuple] = None,
    n_levels: Optional[int] = None,
    pyramid_resolutions: Optional[list] = None,
    make_isotropic: bool = True,
    orientation_permutation: Optional[tuple] = None,
    orientation_flips: Optional[tuple] = None,
    pbar: Optional[tqdm] = None
):
    """
    Apply a transform to a zarr file by resampling into RAS-aligned space.

    The output is centered on the transformed brain volume and preserves the
    original resolution: rotation/off-axis alignment is corrected without
    placing the brain in the Allen atlas coordinate system.

    Parameters
    ----------
    input_path : str
        Path to input OME-Zarr.
    output_path : str
        Path to output OME-Zarr.
    transform : sitk.Transform
        Transform to apply.
    chunks : tuple, optional
        Chunk size for output.
    n_levels : int, optional
        Number of pyramid levels (if None, use source pyramid or Allen
        resolutions).
    pyramid_resolutions : list, optional
        Explicit target pyramid resolutions in µm.
    make_isotropic : bool
        Resample to isotropic voxels at each pyramid level.
    orientation_permutation : tuple, optional
        Axis permutation for orientation correction.
    orientation_flips : tuple, optional
        Axis flips for orientation correction.
    pbar : tqdm, optional
        Progress bar.
    """
    def update_pbar():
        if pbar:
            pbar.update(1)

    # Always load level 0 and take its spacing from the file itself: the
    # registration may have run on a downsampled level, so only the file's
    # level-0 metadata gives the correct physical extent.
    vol_zarr, level0_resolution = read_omezarr(input_path, level=0)
    if chunks is None:
        chunks = getattr(vol_zarr, 'chunks', None)

    vol = np.asarray(vol_zarr[:])
    original_dtype = vol.dtype
    update_pbar()

    # Optional orientation correction prior to resampling.
    resolution = level0_resolution
    if orientation_permutation is not None:
        vol = apply_orientation_transform(vol, orientation_permutation, orientation_flips)
        resolution = reorder_resolution(resolution, orientation_permutation)

    # float32 for interpolation; free the big numpy buffer right away.
    vol_sitk = allen.numpy_to_sitk_image(vol, resolution, cast_dtype=np.float32)
    del vol
    update_pbar()

    # Reference grid + transform that center the brain in the output.
    reference, centered_transform = compute_centered_reference_and_transform(vol_sitk, transform)

    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(reference)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(0)
    resampler.SetTransform(centered_transform)

    transformed_sitk = resampler.Execute(vol_sitk)
    del vol_sitk  # free input before extracting the output array
    transformed = sitk.GetArrayFromImage(transformed_sitk)
    del transformed_sitk  # free the SITK image once numpy owns the data
    update_pbar()

    # SITK numpy order is (Z, Y, X); convert back to our (Z, X, Y).
    transformed = np.transpose(transformed, (0, 2, 1))
    update_pbar()

    # Restore the original dtype, rounding and clipping for integer types.
    if np.issubdtype(original_dtype, np.integer):
        info = np.iinfo(original_dtype)
        transformed = np.clip(np.rint(transformed), info.min, info.max).astype(original_dtype)
    else:
        transformed = transformed.astype(original_dtype)

    writer = AnalysisOmeZarrWriter(
        output_path,
        shape=transformed.shape,
        chunk_shape=chunks,
        dtype=transformed.dtype,
        overwrite=True
    )
    writer[:] = transformed

    if n_levels is not None:
        writer.finalize(list(resolution), n_levels=n_levels)
    else:
        # Priority: explicit argument > input-file pyramid > Allen defaults.
        if pyramid_resolutions is not None:
            target_resolutions = pyramid_resolutions
        else:
            target_resolutions = get_pyramid_resolutions_from_zarr(Path(input_path))
            if target_resolutions is None:
                target_resolutions = list(allen.AVAILABLE_RESOLUTIONS)
        writer.finalize(list(resolution), target_resolutions_um=target_resolutions,
                        make_isotropic=make_isotropic)

    update_pbar()


# =============================================================================
# Preview generation
# =============================================================================

def create_input_preview(input_path: str, output_path: str, level: int = 0):
    """Create a preview of the input volume to help determine its orientation."""
    vol_zarr, resolution = read_omezarr(input_path, level=level)
    vol = np.asarray(vol_zarr[:])

    # Center slices along each axis.
    z_mid = vol.shape[0] // 2
    x_mid = vol.shape[1] // 2
    y_mid = vol.shape[2] // 2

    # Robust display range: clip to the 1st/99th percentile.
    vmin, vmax = np.percentile(vol, [1, 99])

    fig, axes = plt.subplots(2, 2, figsize=(14, 14))
    fig.suptitle(f'Input Volume Preview\nShape: {vol.shape} (Z, X, Y), Resolution: {resolution} mm',
                 fontsize=14, y=0.98)

    # Axial slice (dim0 midpoint)
    axes[0, 0].imshow(vol[z_mid, :, :].T, cmap='gray', origin='lower', vmin=vmin, vmax=vmax)
    axes[0, 0].set_title('Slice at dim0 midpoint\nShows: dim1 × dim2')
    axes[0, 0].set_xlabel('dim1 →')
    axes[0, 0].set_ylabel('dim2 →')

    # Sagittal slice (dim1 midpoint)
    axes[0, 1].imshow(vol[::-1, x_mid, :], cmap='gray', origin='lower', vmin=vmin, vmax=vmax)
    axes[0, 1].set_title('Slice at dim1 midpoint\nShows: dim2 × dim0')
    axes[0, 1].set_xlabel('dim2 →')
    axes[0, 1].set_ylabel('dim0 →')

    # Coronal slice (dim2 midpoint)
    axes[1, 0].imshow(vol[::-1, :, y_mid], cmap='gray', origin='lower', vmin=vmin, vmax=vmax)
    axes[1, 0].set_title('Slice at dim2 midpoint\nShows: dim1 × dim0')
    axes[1, 0].set_xlabel('dim1 →')
    axes[1, 0].set_ylabel('dim0 →')

    # Help text panel
    axes[1, 1].axis('off')
    help_text = """
ORIENTATION GUIDE (Allen Atlas = RAS+)

Allen RAS+ convention:
  • R (Right): +X direction
  • A (Anterior): +Y direction (nose)
  • S (Superior): +Z direction (top)

For each dimension, identify the anatomical direction:
  R/L for right/left
  A/P for anterior/posterior
  S/I for superior/inferior

Example:
  dim0→Superior, dim1→Anterior, dim2→Right
  → orientation code = 'SAR'
"""
    axes[1, 1].text(0.02, 0.98, help_text, transform=axes[1, 1].transAxes,
                    fontsize=10, verticalalignment='top', fontfamily='monospace',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Input preview saved to: {output_path}")
+ """ + def update_pbar(): + if pbar: + pbar.update(1) + + # Load original + vol_original, orig_res = read_omezarr(input_path, level=level) + vol_original = np.asarray(vol_original[:]) + + if orientation_permutation is not None: + vol_original = apply_orientation_transform(vol_original, orientation_permutation, orientation_flips) + orig_res = reorder_resolution(tuple(orig_res), orientation_permutation) + update_pbar() + + # Load aligned volume from output file, or compute it + if output_path and Path(output_path).exists(): + vol_aligned, aligned_res = read_omezarr(output_path, level=level) + vol_aligned = np.asarray(vol_aligned[:]) + else: + # Compute aligned volume using the transform + vol_sitk = allen.numpy_to_sitk_image(vol_original, resolution) + # Create reference and centered transform + reference, centered_transform = compute_centered_reference_and_transform(vol_sitk, transform) + + resampler = sitk.ResampleImageFilter() + resampler.SetReferenceImage(reference) + resampler.SetInterpolator(sitk.sitkLinear) + resampler.SetDefaultPixelValue(0) + resampler.SetTransform(centered_transform) + transformed_sitk = resampler.Execute(vol_sitk) + vol_aligned = np.transpose(sitk.GetArrayFromImage(transformed_sitk), (0, 2, 1)) + aligned_res = resolution + update_pbar() + + # Load Allen template at native resolution for reference + # We'll just show it as a reference, not spatially aligned + allen_sitk = allen.download_template_ras_aligned(allen_resolution, cache=True) + allen_template = sitk.GetArrayFromImage(allen_sitk) + # Convert from SITK (Z, Y, X) to our (Z, X, Y) ordering + allen_template = np.transpose(allen_template, (0, 2, 1)) + update_pbar() + + # Helper functions + def get_center_slices(vol): + """Get center slices in each plane.""" + z, x, y = vol.shape[0] // 2, vol.shape[1] // 2, vol.shape[2] // 2 + return vol[z, :, :], vol[:, x, :], vol[:, :, y] + + def get_display_range(vol): + """Get display range from non-zero values.""" + nonzero = vol[vol > 0] + if 
len(nonzero) > 0: + return np.percentile(nonzero, [1, 99]) + return 0, 1 + + def find_content_center_slices(vol): + """Find the slice with maximum content independently for each axis. + + Using a shared 3D centroid for all three views fails when the brain is + asymmetric (e.g. cut at 45°): the centroid lands near the cut boundary, + so one or more of the orthogonal slice views passes through the cut plane + and shows a black stripe. Instead, pick each index independently as the + slice with the highest total signal along that axis. + """ + if vol.max() == 0: + return get_center_slices(vol) + z = int(np.argmax(vol.sum(axis=(1, 2)))) + x = int(np.argmax(vol.sum(axis=(0, 2)))) + y = int(np.argmax(vol.sum(axis=(0, 1)))) + return vol[z, :, :], vol[:, x, :], vol[:, :, y] + + # Get slices - use content-centered slices for aligned volume + orig_slices = get_center_slices(vol_original) + aligned_slices = find_content_center_slices(vol_aligned) + allen_slices = get_center_slices(allen_template) + + orig_vmin, orig_vmax = get_display_range(vol_original) + align_vmin, align_vmax = get_display_range(vol_aligned) + allen_vmin, allen_vmax = get_display_range(allen_template) + + # Create figure + fig, axes = plt.subplots(3, 3, figsize=(18, 18)) + fig.suptitle('Alignment Preview: Original vs Aligned vs Allen Template (Reference)', fontsize=16) + + plane_names = ['Axial (XY)', 'Sagittal (XZ)', 'Coronal (YZ)'] + + for row, plane_name in enumerate(plane_names): + # Original - use .T for row 0 (XY plane) to match display convention + data = orig_slices[row].T if row == 0 else orig_slices[row][::-1, :] + axes[row, 0].imshow(data, cmap='gray', origin='lower', vmin=orig_vmin, vmax=orig_vmax) + axes[row, 0].set_title(f'Original - {plane_name}') + axes[row, 0].axis('off') + + # Aligned + data = aligned_slices[row].T if row == 0 else aligned_slices[row][::-1, :] + axes[row, 1].imshow(data, cmap='gray', origin='lower', vmin=align_vmin, vmax=align_vmax) + axes[row, 1].set_title(f'Aligned - 
{plane_name}') + axes[row, 1].axis('off') + + # Allen (reference) + data = allen_slices[row].T if row == 0 else allen_slices[row][::-1, :] + axes[row, 2].imshow(data, cmap='gray', origin='lower', vmin=allen_vmin, vmax=allen_vmax) + axes[row, 2].set_title(f'Allen {allen_resolution}µm - {plane_name}') + axes[row, 2].axis('off') + + # Add info text + info_text = (f"Original shape: {vol_original.shape}\n" + f"Aligned shape: {vol_aligned.shape}\n" + f"Allen shape: {allen_template.shape}") + fig.text(0.02, 0.02, info_text, fontsize=10, family='monospace', + bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5)) + + plt.tight_layout() + Path(preview_path).parent.mkdir(parents=True, exist_ok=True) + fig.savefig(preview_path, dpi=150, bbox_inches='tight') + plt.close(fig) + update_pbar() + + print(f"Alignment preview saved to: {preview_path}") + + +# ============================================================================= +# Main entry point +# ============================================================================= + +def create_orientation_preview( + input_path: str, + preview_path: str, + level: int = 0, + orientation_permutation: Optional[tuple] = None, + orientation_flips: Optional[tuple] = None, + initial_rotation_deg: tuple = (0.0, 0.0, 0.0), +): + """ + Save a 3-panel orthogonal preview of the volume after orientation correction + and initial rotation are applied. + + Axes are labelled in RAS space (Z=S, X=R, Y=A) so the result can be + inspected directly against the Allen atlas orientation. + + Parameters + ---------- + input_path : str + Path to input OME-Zarr. + preview_path : str + Output PNG path. + level : int + Pyramid level to load (lower = higher resolution but slower). + orientation_permutation : tuple, optional + Axis permutation from ``parse_orientation_code``. + orientation_flips : tuple, optional + Axis flips from ``parse_orientation_code``. 
+ initial_rotation_deg : tuple of float + (Rx, Ry, Rz) initial rotation angles in degrees applied after orientation. + """ + vol_zarr, resolution = read_omezarr(input_path, level=level) + vol = np.asarray(vol_zarr[:]).astype(np.float32) + + # Apply orientation permutation + flips + if orientation_permutation is not None: + vol = apply_orientation_transform(vol, orientation_permutation, orientation_flips) + resolution = list(reorder_resolution(tuple(resolution), orientation_permutation)) + + # Apply initial rotation via SimpleITK (same path as the registration uses) + if any(r != 0.0 for r in initial_rotation_deg): + vol_sitk = allen.numpy_to_sitk_image(vol, resolution, cast_dtype=np.float32) + center = vol_sitk.TransformContinuousIndexToPhysicalPoint( + [s / 2.0 for s in vol_sitk.GetSize()]) + rx, ry, rz = [np.deg2rad(a) for a in initial_rotation_deg] + t = sitk.Euler3DTransform() + t.SetCenter(center) + t.SetRotation(rx, ry, rz) + resampler = sitk.ResampleImageFilter() + resampler.SetReferenceImage(vol_sitk) + resampler.SetTransform(t.GetInverse()) + resampler.SetInterpolator(sitk.sitkLinear) + vol = sitk.GetArrayFromImage(resampler.Execute(vol_sitk)) + + # Display range from non-zero voxels + nonzero = vol[vol > 0] + vmin, vmax = np.percentile(nonzero if len(nonzero) else vol.ravel(), [1, 99]) + + # Build title + applied = [] + if orientation_permutation is not None: + applied.append(f"orientation") + if any(r != 0.0 for r in initial_rotation_deg): + applied.append(f"rotation {list(initial_rotation_deg)}°") + subtitle = f"({', '.join(applied)} applied)" if applied else "(no corrections applied)" + + z_mid = vol.shape[0] // 2 + x_mid = vol.shape[1] // 2 + y_mid = vol.shape[2] // 2 + + fig, axes = plt.subplots(1, 3, figsize=(18, 6)) + fig.suptitle( + f"Orientation Preview — {subtitle}\n" + f"Shape: {vol.shape} | After corrections: Z=S (Superior), X=R (Right), Y=A (Anterior)", + fontsize=11 + ) + + # Axial (mid-Z): rows=Y (A), cols=X (R) + axes[0].imshow(vol[z_mid, 
:, :].T, cmap='gray', origin='lower', vmin=vmin, vmax=vmax) + axes[0].set_title(f'Axial (Z={z_mid})') + axes[0].set_xlabel('X (← L R →)') + axes[0].set_ylabel('Y (← P A →)') + + # Sagittal (mid-X): rows=Z (S, flipped), cols=Y (A) + axes[1].imshow(vol[::-1, x_mid, :], cmap='gray', origin='lower', vmin=vmin, vmax=vmax) + axes[1].set_title(f'Sagittal (X={x_mid})') + axes[1].set_xlabel('Y (← P A →)') + axes[1].set_ylabel('Z (← I S →)') + + # Coronal (mid-Y): rows=Z (S, flipped), cols=X (R) + axes[2].imshow(vol[::-1, :, y_mid], cmap='gray', origin='lower', vmin=vmin, vmax=vmax) + axes[2].set_title(f'Coronal (Y={y_mid})') + axes[2].set_xlabel('X (← L R →)') + axes[2].set_ylabel('Z (← I S →)') + + plt.tight_layout() + Path(preview_path).parent.mkdir(parents=True, exist_ok=True) + fig.savefig(preview_path, dpi=150, bbox_inches='tight') + plt.close(fig) + print(f"Orientation preview saved to: {preview_path}") + + +# ============================================================================= +# Main entry point +# ============================================================================= + +def main(): + """Main entry point: parse arguments and run alignment workflow.""" + parser = _build_arg_parser() + args = parser.parse_args() + + input_path = Path(args.input_zarr) + output_path = Path(args.output_zarr) + + if not input_path.exists(): + raise FileNotFoundError(f"Input zarr not found: {input_path}") + + # Preview-only mode + if args.preview_only: + preview_path = args.preview or "input_preview.png" + create_input_preview(str(input_path), preview_path, level=args.level) + return + + # Parse orientation + orientation_permutation = None + orientation_flips = None + if args.input_orientation: + try: + orientation_permutation, orientation_flips = parse_orientation_code(args.input_orientation) + print(f"Input orientation '{args.input_orientation}':") + print(f" Axis permutation: {orientation_permutation}") + print(f" Axis flips: {orientation_flips}") + except ValueError as 
e: + parser.error(str(e)) + + # Orientation + initial-rotation preview (can exit before registration) + if args.orientation_preview or args.orientation_preview_only: + preview_out = args.orientation_preview or "orientation_preview.png" + create_orientation_preview( + str(input_path), + preview_out, + level=args.level, + orientation_permutation=orientation_permutation, + orientation_flips=orientation_flips, + initial_rotation_deg=tuple(args.initial_rotation), + ) + if args.orientation_preview_only: + return + + # Load input volume + vol_zarr, zarr_resolution = read_omezarr(str(input_path), level=args.level) + resolution = tuple(zarr_resolution) + + # Progress bar - allocate steps for each phase + registration_steps = 3 # Steps allocated for registration progress + base_steps = 2 if args.store_transform_only else 5 # Load + save steps + total_steps = base_steps + registration_steps + if args.preview: + total_steps += 4 + pbar = tqdm(total=total_steps, desc="Aligning to RAS") + + vol = np.asarray(vol_zarr[:]) + pbar.update(1) + + if args.verbose: + print(f"Volume shape: {vol.shape}, Resolution: {resolution} mm") + + # Apply orientation correction for registration + if orientation_permutation is not None: + vol = apply_orientation_transform(vol, orientation_permutation, orientation_flips) + resolution = reorder_resolution(resolution, orientation_permutation) + + # Create progress callback for registration + registration_start_step = pbar.n + progress_callback = create_registration_progress_callback( + max_iterations=args.max_iterations, + n_resolution_levels=3, + pbar=pbar, + registration_start_step=registration_start_step, + registration_steps=registration_steps + ) + + # Register to Allen atlas + pbar.set_postfix_str("registering...") + transform, stop_condition, error = allen.register_3d_rigid_to_allen( + moving_image=vol, + moving_spacing=resolution, + allen_resolution=args.allen_resolution, + metric=args.metric, + max_iterations=args.max_iterations, + 
verbose=args.verbose, + progress_callback=progress_callback, + initial_rotation_deg=tuple(args.initial_rotation), + ) + # Ensure progress bar reaches end of registration steps + pbar.n = registration_start_step + registration_steps + pbar.refresh() + + print(f"Registration complete: {stop_condition}") + print(f"Final metric value: {error:.6f}") + del vol # free registration-level volume before loading full-resolution data + + # Apply or store transform + if args.store_transform_only: + store_transform_in_metadata(str(input_path), transform) + pbar.update(1) + else: + apply_transform_to_zarr( + str(input_path), + str(output_path), + transform, + chunks=tuple(args.chunks) if args.chunks else None, + n_levels=args.n_levels, + pyramid_resolutions=args.pyramid_resolutions, + make_isotropic=args.make_isotropic, + orientation_permutation=orientation_permutation, + orientation_flips=orientation_flips, + pbar=pbar, + ) + print(f"Aligned volume saved to: {output_path}") + + # Save transform file + # Strip the compound .ome.zarr extension (Path.stem only removes the last suffix) + stem = output_path.with_suffix('').with_suffix('').name + transform_path = output_path.parent / f"{stem}_transform.tfm" + sitk.WriteTransform(transform, str(transform_path)) + print(f"Transform saved to: {transform_path}") + pbar.update(1) + + # Generate preview + if args.preview: + pbar.set_postfix_str("generating preview...") + create_alignment_preview( + str(input_path), + str(output_path) if not args.store_transform_only else None, + transform, + resolution, + args.preview, + allen_resolution=args.allen_resolution, + level=args.level, + orientation_permutation=orientation_permutation, + orientation_flips=orientation_flips, + pbar=pbar, + ) + + pbar.set_postfix_str("complete") + pbar.close() + + +if __name__ == "__main__": + main() diff --git a/scripts/linum_download_allen.py b/scripts/linum_download_allen.py index 6dfb9138..82f96b22 100644 --- a/scripts/linum_download_allen.py +++ 
b/scripts/linum_download_allen.py @@ -5,6 +5,9 @@ Download the Allen mouse brain template, and setting the correct RAS+ direction and spacing. """ +# Configure thread limits before numpy/scipy imports +import linumpy._thread_config # noqa: F401 + import argparse from pathlib import Path @@ -50,6 +53,11 @@ def main(): vol = sitk.Flip(vol, (False, False, True)) vol.SetDirection([1, 0, 0, 0, 1, 0, 0, 0, 1]) + # Match the pipeline output dtype ([0, 1] float32) so both volumes + # display on the same intensity scale in ITK-SNAP / napari. + vol = sitk.Cast(vol, sitk.sitkFloat32) + vol = sitk.RescaleIntensity(vol, 0.0, 1.0) + # Save the volume sitk.WriteImage(vol, str(output))