Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions autoarray/dataset/imaging/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,9 @@ def __init__(
self.psf = psf

if psf is not None:
if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0:
raise exc.KernelException("Kernel2D Kernel2D must be odd")
if not psf.use_fft:
if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0:
raise exc.KernelException("Kernel2D Kernel2D must be odd")

self.grids = GridsDataset(
mask=self.data.mask,
Expand Down
3 changes: 0 additions & 3 deletions autoarray/mask/derive/mask_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,6 @@ def blurring_from(self, kernel_shape_native: Tuple[int, int]) -> Mask2D:

from autoarray.mask.mask_2d import Mask2D

if kernel_shape_native[0] % 2 == 0 or kernel_shape_native[1] % 2 == 0:
raise exc.MaskException("psf_size of exterior region must be odd")

blurring_mask = mask_2d_util.blurring_mask_2d_from(
mask_2d=self.mask,
kernel_shape_native=kernel_shape_native,
Expand Down
74 changes: 44 additions & 30 deletions autoarray/structures/arrays/kernel_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,14 +723,18 @@ def convolved_image_from(
blurred_image_full = xp.fft.irfft2(
fft_psf * fft_image_native, s=fft_shape, axes=(0, 1)
)
ky, kx = self.native.array.shape # (21, 21)
Copy link

Copilot AI Feb 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The inline comment # (21, 21) appears to be an example value for demonstration purposes, but it may be unclear to future readers that this is not a constant value. Consider clarifying this comment to indicate it's an example, e.g., # e.g., (21, 21) or removing it if it's not needed.

Suggested change
ky, kx = self.native.array.shape # (21, 21)
ky, kx = self.native.array.shape # e.g., (21, 21)

Copilot uses AI. Check for mistakes.
off_y = (ky - 1) // 2
off_x = (kx - 1) // 2

# Crop back to mask_shape
start_indices = tuple(
(full_size - out_size) // 2
for full_size, out_size in zip(full_shape, mask_shape)
blurred_image_full = xp.roll(
blurred_image_full, shift=(-off_y, -off_x), axis=(0, 1)
)

start_indices = (off_y, off_x)

blurred_image_native = jax.lax.dynamic_slice(
blurred_image_full, start_indices, mask_shape
blurred_image_full, start_indices, image.mask.shape
)
Comment on lines 736 to 738
Copy link

Copilot AI Feb 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential bug when self.fft_shape is None: After resizing the image to fft_shape at line 686, image.mask.shape becomes fft_shape. At line 737, the code tries to extract a region of size image.mask.shape (which is fft_shape) from blurred_image_full starting at (off_y, off_x). Since blurred_image_full has shape fft_shape and off_y, off_x > 0 for kernels with size > 1, this extraction would exceed the array bounds. The old code used mask_shape (computed at line 675 before the resize) for the extraction size, which was correct. Consider saving mask_shape before the resize and using it for extraction instead of image.mask.shape.

Copilot uses AI. Check for mistakes.

# Return slim form; optionally cast for downstream stability
Expand Down Expand Up @@ -806,6 +810,10 @@ def convolved_mapping_matrix_from(
ndarray of shape (N_pix, N_src)
Convolved mapping matrix in slim form.
"""
# -------------------------------------------------------------------------
# NumPy path unchanged
# -------------------------------------------------------------------------

# -------------------------------------------------------------------------
# NumPy path unchanged
# -------------------------------------------------------------------------
Comment on lines 817 to 819
Copy link

Copilot AI Feb 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicate comment section detected. Lines 813-815 contain the same comment "NumPy path unchanged" as lines 817-819. One of these duplicate comment blocks should be removed to improve code clarity.

Suggested change
# -------------------------------------------------------------------------
# NumPy path unchanged
# -------------------------------------------------------------------------

Copilot uses AI. Check for mistakes.
Expand Down Expand Up @@ -835,34 +843,24 @@ def convolved_mapping_matrix_from(
import jax.numpy as jnp

# -------------------------------------------------------------------------
# Validate cached FFT shapes / state
# Cached FFT shapes/state (REQUIRED)
# -------------------------------------------------------------------------
if self.fft_shape is None:
full_shape, fft_shape, mask_shape = self.fft_shape_from(mask=mask)
raise ValueError(
f"FFT convolution requires precomputed padded shapes, but `self.fft_shape` is None.\n"
f"Expected mapping matrix padded to match FFT shape of PSF.\n"
f"PSF fft_shape: {fft_shape}, mask shape: {mask.shape}, "
f"mapping_matrix shape: {getattr(mapping_matrix, 'shape', 'unknown')}."
"FFT convolution requires precomputed FFT shapes on the PSF."
)
else:
fft_shape = self.fft_shape
full_shape = self.full_shape
mask_shape = self.mask_shape
fft_psf_mapping = self.fft_psf_mapping

fft_shape = self.fft_shape
fft_psf_mapping = self.fft_psf_mapping

# -------------------------------------------------------------------------
# Mixed precision dtypes (JAX only)
# Mixed precision handling
# -------------------------------------------------------------------------
fft_complex_dtype = jnp.complex64 if use_mixed_precision else jnp.complex128

# Ensure PSF FFT dtype matches the FFT path
fft_psf_mapping = jnp.asarray(fft_psf_mapping, dtype=fft_complex_dtype)

# -------------------------------------------------------------------------
# Build native cube in the FFT dtype (THIS IS THE KEY)
# This relies on mapping_matrix_native_from honoring the use_mixed_precision
# kwarg when constructing the native mapping matrix.
# Build native cube on the *native mask grid*
# -------------------------------------------------------------------------
mapping_matrix_native = self.mapping_matrix_native_from(
mapping_matrix=mapping_matrix,
Expand All @@ -872,35 +870,51 @@ def convolved_mapping_matrix_from(
use_mixed_precision=use_mixed_precision,
xp=xp,
)
# shape: (ny_native, nx_native, n_src)

# -------------------------------------------------------------------------
# FFT convolution
# -------------------------------------------------------------------------
fft_mapping_matrix_native = xp.fft.rfft2(
mapping_matrix_native, s=fft_shape, axes=(0, 1)
)

blurred_mapping_matrix_full = xp.fft.irfft2(
fft_psf_mapping * fft_mapping_matrix_native,
s=fft_shape,
axes=(0, 1),
)

# -------------------------------------------------------------------------
# Crop back to mask-shape
# APPLY SAME FIX AS convolved_image_from
# -------------------------------------------------------------------------
start_indices = tuple(
(full_size - out_size) // 2
for full_size, out_size in zip(full_shape, mask_shape)
) + (0,)
out_shape_full = mask_shape + (blurred_mapping_matrix_full.shape[2],)
ky, kx = self.native.array.shape
off_y = (ky - 1) // 2
off_x = (kx - 1) // 2

blurred_mapping_matrix_full = xp.roll(
blurred_mapping_matrix_full,
shift=(-off_y, -off_x),
axis=(0, 1),
)

# -------------------------------------------------------------------------
# Extract native grid (same as image path)
# -------------------------------------------------------------------------
native_shape = mask.shape
start_indices = (off_y, off_x, 0)

out_shape = native_shape + (blurred_mapping_matrix_full.shape[2],)

blurred_mapping_matrix_native = jax.lax.dynamic_slice(
blurred_mapping_matrix_full,
start_indices,
out_shape_full,
out_shape,
)

# Return slim form
# -------------------------------------------------------------------------
# Slim using ORIGINAL mask indices (same grid)
# -------------------------------------------------------------------------
blurred_slim = blurred_mapping_matrix_native[mask.slim_to_native_tuple]

return blurred_slim
Expand Down
Loading