Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions autoarray/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from autoconf import jax_wrapper
from autoconf.dictable import register_parser
from autoconf import conf

Expand Down
49 changes: 27 additions & 22 deletions autoarray/abstract_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

from abc import ABC
from abc import abstractmethod
import jax.numpy as jnp
from jax._src.tree_util import register_pytree_node

import numpy as np

Expand Down Expand Up @@ -75,20 +73,20 @@ def __init__(self, array, xp=np):
while isinstance(array, AbstractNDArray):
array = array.array
self._array = array
try:
register_pytree_node(
type(self),
self.instance_flatten,
self.instance_unflatten,
)
except ValueError:
pass
# try:
# register_pytree_node(
# type(self),
# self.instance_flatten,
# self.instance_unflatten,
# )
# except ValueError:
# pass

self._xp = xp

def invert(self):
new = self.copy()
new._array = jnp.invert(new._array)
new._array = self._xp.invert(new._array)
return new

@classmethod
Expand Down Expand Up @@ -117,7 +115,7 @@ def instance_unflatten(cls, aux_data, children):
setattr(instance, key, value)
return instance

def with_new_array(self, array: jnp.ndarray) -> "AbstractNDArray":
def with_new_array(self, array: np.ndarray) -> "AbstractNDArray":
"""
Copy this object but give it a new array.

Expand All @@ -137,10 +135,9 @@ def with_new_array(self, array: jnp.ndarray) -> "AbstractNDArray":
new_array._array = array
return new_array

@staticmethod
def flip_hdu_for_ds9(values):
def flip_hdu_for_ds9(self, values):
if conf.instance["general"]["fits"]["flip_for_ds9"]:
return jnp.flipud(values)
return self._xp.flipud(values)
return values

def copy(self):
Expand Down Expand Up @@ -170,7 +167,7 @@ def __iter__(self):

@to_new_array
def sqrt(self):
return jnp.sqrt(self._array)
return self._xp.sqrt(self._array)

@property
def array(self):
Expand Down Expand Up @@ -333,20 +330,28 @@ def __getattr__(self, item):
)

def __getitem__(self, item):

result = self._array[item]

if isinstance(item, slice):
result = self.with_new_array(result)
if isinstance(result, jnp.ndarray):
result = self.with_new_array(result)

try:
import jax.numpy as jnp
if isinstance(result, jnp.ndarray):
result = self.with_new_array(result)
except ImportError:
pass

return result

def __setitem__(self, key, value):
from jax import Array

if isinstance(key, (jnp.ndarray, AbstractNDArray, Array)):
self._array = jnp.where(key, value, self._array)
else:
if isinstance(self._array, np.ndarray):
self._array[key] = value
else:
import jax.numpy as jnp
self._array = jnp.where(key, value, self._array)

def __repr__(self):
return repr(self._array).replace(
Expand Down
2 changes: 0 additions & 2 deletions autoarray/config/general.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
jax:
use_jax: true # If True, uses JAX internally, whereas False uses normal Numpy.
fits:
flip_for_ds9: false # If True, the image is flipped before output to a .fits file, which is useful for viewing in DS9.
psf:
Expand Down
2 changes: 1 addition & 1 deletion autoarray/fit/fit_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def subtracted_from(grid, offset):
if grid is None:
return None

return grid.subtracted_from(offset=offset)
return grid.subtracted_from(offset=offset, xp=self._xp)

lp = subtracted_from(
grid=self.dataset.grids.lp, offset=self.dataset_model.grid_offset
Expand Down
2 changes: 1 addition & 1 deletion autoarray/inversion/inversion/interferometer/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]:
"""
return [
self.transformer.transform_mapping_matrix(
mapping_matrix=linear_obj.mapping_matrix
mapping_matrix=linear_obj.mapping_matrix, xp=self._xp
)
for linear_obj in self.linear_obj_list
]
Expand Down
2 changes: 0 additions & 2 deletions autoarray/mask/derive/indexes_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import logging
import numpy as np

from jax._src.tree_util import register_pytree_node_class
from typing import TYPE_CHECKING

if TYPE_CHECKING:
Expand All @@ -14,7 +13,6 @@
logger = logging.getLogger(__name__)


@register_pytree_node_class
class DeriveIndexes2D:

def __init__(self, mask: Mask2D, xp=np):
Expand Down
31 changes: 21 additions & 10 deletions autoarray/operators/over_sampling/over_sampler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np

from jax._src.tree_util import register_pytree_node_class
from typing import Union

from autoconf import conf
Expand All @@ -11,7 +10,6 @@
from autoarray.operators.over_sampling import over_sample_util


@register_pytree_node_class
class OverSampler:
def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]):
"""
Expand Down Expand Up @@ -229,6 +227,7 @@ def binned_array_2d_from(self, array: Array2D, xp=np) -> "Array2D":
Sub-pixels that are part of the same mask array pixel are indexed next to one another, such that the second
sub-pixel in the first pixel has index 1, its next sub-pixel has index 2, and so forth.
"""

if conf.instance["general"]["structures"]["native_binned_only"]:
return self

Expand All @@ -245,16 +244,28 @@ def binned_array_2d_from(self, array: Array2D, xp=np) -> "Array2D":

else:

import jax
if xp.__name__.startswith("jax"):

# Compute the group means
import jax

sums = jax.ops.segment_sum(
array, self.segment_ids, self.mask.pixels_in_mask
)
counts = jax.ops.segment_sum(
xp.ones_like(array), self.segment_ids, self.mask.pixels_in_mask
)

else:

# Sum values per segment
sums = np.bincount(self.segment_ids, weights=array, minlength=self.mask.pixels_in_mask)

# Count number of items per segment
counts = np.bincount(self.segment_ids, minlength=self.mask.pixels_in_mask)

# Avoid division by zero
counts[counts == 0] = 1

sums = jax.ops.segment_sum(
array, self.segment_ids, self.mask.pixels_in_mask
)
counts = jax.ops.segment_sum(
xp.ones_like(array), self.segment_ids, self.mask.pixels_in_mask
)
binned_array_2d = sums / counts

return Array2D(
Expand Down
48 changes: 22 additions & 26 deletions autoarray/operators/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def __init__(
uv_wavelengths: np.ndarray,
real_space_mask: Mask2D,
preload_transform: bool = True,
xp=np,
):
"""
A direct Fourier transform (DFT) operator for radio interferometric imaging.
Expand Down Expand Up @@ -112,9 +111,7 @@ def __init__(
2.0 * self.grid.shape_native[1]
)

self._xp = xp

def visibilities_from(self, image: Array2D) -> Visibilities:
def visibilities_from(self, image: Array2D, xp=np) -> Visibilities:
"""
Computes the visibilities from a real-space image using the direct Fourier transform (DFT).

Expand All @@ -138,19 +135,20 @@ def visibilities_from(self, image: Array2D) -> Visibilities:
image_1d=image.array,
preloaded_reals=self.preload_real_transforms,
preloaded_imags=self.preload_imag_transforms,
xp=self._xp,
xp=xp,
)
else:
visibilities = transformer_util.visibilities_from(
image_1d=image.slim.array,
grid_radians=self.grid.array,
uv_wavelengths=self.uv_wavelengths,
xp=xp
)

return Visibilities(visibilities=self._xp.array(visibilities))
return Visibilities(visibilities=xp.array(visibilities))

def image_from(
self, visibilities: Visibilities, use_adjoint_scaling: bool = False
self, visibilities: Visibilities, use_adjoint_scaling: bool = False, xp=np
) -> Array2D:
"""
Computes the real-space image from a set of visibilities using the adjoint of the DFT.
Expand Down Expand Up @@ -178,12 +176,12 @@ def image_from(
)

image_native = array_2d_util.array_2d_native_from(
array_2d_slim=image_slim, mask_2d=self.real_space_mask, xp=self._xp
array_2d_slim=image_slim, mask_2d=self.real_space_mask, xp=xp
)

return Array2D(values=image_native, mask=self.real_space_mask)

def transform_mapping_matrix(self, mapping_matrix: np.ndarray) -> np.ndarray:
def transform_mapping_matrix(self, mapping_matrix: np.ndarray, xp=np) -> np.ndarray:
"""
Applies the DFT to a mapping matrix that maps source pixels to image pixels.

Expand Down Expand Up @@ -310,8 +308,6 @@ def __init__(
2.0 * self.grid.shape_native[1]
)

self._xp = xp

def initialize_plan(self, ratio: int = 2, interp_kernel: Tuple[int, int] = (6, 6)):
"""
Initializes the PyNUFFT plan for performing the NUFFT operation.
Expand Down Expand Up @@ -394,7 +390,7 @@ def visibilities_from(self, image: Array2D) -> Visibilities:
)

def image_from(
self, visibilities: Visibilities, use_adjoint_scaling: bool = False
self, visibilities: Visibilities, use_adjoint_scaling: bool = False, xp=np
) -> Array2D:
"""
Reconstructs a real-space image from visibilities using the NUFFT adjoint transform.
Expand Down Expand Up @@ -425,24 +421,24 @@ def image_from(

return Array2D(values=image, mask=self.real_space_mask)

def transform_mapping_matrix(self, mapping_matrix: np.ndarray) -> np.ndarray:
def transform_mapping_matrix(self, mapping_matrix: np.ndarray, xp=np) -> np.ndarray:
"""
Applies the NUFFT forward transform to each column of a mapping matrix, producing transformed visibilities.
Applies the NUFFT forward transform to each column of a mapping matrix, producing transformed visibilities.

Parameters
----------
mapping_matrix
A 2D array where each column corresponds to a source-plane pixel intensity distribution flattened into image space.
Parameters
----------
mapping_matrix
A 2D array where each column corresponds to a source-plane pixel intensity distribution flattened into image space.

Returns
Returns
-------
A complex-valued 2D array where each column contains the visibilities corresponding to the respective column
in the input mapping matrix.
A complex-valued 2D array where each column contains the visibilities corresponding to the respective column
in the input mapping matrix.

Notes
-----
- Each column of the input mapping matrix is reshaped into the native 2D image grid before transformation.
- This method repeatedly calls `visibilities_from` for each column, which may be computationally intensive.
Notes
-----
- Each column of the input mapping matrix is reshaped into the native 2D image grid before transformation.
- This method repeatedly calls `visibilities_from` for each column, which may be computationally intensive.
"""
transformed_mapping_matrix = 0 + 0j * np.zeros(
(self.uv_wavelengths.shape[0], mapping_matrix.shape[1])
Expand All @@ -452,7 +448,7 @@ def transform_mapping_matrix(self, mapping_matrix: np.ndarray) -> np.ndarray:
image_2d = array_2d_util.array_2d_native_from(
array_2d_slim=mapping_matrix[:, source_pixel_1d_index],
mask_2d=self.grid.mask,
xp=self._xp,
xp=xp,
)

image = Array2D(values=image_2d, mask=self.grid.mask)
Expand Down
26 changes: 13 additions & 13 deletions autoarray/operators/transformer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def visibilities_via_preload_from(


def visibilities_from(
image_1d: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray
image_1d: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray, xp=np
) -> np.ndarray:
"""
Compute complex visibilities from an input sky image using the Fourier transform,
Expand Down Expand Up @@ -150,19 +150,19 @@ def visibilities_from(
# Compute the dot product for each pixel-uv pair
phase = (
-2.0
* np.pi
* xp.pi
* (
np.outer(grid_radians[:, 1], uv_wavelengths[:, 0])
+ np.outer(grid_radians[:, 0], uv_wavelengths[:, 1])
xp.outer(grid_radians[:, 1], uv_wavelengths[:, 0])
+ xp.outer(grid_radians[:, 0], uv_wavelengths[:, 1])
)
) # shape (n_pixels, n_vis)

# Multiply image values with phase terms
vis_real = image_1d[:, None] * np.cos(phase)
vis_imag = image_1d[:, None] * np.sin(phase)
vis_real = image_1d[:, None] * xp.cos(phase)
vis_imag = image_1d[:, None] * xp.sin(phase)

# Sum over all pixels for each visibility
visibilities = np.sum(vis_real + 1j * vis_imag, axis=0)
visibilities = xp.sum(vis_real + 1j * vis_imag, axis=0)

return visibilities

Expand Down Expand Up @@ -247,7 +247,7 @@ def transformed_mapping_matrix_via_preload_from(


def transformed_mapping_matrix_from(
mapping_matrix: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray
mapping_matrix: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray, xp=np
) -> np.ndarray:
"""
Computes the Fourier-transformed mapping matrix used in radio interferometric imaging.
Expand All @@ -273,16 +273,16 @@ def transformed_mapping_matrix_from(
# Compute phase term: (n_image_pixels, n_visibilities)
phase = (
-2.0
* np.pi
* xp.pi
* (
np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) # y * u
+ np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) # x * v
xp.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) # y * u
+ xp.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) # x * v
)
)

# Compute real and imaginary Fourier matrices
fourier_real = np.cos(phase)
fourier_imag = np.sin(phase)
fourier_real = xp.cos(phase)
fourier_imag = xp.sin(phase)

# Only compute contributions from non-zero mapping entries
# This matrix multiplication is: (n_visibilities x n_image_pixels) dot (n_image_pixels x n_source_pixels)
Expand Down
Loading
Loading