Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions autoarray/abstract_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ def __getitem__(self, item):

try:
import jax.numpy as jnp

if isinstance(result, jnp.ndarray):
result = self.with_new_array(result)
except ImportError:
Expand All @@ -351,6 +352,7 @@ def __setitem__(self, key, value):
self._array[key] = value
else:
import jax.numpy as jnp

self._array = jnp.where(key, value, self._array)

def __repr__(self):
Expand Down
4 changes: 3 additions & 1 deletion autoarray/dataset/imaging/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from pathlib import Path
from typing import Optional, Union

from autoconf import cached_property

from autoarray.dataset.abstract.dataset import AbstractDataset
from autoarray.dataset.grids import GridsDataset
from autoarray.dataset.imaging.w_tilde import WTildeImaging
Expand Down Expand Up @@ -191,7 +193,7 @@ def __init__(
psf=self.psf,
)

@property
@cached_property
def w_tilde(self):
"""
The w_tilde formalism of the linear algebra equations precomputes the convolution of every pair of masked
Expand Down
6 changes: 5 additions & 1 deletion autoarray/dataset/interferometer/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import numpy as np
from pathlib import Path

from autoconf import cached_property

from autoconf.fitsable import ndarray_via_fits_from, output_to_fits

from autoarray.dataset.abstract.dataset import AbstractDataset
Expand All @@ -14,6 +16,8 @@

from autoarray.inversion.inversion.interferometer import inversion_interferometer_util

from autoarray import exc

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -165,7 +169,7 @@ def w_tilde_preprocessing(self):

fits.writeto(filename, data=curvature_preload)

@property
@cached_property
def w_tilde(self):
"""
The w_tilde formalism of the linear algebra equations precomputes the Fourier Transform of all the visibilities
Expand Down
14 changes: 7 additions & 7 deletions autoarray/fit/fit_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def residual_map_with_mask_from(
model_data
The model data used to fit the data.
"""
return xp.where(xp.asarray(mask) == 0, xp.subtract(data, model_data), 0)
return xp.where(mask == 0, xp.subtract(data, model_data), 0)


@to_new_array
Expand All @@ -221,7 +221,7 @@ def normalized_residual_map_with_mask_from(
mask
The mask applied to the residual-map, where `False` entries are included in the calculation.
"""
return xp.where(xp.asarray(mask) == 0, xp.divide(residual_map, noise_map), 0)
return xp.where(mask == 0, xp.divide(residual_map, noise_map), 0)


@to_new_array
Expand All @@ -244,7 +244,7 @@ def chi_squared_map_with_mask_from(
mask
The mask applied to the residual-map, where `False` entries are included in the calculation.
"""
return xp.where(xp.asarray(mask) == 0, xp.square(residual_map / noise_map), 0)
return xp.where(mask == 0, xp.square(residual_map / noise_map), 0)


def chi_squared_with_mask_from(
Expand All @@ -263,7 +263,7 @@ def chi_squared_with_mask_from(
mask
The mask applied to the chi-squared-map, where `False` entries are included in the calculation.
"""
return float(xp.sum(chi_squared_map[xp.asarray(mask) == 0]))
return float(xp.sum(chi_squared_map[mask == 0]))


def chi_squared_with_mask_fast_from(
Expand Down Expand Up @@ -301,8 +301,8 @@ def chi_squared_with_mask_fast_from(
xp.subtract(
data,
model_data,
)[xp.asarray(mask) == 0],
noise_map[xp.asarray(mask) == 0],
)[mask == 0],
noise_map[mask == 0],
)
)
)
Expand All @@ -326,7 +326,7 @@ def noise_normalization_with_mask_from(
mask
The mask applied to the noise-map, where `False` entries are included in the calculation.
"""
return float(xp.sum(xp.log(2 * xp.pi * noise_map[xp.asarray(mask) == 0] ** 2.0)))
return float(xp.sum(xp.log(2 * xp.pi * noise_map[mask == 0] ** 2.0)))


def chi_squared_with_noise_covariance_from(
Expand Down
11 changes: 8 additions & 3 deletions autoarray/inversion/pixelization/mesh/rectangular.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,12 @@ def requires_image_mesh(self):

class RectangularSource(RectangularMagnification):

def __init__(self, shape: Tuple[int, int] = (3, 3), weight_power: float = 1.0, weight_floor : float = 0.0):
def __init__(
self,
shape: Tuple[int, int] = (3, 3),
weight_power: float = 1.0,
weight_floor: float = 0.0,
):
"""
A uniform mesh of rectangular pixels, which without interpolation are paired with a 2D grid of (y,x)
coordinates.
Expand Down Expand Up @@ -203,9 +208,9 @@ def mesh_weight_map_from(self, adapt_data, xp=np) -> np.ndarray:
xp
The array library to use.
"""
mesh_weight_map = xp.asarray(adapt_data.array)
mesh_weight_map = adapt_data.array
mesh_weight_map = xp.clip(mesh_weight_map, 1e-12, None)
mesh_weight_map = mesh_weight_map ** self.weight_power
mesh_weight_map = mesh_weight_map**self.weight_power

# Apply floor using xp.where (safe for NumPy and JAX)
mesh_weight_map = xp.where(
Expand Down
2 changes: 1 addition & 1 deletion autoarray/inversion/regularization/gaussian_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def gauss_cov_matrix_from(
The Gaussian covariance matrix.
"""
# Ensure array:
pts = xp.asarray(pixel_points) # (N, 2)
pts = pixel_points # (N, 2)
# Compute squared distances: ||p_i - p_j||^2
diffs = pts[:, None, :] - pts[None, :, :] # (N, N, 2)
d2 = xp.sum(diffs**2, axis=-1) # (N, N)
Expand Down
3 changes: 3 additions & 0 deletions autoarray/mask/mask_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,9 @@ def __init__(
xp=xp,
)

slim_to_native = self.derive_indexes.native_for_slim.astype("int32")
self.slim_to_native_tuple = (slim_to_native[:, 0], slim_to_native[:, 1])

@property
def native_for_slim(self):
return self.derive_indexes.native_for_slim
Expand Down
8 changes: 6 additions & 2 deletions autoarray/operators/over_sampling/over_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,14 @@ def binned_array_2d_from(self, array: Array2D, xp=np) -> "Array2D":
else:

# Sum values per segment
sums = np.bincount(self.segment_ids, weights=array, minlength=self.mask.pixels_in_mask)
sums = np.bincount(
self.segment_ids, weights=array, minlength=self.mask.pixels_in_mask
)

# Count number of items per segment
counts = np.bincount(self.segment_ids, minlength=self.mask.pixels_in_mask)
counts = np.bincount(
self.segment_ids, minlength=self.mask.pixels_in_mask
)

# Avoid division by zero
counts[counts == 0] = 1
Expand Down
2 changes: 1 addition & 1 deletion autoarray/operators/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def visibilities_from(self, image: Array2D, xp=np) -> Visibilities:
image_1d=image.slim.array,
grid_radians=self.grid.array,
uv_wavelengths=self.uv_wavelengths,
xp=xp
xp=xp,
)

return Visibilities(visibilities=xp.array(visibilities))
Expand Down
5 changes: 4 additions & 1 deletion autoarray/operators/transformer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,10 @@ def transformed_mapping_matrix_via_preload_from(


def transformed_mapping_matrix_from(
mapping_matrix: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray, xp=np
mapping_matrix: np.ndarray,
grid_radians: np.ndarray,
uv_wavelengths: np.ndarray,
xp=np,
) -> np.ndarray:
"""
Computes the Fourier-transformed mapping matrix used in radio interferometric imaging.
Expand Down
2 changes: 1 addition & 1 deletion autoarray/preloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(

ids_zeros = np.array(source_pixel_zeroed_indices, dtype=int)

values_to_solve = np.ones(np.max(mapper_indices)+1, dtype=bool)
values_to_solve = np.ones(np.max(mapper_indices) + 1, dtype=bool)
values_to_solve[ids_zeros] = False

self.source_pixel_zeroed_indices_to_keep = np.where(values_to_solve)[0]
Expand Down
Loading
Loading