From e0c70a7da8a835f149752474246fa5cdce6a502a Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 17 Jun 2025 20:58:20 +0100 Subject: [PATCH 01/18] update to much simpler tests for transformer --- autoarray/operators/transformer.py | 2 +- autoarray/operators/transformer_util.py | 92 +++--- test_autoarray/operators/test_transformer.py | 287 +++++-------------- 3 files changed, 117 insertions(+), 264 deletions(-) diff --git a/autoarray/operators/transformer.py b/autoarray/operators/transformer.py index c5b1d3a50..288648016 100644 --- a/autoarray/operators/transformer.py +++ b/autoarray/operators/transformer.py @@ -85,7 +85,7 @@ def visibilities_from(self, image): ) else: - visibilities = transformer_util.visibilities_jit( + visibilities = transformer_util.visibilities_direct_from( image_1d=np.array(image.slim.array), grid_radians=np.array(self.grid), uv_wavelengths=self.uv_wavelengths, diff --git a/autoarray/operators/transformer_util.py b/autoarray/operators/transformer_util.py index 395034794..1558fb0c7 100644 --- a/autoarray/operators/transformer_util.py +++ b/autoarray/operators/transformer_util.py @@ -88,60 +88,68 @@ def visibilities_via_preload_jit_from(image_1d, preloaded_reals, preloaded_imags return visibilities -@numba_util.jit() -def visibilities_jit(image_1d, grid_radians, uv_wavelengths): - visibilities = 0 + 0j * np.zeros(shape=(uv_wavelengths.shape[0])) +def visibilities_direct_from(image_1d : np.ndarray, grid_radians : np.ndarray, uv_wavelengths : np.ndarray) -> np.ndarray: + """ + Compute complex visibilities from an input sky image using the Fourier transform, + simulating the response of an astronomical radio interferometer. - for image_1d_index in range(image_1d.shape[0]): - for vis_1d_index in range(uv_wavelengths.shape[0]): - vis_real = image_1d[image_1d_index] * np.cos( - -2.0 - * np.pi - * ( - grid_radians[image_1d_index, 1] * uv_wavelengths[vis_1d_index, 0] - + grid_radians[image_1d_index, 0] * uv_wavelengths[vis_1d_index, 1] - ) - ) - vis_imag = image_1d[image_1d_index] * np.sin( - -2.0 - * np.pi - * ( - grid_radians[image_1d_index, 1] * uv_wavelengths[vis_1d_index, 0] - + grid_radians[image_1d_index, 0] * uv_wavelengths[vis_1d_index, 1] - ) - ) - visibilities[vis_1d_index] += vis_real + 1j * vis_imag + This function converts an image defined on a sky coordinate grid into its + visibility-space representation, given a set of (u,v) spatial frequency + coordinates (in wavelengths), as sampled by a radio interferometer. + + Parameters + ---------- + image_1d + The 1D flattened sky brightness values corresponding to each pixel in the grid. + + grid_radians + The angular (y, x) positions of each image pixel in radians, matching image_1d. + + uv_wavelengths + The (u, v) spatial frequencies in units of wavelengths, for each baseline + of the interferometer. + + Returns + ------- + visibilities + The complex visibilities (Fourier components) corresponding to each + (u, v) coordinate, representing the interferometer’s measurement. + """ + # Compute the dot product for each pixel-uv pair + phase = -2.0 * np.pi * ( + np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) + + np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) + ) # shape (n_pixels, n_vis) + + # Multiply image values with phase terms + vis_real = image_1d[:, None] * np.cos(phase) + vis_imag = image_1d[:, None] * np.sin(phase) + + # Sum over all pixels for each visibility + visibilities = np.sum(vis_real + 1j * vis_imag, axis=0) return visibilities -@numba_util.jit() def image_via_jit_from(n_pixels, grid_radians, uv_wavelengths, visibilities): - image_1d = np.zeros(n_pixels) - for image_1d_index in range(image_1d.shape[0]): - for vis_1d_index in range(uv_wavelengths.shape[0]): - image_1d[image_1d_index] += visibilities[vis_1d_index, 0] * np.cos( - 2.0 - * np.pi - * ( - grid_radians[image_1d_index, 1] * uv_wavelengths[vis_1d_index, 0] - + grid_radians[image_1d_index, 0] * uv_wavelengths[vis_1d_index, 1] - ) - ) + aaaa - image_1d[image_1d_index] -= visibilities[vis_1d_index, 1] * np.sin( - 2.0 - * np.pi - * ( - grid_radians[image_1d_index, 1] * uv_wavelengths[vis_1d_index, 0] - + grid_radians[image_1d_index, 0] * uv_wavelengths[vis_1d_index, 1] - ) - ) + # Compute the phase term for each (pixel, visibility) pair + phase = 2.0 * np.pi * ( + np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) + + np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) + ) + + real_part = np.dot(np.cos(phase), visibilities[:, 0]) + imag_part = np.dot(np.sin(phase), visibilities[:, 1]) + + image_1d = real_part - imag_part return image_1d + @numba_util.jit() def transformed_mapping_matrix_via_preload_jit_from( mapping_matrix, preloaded_reals, preloaded_imags diff --git a/test_autoarray/operators/test_transformer.py b/test_autoarray/operators/test_transformer.py index 33ea1ca42..c2af5b082 100644 --- a/test_autoarray/operators/test_transformer.py +++ b/test_autoarray/operators/test_transformer.py @@ -4,147 +4,68 @@ import pytest -class MockDeriveMask2D: - def __init__(self, grid): - self.mask = grid.derive_mask.all_false - self.grid = grid - - @property - def sub_1(self): - return self - - @property - def derive_grid(self): - return MockDeriveGrid2D( - grid=self.grid, - ) - - -class MockDeriveGrid2D: - def __init__(self, grid): - self.unmasked = MockMaskedGrid(grid=grid) - - -class MockRealSpaceMask: - def __init__(self, grid): - self.grid = grid - self.unmasked = MockMaskedGrid(grid=grid) - - @property - def pixels_in_mask(self): - return self.unmasked.slim.in_radians.shape[0] - - @property - def derive_mask(self): - return MockDeriveMask2D( - grid=self.grid, - ) - - @property - def derive_grid(self): - return MockDeriveGrid2D( - grid=self.grid, - ) - - @property - def pixel_scales(self): - return self.grid.pixel_scales - - @property - def origin(self): - return self.grid.origin - - -class MockMaskedGrid: - def __init__(self, grid): - self.in_radians = grid - self.slim = grid - - -def test__dft__visibilities_from(): - uv_wavelengths = np.ones(shape=(4, 2)) - - grid_radians = aa.Grid2D.no_mask(values=[[[1.0, 1.0]]], pixel_scales=1.0) - - real_space_mask = MockRealSpaceMask(grid=grid_radians) +def test__dft__visibilities_from(visibilities_7, uv_wavelengths_7x2, mask_2d_7x7): transformer = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, + uv_wavelengths=uv_wavelengths_7x2, + real_space_mask=mask_2d_7x7, preload_transform=False, ) - image = aa.Array2D.ones(shape_native=(1, 1), pixel_scales=1.0) - - visibilities = transformer.visibilities_from(image=image) - - assert visibilities == pytest.approx( - np.array([1.0 + 0.0j, 1.0 + 0.0j, 1.0 + 0.0j, 1.0 + 0.0j]), 1.0e-4 + image = aa.Array2D( + values=[ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.5, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.5, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + mask=mask_2d_7x7, ) - uv_wavelengths = np.array([[0.2, 1.0], [0.5, 1.1], [0.8, 1.2]]) - - grid_radians = aa.Grid2D.no_mask( - values=[[[0.1, 0.2], [0.3, 0.4]]], pixel_scales=1.0 - ) - - real_space_mask = MockRealSpaceMask(grid=grid_radians) - - transformer = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, - preload_transform=False, - ) - - image = aa.Array2D.ones(shape_native=(1, 2), pixel_scales=1.0) - visibilities = transformer.visibilities_from(image=image) - assert visibilities == pytest.approx( + print(visibilities) + + assert visibilities[0:3] == pytest.approx( np.array( - [-0.091544 - 1.45506j, -0.73359736 - 0.781201j, -0.613160 - 0.077460j] + [ + -0.06434514 - 0.61763293j, + 1.71143349 - 1.184022j, + 0.90200541 + 0.03726693j, + ] ), 1.0e-4, ) - uv_wavelengths = np.array([[0.2, 1.0], [0.5, 1.1], [0.8, 1.2]]) - grid_radians = aa.Grid2D.no_mask( - values=[[[0.1, 0.2], [0.3, 0.4]]], pixel_scales=1.0 - ) - real_space_mask = MockRealSpaceMask(grid=grid_radians) +def test__dft__image_from(visibilities_7, uv_wavelengths_7x2, mask_2d_7x7): transformer = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, + uv_wavelengths=uv_wavelengths_7x2, + real_space_mask=mask_2d_7x7, preload_transform=False, ) - image = aa.Array2D.no_mask([[3.0, 6.0]], pixel_scales=1.0) + image = transformer.image_from(visibilities=visibilities_7) - visibilities = transformer.visibilities_from(image=image) + assert image[0:3] == pytest.approx([-1.49022481, -0.22395855, -0.45588535], 1.0e-4) - assert visibilities == pytest.approx( - np.array([-2.46153 - 6.418822j, -5.14765 - 1.78146j, -3.11681 + 2.48210j]), - 1.0e-4, - ) - -def test__dft__visibilities_from__preload_and_non_preload_give_same_answer(): - uv_wavelengths = np.array([[0.2, 1.0], [0.5, 1.1], [0.8, 1.2]]) - grid_radians = aa.Grid2D.no_mask( - values=[[[0.1, 0.2], [0.3, 0.4]]], pixel_scales=1.0 - ) - real_space_mask = MockRealSpaceMask(grid=grid_radians) +def test__dft__visibilities_from__preload_and_non_preload_give_same_answer( + visibilities_7, uv_wavelengths_7x2, mask_2d_7x7 +): transformer_preload = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, + uv_wavelengths=uv_wavelengths_7x2, + real_space_mask=mask_2d_7x7, preload_transform=True, ) transformer = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, + uv_wavelengths=uv_wavelengths_7x2, + real_space_mask=mask_2d_7x7, preload_transform=False, ) @@ -156,14 +77,13 @@ def test__dft__visibilities_from__preload_and_non_preload_give_same_answer(): assert (visibilities_via_preload == visibilities).all() -def test__dft__transform_mapping_matrix(): - uv_wavelengths = np.ones(shape=(4, 2)) - grid_radians = aa.Grid2D.no_mask(values=[[[1.0, 1.0]]], pixel_scales=1.0) - real_space_mask = MockRealSpaceMask(grid=grid_radians) +def test__dft__transform_mapping_matrix( + visibilities_7, uv_wavelengths_7x2, mask_2d_7x7 +): transformer = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, + uv_wavelengths=uv_wavelengths_7x2, + real_space_mask=mask_2d_7x7, preload_transform=False, ) @@ -173,86 +93,31 @@ def test__dft__transform_mapping_matrix(): mapping_matrix=mapping_matrix ) - assert transformed_mapping_matrix == pytest.approx( - np.array([[1.0 + 0.0j], [1.0 + 0.0j], [1.0 + 0.0j], [1.0 + 0.0j]]), 1.0e-4 - ) - - uv_wavelengths = np.array([[0.2, 1.0], [0.5, 1.1], [0.8, 1.2]]) - - grid_radians = aa.Grid2D.no_mask( - values=[[[0.1, 0.2], [0.3, 0.4]]], pixel_scales=1.0 - ) - real_space_mask = MockRealSpaceMask(grid=grid_radians) - - transformer = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, - preload_transform=False, - ) - - mapping_matrix = np.ones(shape=(2, 2)) - - transformed_mapping_matrix = transformer.transform_mapping_matrix( - mapping_matrix=mapping_matrix - ) - - assert transformed_mapping_matrix == pytest.approx( - np.array( - [ - [-0.091544 - 1.45506j, -0.091544 - 1.45506j], - [-0.733597 - 0.78120j, -0.733597 - 0.78120j], - [-0.61316 - 0.07746j, -0.61316 - 0.07746j], - ] - ), - 1.0e-4, - ) - - grid_radians = aa.Grid2D.no_mask( - [[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]], pixel_scales=1.0 - ) - real_space_mask = MockRealSpaceMask(grid=grid_radians) - - uv_wavelengths = np.array([[0.7, 0.8], [0.9, 1.0]]) - - transformer = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, - preload_transform=False, - ) - - mapping_matrix = np.array([[0.0, 0.5], [0.0, 0.2], [1.0, 0.0]]) - - transformed_mapping_matrix = transformer.transform_mapping_matrix( - mapping_matrix=mapping_matrix - ) - - assert transformed_mapping_matrix == pytest.approx( + assert transformed_mapping_matrix[0:3, :] == pytest.approx( np.array( [ - [0.42577 + 0.90482j, -0.10473 - 0.46607j], - [0.968583 - 0.24868j, -0.20085 - 0.32227j], + [0.80682556 - 0.59078974j], + [-0.19648896 - 0.98050604j], + [-0.47002763 - 0.8826517j], ] ), 1.0e-4, ) -def test__dft__transformed_mapping_matrix__preload_and_non_preload_give_same_answer(): - uv_wavelengths = np.array([[0.2, 1.0], [0.5, 1.1], [0.8, 1.2]]) - grid_radians = aa.Grid2D.no_mask( - values=[[[0.1, 0.2], [0.3, 0.4]]], pixel_scales=1.0 - ) - real_space_mask = MockRealSpaceMask(grid=grid_radians) +def test__dft__transformed_mapping_matrix__preload_and_non_preload_give_same_answer( + visibilities_7, uv_wavelengths_7x2, mask_2d_7x7 +): transformer_preload = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, + uv_wavelengths=uv_wavelengths_7x2, + real_space_mask=mask_2d_7x7, preload_transform=True, ) transformer = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, + uv_wavelengths=uv_wavelengths_7x2, + real_space_mask=mask_2d_7x7, preload_transform=False, ) @@ -270,53 +135,40 @@ def test__dft__transformed_mapping_matrix__preload_and_non_preload_give_same_ans def test__nufft__visibilities_from(): - uv_wavelengths = np.array([[0.2, 1.0], [0.5, 1.1], [0.8, 1.2]]) - grid_radians = aa.Grid2D.uniform(shape_native=(5, 5), pixel_scales=0.005).in_radians - real_space_mask = MockRealSpaceMask(grid=grid_radians) + uv_wavelengths = np.array([[0.2, 1.0], [0.5, 1.1], [0.8, 1.2]]) + real_space_mask = aa.Mask2D.all_false(shape_native=(5, 5), pixel_scales=0.005) image = aa.Array2D.ones( - shape_native=grid_radians.shape_native, - pixel_scales=grid_radians.pixel_scales, + shape_native=(5, 5), + pixel_scales=0.005, ) - transformer_dft = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, - preload_transform=False, - ) - - visibilities_dft = transformer_dft.visibilities_from(image=image.native) - - real_space_mask = aa.Mask2D.all_false(shape_native=(5, 5), pixel_scales=0.005) - transformer_nufft = aa.TransformerNUFFT( uv_wavelengths=uv_wavelengths, real_space_mask=real_space_mask ) visibilities_nufft = transformer_nufft.visibilities_from(image=image.native) - assert visibilities_dft == pytest.approx(visibilities_nufft, 2.0) assert visibilities_nufft[0] == pytest.approx(25.02317617953263 + 0.0j, 1.0e-7) -def test__nufft__transform_mapping_matrix(): - uv_wavelengths = np.array([[0.2, 1.0], [0.5, 1.1], [0.8, 1.2]]) +def test__nufft__image_from(visibilities_7, uv_wavelengths_7x2, mask_2d_7x7): - grid_radians = aa.Grid2D.uniform(shape_native=(5, 5), pixel_scales=0.005) - real_space_mask = MockRealSpaceMask(grid=grid_radians) + transformer = aa.TransformerNUFFT( + uv_wavelengths=uv_wavelengths_7x2, + real_space_mask=mask_2d_7x7, + ) - mapping_matrix = np.ones(shape=(25, 3)) + image = transformer.image_from(visibilities=visibilities_7) - transformer_dft = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, - preload_transform=False, - ) + assert image[0:3] == pytest.approx([0.00726546, 0.01149121, 0.01421022], 1.0e-4) - transformed_mapping_matrix_dft = transformer_dft.transform_mapping_matrix( - mapping_matrix=mapping_matrix - ) + +def test__nufft__transform_mapping_matrix(): + uv_wavelengths = np.array([[0.2, 1.0], [0.5, 1.1], [0.8, 1.2]]) + + mapping_matrix = np.ones(shape=(25, 3)) real_space_mask = aa.Mask2D.all_false(shape_native=(5, 5), pixel_scales=0.005) @@ -328,13 +180,6 @@ def test__nufft__transform_mapping_matrix(): mapping_matrix=mapping_matrix ) - assert transformed_mapping_matrix_dft == pytest.approx( - transformed_mapping_matrix_nufft, 2.0 - ) - assert transformed_mapping_matrix_dft == pytest.approx( - transformed_mapping_matrix_nufft, 2.0 - ) - assert transformed_mapping_matrix_nufft[0, 0] == pytest.approx( 25.02317 + 0.0j, 1.0e-4 ) From fd383e1b386e94d5ac6bdf70c7b1d5253f097292 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 17 Jun 2025 21:10:39 +0100 Subject: [PATCH 02/18] remove pylops legact --- autoarray/operators/transformer.py | 63 +------------------- autoarray/operators/transformer_util.py | 51 ++++++++++++---- test_autoarray/operators/test_transformer.py | 4 +- 3 files changed, 44 insertions(+), 74 deletions(-) diff --git a/autoarray/operators/transformer.py b/autoarray/operators/transformer.py index 288648016..7bc302771 100644 --- a/autoarray/operators/transformer.py +++ b/autoarray/operators/transformer.py @@ -94,11 +94,10 @@ def visibilities_from(self, image): return Visibilities(visibilities=visibilities) def image_from(self, visibilities, use_adjoint_scaling: bool = False): - image_slim = transformer_util.image_via_jit_from( - n_pixels=self.grid.shape[0], + image_slim = transformer_util.image_direct_from( + visibilities=visibilities.in_array, grid_radians=np.array(self.grid.array), uv_wavelengths=self.uv_wavelengths, - visibilities=visibilities.in_array, ) image_native = array_2d_util.array_2d_native_from( @@ -249,61 +248,3 @@ def transform_mapping_matrix(self, mapping_matrix): transformed_mapping_matrix[:, source_pixel_1d_index] = visibilities return transformed_mapping_matrix - - def forward_lop(self, x): - """ - Forward NUFFT on CPU - :param x: The input numpy array, with the size of Nd or Nd + (batch,) - :type: numpy array with the dtype of numpy.complex64 - :return: y: The output numpy array, with the size of (M,) or (M, batch) - :rtype: numpy array with the dtype of numpy.complex64 - """ - - warnings.filterwarnings("ignore") - - x2d = array_2d_util.array_2d_native_complex_via_indexes_from( - array_2d_slim=x, - shape_native=self.real_space_mask.shape_native, - native_index_for_slim_index_2d=self.native_index_for_slim_index, - )[::-1, :] - - y = self.k2y(self.xx2k(self.x2xx(x2d))) - return np.concatenate((y.real, y.imag), axis=0) - - def adjoint_lop(self, y): - """ - Adjoint NUFFT on CPU - :param y: The input numpy array, with the size of (M,) or (M, batch) - :type: numpy array with the dtype of numpy.complex64 - :return: x: The output numpy array, - with the size of Nd or Nd + (batch, ) - :rtype: numpy array with the dtype of numpy.complex64 - """ - - warnings.filterwarnings("ignore") - - def a_complex_from(a_real, a_imag): - return a_real + 1j * a_imag - - y = a_complex_from( - a_real=y[: int(self.shape[0] / 2.0)], a_imag=y[int(self.shape[0] / 2.0) :] - ) - - x2d = np.real(self.xx2x(self.k2xx(self.y2k(y)))) - - x = array_2d_util.array_2d_slim_complex_from( - array_2d_native=x2d[::-1, :], - mask=np.array(self.real_space_mask), - ) - x = x.real # NOTE: - - # NOTE: - x *= self.adjoint_scaling - - return x - - def _matvec(self, x): - return self.forward_lop(x) - - def _rmatvec(self, x): - return self.adjoint_lop(x) diff --git a/autoarray/operators/transformer_util.py b/autoarray/operators/transformer_util.py index 1558fb0c7..90b0bd798 100644 --- a/autoarray/operators/transformer_util.py +++ b/autoarray/operators/transformer_util.py @@ -88,7 +88,9 @@ def visibilities_via_preload_jit_from(image_1d, preloaded_reals, preloaded_imags return visibilities -def visibilities_direct_from(image_1d : np.ndarray, grid_radians : np.ndarray, uv_wavelengths : np.ndarray) -> np.ndarray: +def visibilities_direct_from( + image_1d: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray +) -> np.ndarray: """ Compute complex visibilities from an input sky image using the Fourier transform, simulating the response of an astronomical radio interferometer. @@ -116,9 +118,13 @@ def visibilities_direct_from(image_1d : np.ndarray, grid_radians : np.ndarray, u (u, v) coordinate, representing the interferometer’s measurement. """ # Compute the dot product for each pixel-uv pair - phase = -2.0 * np.pi * ( - np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) + - np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) + phase = ( + -2.0 + * np.pi + * ( + np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) + + np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) + ) ) # shape (n_pixels, n_vis) # Multiply image values with phase terms @@ -131,14 +137,40 @@ def visibilities_direct_from(image_1d : np.ndarray, grid_radians : np.ndarray, u return visibilities -def image_via_jit_from(n_pixels, grid_radians, uv_wavelengths, visibilities): +def image_direct_from( + visibilities: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray +) -> np.ndarray: + """ + Reconstruct a real-valued sky image from complex interferometric visibilities + using an inverse Fourier transform approximation. + + This function simulates the synthesis imaging equation of a radio interferometer + by summing sinusoidal components across all (u, v) spatial frequencies. + + Parameters + ---------- + visibilities + The real and imaginary parts of the complex visibilities for each (u, v) point. + + grid_radians + The angular (y, x) coordinates of each pixel in radians. - aaaa + uv_wavelengths + The (u, v) spatial frequencies in units of wavelengths for each baseline. + Returns + ------- + image_1d + The reconstructed real-valued image in sky coordinates. + """ # Compute the phase term for each (pixel, visibility) pair - phase = 2.0 * np.pi * ( - np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) + - np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) + phase = ( + 2.0 + * np.pi + * ( + np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) + + np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) + ) ) real_part = np.dot(np.cos(phase), visibilities[:, 0]) @@ -149,7 +181,6 @@ def image_via_jit_from(n_pixels, grid_radians, uv_wavelengths, visibilities): return image_1d - @numba_util.jit() def transformed_mapping_matrix_via_preload_jit_from( mapping_matrix, preloaded_reals, preloaded_imags diff --git a/test_autoarray/operators/test_transformer.py b/test_autoarray/operators/test_transformer.py index c2af5b082..1da779c9e 100644 --- a/test_autoarray/operators/test_transformer.py +++ b/test_autoarray/operators/test_transformer.py @@ -27,8 +27,6 @@ def test__dft__visibilities_from(visibilities_7, uv_wavelengths_7x2, mask_2d_7x7 visibilities = transformer.visibilities_from(image=image) - print(visibilities) - assert visibilities[0:3] == pytest.approx( np.array( [ @@ -162,7 +160,7 @@ def test__nufft__image_from(visibilities_7, uv_wavelengths_7x2, mask_2d_7x7): image = transformer.image_from(visibilities=visibilities_7) - assert image[0:3] == pytest.approx([0.00726546, 0.01149121, 0.01421022], 1.0e-4) + assert image[0:3] == pytest.approx([0.00726546, 0.01149121, 0.01421022], 1.0e-4) def test__nufft__transform_mapping_matrix(): From 7961446d39574005b389a0618eca9749369c47c4 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 17 Jun 2025 21:13:22 +0100 Subject: [PATCH 03/18] fix unitt est with shaping --- autoarray/operators/transformer_util.py | 1 + test_autoarray/operators/test_transformer.py | 13 ++++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/autoarray/operators/transformer_util.py b/autoarray/operators/transformer_util.py index 90b0bd798..c4dab1124 100644 --- a/autoarray/operators/transformer_util.py +++ b/autoarray/operators/transformer_util.py @@ -117,6 +117,7 @@ def visibilities_direct_from( The complex visibilities (Fourier components) corresponding to each (u, v) coordinate, representing the interferometer’s measurement. """ + # Compute the dot product for each pixel-uv pair phase = ( -2.0 diff --git a/test_autoarray/operators/test_transformer.py b/test_autoarray/operators/test_transformer.py index 1da779c9e..d7c97c2f2 100644 --- a/test_autoarray/operators/test_transformer.py +++ b/test_autoarray/operators/test_transformer.py @@ -67,7 +67,18 @@ def test__dft__visibilities_from__preload_and_non_preload_give_same_answer( preload_transform=False, ) - image = aa.Array2D.no_mask([[2.0, 6.0]], pixel_scales=1.0) + image = aa.Array2D( + values=[ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.5, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.5, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + mask=mask_2d_7x7, + ) visibilities_via_preload = transformer_preload.visibilities_from(image=image) visibilities = transformer.visibilities_from(image=image) From 25a509612a8d3cfc4515c86613691e611a1708b2 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 18 Jun 2025 14:27:34 +0100 Subject: [PATCH 04/18] preload_real_transforms converted to numpy --- autoarray/operators/transformer_util.py | 35 +++++++++---------------- 1 file changed, 12 insertions(+), 23 deletions(-) diff --git a/autoarray/operators/transformer_util.py b/autoarray/operators/transformer_util.py index c4dab1124..26bb8325d 100644 --- a/autoarray/operators/transformer_util.py +++ b/autoarray/operators/transformer_util.py @@ -3,18 +3,16 @@ from autoarray import numba_util -@numba_util.jit() -def preload_real_transforms( - grid_radians: np.ndarray, uv_wavelengths: np.ndarray -) -> np.ndarray: + +def preload_real_transforms(grid_radians: np.ndarray, uv_wavelengths: np.ndarray) -> np.ndarray: """ - Sets up the real preloaded values used by the direct fourier transform (`TransformerDFT`) to speed up + Sets up the real preloaded values used by the direct Fourier transform (`TransformerDFT`) to speed up the Fourier transform calculations. The preloaded values are the cosine terms of every (y,x) radian coordinate on the real-space grid multiplied by - everu `uv_wavelength` value. + every `uv_wavelength` value. - For large numbers of visibilities (> 100000) this array requires large amounts of memory ( > 1 GB) and it is + For large numbers of visibilities (> 100000) this array requires large amounts of memory (> 1 GB) and it is recommended this preloading is not used. Parameters @@ -28,25 +26,16 @@ def preload_real_transforms( Returns ------- - np.ndarray - The preloaded values of the cosine terms in the calculation of real entries of the direct Fourier transform. - + The preloaded values of the cosine terms in the calculation of real entries of the direct Fourier transform. """ - - preloaded_real_transforms = np.zeros( - shape=(grid_radians.shape[0], uv_wavelengths.shape[0]) + # Compute the phase matrix: shape (n_pixels, n_visibilities) + phase = -2.0 * np.pi * ( + np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) + # y * u + np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) # x * v ) - for image_1d_index in range(grid_radians.shape[0]): - for vis_1d_index in range(uv_wavelengths.shape[0]): - preloaded_real_transforms[image_1d_index, vis_1d_index] += np.cos( - -2.0 - * np.pi - * ( - grid_radians[image_1d_index, 1] * uv_wavelengths[vis_1d_index, 0] - + grid_radians[image_1d_index, 0] * uv_wavelengths[vis_1d_index, 1] - ) - ) + # Compute cosine of the phase matrix + preloaded_real_transforms = np.cos(phase) return preloaded_real_transforms From 88748fd7d279226a8e4e8ead2c8684df79fae14b Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 18 Jun 2025 14:28:32 +0100 Subject: [PATCH 05/18] preload_imag_transforms --- autoarray/operators/transformer_util.py | 43 +++++++++++++++++-------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/autoarray/operators/transformer_util.py b/autoarray/operators/transformer_util.py index 26bb8325d..42f58d7fe 100644 --- a/autoarray/operators/transformer_util.py +++ b/autoarray/operators/transformer_util.py @@ -40,26 +40,41 @@ def preload_real_transforms(grid_radians: np.ndarray, uv_wavelengths: np.ndarray return preloaded_real_transforms -@numba_util.jit() -def preload_imag_transforms(grid_radians, uv_wavelengths): - preloaded_imag_transforms = np.zeros( - shape=(grid_radians.shape[0], uv_wavelengths.shape[0]) +def preload_imag_transforms(grid_radians: np.ndarray, uv_wavelengths: np.ndarray) -> np.ndarray: + """ + Sets up the imaginary preloaded values used by the direct Fourier transform (`TransformerDFT`) to speed up + the Fourier transform calculations in interferometric imaging. + + The preloaded values are the sine terms of every (y,x) radian coordinate on the real-space grid multiplied by + every `uv_wavelength` value. These are used to compute the imaginary components of visibilities. + + For large numbers of visibilities (> 100000), this array can require significant memory (> 1 GB), so preloading + should be used with care. + + Parameters + ---------- + grid_radians + The grid in radians corresponding to the (y,x) coordinates in real space. + uv_wavelengths + The (u,v) coordinates in the Fourier plane (in units of wavelengths). + + Returns + ------- + The sine term preloads used in imaginary-part DFT calculations. + """ + # Compute the phase matrix: shape (n_pixels, n_visibilities) + phase = -2.0 * np.pi * ( + np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) + # y * u + np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) # x * v ) - for image_1d_index in range(grid_radians.shape[0]): - for vis_1d_index in range(uv_wavelengths.shape[0]): - preloaded_imag_transforms[image_1d_index, vis_1d_index] += np.sin( - -2.0 - * np.pi - * ( - grid_radians[image_1d_index, 1] * uv_wavelengths[vis_1d_index, 0] - + grid_radians[image_1d_index, 0] * uv_wavelengths[vis_1d_index, 1] - ) - ) + # Compute sine of the phase matrix + preloaded_imag_transforms = np.sin(phase) return preloaded_imag_transforms + @numba_util.jit() def visibilities_via_preload_jit_from(image_1d, preloaded_reals, preloaded_imags): visibilities = 0 + 0j * np.zeros(shape=(preloaded_reals.shape[1])) From 4415121f49a7ffcb8ea217babaac722aec332c52 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 18 Jun 2025 14:39:31 +0100 Subject: [PATCH 06/18] rename functions --- autoarray/operators/transformer.py | 16 ++--- autoarray/operators/transformer_util.py | 74 +++++++++++++------- test_autoarray/operators/test_transformer.py | 2 +- 3 files changed, 58 insertions(+), 34 deletions(-) diff --git a/autoarray/operators/transformer.py b/autoarray/operators/transformer.py index 7bc302771..b668dd641 100644 --- a/autoarray/operators/transformer.py +++ b/autoarray/operators/transformer.py @@ -47,12 +47,12 @@ def __init__(self, uv_wavelengths, real_space_mask, preload_transform=True): self.preload_transform = preload_transform if preload_transform: - self.preload_real_transforms = transformer_util.preload_real_transforms( + self.preload_real_transforms = transformer_util.preload_real_transforms_from( grid_radians=np.array(self.grid.array), uv_wavelengths=self.uv_wavelengths, ) - self.preload_imag_transforms = transformer_util.preload_imag_transforms( + self.preload_imag_transforms = transformer_util.preload_imag_transforms_from( grid_radians=np.array(self.grid.array), uv_wavelengths=self.uv_wavelengths, ) @@ -78,14 +78,14 @@ def __init__(self, uv_wavelengths, real_space_mask, preload_transform=True): def visibilities_from(self, image): if self.preload_transform: - visibilities = transformer_util.visibilities_via_preload_jit_from( + visibilities = transformer_util.visibilities_via_preload_from( image_1d=np.array(image.array), - preloaded_reals=self.preload_real_transforms, - preloaded_imags=self.preload_imag_transforms, + preloaded_reals=self.preload_real_transforms_from, + preloaded_imags=self.preload_imag_transforms_from, ) else: - visibilities = transformer_util.visibilities_direct_from( + visibilities = transformer_util.visibilities_from( image_1d=np.array(image.slim.array), grid_radians=np.array(self.grid), uv_wavelengths=self.uv_wavelengths, @@ -111,8 +111,8 @@ def transform_mapping_matrix(self, mapping_matrix): if self.preload_transform: return transformer_util.transformed_mapping_matrix_via_preload_jit_from( mapping_matrix=mapping_matrix, - preloaded_reals=self.preload_real_transforms, - preloaded_imags=self.preload_imag_transforms, + preloaded_reals=self.preload_real_transforms_from, + preloaded_imags=self.preload_imag_transforms_from, ) else: diff --git a/autoarray/operators/transformer_util.py b/autoarray/operators/transformer_util.py index 42f58d7fe..26f29faeb 100644 --- a/autoarray/operators/transformer_util.py +++ b/autoarray/operators/transformer_util.py @@ -3,8 +3,9 @@ from autoarray import numba_util - -def preload_real_transforms(grid_radians: np.ndarray, uv_wavelengths: np.ndarray) -> np.ndarray: +def preload_real_transforms_from( + grid_radians: np.ndarray, uv_wavelengths: np.ndarray +) -> np.ndarray: """ Sets up the real preloaded values used by the direct Fourier transform (`TransformerDFT`) to speed up the Fourier transform calculations. @@ -29,9 +30,13 @@ def preload_real_transforms(grid_radians: np.ndarray, uv_wavelengths: np.ndarray The preloaded values of the cosine terms in the calculation of real entries of the direct Fourier transform. """ # Compute the phase matrix: shape (n_pixels, n_visibilities) - phase = -2.0 * np.pi * ( - np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) + # y * u - np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) # x * v + phase = ( + -2.0 + * np.pi + * ( + np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) # y * u + + np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) # x * v + ) ) # Compute cosine of the phase matrix @@ -40,7 +45,9 @@ def preload_real_transforms(grid_radians: np.ndarray, uv_wavelengths: np.ndarray return preloaded_real_transforms -def preload_imag_transforms(grid_radians: np.ndarray, uv_wavelengths: np.ndarray) -> np.ndarray: +def preload_imag_transforms_from( + grid_radians: np.ndarray, uv_wavelengths: np.ndarray +) -> np.ndarray: """ Sets up the imaginary preloaded values used by the direct Fourier transform (`TransformerDFT`) to speed up the Fourier transform calculations in interferometric imaging. @@ -63,9 +70,13 @@ def preload_imag_transforms(grid_radians: np.ndarray, uv_wavelengths: np.ndarray The sine term preloads used in imaginary-part DFT calculations. """ # Compute the phase matrix: shape (n_pixels, n_visibilities) - phase = -2.0 * np.pi * ( - np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) + # y * u - np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) # x * v + phase = ( + -2.0 + * np.pi + * ( + np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) # y * u + + np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) # x * v + ) ) # Compute sine of the phase matrix @@ -74,25 +85,40 @@ def preload_imag_transforms(grid_radians: np.ndarray, uv_wavelengths: np.ndarray return preloaded_imag_transforms +def visibilities_via_preload_from( + image_1d: np.ndarray, preloaded_reals: np.ndarray, preloaded_imags: np.ndarray +) -> np.ndarray: + """ + Computes interferometric visibilities using preloaded real and imaginary DFT transform components. -@numba_util.jit() -def visibilities_via_preload_jit_from(image_1d, preloaded_reals, preloaded_imags): - visibilities = 0 + 0j * np.zeros(shape=(preloaded_reals.shape[1])) - - for image_1d_index in range(image_1d.shape[0]): - for vis_1d_index in range(preloaded_reals.shape[1]): - vis_real = ( - image_1d[image_1d_index] * preloaded_reals[image_1d_index, vis_1d_index] - ) - vis_imag = ( - image_1d[image_1d_index] * preloaded_imags[image_1d_index, vis_1d_index] - ) - visibilities[vis_1d_index] += vis_real + 1j * vis_imag + This function performs a direct Fourier transform (DFT) using precomputed cosine (real) and sine (imaginary) + terms. It is used in radio astronomy to compute visibilities from an image for a given interferometric + observation setup. + + Parameters + ---------- + image_1d : ndarray of shape (n_pixels,) + The 1D image vector (real-space brightness values). + preloaded_reals : ndarray of shape (n_pixels, n_visibilities) + The preloaded cosine terms (real part of DFT matrix). + preloaded_imags : ndarray of shape (n_pixels, n_visibilities) + The preloaded sine terms (imaginary part of DFT matrix). + + Returns + ------- + visibilities : ndarray of shape (n_visibilities,) + The complex visibilities computed by summing over all pixels. + """ + # Perform the dot product between the image and preloaded transform matrices + vis_real = np.dot(image_1d, preloaded_reals) # shape (n_visibilities,) + vis_imag = np.dot(image_1d, preloaded_imags) # shape (n_visibilities,) + + visibilities = vis_real + 1j * vis_imag return visibilities -def visibilities_direct_from( +def visibilities_from( image_1d: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray ) -> np.ndarray: """ @@ -107,10 +133,8 @@ def visibilities_direct_from( ---------- image_1d The 1D flattened sky brightness values corresponding to each pixel in the grid. - grid_radians The angular (y, x) positions of each image pixel in radians, matching image_1d. - uv_wavelengths The (u, v) spatial frequencies in units of wavelengths, for each baseline of the interferometer. diff --git a/test_autoarray/operators/test_transformer.py b/test_autoarray/operators/test_transformer.py index d7c97c2f2..8ac7f53a6 100644 --- a/test_autoarray/operators/test_transformer.py +++ b/test_autoarray/operators/test_transformer.py @@ -83,7 +83,7 @@ def test__dft__visibilities_from__preload_and_non_preload_give_same_answer( visibilities_via_preload = transformer_preload.visibilities_from(image=image) visibilities = transformer.visibilities_from(image=image) - assert (visibilities_via_preload == visibilities).all() + assert visibilities_via_preload == pytest.approx(visibilities, 1.0e-4) def test__dft__transform_mapping_matrix( From ca034e7123b7b17c1430b01fac68511d58fcda97 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 18 Jun 2025 14:48:14 +0100 Subject: [PATCH 07/18] transformed_mapping_matrix_from --- autoarray/operators/transformer.py | 6 +- autoarray/operators/transformer_util.py | 69 ++++++++++---------- test_autoarray/operators/test_transformer.py | 10 +-- 3 files changed, 43 insertions(+), 42 deletions(-) diff --git a/autoarray/operators/transformer.py b/autoarray/operators/transformer.py index b668dd641..956483510 100644 --- a/autoarray/operators/transformer.py +++ b/autoarray/operators/transformer.py @@ -111,12 +111,12 @@ def transform_mapping_matrix(self, mapping_matrix): if self.preload_transform: return transformer_util.transformed_mapping_matrix_via_preload_jit_from( mapping_matrix=mapping_matrix, - preloaded_reals=self.preload_real_transforms_from, - preloaded_imags=self.preload_imag_transforms_from, + preloaded_reals=self.preload_real_transforms, + preloaded_imags=self.preload_imag_transforms, ) else: - return transformer_util.transformed_mapping_matrix_jit( + return transformer_util.transformed_mapping_matrix_from( mapping_matrix=mapping_matrix, grid_radians=np.array(self.grid), uv_wavelengths=self.uv_wavelengths, diff --git a/autoarray/operators/transformer_util.py b/autoarray/operators/transformer_util.py index 26f29faeb..9591e7863 100644 --- a/autoarray/operators/transformer_util.py +++ b/autoarray/operators/transformer_util.py @@ -233,42 +233,43 @@ def transformed_mapping_matrix_via_preload_jit_from( return transfomed_mapping_matrix -@numba_util.jit() -def transformed_mapping_matrix_jit(mapping_matrix, grid_radians, uv_wavelengths): - transfomed_mapping_matrix = 0 + 0j * np.zeros( - (uv_wavelengths.shape[0], mapping_matrix.shape[1]) - ) +def transformed_mapping_matrix_from(mapping_matrix: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray) -> np.ndarray: + """ + Computes the Fourier-transformed mapping matrix used in radio interferometric imaging. - for pixel_1d_index in range(mapping_matrix.shape[1]): - for image_1d_index in range(mapping_matrix.shape[0]): - value = mapping_matrix[image_1d_index, pixel_1d_index] + This function applies a direct Fourier transform to each pixel column of the mapping matrix using the + uv-wavelength coordinates. The result is a matrix that maps source pixel intensities to complex visibilities, + which represent how a model image would appear to an interferometer. - if value > 0: - for vis_1d_index in range(uv_wavelengths.shape[0]): - vis_real = value * np.cos( - -2.0 - * np.pi - * ( - grid_radians[image_1d_index, 1] - * uv_wavelengths[vis_1d_index, 0] - + grid_radians[image_1d_index, 0] - * uv_wavelengths[vis_1d_index, 1] - ) - ) + Parameters + ---------- + mapping_matrix : ndarray of shape (n_image_pixels, n_source_pixels) + The mapping matrix from image-plane pixels to source-plane pixels. + grid_radians : ndarray of shape (n_image_pixels, 2) + The (y,x) positions of each image pixel in radians. + uv_wavelengths : ndarray of shape (n_visibilities, 2) + The (u,v) coordinates of the sampled Fourier modes in units of wavelength. - vis_imag = value * np.sin( - -2.0 - * np.pi - * ( - grid_radians[image_1d_index, 1] - * uv_wavelengths[vis_1d_index, 0] - + grid_radians[image_1d_index, 0] - * uv_wavelengths[vis_1d_index, 1] - ) - ) + Returns + ------- + transformed_matrix : ndarray of shape (n_visibilities, n_source_pixels) + The transformed mapping matrix in the visibility domain (complex-valued). + """ + # Compute phase term: (n_image_pixels, n_visibilities) + phase = -2.0 * np.pi * ( + np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) + # y * u + np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) # x * v + ) - transfomed_mapping_matrix[vis_1d_index, pixel_1d_index] += ( - vis_real + 1j * vis_imag - ) + # Compute real and imaginary Fourier matrices + fourier_real = np.cos(phase) + fourier_imag = np.sin(phase) - return transfomed_mapping_matrix + # Only compute contributions from non-zero mapping entries + # This matrix multiplication is: (n_visibilities x n_image_pixels) dot (n_image_pixels x n_source_pixels) + vis_real = fourier_real.T @ mapping_matrix # (n_vis, n_src) + vis_imag = fourier_imag.T @ mapping_matrix # (n_vis, n_src) + + transformed_matrix = vis_real + 1j * vis_imag + + return transformed_matrix diff --git a/test_autoarray/operators/test_transformer.py b/test_autoarray/operators/test_transformer.py index 8ac7f53a6..e678772a2 100644 --- a/test_autoarray/operators/test_transformer.py +++ b/test_autoarray/operators/test_transformer.py @@ -96,7 +96,7 @@ def test__dft__transform_mapping_matrix( preload_transform=False, ) - mapping_matrix = np.ones(shape=(1, 1)) + mapping_matrix = np.ones(shape=(9, 1)) transformed_mapping_matrix = transformer.transform_mapping_matrix( mapping_matrix=mapping_matrix @@ -105,12 +105,12 @@ def test__dft__transform_mapping_matrix( assert transformed_mapping_matrix[0:3, :] == pytest.approx( np.array( [ - [0.80682556 - 0.59078974j], - [-0.19648896 - 0.98050604j], - [-0.47002763 - 0.8826517j], + [1.48496084+0.00000000e+00j], + [3.02988906+4.44089210e-16], + [0.86395556+0.00000000e+00], ] ), - 1.0e-4, + abs=1.0e-4, ) From b49e92022d15906089b5237e851938eba059c961 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 18 Jun 2025 14:52:04 +0100 Subject: [PATCH 08/18] transformed_mapping_matrix_via_preload_from --- autoarray/operators/transformer.py | 22 ++++--- autoarray/operators/transformer_util.py | 62 +++++++++++++------- test_autoarray/operators/test_transformer.py | 8 +-- 3 files changed, 57 insertions(+), 35 deletions(-) diff --git a/autoarray/operators/transformer.py b/autoarray/operators/transformer.py index 956483510..25565544c 100644 --- a/autoarray/operators/transformer.py +++ b/autoarray/operators/transformer.py @@ -47,14 +47,18 @@ def __init__(self, uv_wavelengths, real_space_mask, preload_transform=True): self.preload_transform = preload_transform if preload_transform: - self.preload_real_transforms = transformer_util.preload_real_transforms_from( - grid_radians=np.array(self.grid.array), - uv_wavelengths=self.uv_wavelengths, + self.preload_real_transforms = ( + transformer_util.preload_real_transforms_from( + grid_radians=np.array(self.grid.array), + uv_wavelengths=self.uv_wavelengths, + ) ) - self.preload_imag_transforms = transformer_util.preload_imag_transforms_from( - grid_radians=np.array(self.grid.array), - uv_wavelengths=self.uv_wavelengths, + self.preload_imag_transforms = ( + transformer_util.preload_imag_transforms_from( + grid_radians=np.array(self.grid.array), + uv_wavelengths=self.uv_wavelengths, + ) ) self.real_space_pixels = self.real_space_mask.pixels_in_mask @@ -80,8 +84,8 @@ def visibilities_from(self, image): if self.preload_transform: visibilities = transformer_util.visibilities_via_preload_from( image_1d=np.array(image.array), - preloaded_reals=self.preload_real_transforms_from, - preloaded_imags=self.preload_imag_transforms_from, + preloaded_reals=self.preload_real_transforms, + preloaded_imags=self.preload_imag_transforms, ) else: @@ -109,7 +113,7 @@ def image_from(self, visibilities, use_adjoint_scaling: bool = False): def transform_mapping_matrix(self, mapping_matrix): if self.preload_transform: - return transformer_util.transformed_mapping_matrix_via_preload_jit_from( + return transformer_util.transformed_mapping_matrix_via_preload_from( mapping_matrix=mapping_matrix, preloaded_reals=self.preload_real_transforms, preloaded_imags=self.preload_imag_transforms, diff --git a/autoarray/operators/transformer_util.py b/autoarray/operators/transformer_util.py index 9591e7863..009f44617 100644 --- a/autoarray/operators/transformer_util.py +++ b/autoarray/operators/transformer_util.py @@ -210,30 +210,44 @@ def image_direct_from( return image_1d -@numba_util.jit() -def transformed_mapping_matrix_via_preload_jit_from( - mapping_matrix, preloaded_reals, preloaded_imags -): - transfomed_mapping_matrix = 0 + 0j * np.zeros( - (preloaded_reals.shape[1], mapping_matrix.shape[1]) - ) +def transformed_mapping_matrix_via_preload_from( + mapping_matrix: np.ndarray, preloaded_reals: np.ndarray, preloaded_imags: np.ndarray +) -> np.ndarray: + """ + Computes the Fourier-transformed mapping matrix using preloaded sine and cosine terms for efficiency. + + This function transforms each source pixel's mapping to visibilities by using precomputed + real (cosine) and imaginary (sine) terms from the direct Fourier transform. + It is used in radio interferometric imaging where source-to-image mappings are projected + into the visibility space. + + Parameters + ---------- + mapping_matrix + The mapping matrix from image-plane pixels to source-plane pixels. + preloaded_reals + Precomputed cosine terms for each pixel-vis pair: cos(-2π(yu + xv)). + preloaded_imags + Precomputed sine terms for each pixel-vis pair: sin(-2π(yu + xv)). - for pixel_1d_index in range(mapping_matrix.shape[1]): - for image_1d_index in range(mapping_matrix.shape[0]): - value = mapping_matrix[image_1d_index, pixel_1d_index] + Returns + ------- + Complex-valued matrix mapping source pixels to visibilities. + """ - if value > 0: - for vis_1d_index in range(preloaded_reals.shape[1]): - vis_real = value * preloaded_reals[image_1d_index, vis_1d_index] - vis_imag = value * preloaded_imags[image_1d_index, vis_1d_index] - transfomed_mapping_matrix[vis_1d_index, pixel_1d_index] += ( - vis_real + 1j * vis_imag - ) + # Broadcasted multiplication and matrix multiplication over non-zero entries - return transfomed_mapping_matrix + vis_real = preloaded_reals.T @ mapping_matrix # (n_visibilities, n_source_pixels) + vis_imag = preloaded_imags.T @ mapping_matrix + transformed_matrix = vis_real + 1j * vis_imag -def transformed_mapping_matrix_from(mapping_matrix: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray) -> np.ndarray: + return transformed_matrix + + +def transformed_mapping_matrix_from( + mapping_matrix: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray +) -> np.ndarray: """ Computes the Fourier-transformed mapping matrix used in radio interferometric imaging. @@ -256,9 +270,13 @@ def transformed_mapping_matrix_from(mapping_matrix: np.ndarray, grid_radians: np The transformed mapping matrix in the visibility domain (complex-valued). """ # Compute phase term: (n_image_pixels, n_visibilities) - phase = -2.0 * np.pi * ( - np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) + # y * u - np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) # x * v + phase = ( + -2.0 + * np.pi + * ( + np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) # y * u + + np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) # x * v + ) ) # Compute real and imaginary Fourier matrices diff --git a/test_autoarray/operators/test_transformer.py b/test_autoarray/operators/test_transformer.py index e678772a2..c4209e9fb 100644 --- a/test_autoarray/operators/test_transformer.py +++ b/test_autoarray/operators/test_transformer.py @@ -105,9 +105,9 @@ def test__dft__transform_mapping_matrix( assert transformed_mapping_matrix[0:3, :] == pytest.approx( np.array( [ - [1.48496084+0.00000000e+00j], - [3.02988906+4.44089210e-16], - [0.86395556+0.00000000e+00], + [1.48496084 + 0.00000000e00j], + [3.02988906 + 4.44089210e-16], + [0.86395556 + 0.00000000e00], ] ), abs=1.0e-4, @@ -130,7 +130,7 @@ def test__dft__transformed_mapping_matrix__preload_and_non_preload_give_same_ans preload_transform=False, ) - mapping_matrix = np.array([[3.0, 5.0], [1.0, 2.0]]) + mapping_matrix = np.ones(shape=(9, 1)) transformed_mapping_matrix_preload = transformer_preload.transform_mapping_matrix( mapping_matrix=mapping_matrix From 19c14a4d2096dbf82bab093c16ad6e5a4330b29e Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 18 Jun 2025 14:53:09 +0100 Subject: [PATCH 09/18] simplify Transformer --- autoarray/operators/transformer.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/autoarray/operators/transformer.py b/autoarray/operators/transformer.py index 25565544c..ca4030367 100644 --- a/autoarray/operators/transformer.py +++ b/autoarray/operators/transformer.py @@ -63,23 +63,11 @@ def __init__(self, uv_wavelengths, real_space_mask, preload_transform=True): self.real_space_pixels = self.real_space_mask.pixels_in_mask - self.shape = ( - int(np.prod(self.total_visibilities)), - int(np.prod(self.real_space_pixels)), - ) - self.dtype = "complex128" - self.explicit = False - # NOTE: This is the scaling factor that needs to be applied to the adjoint operator self.adjoint_scaling = (2.0 * self.grid.shape_native[0]) * ( 2.0 * self.grid.shape_native[1] ) - self.matvec_count = 0 - self.rmatvec_count = 0 - self.matmat_count = 0 - self.rmatmat_count = 0 - def visibilities_from(self, image): if self.preload_transform: visibilities = transformer_util.visibilities_via_preload_from( @@ -167,26 +155,11 @@ def __init__(self, uv_wavelengths, real_space_mask): # NOTE: If reshaped the shape of the operator is (2 x Nvis, Np) else it is (Nvis, Np) self.total_visibilities = int(uv_wavelengths.shape[0] * uv_wavelengths.shape[1]) - self.shape = ( - int(np.prod(self.total_visibilities)), - int(np.prod(self.real_space_pixels)), - ) - - # NOTE: If the operator is reshaped then the output is real. - self.dtype = "float64" - - self.explicit = False - # NOTE: This is the scaling factor that needs to be applied to the adjoint operator self.adjoint_scaling = (2.0 * self.grid.shape_native[0]) * ( 2.0 * self.grid.shape_native[1] ) - self.matvec_count = 0 - self.rmatvec_count = 0 - self.matmat_count = 0 - self.rmatmat_count = 0 - def initialize_plan(self, ratio=2, interp_kernel=(6, 6)): if not isinstance(ratio, int): ratio = int(ratio) From 4c41eb1426474d5c1a6339af257bed441202959d Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 18 Jun 2025 14:56:28 +0100 Subject: [PATCH 10/18] TransformerDFT docstring --- autoarray/operators/transformer.py | 47 ++++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/autoarray/operators/transformer.py b/autoarray/operators/transformer.py index ca4030367..46ebb004f 100644 --- a/autoarray/operators/transformer.py +++ b/autoarray/operators/transformer.py @@ -14,6 +14,7 @@ class NUFFTPlaceholder: NUFFT_cpu = NUFFTPlaceholder +from autoarray.mask.mask_2d import Mask2D from autoarray.structures.arrays.uniform_2d import Array2D from autoarray.structures.grids.uniform_2d import Grid2D from autoarray.structures.visibilities import Visibilities @@ -33,8 +34,44 @@ def pynufft_exception(): class TransformerDFT: - def __init__(self, uv_wavelengths, real_space_mask, preload_transform=True): - + def __init__(self, uv_wavelengths : np.ndarray, real_space_mask : Mask2D, preload_transform : bool = True): + """ + A direct Fourier transform (DFT) operator for radio interferometric imaging. + + This class performs the forward and inverse mapping between real-space images and + complex visibilities measured by an interferometer. It uses a direct implementation + of the Fourier transform (not FFT-based), making it suitable for irregular uv-coverage. + + Optionally, it precomputes and stores the sine and cosine terms used in the transform, + which can significantly improve performance for repeated operations but at the cost of memory. + + Parameters + ---------- + uv_wavelengths + The (u, v) coordinates in wavelengths of the measured visibilities. + real_space_mask + The real-space mask that defines the image grid and which pixels are valid. + preload_transform + If True, precomputes and stores the cosine and sine terms for the Fourier transform. + This accelerates repeated transforms but consumes additional memory (~1GB+ for large datasets). + + Attributes + ---------- + grid : ndarray + The unmasked real-space grid in radians. + total_visibilities : int + The number of measured visibilities. + total_image_pixels : int + The number of unmasked pixels in the real-space image grid. + preload_real_transforms : ndarray, optional + The precomputed cosine terms used in the real part of the DFT. + preload_imag_transforms : ndarray, optional + The precomputed sine terms used in the imaginary part of the DFT. + real_space_pixels : int + Alias for `total_image_pixels`. + adjoint_scaling : float + Scaling factor applied to the adjoint operator to normalize the inverse transform. + """ super().__init__() self.uv_wavelengths = uv_wavelengths.astype("float") @@ -68,7 +105,7 @@ def __init__(self, uv_wavelengths, real_space_mask, preload_transform=True): 2.0 * self.grid.shape_native[1] ) - def visibilities_from(self, image): + def visibilities_from(self, image : Array2D) -> Visibilities: if self.preload_transform: visibilities = transformer_util.visibilities_via_preload_from( image_1d=np.array(image.array), @@ -85,7 +122,7 @@ def visibilities_from(self, image): return Visibilities(visibilities=visibilities) - def image_from(self, visibilities, use_adjoint_scaling: bool = False): + def image_from(self, visibilities : Visibilities, use_adjoint_scaling: bool = False) -> Array2D: image_slim = transformer_util.image_direct_from( visibilities=visibilities.in_array, grid_radians=np.array(self.grid.array), @@ -99,7 +136,7 @@ def image_from(self, visibilities, use_adjoint_scaling: bool = False): return Array2D(values=image_native, mask=self.real_space_mask) - def transform_mapping_matrix(self, mapping_matrix): + def transform_mapping_matrix(self, mapping_matrix : np.ndarray) -> np.ndarray: if self.preload_transform: return transformer_util.transformed_mapping_matrix_via_preload_from( mapping_matrix=mapping_matrix, From 5500b225f9a01bddfd55b47b900586b30410cb62 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 18 Jun 2025 15:17:20 +0100 Subject: [PATCH 11/18] transformer now uses jax arrays in DFT --- autoarray/operators/transformer.py | 286 ++++++++++++++++++++---- autoarray/operators/transformer_util.py | 3 +- 2 files changed, 239 insertions(+), 50 deletions(-) diff --git a/autoarray/operators/transformer.py b/autoarray/operators/transformer.py index 46ebb004f..3c5eeae33 100644 --- a/autoarray/operators/transformer.py +++ b/autoarray/operators/transformer.py @@ -2,6 +2,7 @@ import copy import numpy as np import warnings +from typing import Tuple class NUFFTPlaceholder: @@ -34,44 +35,49 @@ def pynufft_exception(): class TransformerDFT: - def __init__(self, uv_wavelengths : np.ndarray, real_space_mask : Mask2D, preload_transform : bool = True): + def __init__( + self, + uv_wavelengths: np.ndarray, + real_space_mask: Mask2D, + preload_transform: bool = True, + ): + """ + A direct Fourier transform (DFT) operator for radio interferometric imaging. + + This class performs the forward and inverse mapping between real-space images and + complex visibilities measured by an interferometer. It uses a direct implementation + of the Fourier transform (not FFT-based), making it suitable for irregular uv-coverage. + + Optionally, it precomputes and stores the sine and cosine terms used in the transform, + which can significantly improve performance for repeated operations but at the cost of memory. + + Parameters + ---------- + uv_wavelengths + The (u, v) coordinates in wavelengths of the measured visibilities. + real_space_mask + The real-space mask that defines the image grid and which pixels are valid. + preload_transform + If True, precomputes and stores the cosine and sine terms for the Fourier transform. + This accelerates repeated transforms but consumes additional memory (~1GB+ for large datasets). + + Attributes + ---------- + grid : ndarray + The unmasked real-space grid in radians. + total_visibilities : int + The number of measured visibilities. + total_image_pixels : int + The number of unmasked pixels in the real-space image grid. + preload_real_transforms : ndarray, optional + The precomputed cosine terms used in the real part of the DFT. + preload_imag_transforms : ndarray, optional + The precomputed sine terms used in the imaginary part of the DFT. + real_space_pixels : int + Alias for `total_image_pixels`. + adjoint_scaling : float + Scaling factor applied to the adjoint operator to normalize the inverse transform. """ - A direct Fourier transform (DFT) operator for radio interferometric imaging. - - This class performs the forward and inverse mapping between real-space images and - complex visibilities measured by an interferometer. It uses a direct implementation - of the Fourier transform (not FFT-based), making it suitable for irregular uv-coverage. - - Optionally, it precomputes and stores the sine and cosine terms used in the transform, - which can significantly improve performance for repeated operations but at the cost of memory. - - Parameters - ---------- - uv_wavelengths - The (u, v) coordinates in wavelengths of the measured visibilities. - real_space_mask - The real-space mask that defines the image grid and which pixels are valid. - preload_transform - If True, precomputes and stores the cosine and sine terms for the Fourier transform. - This accelerates repeated transforms but consumes additional memory (~1GB+ for large datasets). - - Attributes - ---------- - grid : ndarray - The unmasked real-space grid in radians. - total_visibilities : int - The number of measured visibilities. - total_image_pixels : int - The number of unmasked pixels in the real-space image grid. - preload_real_transforms : ndarray, optional - The precomputed cosine terms used in the real part of the DFT. - preload_imag_transforms : ndarray, optional - The precomputed sine terms used in the imaginary part of the DFT. - real_space_pixels : int - Alias for `total_image_pixels`. - adjoint_scaling : float - Scaling factor applied to the adjoint operator to normalize the inverse transform. - """ super().__init__() self.uv_wavelengths = uv_wavelengths.astype("float") @@ -105,24 +111,63 @@ def __init__(self, uv_wavelengths : np.ndarray, real_space_mask : Mask2D, preloa 2.0 * self.grid.shape_native[1] ) - def visibilities_from(self, image : Array2D) -> Visibilities: + def visibilities_from(self, image: Array2D) -> Visibilities: + """ + Computes the visibilities from a real-space image using the direct Fourier transform (DFT). + + This method transforms the input image into the uv-plane (Fourier space), simulating the + measurements made by an interferometer at specified uv-wavelengths. + + If `preload_transform` is True, it uses precomputed sine and cosine terms to accelerate the computation. + + Parameters + ---------- + image + The real-space image to be transformed to the uv-plane. Must be defined on the + same grid and mask as this transformer's `real_space_mask`. + + Returns + ------- + The complex visibilities resulting from the Fourier transform of the input image. + """ if self.preload_transform: visibilities = transformer_util.visibilities_via_preload_from( - image_1d=np.array(image.array), + image_1d=image.array, preloaded_reals=self.preload_real_transforms, preloaded_imags=self.preload_imag_transforms, ) else: visibilities = transformer_util.visibilities_from( - image_1d=np.array(image.slim.array), + image_1d=image.slim.array, grid_radians=np.array(self.grid), uv_wavelengths=self.uv_wavelengths, ) return Visibilities(visibilities=visibilities) - def image_from(self, visibilities : Visibilities, use_adjoint_scaling: bool = False) -> Array2D: + def image_from( + self, visibilities: Visibilities, use_adjoint_scaling: bool = False + ) -> Array2D: + """ + Computes the real-space image from a set of visibilities using the adjoint of the DFT. + + This is not a true inverse Fourier transform, but rather the adjoint operation, which maps + complex visibilities back into image space. This is typically used as the first step + in inverse imaging algorithms like CLEAN or regularized reconstruction. + + Parameters + ---------- + visibilities + The complex visibilities to be transformed into a real-space image. + use_adjoint_scaling + If True, the result is scaled by a normalization factor. Currently unused. + + Returns + ------- + The real-space image resulting from the adjoint DFT operation, defined on the same + mask as this transformer's `real_space_mask`. + """ image_slim = transformer_util.image_direct_from( visibilities=visibilities.in_array, grid_radians=np.array(self.grid.array), @@ -136,7 +181,26 @@ def image_from(self, visibilities : Visibilities, use_adjoint_scaling: bool = Fa return Array2D(values=image_native, mask=self.real_space_mask) - def transform_mapping_matrix(self, mapping_matrix : np.ndarray) -> np.ndarray: + def transform_mapping_matrix(self, mapping_matrix: np.ndarray) -> np.ndarray: + """ + Applies the DFT to a mapping matrix that maps source pixels to image pixels. + + This is used in linear inversion frameworks, where the transform of each source basis function + (represented by a column of the mapping matrix) is computed individually. The result is a matrix + mapping source pixels directly to visibilities. + + If `preload_transform` is True, the computation is accelerated using precomputed sine and cosine terms. + + Parameters + ---------- + mapping_matrix + A 2D array of shape (n_image_pixels, n_source_pixels) that maps source pixels to image-plane pixels. + + Returns + ------- + A 2D complex-valued array of shape (n_visibilities, n_source_pixels) that maps source-plane basis + functions directly to the visibilities. + """ if self.preload_transform: return transformer_util.transformed_mapping_matrix_via_preload_from( mapping_matrix=mapping_matrix, @@ -153,7 +217,51 @@ def transform_mapping_matrix(self, mapping_matrix : np.ndarray) -> np.ndarray: class TransformerNUFFT(NUFFT_cpu): - def __init__(self, uv_wavelengths, real_space_mask): + def __init__(self, uv_wavelengths: np.ndarray, real_space_mask: Mask2D): + """ + Performs the Non-Uniform Fast Fourier Transform (NUFFT) for interferometric image reconstruction. + + This transformer uses the PyNUFFT library to efficiently compute the Fourier transform + of an image defined on a regular real-space grid to a set of non-uniform uv-plane (Fourier space) + coordinates, as is typical in radio interferometry. + + It is initialized with the interferometer uv-wavelengths and a real-space mask, which defines + the pixelized image domain. + + Parameters + ---------- + uv_wavelengths + The uv-coordinates (Fourier-space sampling points) corresponding to the measured visibilities. + Should be an array of shape (n_vis, 2), where the two columns represent u and v coordinates in wavelengths. + + real_space_mask + The 2D mask defining the real-space pixel grid on which the image is defined. Used to create the + unmasked grid required for NUFFT planning. + + Notes + ----- + - The `initialize_plan()` method builds the internal NUFFT plan based on the input grid and uv sampling. + - A complex exponential `shift` factor is applied to align the center of the Fourier transform correctly, + accounting for the pixel-center offset in the real-space grid. + - The adjoint operation (used in inverse imaging) must be scaled by `adjoint_scaling` to normalize its output. + - This transformer inherits directly from PyNUFFT's `NUFFT_cpu` base class. + - If `NUFFTPlaceholder` is detected (indicating PyNUFFT is not available), an exception is raised. + + Attributes + ---------- + grid : Grid2D + The real-space pixel grid derived from the mask, in radians. + native_index_for_slim_index : np.ndarray + Index map converting from slim (1D) grid to native (2D) indexing, for image reshaping. + shift : np.ndarray + Complex exponential phase shift applied to account for real-space pixel centering. + real_space_pixels : int + Total number of valid real-space pixels defined by the mask. + total_visibilities : int + Total number of visibilities across all uv-wavelength components. + adjoint_scaling : float + Scaling factor for adjoint operations to normalize reconstructed images. + """ if isinstance(self, NUFFTPlaceholder): pynufft_exception() @@ -197,7 +305,35 @@ def __init__(self, uv_wavelengths, real_space_mask): 2.0 * self.grid.shape_native[1] ) - def initialize_plan(self, ratio=2, interp_kernel=(6, 6)): + def initialize_plan(self, ratio: int = 2, interp_kernel: Tuple[int, int] = (6, 6)): + """ + Initializes the PyNUFFT plan for performing the NUFFT operation. + + This method precomputes the interpolation structure and gridding + needed by the NUFFT algorithm to map between the regular real-space + image grid and the non-uniform uv-plane sampling defined by the + interferometric visibilities. + + Parameters + ---------- + ratio + The oversampling ratio used to pad the Fourier grid before interpolation. + A higher value improves accuracy at the cost of increased memory and computation. + Default is 2 (i.e., the Fourier grid is twice the size of the image grid). + + interp_kernel + The interpolation kernel size along each axis, given as (Jy, Jx). + This determines how many neighboring Fourier grid points are used + to interpolate each uv-point. + Default is (6, 6), a good trade-off between accuracy and performance. + + Notes + ----- + - The uv-coordinates are normalized and rescaled into the range expected by PyNUFFT + using the real-space grid’s pixel scale and the Nyquist frequency limit. + - The plan must be initialized before performing any NUFFT operations (e.g., forward or adjoint). + - This method modifies the internal state of the NUFFT object by calling `self.plan(...)`. + """ if not isinstance(ratio, int): ratio = int(ratio) @@ -221,9 +357,23 @@ def initialize_plan(self, ratio=2, interp_kernel=(6, 6)): Jd=interp_kernel, ) - def visibilities_from(self, image): + def visibilities_from(self, image: Array2D) -> Visibilities: """ - ... + Computes visibilities from a real-space image using the NUFFT forward transform. + + Parameters + ---------- + image + The input image in real space, represented as a 2D array object. + + Returns + ------- + The complex visibilities in the uv-plane computed via the NUFFT forward operation. + + Notes + ----- + - The image is flipped vertically before transformation to account for PyNUFFT’s internal data layout. + - Warnings during the NUFFT computation are suppressed for cleaner output. """ warnings.filterwarnings("ignore") @@ -234,7 +384,29 @@ def visibilities_from(self, image): ) # flip due to PyNUFFT internal flip ) - def image_from(self, visibilities, use_adjoint_scaling: bool = False): + def image_from( + self, visibilities: Visibilities, use_adjoint_scaling: bool = False + ) -> Array2D: + """ + Reconstructs a real-space image from visibilities using the NUFFT adjoint transform. + + Parameters + ---------- + visibilities + The complex visibilities in the uv-plane to be inverted. + use_adjoint_scaling + If True, apply a scaling factor to the adjoint result to improve accuracy. + Default is False. + + Returns + ------- + The reconstructed real-space image after applying the NUFFT adjoint transform. + + Notes + ----- + - The output image is flipped vertically to align with the input image orientation. + - Warnings during the adjoint operation are suppressed. + """ with warnings.catch_warnings(): warnings.simplefilter("ignore") image = np.real(self.adjoint(visibilities))[::-1, :] @@ -244,7 +416,25 @@ def image_from(self, visibilities, use_adjoint_scaling: bool = False): return Array2D(values=image, mask=self.real_space_mask) - def transform_mapping_matrix(self, mapping_matrix): + def transform_mapping_matrix(self, mapping_matrix: np.ndarray) -> np.ndarray: + """ + Applies the NUFFT forward transform to each column of a mapping matrix, producing transformed visibilities. + + Parameters + ---------- + mapping_matrix : np.ndarray + A 2D array where each column corresponds to a source-plane pixel intensity distribution flattened into image space. + + Returns + ------- + np.ndarray + A complex-valued 2D array where each column contains the visibilities corresponding to the respective column in the input mapping matrix. + + Notes + ----- + - Each column of the input mapping matrix is reshaped into the native 2D image grid before transformation. + - This method repeatedly calls `visibilities_from` for each column, which may be computationally intensive. + """ transformed_mapping_matrix = 0 + 0j * np.zeros( (self.uv_wavelengths.shape[0], mapping_matrix.shape[1]) ) diff --git a/autoarray/operators/transformer_util.py b/autoarray/operators/transformer_util.py index 009f44617..7bd30b7da 100644 --- a/autoarray/operators/transformer_util.py +++ b/autoarray/operators/transformer_util.py @@ -1,7 +1,6 @@ +import jax.numpy as jnp import numpy as np -from autoarray import numba_util - def preload_real_transforms_from( grid_radians: np.ndarray, uv_wavelengths: np.ndarray From 93af09fec4b1c142327d92593ea909bfce8e1330 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 18 Jun 2025 15:24:45 +0100 Subject: [PATCH 12/18] remove ordered_1d --- autoarray/operators/transformer.py | 3 ++- autoarray/operators/transformer_util.py | 4 ++-- autoarray/structures/visibilities.py | 24 ------------------- .../structures/test_visibilities.py | 6 ----- 4 files changed, 4 insertions(+), 33 deletions(-) diff --git a/autoarray/operators/transformer.py b/autoarray/operators/transformer.py index 3c5eeae33..dbb01a58d 100644 --- a/autoarray/operators/transformer.py +++ b/autoarray/operators/transformer.py @@ -1,5 +1,6 @@ from astropy import units import copy +import jax.numpy as jnp import numpy as np import warnings from typing import Tuple @@ -144,7 +145,7 @@ def visibilities_from(self, image: Array2D) -> Visibilities: uv_wavelengths=self.uv_wavelengths, ) - return Visibilities(visibilities=visibilities) + return Visibilities(visibilities=jnp.array(visibilities)) def image_from( self, visibilities: Visibilities, use_adjoint_scaling: bool = False diff --git a/autoarray/operators/transformer_util.py b/autoarray/operators/transformer_util.py index 7bd30b7da..34659510a 100644 --- a/autoarray/operators/transformer_util.py +++ b/autoarray/operators/transformer_util.py @@ -109,8 +109,8 @@ def visibilities_via_preload_from( The complex visibilities computed by summing over all pixels. """ # Perform the dot product between the image and preloaded transform matrices - vis_real = np.dot(image_1d, preloaded_reals) # shape (n_visibilities,) - vis_imag = np.dot(image_1d, preloaded_imags) # shape (n_visibilities,) + vis_real = jnp.dot(image_1d, preloaded_reals) # shape (n_visibilities,) + vis_imag = jnp.dot(image_1d, preloaded_imags) # shape (n_visibilities,) visibilities = vis_real + 1j * vis_imag diff --git a/autoarray/structures/visibilities.py b/autoarray/structures/visibilities.py index 8cc94dca5..7557d39c4 100644 --- a/autoarray/structures/visibilities.py +++ b/autoarray/structures/visibilities.py @@ -50,16 +50,8 @@ def __init__(self, visibilities: Union[np.ndarray, List[complex]]): .ravel() ) - self.ordered_1d = np.concatenate( - (np.real(visibilities), np.imag(visibilities)), axis=0 - ) - super().__init__(array=visibilities) - def __array_finalize__(self, obj): - if hasattr(obj, "ordered_1d"): - self.ordered_1d = obj.ordered_1d - @property def slim(self) -> "AbstractVisibilities": return self @@ -232,20 +224,4 @@ def __init__(self, visibilities: Union[np.ndarray, List[complex]], *args, **kwar .ravel() ) - self.ordered_1d = np.concatenate( - (np.real(visibilities), np.imag(visibilities)), axis=0 - ) super().__init__(visibilities=visibilities) - - weight_list = 1.0 / self.in_array**2.0 - - self.weight_list_ordered_1d = np.concatenate( - (weight_list[:, 0], weight_list[:, 1]), axis=0 - ) - - def __array_finalize__(self, obj): - if hasattr(obj, "ordered_1d"): - self.ordered_1d = obj.ordered_1d - - if hasattr(obj, "weight_list_ordered_1d"): - self.weight_list_ordered_1d = obj.weight_list_ordered_1d diff --git a/test_autoarray/structures/test_visibilities.py b/test_autoarray/structures/test_visibilities.py index f029ef3ed..82c6bddf0 100644 --- a/test_autoarray/structures/test_visibilities.py +++ b/test_autoarray/structures/test_visibilities.py @@ -16,7 +16,6 @@ def test__manual__makes_visibilities_without_other_inputs(): assert type(visibilities) == vis.Visibilities assert (visibilities.slim == np.array([1.0 + 2.0j, 3.0 + 4.0j])).all() assert (visibilities.in_array == np.array([[1.0, 2.0], [3.0, 4.0]])).all() - assert (visibilities.ordered_1d == np.array([1.0, 3.0, 2.0, 4.0])).all() assert (visibilities.amplitudes == np.array([np.sqrt(5), 5.0])).all() assert visibilities.phases == pytest.approx( np.array([1.10714872, 0.92729522]), 1.0e-4 @@ -29,7 +28,6 @@ def test__manual__makes_visibilities_without_other_inputs(): assert ( visibilities.in_array == np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) ).all() - assert (visibilities.ordered_1d == np.array([1.0, 3.0, 5.0, 2.0, 4.0, 6.0])).all() def test__manual__makes_visibilities_with_converted_input_as_list(): @@ -121,7 +119,3 @@ def test__visibilities_noise_has_attributes(): assert (noise_map.slim == np.array([1.0 + 2.0j, 3.0 + 4.0j])).all() assert (noise_map.amplitudes == np.array([np.sqrt(5), 5.0])).all() assert noise_map.phases == pytest.approx(np.array([1.10714872, 0.92729522]), 1.0e-4) - assert (noise_map.ordered_1d == np.array([1.0, 3.0, 2.0, 4.0])).all() - assert ( - noise_map.weight_list_ordered_1d == np.array([1.0, 1.0 / 9.0, 0.25, 0.0625]) - ).all() From 6526dde40ae46dc230fdfae14d0cd3cd8e86cc9e Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 18 Jun 2025 15:29:03 +0100 Subject: [PATCH 13/18] JAX intereferometer grad works --- autoarray/fit/fit_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/autoarray/fit/fit_util.py b/autoarray/fit/fit_util.py index c61262c1a..8d983bf1f 100644 --- a/autoarray/fit/fit_util.py +++ b/autoarray/fit/fit_util.py @@ -158,8 +158,8 @@ def chi_squared_complex_from(*, chi_squared_map: jnp.ndarray) -> float: chi_squared_map The chi-squared-map of values of the model-data fit to the dataset. """ - chi_squared_real = jnp.sum(np.array(chi_squared_map.real)) - chi_squared_imag = jnp.sum(np.array(chi_squared_map.imag)) + chi_squared_real = jnp.sum(chi_squared_map.array.real) + chi_squared_imag = jnp.sum(chi_squared_map.array.imag) return chi_squared_real + chi_squared_imag From ea9d772d60076f71ae4ff6252717aa7a95cf22d1 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 18 Jun 2025 15:34:35 +0100 Subject: [PATCH 14/18] dft_preload_transform added to Interferometer inputs --- autoarray/dataset/interferometer/dataset.py | 14 +++++++++++--- autoarray/operators/transformer.py | 2 +- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/autoarray/dataset/interferometer/dataset.py b/autoarray/dataset/interferometer/dataset.py index 01b5d84bd..f6f0ad22f 100644 --- a/autoarray/dataset/interferometer/dataset.py +++ b/autoarray/dataset/interferometer/dataset.py @@ -10,7 +10,7 @@ from autoarray.dataset.interferometer.w_tilde import WTildeInterferometer from autoarray.dataset.grids import GridsDataset from autoarray.operators.transformer import TransformerNUFFT - +from autoarray.mask.mask_2d import Mask2D from autoarray.structures.visibilities import Visibilities from autoarray.structures.visibilities import VisibilitiesNoiseMap @@ -25,8 +25,9 @@ def __init__( data: Visibilities, noise_map: VisibilitiesNoiseMap, uv_wavelengths: np.ndarray, - real_space_mask, + real_space_mask : Mask2D, transformer_class=TransformerNUFFT, + dft_preload_transform : bool = True, preprocessing_directory=None, ): """ @@ -73,6 +74,9 @@ def __init__( transformer_class The class of the Fourier Transform which maps images from real space to Fourier space visibilities and the uv-plane. + dft_preload_transform + If True, precomputes and stores the cosine and sine terms for the Fourier transform. + This accelerates repeated transforms but consumes additional memory (~1GB+ for large datasets). """ self.real_space_mask = real_space_mask @@ -86,7 +90,9 @@ def __init__( self.uv_wavelengths = uv_wavelengths self.transformer = transformer_class( - uv_wavelengths=uv_wavelengths, real_space_mask=real_space_mask + uv_wavelengths=uv_wavelengths, + real_space_mask=real_space_mask, + dft_preload_transform=dft_preload_transform, ) self.preprocessing_directory = ( @@ -114,6 +120,7 @@ def from_fits( noise_map_hdu=0, uv_wavelengths_hdu=0, transformer_class=TransformerNUFFT, + dft_preload_transform: bool = True, ): """ Factory for loading the interferometer data_type from .fits files, as well as computing properties like the @@ -139,6 +146,7 @@ def from_fits( noise_map=noise_map, uv_wavelengths=uv_wavelengths, transformer_class=transformer_class, + dft_preload_transform=dft_preload_transform, ) def w_tilde_preprocessing(self): diff --git a/autoarray/operators/transformer.py b/autoarray/operators/transformer.py index dbb01a58d..0e6f78f79 100644 --- a/autoarray/operators/transformer.py +++ b/autoarray/operators/transformer.py @@ -218,7 +218,7 @@ def transform_mapping_matrix(self, mapping_matrix: np.ndarray) -> np.ndarray: class TransformerNUFFT(NUFFT_cpu): - def __init__(self, uv_wavelengths: np.ndarray, real_space_mask: Mask2D): + def __init__(self, uv_wavelengths: np.ndarray, real_space_mask: Mask2D, **kwargs): """ Performs the Non-Uniform Fast Fourier Transform (NUFFT) for interferometric image reconstruction. From 3a6a5dcba5c9bf97e777f0da09c8ceeb0d878b85 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 18 Jun 2025 15:39:30 +0100 Subject: [PATCH 15/18] black --- autoarray/dataset/interferometer/dataset.py | 6 +++--- autoarray/operators/transformer.py | 13 ++++++------- test_autoarray/operators/test_transformer.py | 2 +- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/autoarray/dataset/interferometer/dataset.py b/autoarray/dataset/interferometer/dataset.py index f6f0ad22f..0a2d5bbdb 100644 --- a/autoarray/dataset/interferometer/dataset.py +++ b/autoarray/dataset/interferometer/dataset.py @@ -25,9 +25,9 @@ def __init__( data: Visibilities, noise_map: VisibilitiesNoiseMap, uv_wavelengths: np.ndarray, - real_space_mask : Mask2D, + real_space_mask: Mask2D, transformer_class=TransformerNUFFT, - dft_preload_transform : bool = True, + dft_preload_transform: bool = True, preprocessing_directory=None, ): """ @@ -92,7 +92,7 @@ def __init__( self.transformer = transformer_class( uv_wavelengths=uv_wavelengths, real_space_mask=real_space_mask, - dft_preload_transform=dft_preload_transform, + preload_transform=dft_preload_transform, ) self.preprocessing_directory = ( diff --git a/autoarray/operators/transformer.py b/autoarray/operators/transformer.py index 0e6f78f79..c3d94f686 100644 --- a/autoarray/operators/transformer.py +++ b/autoarray/operators/transformer.py @@ -91,6 +91,7 @@ def __init__( self.preload_transform = preload_transform if preload_transform: + self.preload_real_transforms = ( transformer_util.preload_real_transforms_from( grid_radians=np.array(self.grid.array), @@ -137,7 +138,6 @@ def visibilities_from(self, image: Array2D) -> Visibilities: preloaded_reals=self.preload_real_transforms, preloaded_imags=self.preload_imag_transforms, ) - else: visibilities = transformer_util.visibilities_from( image_1d=image.slim.array, @@ -209,12 +209,11 @@ def transform_mapping_matrix(self, mapping_matrix: np.ndarray) -> np.ndarray: preloaded_imags=self.preload_imag_transforms, ) - else: - return transformer_util.transformed_mapping_matrix_from( - mapping_matrix=mapping_matrix, - grid_radians=np.array(self.grid), - uv_wavelengths=self.uv_wavelengths, - ) + return transformer_util.transformed_mapping_matrix_from( + mapping_matrix=mapping_matrix, + grid_radians=np.array(self.grid), + uv_wavelengths=self.uv_wavelengths, + ) class TransformerNUFFT(NUFFT_cpu): diff --git a/test_autoarray/operators/test_transformer.py b/test_autoarray/operators/test_transformer.py index c4209e9fb..4d432cf6b 100644 --- a/test_autoarray/operators/test_transformer.py +++ b/test_autoarray/operators/test_transformer.py @@ -83,7 +83,7 @@ def test__dft__visibilities_from__preload_and_non_preload_give_same_answer( visibilities_via_preload = transformer_preload.visibilities_from(image=image) visibilities = transformer.visibilities_from(image=image) - assert visibilities_via_preload == pytest.approx(visibilities, 1.0e-4) + assert visibilities_via_preload == pytest.approx(visibilities.array, 1.0e-4) def test__dft__transform_mapping_matrix( From 5ad6997851f66a45bf7c10b8f3fa557f4309f59a Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 18 Jun 2025 17:01:06 +0100 Subject: [PATCH 16/18] refactor use of dirty image in inversion --- autoarray/dataset/interferometer/dataset.py | 22 ++++-- autoarray/dataset/interferometer/w_tilde.py | 2 - .../inversion/interferometer/abstract.py | 19 +++++ .../inversion_interferometer_util.py | 46 ------------- .../inversion/interferometer/mapping.py | 22 ------ .../inversion/interferometer/w_tilde.py | 19 ----- .../test_inversion_interferometer_util.py | 69 +------------------ 7 files changed, 36 insertions(+), 163 deletions(-) diff --git a/autoarray/dataset/interferometer/dataset.py b/autoarray/dataset/interferometer/dataset.py index 0a2d5bbdb..483260c1f 100644 --- a/autoarray/dataset/interferometer/dataset.py +++ b/autoarray/dataset/interferometer/dataset.py @@ -166,6 +166,21 @@ def w_tilde_preprocessing(self): fits.writeto(filename, data=curvature_preload) + @cached_property + def dirty_image_for_inversion(self) -> np.ndarray: + """ + Returns a dirty image with scaling applied to the visibilities, which is used in the inversion + linear algebra. + + In particular, it enables fast computation of the `data_vector` in the linear algebra equations. + """ + + return self.transformer.image_from( + visibilities=self.data.real * self.noise_map.real**-2.0 + + 1j * self.data.imag * self.noise_map.imag**-2.0, + use_adjoint_scaling=True, + ) + @cached_property def w_tilde(self): """ @@ -206,16 +221,9 @@ def w_tilde(self): ).astype("int"), ) - dirty_image = self.transformer.image_from( - visibilities=self.data.real * self.noise_map.real**-2.0 - + 1j * self.data.imag * self.noise_map.imag**-2.0, - use_adjoint_scaling=True, - ) - return WTildeInterferometer( w_matrix=w_matrix, curvature_preload=curvature_preload, - dirty_image=np.array(dirty_image.array), real_space_mask=self.real_space_mask, noise_map_value=self.noise_map[0], ) diff --git a/autoarray/dataset/interferometer/w_tilde.py b/autoarray/dataset/interferometer/w_tilde.py index dbd27247d..c5275b895 100644 --- a/autoarray/dataset/interferometer/w_tilde.py +++ b/autoarray/dataset/interferometer/w_tilde.py @@ -9,7 +9,6 @@ def __init__( self, w_matrix: np.ndarray, curvature_preload: np.ndarray, - dirty_image: np.ndarray, real_space_mask: Mask2D, noise_map_value: float, ): @@ -43,7 +42,6 @@ def __init__( curvature_preload=curvature_preload, noise_map_value=noise_map_value ) - self.dirty_image = dirty_image self.real_space_mask = real_space_mask self.w_matrix = w_matrix diff --git a/autoarray/inversion/inversion/interferometer/abstract.py b/autoarray/inversion/inversion/interferometer/abstract.py index 47e1c84bf..2fe7b3f84 100644 --- a/autoarray/inversion/inversion/interferometer/abstract.py +++ b/autoarray/inversion/inversion/interferometer/abstract.py @@ -1,6 +1,8 @@ import numpy as np from typing import Dict, List, Optional, Union +from autoconf import cached_property + from autoarray.dataset.interferometer.dataset import Interferometer from autoarray.inversion.inversion.dataset_interface import DatasetInterface from autoarray.inversion.inversion.abstract import AbstractInversion @@ -77,6 +79,23 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]: for linear_obj in self.linear_obj_list ] + @cached_property + @profile_func + def data_vector(self) -> np.ndarray: + """ + The `data_vector` is a 1D vector whose values are solved for by the simultaneous linear equations constructed + by this object. + + The linear algebra is described in the paper https://arxiv.org/pdf/astro-ph/0302587.pdf), where the + data vector is given by equation (4) and the letter D. + + If there are multiple linear objects the `data_vectors` are concatenated ensuring their values are solved + for simultaneously. + + The calculation is described in more detail in `inversion_util.w_tilde_data_interferometer_from`. + """ + return self.dataset.dirty_image_for_inversion.array @ self.mapping_matrix + @property @profile_func def mapped_reconstructed_image_dict( diff --git a/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py b/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py index 29580c06d..9ddcb996e 100644 --- a/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py +++ b/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py @@ -387,52 +387,6 @@ def w_tilde_via_preload_from(w_tilde_preload, native_index_for_slim_index): return w_tilde_via_preload -@numba_util.jit() -def data_vector_via_transformed_mapping_matrix_from( - transformed_mapping_matrix: np.ndarray, - visibilities: np.ndarray, - noise_map: np.ndarray, -) -> np.ndarray: - """ - Returns the data vector `D` from a transformed mapping matrix `f` and the 1D image `d` and 1D noise-map `sigma` - (see Warren & Dye 2003). - - Parameters - ---------- - transformed_mapping_matrix - The matrix representing the transformed mappings between sub-grid pixels and pixelization pixels. - image - Flattened 1D array of the observed image the inversion is fitting. - noise_map - Flattened 1D array of the noise-map used by the inversion during the fit. - """ - - data_vector = np.zeros(transformed_mapping_matrix.shape[1]) - - visibilities_real = visibilities.real - visibilities_imag = visibilities.imag - transformed_mapping_matrix_real = transformed_mapping_matrix.real - transformed_mapping_matrix_imag = transformed_mapping_matrix.imag - noise_map_real = noise_map.real - noise_map_imag = noise_map.imag - - for vis_1d_index in range(transformed_mapping_matrix.shape[0]): - for pix_1d_index in range(transformed_mapping_matrix.shape[1]): - real_value = ( - visibilities_real[vis_1d_index] - * transformed_mapping_matrix_real[vis_1d_index, pix_1d_index] - / (noise_map_real[vis_1d_index] ** 2.0) - ) - imag_value = ( - visibilities_imag[vis_1d_index] - * transformed_mapping_matrix_imag[vis_1d_index, pix_1d_index] - / (noise_map_imag[vis_1d_index] ** 2.0) - ) - data_vector[pix_1d_index] += real_value + imag_value - - return data_vector - - @numba_util.jit() def curvature_matrix_via_w_tilde_curvature_preload_interferometer_from( curvature_preload: np.ndarray, diff --git a/autoarray/inversion/inversion/interferometer/mapping.py b/autoarray/inversion/inversion/interferometer/mapping.py index 2b3219a2e..feee87c68 100644 --- a/autoarray/inversion/inversion/interferometer/mapping.py +++ b/autoarray/inversion/inversion/interferometer/mapping.py @@ -59,28 +59,6 @@ def __init__( run_time_dict=run_time_dict, ) - @cached_property - @profile_func - def data_vector(self) -> np.ndarray: - """ - The `data_vector` is a 1D vector whose values are solved for by the simultaneous linear equations constructed - by this object. - - The linear algebra is described in the paper https://arxiv.org/pdf/astro-ph/0302587.pdf), where the - data vector is given by equation (4) and the letter D. - - If there are multiple linear objects their `operated_mapping_matrix` properties will have already been - concatenated ensuring their `data_vector` values are solved for simultaneously. - - The calculation is described in more detail in `inversion_util.data_vector_via_transformed_mapping_matrix_from`. - """ - - return inversion_interferometer_util.data_vector_via_transformed_mapping_matrix_from( - transformed_mapping_matrix=np.array(self.operated_mapping_matrix), - visibilities=np.array(self.data), - noise_map=np.array(self.noise_map), - ) - @cached_property @profile_func def curvature_matrix(self) -> np.ndarray: diff --git a/autoarray/inversion/inversion/interferometer/w_tilde.py b/autoarray/inversion/inversion/interferometer/w_tilde.py index 83febb864..82652a6bd 100644 --- a/autoarray/inversion/inversion/interferometer/w_tilde.py +++ b/autoarray/inversion/inversion/interferometer/w_tilde.py @@ -70,25 +70,6 @@ def __init__( self.settings = settings - @cached_property - @profile_func - def data_vector(self) -> np.ndarray: - """ - The `data_vector` is a 1D vector whose values are solved for by the simultaneous linear equations constructed - by this object. - - The linear algebra is described in the paper https://arxiv.org/pdf/astro-ph/0302587.pdf), where the - data vector is given by equation (4) and the letter D. - - If there are multiple linear objects the `data_vectors` are concatenated ensuring their values are solved - for simultaneously. - - The calculation is described in more detail in `inversion_util.w_tilde_data_interferometer_from`. - """ - return np.dot( - self.linear_obj_list[0].mapping_matrix.T, self.w_tilde.dirty_image - ) - @cached_property @profile_func def curvature_matrix(self) -> np.ndarray: diff --git a/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py b/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py index c7dc03221..de52b3301 100644 --- a/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py +++ b/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py @@ -3,71 +3,6 @@ import pytest -def test__data_vector_via_transformed_mapping_matrix_from(): - mapping_matrix = np.array( - [ - [1.0, 1.0, 0.0], - [1.0, 0.0, 0.0], - [0.0, 1.0, 0.0], - [0.0, 1.0, 1.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - ] - ) - - data_real = np.array([4.0, 1.0, 1.0, 16.0, 1.0, 1.0]) - noise_map_real = np.array([2.0, 1.0, 1.0, 4.0, 1.0, 1.0]) - - data_vector_real_via_blurred = ( - aa.util.inversion_imaging.data_vector_via_blurred_mapping_matrix_from( - blurred_mapping_matrix=mapping_matrix, - image=data_real, - noise_map=noise_map_real, - ) - ) - - data_imag = np.array([4.0, 1.0, 1.0, 16.0, 1.0, 1.0]) - noise_map_imag = np.array([2.0, 1.0, 1.0, 4.0, 1.0, 1.0]) - - data_vector_imag_via_blurred = ( - aa.util.inversion_imaging.data_vector_via_blurred_mapping_matrix_from( - blurred_mapping_matrix=mapping_matrix, - image=data_imag, - noise_map=noise_map_imag, - ) - ) - - data_vector_complex_via_blurred = ( - data_vector_real_via_blurred + data_vector_imag_via_blurred - ) - - transformed_mapping_matrix = np.array( - [ - [1.0 + 1.0j, 1.0 + 1.0j, 0.0 + 0.0j], - [1.0 + 1.0j, 0.0 + 0.0j, 0.0 + 0.0j], - [0.0 + 0.0j, 1.0 + 1.0j, 0.0 + 0.0j], - [0.0 + 0.0j, 1.0 + 1.0j, 1.0 + 1.0j], - [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j], - [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j], - ] - ) - - data = np.array( - [4.0 + 4.0j, 1.0 + 1.0j, 1.0 + 1.0j, 16.0 + 16.0j, 1.0 + 1.0j, 1.0 + 1.0j] - ) - noise_map = np.array( - [2.0 + 2.0j, 1.0 + 1.0j, 1.0 + 1.0j, 4.0 + 4.0j, 1.0 + 1.0j, 1.0 + 1.0j] - ) - - data_vector_via_transformed = aa.util.inversion_interferometer.data_vector_via_transformed_mapping_matrix_from( - transformed_mapping_matrix=transformed_mapping_matrix, - visibilities=data, - noise_map=noise_map, - ) - - assert (data_vector_complex_via_blurred == data_vector_via_transformed).all() - - def test__w_tilde_curvature_interferometer_from(): noise_map = np.array([1.0, 2.0, 3.0]) uv_wavelengths = np.array([[0.0001, 2.0, 3000.0], [3000.0, 2.0, 0.0001]]) @@ -387,6 +322,6 @@ def test__identical_inversion_source_and_image_loops(): assert inversion_image_loop.mapped_reconstructed_image.array == pytest.approx( inversion_source_loop.mapped_reconstructed_image.array, 1.0e-2 ) - assert inversion_image_loop.mapped_reconstructed_data == pytest.approx( - inversion_source_loop.mapped_reconstructed_data, 1.0e-2 + assert inversion_image_loop.mapped_reconstructed_data.array == pytest.approx( + inversion_source_loop.mapped_reconstructed_data.array, 1.0e-2 ) From c798d1775b3e599932f72ae7b31b745a22b41261 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 18 Jun 2025 19:24:43 +0100 Subject: [PATCH 17/18] JAX on transforfed mapping matrtix --- autoarray/dataset/interferometer/dataset.py | 22 ++---- autoarray/dataset/interferometer/w_tilde.py | 2 + autoarray/fit/fit_interferometer.py | 2 +- autoarray/fit/fit_util.py | 4 +- autoarray/inversion/inversion/abstract.py | 2 +- .../inversion/inversion/dataset_interface.py | 6 +- .../inversion/interferometer/abstract.py | 19 ----- .../inversion_interferometer_util.py | 54 +++++++++++---- .../inversion/interferometer/mapping.py | 33 +++++++-- .../inversion/interferometer/w_tilde.py | 17 +++++ autoarray/structures/visibilities.py | 2 +- .../dataset/interferometer/test_simulator.py | 2 +- .../test_inversion_interferometer_util.py | 69 ++++++++++++++++++- 13 files changed, 168 insertions(+), 66 deletions(-) diff --git a/autoarray/dataset/interferometer/dataset.py b/autoarray/dataset/interferometer/dataset.py index 483260c1f..0a2d5bbdb 100644 --- a/autoarray/dataset/interferometer/dataset.py +++ b/autoarray/dataset/interferometer/dataset.py @@ -166,21 +166,6 @@ def w_tilde_preprocessing(self): fits.writeto(filename, data=curvature_preload) - @cached_property - def dirty_image_for_inversion(self) -> np.ndarray: - """ - Returns a dirty image with scaling applied to the visibilities, which is used in the inversion - linear algebra. - - In particular, it enables fast computation of the `data_vector` in the linear algebra equations. - """ - - return self.transformer.image_from( - visibilities=self.data.real * self.noise_map.real**-2.0 - + 1j * self.data.imag * self.noise_map.imag**-2.0, - use_adjoint_scaling=True, - ) - @cached_property def w_tilde(self): """ @@ -221,9 +206,16 @@ def w_tilde(self): ).astype("int"), ) + dirty_image = self.transformer.image_from( + visibilities=self.data.real * self.noise_map.real**-2.0 + + 1j * self.data.imag * self.noise_map.imag**-2.0, + use_adjoint_scaling=True, + ) + return WTildeInterferometer( w_matrix=w_matrix, curvature_preload=curvature_preload, + dirty_image=np.array(dirty_image.array), real_space_mask=self.real_space_mask, noise_map_value=self.noise_map[0], ) diff --git a/autoarray/dataset/interferometer/w_tilde.py b/autoarray/dataset/interferometer/w_tilde.py index c5275b895..dbd27247d 100644 --- a/autoarray/dataset/interferometer/w_tilde.py +++ b/autoarray/dataset/interferometer/w_tilde.py @@ -9,6 +9,7 @@ def __init__( self, w_matrix: np.ndarray, curvature_preload: np.ndarray, + dirty_image: np.ndarray, real_space_mask: Mask2D, noise_map_value: float, ): @@ -42,6 +43,7 @@ def __init__( curvature_preload=curvature_preload, noise_map_value=noise_map_value ) + self.dirty_image = dirty_image self.real_space_mask = real_space_mask self.w_matrix = w_matrix diff --git a/autoarray/fit/fit_interferometer.py b/autoarray/fit/fit_interferometer.py index ec4c1d99d..40a713bc4 100644 --- a/autoarray/fit/fit_interferometer.py +++ b/autoarray/fit/fit_interferometer.py @@ -113,7 +113,7 @@ def chi_squared(self) -> float: Returns the chi-squared terms of the model data's fit to an dataset, by summing the chi-squared-map. """ return fit_util.chi_squared_complex_from( - chi_squared_map=self.chi_squared_map, + chi_squared_map=self.chi_squared_map.array, ) @property diff --git a/autoarray/fit/fit_util.py b/autoarray/fit/fit_util.py index 8d983bf1f..7da5cd310 100644 --- a/autoarray/fit/fit_util.py +++ b/autoarray/fit/fit_util.py @@ -158,8 +158,8 @@ def chi_squared_complex_from(*, chi_squared_map: jnp.ndarray) -> float: chi_squared_map The chi-squared-map of values of the model-data fit to the dataset. """ - chi_squared_real = jnp.sum(chi_squared_map.array.real) - chi_squared_imag = jnp.sum(chi_squared_map.array.imag) + chi_squared_real = jnp.sum(chi_squared_map.real) + chi_squared_imag = jnp.sum(chi_squared_map.imag) return chi_squared_real + chi_squared_imag diff --git a/autoarray/inversion/inversion/abstract.py b/autoarray/inversion/inversion/abstract.py index ebbb9927f..bec0b1ce2 100644 --- a/autoarray/inversion/inversion/abstract.py +++ b/autoarray/inversion/inversion/abstract.py @@ -287,7 +287,7 @@ def mapping_matrix(self) -> np.ndarray: If there are multiple linear objects, the mapping matrices are stacked such that their simultaneous linear equations are solved simultaneously. This property returns the stacked mapping matrix. """ - return np.hstack( + return jnp.hstack( [linear_obj.mapping_matrix for linear_obj in self.linear_obj_list] ) diff --git a/autoarray/inversion/inversion/dataset_interface.py b/autoarray/inversion/inversion/dataset_interface.py index c1cf2db9a..f37efef36 100644 --- a/autoarray/inversion/inversion/dataset_interface.py +++ b/autoarray/inversion/inversion/dataset_interface.py @@ -36,6 +36,9 @@ def __init__( noise_map An array describing the RMS standard deviation error in each pixel used for computing quantities like the chi-squared in a fit (in PyAutoGalaxy and PyAutoLens the recommended units are electrons per second). + grids + The grids of (y,x) Cartesian coordinates that the image data is paired with, which are used for evaluting + light profiles and calculations associated with a pixelization. over_sampler Performs over-sampling whereby the masked image pixels are split into sub-pixels, which are all mapped via the mapper with sub-fractional values of flux. @@ -50,9 +53,6 @@ def __init__( w_tilde The w_tilde matrix used by the w-tilde formalism to construct the data vector and curvature matrix during an inversion efficiently.. - grids - The grids of (y,x) Cartesian coordinates that the image data is paired with, which are used for evaluting - light profiles and calculations associated with a pixelization. noise_covariance_matrix A noise-map covariance matrix representing the covariance between noise in every `data` value, which can be used via a bespoke fit to account for correlated noise in the data. diff --git a/autoarray/inversion/inversion/interferometer/abstract.py b/autoarray/inversion/inversion/interferometer/abstract.py index 2fe7b3f84..47e1c84bf 100644 --- a/autoarray/inversion/inversion/interferometer/abstract.py +++ b/autoarray/inversion/inversion/interferometer/abstract.py @@ -1,8 +1,6 @@ import numpy as np from typing import Dict, List, Optional, Union -from autoconf import cached_property - from autoarray.dataset.interferometer.dataset import Interferometer from autoarray.inversion.inversion.dataset_interface import DatasetInterface from autoarray.inversion.inversion.abstract import AbstractInversion @@ -79,23 +77,6 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]: for linear_obj in self.linear_obj_list ] - @cached_property - @profile_func - def data_vector(self) -> np.ndarray: - """ - The `data_vector` is a 1D vector whose values are solved for by the simultaneous linear equations constructed - by this object. - - The linear algebra is described in the paper https://arxiv.org/pdf/astro-ph/0302587.pdf), where the - data vector is given by equation (4) and the letter D. - - If there are multiple linear objects the `data_vectors` are concatenated ensuring their values are solved - for simultaneously. - - The calculation is described in more detail in `inversion_util.w_tilde_data_interferometer_from`. - """ - return self.dataset.dirty_image_for_inversion.array @ self.mapping_matrix - @property @profile_func def mapped_reconstructed_image_dict( diff --git a/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py b/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py index 9ddcb996e..1a3e647dc 100644 --- a/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py +++ b/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py @@ -387,6 +387,44 @@ def w_tilde_via_preload_from(w_tilde_preload, native_index_for_slim_index): return w_tilde_via_preload +def data_vector_via_transformed_mapping_matrix_from( + transformed_mapping_matrix: np.ndarray, + visibilities: np.ndarray, + noise_map: np.ndarray, +) -> np.ndarray: + """ + Returns the data vector `D` from a transformed mapping matrix `f` and the 1D image `d` and 1D noise-map `sigma` + (see Warren & Dye 2003). + + Parameters + ---------- + transformed_mapping_matrix + The matrix representing the transformed mappings between sub-grid pixels and pixelization pixels. + image + Flattened 1D array of the observed image the inversion is fitting. + noise_map + Flattened 1D array of the noise-map used by the inversion during the fit. + """ + # Extract components + vis_real = visibilities.real + vis_imag = visibilities.imag + f_real = transformed_mapping_matrix.real + f_imag = transformed_mapping_matrix.imag + noise_real = noise_map.real + noise_imag = noise_map.imag + + # Square noise components + inv_var_real = 1.0 / (noise_real**2) + inv_var_imag = 1.0 / (noise_imag**2) + + # Real and imaginary contributions + weighted_real = (vis_real * inv_var_real)[:, None] * f_real + weighted_imag = (vis_imag * inv_var_imag)[:, None] * f_imag + + # Sum over visibilities + return np.sum(weighted_real + weighted_imag, axis=0) + + @numba_util.jit() def curvature_matrix_via_w_tilde_curvature_preload_interferometer_from( curvature_preload: np.ndarray, @@ -466,7 +504,6 @@ def curvature_matrix_via_w_tilde_curvature_preload_interferometer_from( return curvature_matrix -@numba_util.jit() def mapped_reconstructed_visibilities_from( transformed_mapping_matrix: np.ndarray, reconstruction: np.ndarray ) -> np.ndarray: @@ -479,20 +516,7 @@ def mapped_reconstructed_visibilities_from( The matrix representing the blurred mappings between sub-grid pixels and pixelization pixels. """ - mapped_reconstructed_visibilities = (0.0 + 0.0j) * np.zeros( - transformed_mapping_matrix.shape[0] - ) - - transformed_mapping_matrix_real = transformed_mapping_matrix.real - transformed_mapping_matrix_imag = transformed_mapping_matrix.imag - - for i in range(transformed_mapping_matrix.shape[0]): - for j in range(reconstruction.shape[0]): - mapped_reconstructed_visibilities[i] += ( - reconstruction[j] * transformed_mapping_matrix_real[i, j] - ) + 1.0j * (reconstruction[j] * transformed_mapping_matrix_imag[i, j]) - - return mapped_reconstructed_visibilities + return transformed_mapping_matrix @ reconstruction """ diff --git a/autoarray/inversion/inversion/interferometer/mapping.py b/autoarray/inversion/inversion/interferometer/mapping.py index feee87c68..77ec2576c 100644 --- a/autoarray/inversion/inversion/interferometer/mapping.py +++ b/autoarray/inversion/inversion/interferometer/mapping.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp import numpy as np from typing import Dict, List, Optional, Union @@ -59,6 +60,28 @@ def __init__( run_time_dict=run_time_dict, ) + @cached_property + @profile_func + def data_vector(self) -> np.ndarray: + """ + The `data_vector` is a 1D vector whose values are solved for by the simultaneous linear equations constructed + by this object. + + The linear algebra is described in the paper https://arxiv.org/pdf/astro-ph/0302587.pdf), where the + data vector is given by equation (4) and the letter D. + + If there are multiple linear objects their `operated_mapping_matrix` properties will have already been + concatenated ensuring their `data_vector` values are solved for simultaneously. + + The calculation is described in more detail in `inversion_util.data_vector_via_transformed_mapping_matrix_from`. + """ + + return inversion_interferometer_util.data_vector_via_transformed_mapping_matrix_from( + transformed_mapping_matrix=self.operated_mapping_matrix, + visibilities=self.data, + noise_map=np.array(self.noise_map), + ) + @cached_property @profile_func def curvature_matrix(self) -> np.ndarray: @@ -84,13 +107,13 @@ def curvature_matrix(self) -> np.ndarray: noise_map=self.noise_map.imag, ) - curvature_matrix = np.add(real_curvature_matrix, imag_curvature_matrix) + curvature_matrix = jnp.add(real_curvature_matrix, imag_curvature_matrix) if len(self.no_regularization_index_list) > 0: curvature_matrix = inversion_util.curvature_matrix_with_added_to_diag_from( curvature_matrix=curvature_matrix, - no_regularization_index_list=self.no_regularization_index_list, value=self.settings.no_regularization_add_to_curvature_diag_value, + no_regularization_index_list=self.no_regularization_index_list, ) return curvature_matrix @@ -130,10 +153,8 @@ def mapped_reconstructed_data_dict( visibilities = ( inversion_interferometer_util.mapped_reconstructed_visibilities_from( - transformed_mapping_matrix=np.array( - operated_mapping_matrix_list[index] - ), - reconstruction=np.array(reconstruction), + transformed_mapping_matrix=operated_mapping_matrix_list[index], + reconstruction=reconstruction, ) ) diff --git a/autoarray/inversion/inversion/interferometer/w_tilde.py b/autoarray/inversion/inversion/interferometer/w_tilde.py index 82652a6bd..e538e8779 100644 --- a/autoarray/inversion/inversion/interferometer/w_tilde.py +++ b/autoarray/inversion/inversion/interferometer/w_tilde.py @@ -70,6 +70,23 @@ def __init__( self.settings = settings + @cached_property + @profile_func + def data_vector(self) -> np.ndarray: + """ + The `data_vector` is a 1D vector whose values are solved for by the simultaneous linear equations constructed + by this object. + + The linear algebra is described in the paper https://arxiv.org/pdf/astro-ph/0302587.pdf), where the + data vector is given by equation (4) and the letter D. + + If there are multiple linear objects the `data_vectors` are concatenated ensuring their values are solved + for simultaneously. + + The calculation is described in more detail in `inversion_util.w_tilde_data_interferometer_from`. + """ + return np.dot(self.mapping_matrix.T, self.w_tilde.dirty_image) + @cached_property @profile_func def curvature_matrix(self) -> np.ndarray: diff --git a/autoarray/structures/visibilities.py b/autoarray/structures/visibilities.py index 7557d39c4..abfb9713e 100644 --- a/autoarray/structures/visibilities.py +++ b/autoarray/structures/visibilities.py @@ -66,7 +66,7 @@ def in_array(self) -> np.ndarray: Returns the 1D complex NumPy array of values with shape [total_visibilities] as a NumPy float array of shape [total_visibilities, 2]. """ - return np.stack((np.real(self), np.imag(self)), axis=-1) + return np.stack((np.real(self.array), np.imag(self.array)), axis=-1) @property def in_grid(self) -> Grid2DIrregular: diff --git a/test_autoarray/dataset/interferometer/test_simulator.py b/test_autoarray/dataset/interferometer/test_simulator.py index e45908cba..ea434c163 100644 --- a/test_autoarray/dataset/interferometer/test_simulator.py +++ b/test_autoarray/dataset/interferometer/test_simulator.py @@ -30,7 +30,7 @@ def test__from_image__setup_with_all_features_off( visibilities = transformer.visibilities_from(image=image) - assert dataset.data == pytest.approx(visibilities, 1.0e-4) + assert dataset.data == pytest.approx(visibilities.array, 1.0e-4) def test__setup_with_noise(uv_wavelengths_7x2, transformer_7x7_7): diff --git a/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py b/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py index de52b3301..96cad9eed 100644 --- a/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py +++ b/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py @@ -3,6 +3,71 @@ import pytest +def test__data_vector_via_transformed_mapping_matrix_from(): + mapping_matrix = np.array( + [ + [1.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 1.0, 1.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ] + ) + + data_real = np.array([4.0, 1.0, 1.0, 16.0, 1.0, 1.0]) + noise_map_real = np.array([2.0, 1.0, 1.0, 4.0, 1.0, 1.0]) + + data_vector_real_via_blurred = ( + aa.util.inversion_imaging.data_vector_via_blurred_mapping_matrix_from( + blurred_mapping_matrix=mapping_matrix, + image=data_real, + noise_map=noise_map_real, + ) + ) + + data_imag = np.array([4.0, 1.0, 1.0, 16.0, 1.0, 1.0]) + noise_map_imag = np.array([2.0, 1.0, 1.0, 4.0, 1.0, 1.0]) + + data_vector_imag_via_blurred = ( + aa.util.inversion_imaging.data_vector_via_blurred_mapping_matrix_from( + blurred_mapping_matrix=mapping_matrix, + image=data_imag, + noise_map=noise_map_imag, + ) + ) + + data_vector_complex_via_blurred = ( + data_vector_real_via_blurred + data_vector_imag_via_blurred + ) + + transformed_mapping_matrix = np.array( + [ + [1.0 + 1.0j, 1.0 + 1.0j, 0.0 + 0.0j], + [1.0 + 1.0j, 0.0 + 0.0j, 0.0 + 0.0j], + [0.0 + 0.0j, 1.0 + 1.0j, 0.0 + 0.0j], + [0.0 + 0.0j, 1.0 + 1.0j, 1.0 + 1.0j], + [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j], + [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j], + ] + ) + + data = np.array( + [4.0 + 4.0j, 1.0 + 1.0j, 1.0 + 1.0j, 16.0 + 16.0j, 1.0 + 1.0j, 1.0 + 1.0j] + ) + noise_map = np.array( + [2.0 + 2.0j, 1.0 + 1.0j, 1.0 + 1.0j, 4.0 + 4.0j, 1.0 + 1.0j, 1.0 + 1.0j] + ) + + data_vector_via_transformed = aa.util.inversion_interferometer.data_vector_via_transformed_mapping_matrix_from( + transformed_mapping_matrix=transformed_mapping_matrix, + visibilities=data, + noise_map=noise_map, + ) + + assert (data_vector_complex_via_blurred == data_vector_via_transformed).all() + + def test__w_tilde_curvature_interferometer_from(): noise_map = np.array([1.0, 2.0, 3.0]) uv_wavelengths = np.array([[0.0001, 2.0, 3000.0], [3000.0, 2.0, 0.0001]]) @@ -223,8 +288,8 @@ def test__identical_inversion_values_for_two_methods(): assert inversion_w_tilde.mapped_reconstructed_image.array == pytest.approx( inversion_mapping_matrices.mapped_reconstructed_image.array, abs=1.0e-1 ) - assert inversion_w_tilde.mapped_reconstructed_data == pytest.approx( - inversion_mapping_matrices.mapped_reconstructed_data, abs=1.0e-1 + assert inversion_w_tilde.mapped_reconstructed_data.array == pytest.approx( + inversion_mapping_matrices.mapped_reconstructed_data.array, abs=1.0e-1 ) From ff0a76a7107acc621ea71172863c06b35795778e Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 18 Jun 2025 19:30:24 +0100 Subject: [PATCH 18/18] update plot --- autoarray/fit/plot/fit_interferometer_plotters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoarray/fit/plot/fit_interferometer_plotters.py b/autoarray/fit/plot/fit_interferometer_plotters.py index 93e7b212e..5247d67e5 100644 --- a/autoarray/fit/plot/fit_interferometer_plotters.py +++ b/autoarray/fit/plot/fit_interferometer_plotters.py @@ -183,7 +183,7 @@ def figures_2d( auto_labels=AutoLabels( title="Model Visibilities", filename="model_data" ), - color_array=np.real(self.fit.model_data), + color_array=np.real(self.fit.model_data.array), ) if residual_map_real: