diff --git a/autoarray/abstract_ndarray.py b/autoarray/abstract_ndarray.py index 9075dea03..6d71f0983 100644 --- a/autoarray/abstract_ndarray.py +++ b/autoarray/abstract_ndarray.py @@ -342,6 +342,7 @@ def __getitem__(self, item): def __setitem__(self, key, value): from jax import Array + if isinstance(key, (jnp.ndarray, AbstractNDArray, Array)): self._array = jnp.where(key, value, self._array) else: diff --git a/autoarray/dataset/imaging/simulator.py b/autoarray/dataset/imaging/simulator.py index 932416fa6..cefe89fda 100644 --- a/autoarray/dataset/imaging/simulator.py +++ b/autoarray/dataset/imaging/simulator.py @@ -100,7 +100,10 @@ def __init__( self.noise_seed = noise_seed def via_image_from( - self, image: Array2D, over_sample_size: Optional[Union[int, np.ndarray]] = None, xp=np + self, + image: Array2D, + over_sample_size: Optional[Union[int, np.ndarray]] = None, + xp=np, ) -> Imaging: """ Simulate an `Imaging` dataset from an input image. diff --git a/autoarray/fit/fit_dataset.py b/autoarray/fit/fit_dataset.py index bc24295cc..745ae5a45 100644 --- a/autoarray/fit/fit_dataset.py +++ b/autoarray/fit/fit_dataset.py @@ -83,7 +83,9 @@ def chi_squared(self) -> float: """ Returns the chi-squared terms of the model data's fit to an dataset, by summing the chi-squared-map. """ - return fit_util.chi_squared_from(chi_squared_map=self.chi_squared_map.array, xp=self._xp) + return fit_util.chi_squared_from( + chi_squared_map=self.chi_squared_map.array, xp=self._xp + ) @property def noise_normalization(self) -> float: @@ -92,7 +94,9 @@ def noise_normalization(self) -> float: [Noise_Term] = sum(log(2*pi*[Noise]**2.0)) """ - return fit_util.noise_normalization_from(noise_map=self.noise_map.array, xp=self._xp) + return fit_util.noise_normalization_from( + noise_map=self.noise_map.array, xp=self._xp + ) @property def log_likelihood(self) -> float: @@ -113,7 +117,7 @@ def __init__( dataset, use_mask_in_fit: bool = False, dataset_model: DatasetModel = None, - xp=np + xp=np, ): """Class to fit a masked dataset where the dataset's data structures are any dimension. @@ -209,7 +213,10 @@ def normalized_residual_map(self) -> ty.DataLike: """ if self.use_mask_in_fit: return fit_util.normalized_residual_map_with_mask_from( - residual_map=self.residual_map, noise_map=self.noise_map, mask=self.mask, xp=self._xp + residual_map=self.residual_map, + noise_map=self.noise_map, + mask=self.mask, + xp=self._xp, ) return super().normalized_residual_map @@ -222,7 +229,10 @@ def chi_squared_map(self) -> ty.DataLike: """ if self.use_mask_in_fit: return fit_util.chi_squared_map_with_mask_from( - residual_map=self.residual_map, noise_map=self.noise_map, mask=self.mask, xp=self._xp + residual_map=self.residual_map, + noise_map=self.noise_map, + mask=self.mask, + xp=self._xp, ) return super().chi_squared_map diff --git a/autoarray/fit/fit_imaging.py b/autoarray/fit/fit_imaging.py index 130f4225c..98ae98547 100644 --- a/autoarray/fit/fit_imaging.py +++ b/autoarray/fit/fit_imaging.py @@ -14,7 +14,7 @@ def __init__( dataset: Imaging, use_mask_in_fit: bool = False, dataset_model: DatasetModel = None, - xp=np + xp=np, ): """ Class to fit a masked imaging dataset. @@ -50,7 +50,7 @@ def __init__( dataset=dataset, use_mask_in_fit=use_mask_in_fit, dataset_model=dataset_model, - xp=xp + xp=xp, ) @property diff --git a/autoarray/fit/fit_interferometer.py b/autoarray/fit/fit_interferometer.py index 5934382b4..b37aaa234 100644 --- a/autoarray/fit/fit_interferometer.py +++ b/autoarray/fit/fit_interferometer.py @@ -17,7 +17,7 @@ def __init__( dataset: Interferometer, dataset_model: DatasetModel = None, use_mask_in_fit: bool = False, - xp=np + xp=np, ): """ Class to fit a masked interferometer dataset. @@ -58,7 +58,7 @@ def __init__( dataset=dataset, dataset_model=dataset_model, use_mask_in_fit=use_mask_in_fit, - xp=xp + xp=xp, ) @property diff --git a/autoarray/fit/fit_util.py b/autoarray/fit/fit_util.py index d8c4541d7..3f7363fca 100644 --- a/autoarray/fit/fit_util.py +++ b/autoarray/fit/fit_util.py @@ -247,7 +247,9 @@ def chi_squared_map_with_mask_from( return xp.where(xp.asarray(mask) == 0, xp.square(residual_map / noise_map), 0) -def chi_squared_with_mask_from(*, chi_squared_map: ty.DataLike, mask: Mask, xp=np) -> float: +def chi_squared_with_mask_from( + *, chi_squared_map: ty.DataLike, mask: Mask, xp=np +) -> float: """ Returns the chi-squared terms of each model data's fit to a masked dataset, by summing the masked chi-squared-map of the fit. @@ -265,7 +267,12 @@ def chi_squared_with_mask_from(*, chi_squared_map: ty.DataLike, mask: Mask, xp=n def chi_squared_with_mask_fast_from( - *, data: ty.DataLike, mask: Mask, model_data: ty.DataLike, noise_map: ty.DataLike, xp=np + *, + data: ty.DataLike, + mask: Mask, + model_data: ty.DataLike, + noise_map: ty.DataLike, + xp=np, ) -> float: """ Returns the chi-squared terms of each model data's fit to a masked dataset, by summing the masked @@ -302,7 +309,9 @@ def chi_squared_with_mask_fast_from( ) -def noise_normalization_with_mask_from(*, noise_map: ty.DataLike, mask: Mask, xp=np) -> float: +def noise_normalization_with_mask_from( + *, noise_map: ty.DataLike, mask: Mask, xp=np +) -> float: """ Returns the noise-map normalization terms of masked noise-map, summing the noise_map value in every pixel as: @@ -317,9 +326,7 @@ def noise_normalization_with_mask_from(*, noise_map: ty.DataLike, mask: Mask, xp mask The mask applied to the noise-map, where `False` entries are included in the calculation. """ - return float( - xp.sum(xp.log(2 * xp.pi * noise_map[xp.asarray(mask) == 0] ** 2.0)) - ) + return float(xp.sum(xp.log(2 * xp.pi * noise_map[xp.asarray(mask) == 0] ** 2.0))) def chi_squared_with_noise_covariance_from( diff --git a/autoarray/inversion/inversion/abstract.py b/autoarray/inversion/inversion/abstract.py index d7b708691..43a93df21 100644 --- a/autoarray/inversion/inversion/abstract.py +++ b/autoarray/inversion/inversion/abstract.py @@ -18,6 +18,7 @@ from autoarray.util import misc_util from autoarray.inversion.inversion import inversion_util + class AbstractInversion: def __init__( self, @@ -25,7 +26,7 @@ def __init__( linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), preloads: Preloads = None, - xp=np + xp=np, ): """ An `Inversion` reconstructs an input dataset using a list of linear objects (e.g. a list of analytic functions @@ -75,8 +76,6 @@ def __init__( self._xp = xp - - @property def data(self): return self.dataset.data @@ -333,10 +332,15 @@ def regularization_matrix(self) -> Optional[np.ndarray]: """ if self._xp.__name__.startswith("jax"): from jax.scipy.linalg import block_diag + return block_diag( - *[linear_obj.regularization_matrix for linear_obj in self.linear_obj_list] + *[ + linear_obj.regularization_matrix + for linear_obj in self.linear_obj_list + ] ) from scipy.linalg import block_diag + return block_diag( *[linear_obj.regularization_matrix for linear_obj in self.linear_obj_list] ) @@ -448,7 +452,7 @@ def reconstruction(self) -> np.ndarray: data_vector=data_vector, curvature_reg_matrix=curvature_reg_matrix, settings=self.settings, - xp=self._xp + xp=self._xp, ) ) @@ -471,13 +475,13 @@ def reconstruction(self) -> np.ndarray: data_vector=self.data_vector, curvature_reg_matrix=self.curvature_reg_matrix, settings=self.settings, - xp=self._xp + xp=self._xp, ) return inversion_util.reconstruction_positive_negative_from( data_vector=self.data_vector, curvature_reg_matrix=self.curvature_reg_matrix, - xp=self._xp + xp=self._xp, ) @property @@ -640,7 +644,9 @@ def regularization_term(self) -> float: return self._xp.matmul( self.reconstruction_reduced.T, - self._xp.matmul(self.regularization_matrix_reduced, self.reconstruction_reduced), + self._xp.matmul( + self.regularization_matrix_reduced, self.reconstruction_reduced + ), ) @property @@ -654,7 +660,11 @@ def log_det_curvature_reg_matrix_term(self) -> float: return 0.0 return 2.0 * self._xp.sum( - self._xp.log(self._xp.diag(self._xp.linalg.cholesky(self.curvature_reg_matrix_reduced))) + self._xp.log( + self._xp.diag( + self._xp.linalg.cholesky(self.curvature_reg_matrix_reduced) + ) + ) ) @property @@ -675,7 +685,11 @@ def log_det_regularization_matrix_term(self) -> float: return 0.0 return 2.0 * self._xp.sum( - self._xp.log(self._xp.diag(self._xp.linalg.cholesky(self.regularization_matrix_reduced))) + self._xp.log( + self._xp.diag( + self._xp.linalg.cholesky(self.regularization_matrix_reduced) + ) + ) ) @property @@ -738,7 +752,9 @@ def regularization_weights_from(self, index: int) -> np.ndarray: return np.zeros((pixels,)) - return regularization.regularization_weights_from(linear_obj=linear_obj, xp=self._xp) + return regularization.regularization_weights_from( + linear_obj=linear_obj, xp=self._xp + ) @property def regularization_weights_mapper_dict(self) -> Dict[LinearObj, np.ndarray]: diff --git a/autoarray/inversion/inversion/factory.py b/autoarray/inversion/inversion/factory.py index bbfebc861..ae0a06172 100644 --- a/autoarray/inversion/inversion/factory.py +++ b/autoarray/inversion/inversion/factory.py @@ -24,7 +24,7 @@ def inversion_from( linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), preloads: Preloads = None, - xp=np + xp=np, ): """ Factory which given an input dataset and list of linear objects, creates an `Inversion`. @@ -60,14 +60,11 @@ def inversion_from( linear_obj_list=linear_obj_list, settings=settings, preloads=preloads, - xp=xp + xp=xp, ) return inversion_interferometer_from( - dataset=dataset, - linear_obj_list=linear_obj_list, - settings=settings, - xp=xp + dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, xp=xp ) @@ -76,7 +73,7 @@ def inversion_imaging_from( linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), preloads: Preloads = None, - xp=np + xp=np, ): """ Factory which given an input `Imaging` dataset and list of linear objects, creates an `InversionImaging`. @@ -129,7 +126,7 @@ def inversion_imaging_from( w_tilde=w_tilde, linear_obj_list=linear_obj_list, settings=settings, - xp=xp + xp=xp, ) return InversionImagingMapping( @@ -137,7 +134,7 @@ def inversion_imaging_from( linear_obj_list=linear_obj_list, settings=settings, preloads=preloads, - xp=xp + xp=xp, ) @@ -145,7 +142,7 @@ def inversion_interferometer_from( dataset: Union[Interferometer, DatasetInterface], linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - xp=np + xp=np, ): """ Factory which given an input `Interferometer` dataset and list of linear objects, creates @@ -199,7 +196,7 @@ def inversion_interferometer_from( w_tilde=w_tilde, linear_obj_list=linear_obj_list, settings=settings, - xp=xp + xp=xp, ) else: @@ -207,5 +204,5 @@ def inversion_interferometer_from( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - xp=xp + xp=xp, ) diff --git a/autoarray/inversion/inversion/imaging/abstract.py b/autoarray/inversion/inversion/imaging/abstract.py index 812b90672..ed001c158 100644 --- a/autoarray/inversion/inversion/imaging/abstract.py +++ b/autoarray/inversion/inversion/imaging/abstract.py @@ -20,7 +20,7 @@ def __init__( linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), preloads: Preloads = None, - xp=np + xp=np, ): """ An `Inversion` reconstructs an input dataset using a list of linear objects (e.g. a list of analytic functions @@ -93,7 +93,9 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]: return [ ( self.psf.convolved_mapping_matrix_from( - mapping_matrix=linear_obj.mapping_matrix, mask=self.mask, xp=self._xp + mapping_matrix=linear_obj.mapping_matrix, + mask=self.mask, + xp=self._xp, ) if linear_obj.operated_mapping_matrix_override is None else self.linear_func_operated_mapping_matrix_dict[linear_obj] @@ -137,7 +139,7 @@ def linear_func_operated_mapping_matrix_dict(self) -> Dict: operated_mapping_matrix = self.psf.convolved_mapping_matrix_from( mapping_matrix=linear_func.mapping_matrix, mask=self.mask, - xp=self._xp + xp=self._xp, ) linear_func_operated_mapping_matrix_dict[linear_func] = ( @@ -217,9 +219,7 @@ def mapper_operated_mapping_matrix_dict(self) -> Dict: for mapper in self.cls_list_from(cls=AbstractMapper): operated_mapping_matrix = self.psf.convolved_mapping_matrix_from( - mapping_matrix=mapper.mapping_matrix, - mask=self.mask, - xp=self._xp + mapping_matrix=mapper.mapping_matrix, mask=self.mask, xp=self._xp ) mapper_operated_mapping_matrix_dict[mapper] = operated_mapping_matrix diff --git a/autoarray/inversion/inversion/imaging/inversion_imaging_util.py b/autoarray/inversion/inversion/imaging/inversion_imaging_util.py index 19d8cbd5a..a9298835a 100644 --- a/autoarray/inversion/inversion/imaging/inversion_imaging_util.py +++ b/autoarray/inversion/inversion/imaging/inversion_imaging_util.py @@ -61,7 +61,7 @@ def w_tilde_data_imaging_from( noise_map_native: np.ndarray, kernel_native: np.ndarray, native_index_for_slim_index, - xp=np + xp=np, ) -> np.ndarray: """ The matrix w_tilde is a matrix of dimensions [image_pixels, image_pixels] that encodes the PSF convolution of diff --git a/autoarray/inversion/inversion/imaging/mapping.py b/autoarray/inversion/inversion/imaging/mapping.py index d02481cdc..f72e27b04 100644 --- a/autoarray/inversion/inversion/imaging/mapping.py +++ b/autoarray/inversion/inversion/imaging/mapping.py @@ -21,7 +21,7 @@ def __init__( linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), preloads: Preloads = None, - xp=np + xp=np, ): """ Constructs linear equations (via vectors and matrices) which allow for sets of simultaneous linear equations @@ -48,7 +48,7 @@ def __init__( linear_obj_list=linear_obj_list, settings=settings, preloads=preloads, - xp=xp + xp=xp, ) @property @@ -142,7 +142,7 @@ def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]: settings=self.settings, add_to_curvature_diag=True, no_regularization_index_list=self.no_regularization_index_list, - xp=self._xp + xp=self._xp, ) curvature_matrix[ @@ -181,7 +181,7 @@ def curvature_matrix(self): settings=self.settings, add_to_curvature_diag=True, no_regularization_index_list=self.no_regularization_index_list, - xp=self._xp + xp=self._xp, ) @property @@ -224,7 +224,7 @@ def mapped_reconstructed_data_dict(self) -> Dict[LinearObj, Array2D]: inversion_util.mapped_reconstructed_data_via_mapping_matrix_from( mapping_matrix=operated_mapping_matrix_list[index], reconstruction=reconstruction, - xp=self._xp + xp=self._xp, ) ) diff --git a/autoarray/inversion/inversion/imaging/w_tilde.py b/autoarray/inversion/inversion/imaging/w_tilde.py index e6346b088..a5400e299 100644 --- a/autoarray/inversion/inversion/imaging/w_tilde.py +++ b/autoarray/inversion/inversion/imaging/w_tilde.py @@ -22,7 +22,7 @@ def __init__( w_tilde: WTildeImaging, linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - xp=np + xp=np, ): """ Constructs linear equations (via vectors and matrices) which allow for sets of simultaneous linear equations @@ -59,10 +59,7 @@ def __init__( ) super().__init__( - dataset=dataset, - linear_obj_list=linear_obj_list, - settings=settings, - xp=xp + dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, xp=xp ) if self.settings.use_w_tilde: @@ -522,9 +519,7 @@ def mapped_reconstructed_data_dict(self) -> Dict[LinearObj, Array2D]: ) mapped_reconstructed_image = self.psf.convolved_image_from( - image=mapped_reconstructed_image, - blurring_image=None, - xp=self._xp + image=mapped_reconstructed_image, blurring_image=None, xp=self._xp ).array mapped_reconstructed_image = Array2D( diff --git a/autoarray/inversion/inversion/interferometer/abstract.py b/autoarray/inversion/inversion/interferometer/abstract.py index d8d51fd9c..dd37952d9 100644 --- a/autoarray/inversion/inversion/interferometer/abstract.py +++ b/autoarray/inversion/inversion/interferometer/abstract.py @@ -18,7 +18,7 @@ def __init__( dataset: Union[Interferometer, DatasetInterface], linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - xp=np + xp=np, ): """ Constructs linear equations (via vectors and matrices) which allow for sets of simultaneous linear equations @@ -42,10 +42,7 @@ def __init__( """ super().__init__( - dataset=dataset, - linear_obj_list=linear_obj_list, - settings=settings, - xp=xp + dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, xp=xp ) @property @@ -113,7 +110,7 @@ def mapped_reconstructed_image_dict( inversion_util.mapped_reconstructed_data_via_mapping_matrix_from( mapping_matrix=linear_obj.mapping_matrix, reconstruction=reconstruction, - xp=self._xp + xp=self._xp, ) ) diff --git a/autoarray/inversion/inversion/interferometer/mapping.py b/autoarray/inversion/inversion/interferometer/mapping.py index 06d1c5dbd..948b9c36c 100644 --- a/autoarray/inversion/inversion/interferometer/mapping.py +++ b/autoarray/inversion/inversion/interferometer/mapping.py @@ -20,7 +20,7 @@ def __init__( dataset: Union[Interferometer, DatasetInterface], linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - xp=np + xp=np, ): """ Constructs linear equations (via vectors and matrices) which allow for sets of simultaneous linear equations @@ -47,10 +47,7 @@ def __init__( """ super().__init__( - dataset=dataset, - linear_obj_list=linear_obj_list, - settings=settings, - xp=xp + dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, xp=xp ) @property @@ -91,13 +88,13 @@ def curvature_matrix(self) -> np.ndarray: real_curvature_matrix = inversion_util.curvature_matrix_via_mapping_matrix_from( mapping_matrix=self.operated_mapping_matrix.real, noise_map=self.noise_map.real, - xp=self._xp + xp=self._xp, ) imag_curvature_matrix = inversion_util.curvature_matrix_via_mapping_matrix_from( mapping_matrix=self.operated_mapping_matrix.imag, noise_map=self.noise_map.imag, - xp=self._xp + xp=self._xp, ) curvature_matrix = self._xp.add(real_curvature_matrix, imag_curvature_matrix) @@ -107,7 +104,7 @@ def curvature_matrix(self) -> np.ndarray: curvature_matrix=curvature_matrix, value=self.settings.no_regularization_add_to_curvature_diag_value, no_regularization_index_list=self.no_regularization_index_list, - xp=self._xp + xp=self._xp, ) return curvature_matrix diff --git a/autoarray/inversion/inversion/interferometer/w_tilde.py b/autoarray/inversion/inversion/interferometer/w_tilde.py index afe9d54be..888241081 100644 --- a/autoarray/inversion/inversion/interferometer/w_tilde.py +++ b/autoarray/inversion/inversion/interferometer/w_tilde.py @@ -25,7 +25,7 @@ def __init__( w_tilde: WTildeInterferometer, linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - xp=np + xp=np, ): """ Constructs linear equations (via vectors and matrices) which allow for sets of simultaneous linear equations @@ -122,7 +122,9 @@ def curvature_matrix_diag(self) -> np.ndarray: if self.settings.use_w_tilde_numpy: return inversion_util.curvature_matrix_via_w_tilde_from( - w_tilde=self.w_tilde.w_matrix, mapping_matrix=self.mapping_matrix, xp=self._xp + w_tilde=self.w_tilde.w_matrix, + mapping_matrix=self.mapping_matrix, + xp=self._xp, ) mapper = self.cls_list_from(cls=AbstractMapper)[0] diff --git a/autoarray/inversion/inversion/inversion_util.py b/autoarray/inversion/inversion/inversion_util.py index f2eb65ca7..2190bf81c 100644 --- a/autoarray/inversion/inversion/inversion_util.py +++ b/autoarray/inversion/inversion/inversion_util.py @@ -7,6 +7,7 @@ from autoarray import exc from autoarray.util.fnnls import fnnls_cholesky + def curvature_matrix_via_w_tilde_from( w_tilde: np.ndarray, mapping_matrix: np.ndarray, xp=np ) -> np.ndarray: @@ -38,7 +39,7 @@ def curvature_matrix_with_added_to_diag_from( curvature_matrix: np.ndarray, value: float, no_regularization_index_list: Optional[List] = None, - xp=np + xp=np, ) -> np.ndarray: """ It is common for the `curvature_matrix` computed to not be positive-definite, leading for the inversion @@ -58,13 +59,13 @@ def curvature_matrix_with_added_to_diag_from( return curvature_matrix.at[ no_regularization_index_list, no_regularization_index_list ].add(value) - curvature_matrix[no_regularization_index_list, no_regularization_index_list] += value + curvature_matrix[ + no_regularization_index_list, no_regularization_index_list + ] += value return curvature_matrix -def curvature_matrix_mirrored_from( - curvature_matrix: np.ndarray, xp=np -) -> np.ndarray: +def curvature_matrix_mirrored_from(curvature_matrix: np.ndarray, xp=np) -> np.ndarray: # Copy the original matrix and its transpose m1 = curvature_matrix @@ -82,7 +83,7 @@ def curvature_matrix_via_mapping_matrix_from( add_to_curvature_diag: bool = False, no_regularization_index_list: Optional[List] = None, settings: SettingsInversion = SettingsInversion(), - xp=np + xp=np, ) -> np.ndarray: """ Returns the curvature matrix `F` from a blurred mapping matrix `f` and the 1D noise-map $\sigma$ @@ -104,7 +105,7 @@ def curvature_matrix_via_mapping_matrix_from( curvature_matrix=curvature_matrix, value=settings.no_regularization_add_to_curvature_diag_value, no_regularization_index_list=no_regularization_index_list, - xp=xp + xp=xp, ) return curvature_matrix @@ -247,6 +248,7 @@ def reconstruction_positive_only_from( if xp.__name__.startswith("jax"): import jaxnnls + return jaxnnls.solve_nnls_primal(curvature_reg_matrix, data_vector) try: @@ -265,8 +267,6 @@ def reconstruction_positive_only_from( raise exc.InversionException() from e - - def preconditioner_matrix_via_mapping_matrix_from( mapping_matrix: np.ndarray, regularization_matrix: np.ndarray, diff --git a/autoarray/inversion/inversion/mapper_valued.py b/autoarray/inversion/inversion/mapper_valued.py index 774080768..2415d52d2 100644 --- a/autoarray/inversion/inversion/mapper_valued.py +++ b/autoarray/inversion/inversion/mapper_valued.py @@ -200,7 +200,7 @@ def mapped_reconstructed_image_from( values=inversion_util.mapped_reconstructed_data_via_mapping_matrix_from( mapping_matrix=mapping_matrix, reconstruction=self.values_masked, - xp=self.mapper._xp + xp=self.mapper._xp, ), mask=self.mapper.mapper_grids.mask, ) diff --git a/autoarray/inversion/linear_obj/func_list.py b/autoarray/inversion/linear_obj/func_list.py index ef251191d..ea7a583e4 100644 --- a/autoarray/inversion/linear_obj/func_list.py +++ b/autoarray/inversion/linear_obj/func_list.py @@ -13,7 +13,7 @@ def __init__( self, grid: Grid1D2DLike, regularization: Optional[AbstractRegularization], - xp=np + xp=np, ): """ A linear object which reconstructs a dataset based on mapping between the data points of that dataset and diff --git a/autoarray/inversion/linear_obj/linear_obj.py b/autoarray/inversion/linear_obj/linear_obj.py index bc498fddb..08b846d4f 100644 --- a/autoarray/inversion/linear_obj/linear_obj.py +++ b/autoarray/inversion/linear_obj/linear_obj.py @@ -6,11 +6,7 @@ class LinearObj: - def __init__( - self, - regularization: Optional[AbstractRegularization], - xp=np - ): + def __init__(self, regularization: Optional[AbstractRegularization], xp=np): """ A linear object which reconstructs a dataset based on mapping between the data points of that dataset and the parameters of the linear object. For example, the linear obj could map to the data via analytic functions @@ -152,4 +148,6 @@ def regularization_matrix(self) -> np.ndarray: if self.regularization is None: return self._xp.zeros((self.params, self.params)) - return self.regularization.regularization_matrix_from(linear_obj=self, xp=self._xp) + return self.regularization.regularization_matrix_from( + linear_obj=self, xp=self._xp + ) diff --git a/autoarray/inversion/pixelization/border_relocator.py b/autoarray/inversion/pixelization/border_relocator.py index 4703d6da1..4856bfd10 100644 --- a/autoarray/inversion/pixelization/border_relocator.py +++ b/autoarray/inversion/pixelization/border_relocator.py @@ -351,15 +351,13 @@ def relocated_grid_from(self, grid: Grid2D, xp=np) -> Grid2D: return grid values = relocated_grid_from( - grid=grid.array, - border_grid=grid.array[self.border_slim], - xp=xp + grid=grid.array, border_grid=grid.array[self.border_slim], xp=xp ) over_sampled = relocated_grid_from( grid=grid.over_sampled.array, border_grid=grid.over_sampled.array[self.sub_border_slim], - xp=xp + xp=xp, ) return Grid2D( @@ -367,7 +365,7 @@ def relocated_grid_from(self, grid: Grid2D, xp=np) -> Grid2D: mask=grid.mask, over_sample_size=self.sub_size, over_sampled=over_sampled, - xp=xp + xp=xp, ) def relocated_mesh_grid_from( @@ -388,9 +386,7 @@ def relocated_mesh_grid_from( return Grid2DIrregular( values=relocated_grid_from( - grid=mesh_grid.array, - border_grid=grid[self.sub_border_slim], - xp=xp + grid=mesh_grid.array, border_grid=grid[self.sub_border_slim], xp=xp ), - xp=xp + xp=xp, ) diff --git a/autoarray/inversion/pixelization/mappers/abstract.py b/autoarray/inversion/pixelization/mappers/abstract.py index cbe051744..38800f4df 100644 --- a/autoarray/inversion/pixelization/mappers/abstract.py +++ b/autoarray/inversion/pixelization/mappers/abstract.py @@ -24,7 +24,7 @@ def __init__( mapper_grids: MapperGrids, regularization: Optional[AbstractRegularization], border_relocator: BorderRelocator, - xp=np + xp=np, ): """ To understand a `Mapper` one must be familiar `Mesh` objects and the `mesh` and `pixelization` packages, where @@ -267,7 +267,7 @@ def mapping_matrix(self) -> np.ndarray: total_mask_pixels=self.over_sampler.mask.pixels_in_mask, slim_index_for_sub_slim_index=self.slim_index_for_sub_slim_index, sub_fraction=self.over_sampler.sub_fraction.array, - xp=self._xp + xp=self._xp, ) def pixel_signals_from(self, signal_scale: float, xp=np) -> np.ndarray: @@ -292,7 +292,7 @@ def pixel_signals_from(self, signal_scale: float, xp=np) -> np.ndarray: pix_size_for_sub_slim_index=self.pix_sizes_for_sub_slim_index, slim_index_for_sub_slim_index=self.over_sampler.slim_for_sub_slim, adapt_data=self.adapt_data.array, - xp=xp + xp=xp, ) def slim_indexes_for_pix_indexes(self, pix_indexes: List) -> List[List]: diff --git a/autoarray/inversion/pixelization/mappers/factory.py b/autoarray/inversion/pixelization/mappers/factory.py index 99bb25190..71ed472af 100644 --- a/autoarray/inversion/pixelization/mappers/factory.py +++ b/autoarray/inversion/pixelization/mappers/factory.py @@ -14,7 +14,7 @@ def mapper_from( mapper_grids: MapperGrids, regularization: Optional[AbstractRegularization], border_relocator: Optional[BorderRelocator] = None, - xp=np + xp=np, ): """ Factory which given input `MapperGrids` and `Regularization` objects creates a `Mapper`. @@ -54,26 +54,26 @@ def mapper_from( mapper_grids=mapper_grids, border_relocator=border_relocator, regularization=regularization, - xp=xp + xp=xp, ) elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DRectangular): return MapperRectangular( mapper_grids=mapper_grids, border_relocator=border_relocator, regularization=regularization, - xp=xp + xp=xp, ) elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DDelaunay): return MapperDelaunay( mapper_grids=mapper_grids, border_relocator=border_relocator, regularization=regularization, - xp=xp + xp=xp, ) elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DVoronoi): return MapperVoronoi( mapper_grids=mapper_grids, border_relocator=border_relocator, regularization=regularization, - xp=xp + xp=xp, ) diff --git a/autoarray/inversion/pixelization/mappers/mapper_grids.py b/autoarray/inversion/pixelization/mappers/mapper_grids.py index 12f98d4ff..59903d03c 100644 --- a/autoarray/inversion/pixelization/mappers/mapper_grids.py +++ b/autoarray/inversion/pixelization/mappers/mapper_grids.py @@ -19,7 +19,7 @@ def __init__( source_plane_mesh_grid: Optional[Abstract2DMesh] = None, image_plane_mesh_grid: Optional[Grid2DIrregular] = None, adapt_data: Optional[np.ndarray] = None, - mesh_weight_map : Optional[Array2D] = None, + mesh_weight_map: Optional[Array2D] = None, ): """ Groups the different grids used by `Mesh` objects, the `mesh` package and the `pixelization` package, which diff --git a/autoarray/inversion/pixelization/mappers/mapper_util.py b/autoarray/inversion/pixelization/mappers/mapper_util.py index 49c51f52d..06857dc69 100644 --- a/autoarray/inversion/pixelization/mappers/mapper_util.py +++ b/autoarray/inversion/pixelization/mappers/mapper_util.py @@ -7,13 +7,18 @@ def forward_interp(xp, yp, x): import jax import jax.numpy as jnp - return jax.vmap(jnp.interp, in_axes=(1, 1, None, None, None))(x, xp, yp, 0, 1).T + + return jax.vmap(jnp.interp, in_axes=(1, 1, 1, None, None), out_axes=(1))( + x, xp, yp, 0, 1 + ) def reverse_interp(xp, yp, x): import jax import jax.numpy as jnp - return jax.vmap(jnp.interp, in_axes=(1, None, 1))(x, xp, yp).T + + return jax.vmap(jnp.interp, in_axes=(1, 1, 1), out_axes=(1))(x, xp, yp) + def forward_interp_np(xp, yp, x): """ @@ -34,6 +39,7 @@ def forward_interp_np(xp, yp, x): return out + def reverse_interp_np(xp, yp, x): """ xp : (N,) or (N, M) @@ -42,8 +48,8 @@ def reverse_interp_np(xp, yp, x): """ # Ensure xp is 2D: (N, M) - if xp.ndim == 1 and yp.ndim == 2: # (N, 1) - xp = np.broadcast_to(xp[:, None] , yp.shape) + if xp.ndim == 1 and yp.ndim == 2: # (N, 1) + xp = np.broadcast_to(xp[:, None], yp.shape) # Shapes K, M = x.shape @@ -57,38 +63,43 @@ def reverse_interp_np(xp, yp, x): return out -def create_transforms(traced_points, mesh_weight_map = None, xp=np): - # make functions that takes a set of traced points - # stored in a (N, 2) array and return functions that - # take in (N, 2) arrays and transform the values into - # the range (0, 1) and the inverse transform + +def create_transforms(traced_points, mesh_weight_map=None, xp=np): + N = traced_points.shape[0] # // 2 if mesh_weight_map is None: t = xp.arange(1, N + 1) / (N + 1) + t = xp.stack([t, t], axis=1) + sort_points = xp.sort(traced_points, axis=0) # [::2] else: - t = xp.cumsum(mesh_weight_map) - - sort_points = xp.sort(traced_points, axis=0) # [::2] + sdx = xp.argsort(traced_points, axis=0) + sort_points = xp.take_along_axis(traced_points, sdx, axis=0) + t = xp.stack([mesh_weight_map, mesh_weight_map], axis=1) + t = xp.take_along_axis(t, sdx, axis=0) + t = xp.cumsum(t, axis=0) if xp.__name__.startswith("jax"): transform = partial(forward_interp, sort_points, t) inv_transform = partial(reverse_interp, t, sort_points) return transform, inv_transform - else: - transform = partial(forward_interp_np, sort_points, t) - inv_transform = partial(reverse_interp_np, t, sort_points) - return transform, inv_transform + transform = partial(forward_interp_np, sort_points, t) + inv_transform = partial(reverse_interp_np, t, sort_points) + return transform, inv_transform -def adaptive_rectangular_transformed_grid_from(source_plane_data_grid, grid, xp=np): +def adaptive_rectangular_transformed_grid_from( + source_plane_data_grid, grid, mesh_weight_map=None, xp=np +): mu = source_plane_data_grid.mean(axis=0) scale = source_plane_data_grid.std(axis=0).min() source_grid_scaled = (source_plane_data_grid - mu) / scale - transform, inv_transform = create_transforms(source_grid_scaled, xp=xp) + transform, inv_transform = create_transforms( + source_grid_scaled, mesh_weight_map=mesh_weight_map, xp=xp + ) def inv_full(U): return inv_transform(U) * scale + mu @@ -96,33 +107,38 @@ def inv_full(U): return inv_full(grid) -def adaptive_rectangular_areas_from(source_grid_size, source_plane_data_grid, xp=np): +def adaptive_rectangular_areas_from( + source_grid_shape, source_plane_data_grid, mesh_weight_map=None, xp=np +): - pixel_edges_1d = xp.linspace(0, 1, source_grid_size + 1) + edges_y = xp.linspace(1, 0, source_grid_shape[0] + 1) + edges_x = xp.linspace(0, 1, source_grid_shape[1] + 1) mu = source_plane_data_grid.mean(axis=0) scale = source_plane_data_grid.std(axis=0).min() source_grid_scaled = (source_plane_data_grid - mu) / scale - transform, inv_transform = create_transforms(source_grid_scaled, xp=xp) + transform, inv_transform = create_transforms( + source_grid_scaled, mesh_weight_map=mesh_weight_map, xp=xp + ) def inv_full(U): return inv_transform(U) * scale + mu - pixel_edges = inv_full(xp.stack([pixel_edges_1d, pixel_edges_1d]).T) + pixel_edges = inv_full(xp.stack([edges_y, edges_x]).T) pixel_lengths = xp.diff(pixel_edges, axis=0).squeeze() # shape (N_source, 2) dy = pixel_lengths[:, 0] dx = pixel_lengths[:, 1] - return xp.outer(dy, dx).flatten() + return xp.abs(xp.outer(dy, dx).flatten()) def adaptive_rectangular_mappings_weights_via_interpolation_from( source_grid_size: int, source_plane_data_grid, source_plane_data_grid_over_sampled, - mesh_weight_map = None, + mesh_weight_map=None, xp=np, ): """ @@ -178,13 +194,16 @@ def adaptive_rectangular_mappings_weights_via_interpolation_from( The bilinear interpolation weights for each of the four neighboring pixels. Order: [w_bl, w_br, w_tl, w_tr]. """ + # --- Step 1. Normalize grid --- mu = source_plane_data_grid.mean(axis=0) scale = source_plane_data_grid.std(axis=0).min() source_grid_scaled = (source_plane_data_grid - mu) / scale # --- Step 2. Build transforms --- - transform, inv_transform = create_transforms(source_grid_scaled, mesh_weight_map=mesh_weight_map, xp=xp) + transform, inv_transform = create_transforms( + source_grid_scaled, mesh_weight_map=mesh_weight_map, xp=xp + ) # --- Step 3. Transform oversampled grid into index space --- grid_over_sampled_scaled = (source_plane_data_grid_over_sampled - mu) / scale @@ -236,7 +255,7 @@ def rectangular_mappings_weights_via_interpolation_from( shape_native: Tuple[int, int], source_plane_data_grid: np.ndarray, source_plane_mesh_grid: np.ndarray, - xp=np + xp=np, ): """ Compute bilinear interpolation weights and corresponding rectangular mesh indices for an irregular grid. @@ -347,7 +366,7 @@ def adaptive_pixel_signals_from( pix_size_for_sub_slim_index: np.ndarray, slim_index_for_sub_slim_index: np.ndarray, adapt_data: np.ndarray, - xp=np + xp=np, ) -> np.ndarray: """ Returns the signal in each pixel, where the signal is the sum of its mapped data values. diff --git a/autoarray/inversion/pixelization/mappers/rectangular.py b/autoarray/inversion/pixelization/mappers/rectangular.py index fa3540961..f8c227c4d 100644 --- a/autoarray/inversion/pixelization/mappers/rectangular.py +++ b/autoarray/inversion/pixelization/mappers/rectangular.py @@ -103,7 +103,7 @@ def pix_sub_weights(self) -> PixSubWeights: self.source_plane_data_grid.over_sampled ), mesh_weight_map=self.mapper_grids.mesh_weight_map, - xp=self._xp + xp=self._xp, ) ) @@ -123,9 +123,10 @@ def areas_transformed(self): rectangular grid, as described in the method `mesh_util.rectangular_neighbors_from`. """ return mapper_util.adaptive_rectangular_areas_from( - source_grid_size=self.shape_native[0], + source_grid_shape=self.shape_native, source_plane_data_grid=self.source_plane_data_grid.array, - xp=self._xp + mesh_weight_map=self.mapper_grids.mesh_weight_map, + xp=self._xp, ) @property @@ -139,11 +140,14 @@ def edges_transformed(self): """ # edges defined in 0 -> 1 space, there is one more edge than pixel centers on each side - edges = self._xp.linspace(0, 1, self.shape_native[0] + 1) - edges_reshaped = self._xp.stack([edges, edges]).T + edges_y = self._xp.linspace(1, 0, self.shape_native[0] + 1) + edges_x = self._xp.linspace(0, 1, self.shape_native[1] + 1) + + edges_reshaped = self._xp.stack([edges_y, edges_x]).T return mapper_util.adaptive_rectangular_transformed_grid_from( source_plane_data_grid=self.source_plane_data_grid.array, grid=edges_reshaped, - xp=self._xp + mesh_weight_map=self.mapper_grids.mesh_weight_map, + xp=self._xp, ) diff --git a/autoarray/inversion/pixelization/mappers/rectangular_uniform.py b/autoarray/inversion/pixelization/mappers/rectangular_uniform.py index 507dd0e5a..a8dd78346 100644 --- a/autoarray/inversion/pixelization/mappers/rectangular_uniform.py +++ b/autoarray/inversion/pixelization/mappers/rectangular_uniform.py @@ -92,7 +92,7 @@ def pix_sub_weights(self) -> PixSubWeights: shape_native=self.shape_native, source_plane_mesh_grid=self.source_plane_mesh_grid.array, source_plane_data_grid=self.source_plane_data_grid.over_sampled, - xp=self._xp + xp=self._xp, ) ) diff --git a/autoarray/inversion/pixelization/mesh/abstract.py b/autoarray/inversion/pixelization/mesh/abstract.py index 5b61bf80c..2dceeb590 100644 --- a/autoarray/inversion/pixelization/mesh/abstract.py +++ b/autoarray/inversion/pixelization/mesh/abstract.py @@ -12,10 +12,7 @@ def __eq__(self, other): return self.__dict__ == other.__dict__ and self.__class__ is other.__class__ def relocated_grid_from( - self, - border_relocator: BorderRelocator, - source_plane_data_grid: Grid2D, - xp=np + self, border_relocator: BorderRelocator, source_plane_data_grid: Grid2D, xp=np ) -> Grid2D: """ Relocates all coordinates of the input `source_plane_data_grid` that are outside of a @@ -41,14 +38,16 @@ def relocated_grid_from( A 2D (y,x) grid of coordinates, whose coordinates outside the border are relocated to its edge. """ if border_relocator is not None: - return border_relocator.relocated_grid_from(grid=source_plane_data_grid, xp=xp) + return border_relocator.relocated_grid_from( + grid=source_plane_data_grid, xp=xp + ) return Grid2D( values=source_plane_data_grid.array, mask=source_plane_data_grid.mask, over_sample_size=source_plane_data_grid.over_sampler.sub_size, over_sampled=source_plane_data_grid.over_sampled.array, - xp=xp + xp=xp, ) def relocated_mesh_grid_from( @@ -56,7 +55,7 @@ def relocated_mesh_grid_from( border_relocator: Optional[BorderRelocator], source_plane_data_grid: Grid2D, source_plane_mesh_grid: Grid2DIrregular, - xp=np + xp=np, ): """ Relocates all coordinates of the input `source_plane_mesh_grid` that are outside of a border (which diff --git a/autoarray/inversion/pixelization/mesh/mesh_util.py b/autoarray/inversion/pixelization/mesh/mesh_util.py index b5aa21b99..8ee5ba685 100644 --- a/autoarray/inversion/pixelization/mesh/mesh_util.py +++ b/autoarray/inversion/pixelization/mesh/mesh_util.py @@ -333,9 +333,7 @@ def rectangular_edges_from(shape_native, pixel_scales, xp=np): ) # xa is the "right" boundary in your convention # Edge order to match your pytest: [(xa,y0)->(xa,y1), (xa,y1)->(xb,y1), (xb,y1)->(xb,y0), (xb,y0)->(xa,y0)] - e0 = xp.array( - [[xa, y0], [xa, y1]] - ) # "top" in your test (vertical at x=xa) + e0 = xp.array([[xa, y0], [xa, y1]]) # "top" in your test (vertical at x=xa) e1 = xp.array( [[xa, y1], [xb, y1]] ) # "right" in your test (horizontal at y=y1) diff --git a/autoarray/inversion/pixelization/mesh/rectangular.py b/autoarray/inversion/pixelization/mesh/rectangular.py index b2a8701dd..c7d4bcfc6 100644 --- a/autoarray/inversion/pixelization/mesh/rectangular.py +++ b/autoarray/inversion/pixelization/mesh/rectangular.py @@ -110,13 +110,10 @@ def mapper_grids_from( relocated_grid = self.relocated_grid_from( border_relocator=border_relocator, source_plane_data_grid=source_plane_data_grid, - xp=xp + xp=xp, ) - mesh_grid = self.mesh_grid_from( - source_plane_data_grid=relocated_grid, - xp=xp - ) + mesh_grid = self.mesh_grid_from(source_plane_data_grid=relocated_grid, xp=xp) mesh_weight_map = self.mesh_weight_map_from(adapt_data=adapt_data, xp=xp) @@ -126,7 +123,7 @@ def mapper_grids_from( source_plane_mesh_grid=mesh_grid, image_plane_mesh_grid=image_plane_mesh_grid, adapt_data=adapt_data, - mesh_weight_map=mesh_weight_map + mesh_weight_map=mesh_weight_map, ) def mesh_grid_from( @@ -151,8 +148,7 @@ def mesh_grid_from( return Mesh2DRectangular.overlay_grid( shape_native=self.shape, grid=Grid2DIrregular(source_plane_data_grid.over_sampled), - xp=xp - + xp=xp, ) @property @@ -162,7 +158,7 @@ def requires_image_mesh(self): class RectangularSource(RectangularMagnification): - def __init__(self, shape: Tuple[int, int] = (3, 3), weight_power : float = 1.0): + def __init__(self, shape: Tuple[int, int] = (3, 3), weight_power: float = 1.0): """ A uniform mesh of rectangular pixels, which without interpolation are paired with a 2D grid of (y,x) coordinates. @@ -208,6 +204,6 @@ def mesh_weight_map_from(self, adapt_data, xp=np) -> np.ndarray: """ mesh_weight_map = xp.asarray(adapt_data.array) mesh_weight_map = xp.clip(mesh_weight_map, 1e-12, None) - mesh_weight_map = mesh_weight_map ** self.weight_power + mesh_weight_map = mesh_weight_map**self.weight_power mesh_weight_map /= xp.sum(mesh_weight_map) return mesh_weight_map diff --git a/autoarray/inversion/pixelization/mesh/rectangular_uniform.py b/autoarray/inversion/pixelization/mesh/rectangular_uniform.py index 6c22b390b..de076828f 100644 --- a/autoarray/inversion/pixelization/mesh/rectangular_uniform.py +++ b/autoarray/inversion/pixelization/mesh/rectangular_uniform.py @@ -32,5 +32,5 @@ def mesh_grid_from( return Mesh2DRectangularUniform.overlay_grid( shape_native=self.shape, grid=Grid2DIrregular(source_plane_data_grid.over_sampled), - xp=xp + xp=xp, ) diff --git a/autoarray/inversion/regularization/adaptive_brightness.py b/autoarray/inversion/regularization/adaptive_brightness.py index b9e3c7a20..04c2ae488 100644 --- a/autoarray/inversion/regularization/adaptive_brightness.py +++ b/autoarray/inversion/regularization/adaptive_brightness.py @@ -116,7 +116,7 @@ def weighted_regularization_matrix_from( mat = mat.at[I, J].add(-w_ij) mat = mat.at[J, I].add(-w_ij) else: - np.add.at(mat, np.diag_indices(S+1), diag_updates_i) + np.add.at(mat, np.diag_indices(S + 1), diag_updates_i) xp.add.at(mat, (I, I), w_ij) xp.add.at(mat, (J, J), w_ij) @@ -207,7 +207,9 @@ def regularization_weights_from(self, linear_obj: LinearObj, xp=np) -> np.ndarra ------- The regularization weights. """ - pixel_signals = linear_obj.pixel_signals_from(signal_scale=self.signal_scale, xp=xp) + pixel_signals = linear_obj.pixel_signals_from( + signal_scale=self.signal_scale, xp=xp + ) return adaptive_regularization_weights_from( inner_coefficient=self.inner_coefficient, @@ -228,10 +230,12 @@ def regularization_matrix_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray ------- The regularization matrix. """ - regularization_weights = self.regularization_weights_from(linear_obj=linear_obj, xp=xp) + regularization_weights = self.regularization_weights_from( + linear_obj=linear_obj, xp=xp + ) return weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=linear_obj.source_plane_mesh_grid.neighbors, - xp=xp + xp=xp, ) diff --git a/autoarray/inversion/regularization/adaptive_brightness_split.py b/autoarray/inversion/regularization/adaptive_brightness_split.py index 2e12ea9c5..a126d0060 100644 --- a/autoarray/inversion/regularization/adaptive_brightness_split.py +++ b/autoarray/inversion/regularization/adaptive_brightness_split.py @@ -90,7 +90,9 @@ def regularization_matrix_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray ------- The regularization matrix. """ - regularization_weights = self.regularization_weights_from(linear_obj=linear_obj, xp=xp) + regularization_weights = self.regularization_weights_from( + linear_obj=linear_obj, xp=xp + ) pix_sub_weights_split_cross = linear_obj.pix_sub_weights_split_cross diff --git a/autoarray/inversion/regularization/adaptive_brightness_split_zeroth.py b/autoarray/inversion/regularization/adaptive_brightness_split_zeroth.py index 0d1720bf1..fc6cb16e4 100644 --- a/autoarray/inversion/regularization/adaptive_brightness_split_zeroth.py +++ b/autoarray/inversion/regularization/adaptive_brightness_split_zeroth.py @@ -92,7 +92,9 @@ def regularization_matrix_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray ------- The regularization matrix. """ - regularization_weights = self.regularization_weights_from(linear_obj=linear_obj, xp=xp) + regularization_weights = self.regularization_weights_from( + linear_obj=linear_obj, xp=xp + ) pix_sub_weights_split_cross = linear_obj.pix_sub_weights_split_cross @@ -120,8 +122,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray ) regularization_matrix_zeroth = brightness_zeroth.regularization_matrix_from( - linear_obj=linear_obj, - xp=xp + linear_obj=linear_obj, xp=xp ) return regularization_matrix + regularization_matrix_zeroth diff --git a/autoarray/inversion/regularization/brightness_zeroth.py b/autoarray/inversion/regularization/brightness_zeroth.py index 42cd57c3c..6c177a9c4 100644 --- a/autoarray/inversion/regularization/brightness_zeroth.py +++ b/autoarray/inversion/regularization/brightness_zeroth.py @@ -134,7 +134,9 @@ def regularization_matrix_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray ------- The regularization matrix. """ - regularization_weights = self.regularization_weights_from(linear_obj=linear_obj, xp=xp) + regularization_weights = self.regularization_weights_from( + linear_obj=linear_obj, xp=xp + ) return brightness_zeroth_regularization_matrix_from( regularization_weights=regularization_weights, xp=xp diff --git a/autoarray/inversion/regularization/constant.py b/autoarray/inversion/regularization/constant.py index 828dcefe0..7685e41f0 100644 --- a/autoarray/inversion/regularization/constant.py +++ b/autoarray/inversion/regularization/constant.py @@ -9,10 +9,7 @@ def constant_regularization_matrix_from( - coefficient: float, - neighbors: np.ndarray, - neighbors_sizes: np.ndarray, - xp=np + coefficient: float, neighbors: np.ndarray, neighbors_sizes: np.ndarray, xp=np ) -> np.ndarray: """ From the pixel-neighbors array, setup the regularization matrix using the instance regularization scheme. @@ -56,9 +53,9 @@ class in the module `autoarray.inversion.regularization`. if xp.__name__.startswith("jax"): return ( - xp.diag(diag_vals).at[ - I_IDX, neighbors - ].add(-regularization_coefficient, mode="drop", unique_indices=True) + xp.diag(diag_vals) + .at[I_IDX, neighbors] + .add(-regularization_coefficient, mode="drop", unique_indices=True) ) else: mat = xp.diag(diag_vals).copy() @@ -70,6 +67,7 @@ class in the module `autoarray.inversion.regularization`. xp.add.at(mat, (I_valid, neigh_valid), -regularization_coefficient) return mat + class Constant(AbstractRegularization): def __init__(self, coefficient: float = 1.0): """ @@ -137,5 +135,5 @@ def regularization_matrix_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray coefficient=self.coefficient, neighbors=linear_obj.neighbors, neighbors_sizes=linear_obj.neighbors.sizes, - xp=xp + xp=xp, ) diff --git a/autoarray/inversion/regularization/constant_zeroth.py b/autoarray/inversion/regularization/constant_zeroth.py index 38d9dd6f9..886f90ae1 100644 --- a/autoarray/inversion/regularization/constant_zeroth.py +++ b/autoarray/inversion/regularization/constant_zeroth.py @@ -13,7 +13,7 @@ def constant_zeroth_regularization_matrix_from( coefficient_zeroth: float, neighbors: np.ndarray, neighbors_sizes, - xp=np + xp=np, ) -> np.ndarray: """ From the pixel-neighbors array, setup the regularization matrix using the instance regularization scheme. @@ -54,9 +54,7 @@ class in the module ``autoarray.inversion.regularization``. if xp.__name__.startswith("jax"): const = ( - xp.diag(diag_vals).at[ - I_IDX, neighbors - ] + xp.diag(diag_vals).at[I_IDX, neighbors] # unique indices should be guranteed by neighbors-spec .add(-regularization_coefficient, mode="drop", unique_indices=True) ) @@ -119,5 +117,5 @@ def regularization_matrix_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray coefficient=self.coefficient_neighbor, coefficient_zeroth=self.coefficient_zeroth, neighbors=linear_obj.neighbors, - xp=xp + xp=xp, ) diff --git a/autoarray/inversion/regularization/gaussian_kernel.py b/autoarray/inversion/regularization/gaussian_kernel.py index 1e00f4551..0e10ae99e 100644 --- a/autoarray/inversion/regularization/gaussian_kernel.py +++ b/autoarray/inversion/regularization/gaussian_kernel.py @@ -9,9 +9,7 @@ def gauss_cov_matrix_from( - scale: float, - pixel_points: np.ndarray, # shape (N, 2) - xp=np + scale: float, pixel_points: np.ndarray, xp=np # shape (N, 2) ) -> np.ndarray: """ Construct the source‐pixel Gaussian covariance matrix for regularization. @@ -112,7 +110,9 @@ def regularization_matrix_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray The regularization matrix. """ covariance_matrix = gauss_cov_matrix_from( - scale=self.scale, pixel_points=linear_obj.source_plane_mesh_grid.array, xp=xp + scale=self.scale, + pixel_points=linear_obj.source_plane_mesh_grid.array, + xp=xp, ) return self.coefficient * xp.linalg.inv(covariance_matrix) diff --git a/autoarray/inversion/regularization/zeroth.py b/autoarray/inversion/regularization/zeroth.py index 38b1060e9..73e73d7a7 100644 --- a/autoarray/inversion/regularization/zeroth.py +++ b/autoarray/inversion/regularization/zeroth.py @@ -8,7 +8,9 @@ from autoarray.inversion.regularization.abstract import AbstractRegularization -def zeroth_regularization_matrix_from(coefficient: float, pixels: int, xp=np) -> np.ndarray: +def zeroth_regularization_matrix_from( + coefficient: float, pixels: int, xp=np +) -> np.ndarray: """ Apply zeroth order regularization which penalizes every pixel's deviation from zero by addiing non-zero terms to the regularization matrix. diff --git a/autoarray/mask/derive/grid_2d.py b/autoarray/mask/derive/grid_2d.py index ef8fc9f85..9ddde558e 100644 --- a/autoarray/mask/derive/grid_2d.py +++ b/autoarray/mask/derive/grid_2d.py @@ -171,7 +171,7 @@ def unmasked(self) -> Grid2D: mask_2d=self.mask, pixel_scales=self.mask.pixel_scales, origin=self.mask.origin, - xp=self._xp + xp=self._xp, ) return Grid2D(values=grid_2d, mask=self.mask) diff --git a/autoarray/mask/derive/indexes_2d.py b/autoarray/mask/derive/indexes_2d.py index 4dabc0661..13bdec3b9 100644 --- a/autoarray/mask/derive/indexes_2d.py +++ b/autoarray/mask/derive/indexes_2d.py @@ -409,6 +409,5 @@ def native_for_slim(self) -> np.ndarray: print(derive_indexes_2d.native_for_slim) """ return mask_2d_util.native_index_for_slim_index_2d_from( - mask_2d=self.mask, - xp=self._xp + mask_2d=self.mask, xp=self._xp ).astype("int") diff --git a/autoarray/mask/mask_1d_util.py b/autoarray/mask/mask_1d_util.py index d73813a71..12be23996 100644 --- a/autoarray/mask/mask_1d_util.py +++ b/autoarray/mask/mask_1d_util.py @@ -1,9 +1,7 @@ import numpy as np -def native_index_for_slim_index_1d_from( - mask_1d: np.ndarray, - xp=np -) -> np.ndarray: + +def native_index_for_slim_index_1d_from(mask_1d: np.ndarray, xp=np) -> np.ndarray: """ Returns an array of shape [total_unmasked_pixels] that maps every unmasked pixel to its corresponding native 2D pixel using its (y,x) pixel indexes. diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index 26af6c0f0..2e8472b3e 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -4,10 +4,8 @@ from autoarray import exc -def native_index_for_slim_index_2d_from( - mask_2d: np.ndarray, - xp=np -) -> np.ndarray: + +def native_index_for_slim_index_2d_from(mask_2d: np.ndarray, xp=np) -> np.ndarray: """ Returns an array of shape [total_unmasked_pixels] that maps every unmasked pixel to its corresponding native 2D pixel using its (y,x) pixel indexes. diff --git a/autoarray/numba_util.py b/autoarray/numba_util.py index b9027d76c..75cdea0c4 100644 --- a/autoarray/numba_util.py +++ b/autoarray/numba_util.py @@ -14,6 +14,7 @@ cache = True parallel = False + def jit(nopython=nopython, cache=cache, parallel=parallel): def wrapper(func): diff --git a/autoarray/operators/mock/mock_psf.py b/autoarray/operators/mock/mock_psf.py index 404fa3022..c52314699 100644 --- a/autoarray/operators/mock/mock_psf.py +++ b/autoarray/operators/mock/mock_psf.py @@ -1,5 +1,6 @@ import numpy as np + class MockPSF: def __init__(self, operated_mapping_matrix=None): self.operated_mapping_matrix = operated_mapping_matrix diff --git a/autoarray/operators/over_sampling/over_sampler.py b/autoarray/operators/over_sampling/over_sampler.py index 29d235674..62a3b8ba6 100644 --- a/autoarray/operators/over_sampling/over_sampler.py +++ b/autoarray/operators/over_sampling/over_sampler.py @@ -11,7 +11,6 @@ from autoarray.operators.over_sampling import over_sample_util - @register_pytree_node_class class OverSampler: def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]): @@ -148,9 +147,7 @@ def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]): self.sub_total = int(np.sum(self.sub_size**2)) self.sub_length = self.sub_size**self.mask.dimensions - self.sub_fraction = Array2D( - values=1.0 / self.sub_length.array, mask=self.mask - ) + self.sub_fraction = Array2D(values=1.0 / self.sub_length.array, mask=self.mask) # Used for JAX based adaptive over sampling. diff --git a/autoarray/operators/transformer.py b/autoarray/operators/transformer.py index 87cde4779..b2d7fc916 100644 --- a/autoarray/operators/transformer.py +++ b/autoarray/operators/transformer.py @@ -39,7 +39,7 @@ def __init__( uv_wavelengths: np.ndarray, real_space_mask: Mask2D, preload_transform: bool = True, - xp=np + xp=np, ): """ A direct Fourier transform (DFT) operator for radio interferometric imaging. @@ -138,7 +138,7 @@ def visibilities_from(self, image: Array2D) -> Visibilities: image_1d=image.array, preloaded_reals=self.preload_real_transforms, preloaded_imags=self.preload_imag_transforms, - xp=self._xp + xp=self._xp, ) else: visibilities = transformer_util.visibilities_from( @@ -178,9 +178,7 @@ def image_from( ) image_native = array_2d_util.array_2d_native_from( - array_2d_slim=image_slim, - mask_2d=self.real_space_mask, - xp=self._xp + array_2d_slim=image_slim, mask_2d=self.real_space_mask, xp=self._xp ) return Array2D(values=image_native, mask=self.real_space_mask) @@ -220,7 +218,9 @@ def transform_mapping_matrix(self, mapping_matrix: np.ndarray) -> np.ndarray: class TransformerNUFFT(NUFFT_cpu): - def __init__(self, uv_wavelengths: np.ndarray, real_space_mask: Mask2D, xp=np, **kwargs): + def __init__( + self, uv_wavelengths: np.ndarray, real_space_mask: Mask2D, xp=np, **kwargs + ): """ Performs the Non-Uniform Fast Fourier Transform (NUFFT) for interferometric image reconstruction. @@ -452,7 +452,7 @@ def transform_mapping_matrix(self, mapping_matrix: np.ndarray) -> np.ndarray: image_2d = array_2d_util.array_2d_native_from( array_2d_slim=mapping_matrix[:, source_pixel_1d_index], mask_2d=self.grid.mask, - xp=self._xp + xp=self._xp, ) image = Array2D(values=image_2d, mask=self.grid.mask) diff --git a/autoarray/operators/transformer_util.py b/autoarray/operators/transformer_util.py index 4a3b358e2..3ff0cf868 100644 --- a/autoarray/operators/transformer_util.py +++ b/autoarray/operators/transformer_util.py @@ -84,7 +84,10 @@ def preload_imag_transforms_from( def visibilities_via_preload_from( - image_1d: np.ndarray, preloaded_reals: np.ndarray, preloaded_imags: np.ndarray, xp=np + image_1d: np.ndarray, + preloaded_reals: np.ndarray, + preloaded_imags: np.ndarray, + xp=np, ) -> np.ndarray: """ Computes interferometric visibilities using preloaded real and imaginary DFT transform components. diff --git a/autoarray/plot/mat_plot/two_d.py b/autoarray/plot/mat_plot/two_d.py index c5fd895a6..27ba81b5f 100644 --- a/autoarray/plot/mat_plot/two_d.py +++ b/autoarray/plot/mat_plot/two_d.py @@ -601,7 +601,7 @@ def _plot_rectangular_mapper( plt.pcolormesh( X, # x-grid Y, # y-grid - np.flipud(pixel_values.array.reshape(shape_native)), # (ny, nx) + pixel_values.array.reshape(shape_native), # (ny, nx) shading="flat", norm=norm, cmap=self.cmap.cmap, diff --git a/autoarray/structures/arrays/array_1d_util.py b/autoarray/structures/arrays/array_1d_util.py index 8c2c9868a..071c1ab14 100644 --- a/autoarray/structures/arrays/array_1d_util.py +++ b/autoarray/structures/arrays/array_1d_util.py @@ -13,7 +13,7 @@ def convert_array_1d( array_1d: Union[np.ndarray, List], mask_1d: Mask1D, store_native: bool = False, - xp=np + xp=np, ) -> np.ndarray: """ The `manual` classmethods in the `Array2D` object take as input a list or ndarray which is returned as an @@ -49,11 +49,7 @@ def convert_array_1d( array_1d_native=array_1d, mask_1d=mask_1d, ) - return array_1d_native_from( - array_1d_slim=array_1d, - mask_1d=mask_1d, - xp=xp - ) + return array_1d_native_from(array_1d_slim=array_1d, mask_1d=mask_1d, xp=xp) def array_1d_slim_from( @@ -124,7 +120,7 @@ def array_1d_native_from( array_1d_slim=array_1d_slim, shape=shape, native_index_for_slim_index_1d=native_index_for_slim_index_1d, - xp=xp + xp=xp, ) @@ -174,4 +170,4 @@ def array_1d_via_indexes_1d_from( else: array[native_index_for_slim_index_1d] = array_1d_slim - return array \ No newline at end of file + return array diff --git a/autoarray/structures/arrays/array_2d_util.py b/autoarray/structures/arrays/array_2d_util.py index bd4c402dd..f186faeca 100644 --- a/autoarray/structures/arrays/array_2d_util.py +++ b/autoarray/structures/arrays/array_2d_util.py @@ -93,7 +93,7 @@ def convert_array_2d( mask_2d: Mask2D, store_native: bool = False, skip_mask: bool = False, - xp=np + xp=np, ) -> np.ndarray: """ The `manual` classmethods in the `Array2D` object take as input a list or ndarray which is returned as an @@ -134,11 +134,7 @@ def convert_array_2d( array_2d_native=array_2d, mask_2d=mask_2d, ) - return array_2d_native_from( - array_2d_slim=array_2d, - mask_2d=mask_2d, - xp=xp - ) + return array_2d_native_from(array_2d_slim=array_2d, mask_2d=mask_2d, xp=xp) def convert_array_2d_to_slim(array_2d: np.ndarray, mask_2d: Mask2D) -> np.ndarray: @@ -169,7 +165,9 @@ def convert_array_2d_to_slim(array_2d: np.ndarray, mask_2d: Mask2D) -> np.ndarra ) -def convert_array_2d_to_native(array_2d: np.ndarray, mask_2d: Mask2D, xp=np) -> np.ndarray: +def convert_array_2d_to_native( + array_2d: np.ndarray, mask_2d: Mask2D, xp=np +) -> np.ndarray: """ The `manual` classmethods in the `Array2D` object take as input a list or ndarray which is returned as an Array2D. @@ -203,11 +201,7 @@ def convert_array_2d_to_native(array_2d: np.ndarray, mask_2d: Mask2D, xp=np) -> "the mask." ) - return array_2d_native_from( - array_2d_slim=array_2d, - mask_2d=mask_2d, - xp=xp - ) + return array_2d_native_from(array_2d_slim=array_2d, mask_2d=mask_2d, xp=xp) def extracted_array_2d_from( @@ -468,9 +462,7 @@ def array_2d_slim_from( def array_2d_native_from( - array_2d_slim: np.ndarray, - mask_2d: np.ndarray, - xp=np + array_2d_slim: np.ndarray, mask_2d: np.ndarray, xp=np ) -> np.ndarray: """ For a slimmed 2D array that was computed by mapping unmasked values from a native 2D array of shape @@ -510,17 +502,17 @@ def array_2d_native_from( shape = (mask_2d.shape[0], mask_2d.shape[1]) native_index_for_slim_index_2d = mask_2d_util.native_index_for_slim_index_2d_from( - mask_2d=mask_2d, - xp=xp + mask_2d=mask_2d, xp=xp ).astype("int") return array_2d_via_indexes_from( array_2d_slim=array_2d_slim, shape=shape, native_index_for_slim_index_2d=native_index_for_slim_index_2d, - xp=xp + xp=xp, ) + def array_2d_via_indexes_from( array_2d_slim: np.ndarray, shape: Tuple[int, int], diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index 247eb07fe..d053d98b1 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -590,7 +590,9 @@ def mapping_matrix_native_from( mask_flat = xp.logical_not(mask.array) if xp.__name__.startswith("jax"): - slim_to_native_tuple = xp.nonzero(mask_flat, size=mapping_matrix.shape[0]) + slim_to_native_tuple = xp.nonzero( + mask_flat, size=mapping_matrix.shape[0] + ) else: slim_to_native = mask.derive_indexes.native_for_slim.astype("int32") slim_to_native_tuple = (slim_to_native[:, 0], slim_to_native[:, 1]) @@ -640,7 +642,9 @@ def mapping_matrix_native_from( slim_to_native_blurring_tuple ].set(blurring_mapping_matrix) else: - mapping_matrix_native[slim_to_native_blurring_tuple] = blurring_mapping_matrix + mapping_matrix_native[slim_to_native_blurring_tuple] = ( + blurring_mapping_matrix + ) return mapping_matrix_native @@ -701,7 +705,7 @@ def convolved_image_from(self, image, blurring_image, jax_method="direct", xp=np if self.fft_shape is None: full_shape, fft_shape, mask_shape = self.fft_shape_from(mask=image.mask) - fft_psf = xp.fft.rfft2(self.stored_native.array, s=fft_shape, axes=(0,1)) + fft_psf = xp.fft.rfft2(self.stored_native.array, s=fft_shape, axes=(0, 1)) image_shape_original = image.shape_native @@ -738,7 +742,9 @@ def convolved_image_from(self, image, blurring_image, jax_method="direct", xp=np if slim_to_native_blurring_tuple is None: mask_flat = xp.logical_not(blurring_image.mask.array) - slim_to_native_blurring_tuple = xp.nonzero(mask_flat, size=blurring_image.shape[0]) + slim_to_native_blurring_tuple = xp.nonzero( + mask_flat, size=blurring_image.shape[0] + ) image_both_native = image_both_native.at[slim_to_native_blurring_tuple].set( xp.asarray(blurring_image.array) @@ -766,7 +772,6 @@ def convolved_image_from(self, image, blurring_image, jax_method="direct", xp=np for full_size, out_size in zip(full_shape, mask_shape) ) - blurred_image_native = jax.lax.dynamic_slice( blurred_image_full, start_indices, out_shape_full ) @@ -790,7 +795,7 @@ def convolved_mapping_matrix_from( blurring_mapping_matrix=None, blurring_mask: Optional[Mask2D] = None, jax_method="direct", - xp=np + xp=np, ): """ Convolve a source-plane mapping matrix with this PSF. @@ -842,10 +847,9 @@ def convolved_mapping_matrix_from( mask=mask, blurring_mapping_matrix=blurring_mapping_matrix, blurring_mask=blurring_mask, - xp=xp + xp=xp, ) - if not self.use_fft: return self.convolved_mapping_matrix_via_real_space_from( mapping_matrix=mapping_matrix, @@ -853,7 +857,7 @@ def convolved_mapping_matrix_from( blurring_mapping_matrix=blurring_mapping_matrix, blurring_mask=blurring_mask, jax_method=jax_method, - xp=xp + xp=xp, ) import jax @@ -888,7 +892,7 @@ def convolved_mapping_matrix_from( mask=mask, blurring_mapping_matrix=blurring_mapping_matrix, blurring_mask=blurring_mask, - xp=xp + xp=xp, ) # FFT convolution @@ -1016,7 +1020,7 @@ def convolved_image_via_real_space_from( image: np.ndarray, blurring_image: Optional[np.ndarray] = None, jax_method: str = "direct", - xp=np + xp=np, ): """ Convolve an input masked image with this PSF in real space. @@ -1058,7 +1062,6 @@ def convolved_image_via_real_space_from( mask_flat = xp.logical_not(image.mask.array) slim_to_native_tuple = xp.nonzero(mask_flat, size=image.shape[0]) - # start with native array padded with zeros image_native = xp.zeros(image.mask.shape, dtype=xp.asarray(image.array).dtype) @@ -1102,7 +1105,7 @@ def convolved_mapping_matrix_via_real_space_from( blurring_mapping_matrix: Optional[np.ndarray] = None, blurring_mask: Optional[Mask2D] = None, jax_method: str = "direct", - xp=np + xp=np, ): """ Convolve a source-plane mapping matrix with this PSF in real space. @@ -1138,7 +1141,7 @@ def convolved_mapping_matrix_via_real_space_from( mask=mask, blurring_mapping_matrix=blurring_mapping_matrix, blurring_mask=blurring_mask, - xp=xp + xp=xp, ) import jax @@ -1158,7 +1161,7 @@ def convolved_mapping_matrix_via_real_space_from( mask=mask, blurring_mapping_matrix=blurring_mapping_matrix, blurring_mask=blurring_mask, - xp=xp + xp=xp, ) # 6) Real-space convolution, broadcast kernel over source axis kernel = self.stored_native.array @@ -1174,10 +1177,7 @@ def convolved_mapping_matrix_via_real_space_from( return blurred_mapping_matrix_native[slim_to_native_tuple] def convolved_image_via_real_space_np_from( - self, - image: np.ndarray, - blurring_image: Optional[np.ndarray] = None, - xp=np + self, image: np.ndarray, blurring_image: Optional[np.ndarray] = None, xp=np ): """ Convolve an input masked image with this PSF in real space. @@ -1258,7 +1258,7 @@ def convolved_mapping_matrix_via_real_space_np_from( mask, blurring_mapping_matrix: Optional[np.ndarray] = None, blurring_mask: Optional[Mask2D] = None, - xp=np + xp=np, ): """ Convolve a source-plane mapping matrix with this PSF in real space. @@ -1302,12 +1302,11 @@ def convolved_mapping_matrix_via_real_space_np_from( mask=mask, blurring_mapping_matrix=blurring_mapping_matrix, blurring_mask=blurring_mask, - xp=xp + xp=xp, ) # 6) Real-space convolution, broadcast kernel over source axis kernel = self.stored_native.array - blurred_mapping_matrix_native = scipy_convolve( mapping_matrix_native, kernel[..., None], @@ -1315,4 +1314,4 @@ def convolved_mapping_matrix_via_real_space_np_from( ) # return slim form - return blurred_mapping_matrix_native[slim_to_native_tuple] \ No newline at end of file + return blurred_mapping_matrix_native[slim_to_native_tuple] diff --git a/autoarray/structures/arrays/uniform_1d.py b/autoarray/structures/arrays/uniform_1d.py index 296cdd631..f9a48b83c 100644 --- a/autoarray/structures/arrays/uniform_1d.py +++ b/autoarray/structures/arrays/uniform_1d.py @@ -23,14 +23,11 @@ def __init__( mask: Mask1D, header: Optional[Header] = None, store_native: bool = False, - xp=np + xp=np, ): values = array_1d_util.convert_array_1d( - array_1d=values, - mask_1d=mask, - store_native=store_native, - xp=xp + array_1d=values, mask_1d=mask, store_native=store_native, xp=xp ) self.mask = mask diff --git a/autoarray/structures/arrays/uniform_2d.py b/autoarray/structures/arrays/uniform_2d.py index b634f6cb8..c92a18f3b 100644 --- a/autoarray/structures/arrays/uniform_2d.py +++ b/autoarray/structures/arrays/uniform_2d.py @@ -239,7 +239,7 @@ def __init__( mask_2d=mask, store_native=store_native, skip_mask=skip_mask, - xp=xp + xp=xp, ) super().__init__(values, xp=xp) diff --git a/autoarray/structures/decorators/to_vector_yx.py b/autoarray/structures/decorators/to_vector_yx.py index 9a82567e7..13ca2b21e 100644 --- a/autoarray/structures/decorators/to_vector_yx.py +++ b/autoarray/structures/decorators/to_vector_yx.py @@ -107,6 +107,8 @@ def wrapper( The function values evaluated on the grid with the same structure as the input grid_like object. """ - return VectorYXMaker(func=func, obj=obj, grid=grid, xp=xp, *args, **kwargs).result + return VectorYXMaker( + func=func, obj=obj, grid=grid, xp=xp, *args, **kwargs + ).result return wrapper diff --git a/autoarray/structures/grids/grid_1d_util.py b/autoarray/structures/grids/grid_1d_util.py index de577f2ec..6fd756d1f 100644 --- a/autoarray/structures/grids/grid_1d_util.py +++ b/autoarray/structures/grids/grid_1d_util.py @@ -49,18 +49,14 @@ def convert_grid_1d( grid_1d_native=grid_1d, mask_1d=mask_1d, ) - return grid_1d_native_from( - grid_1d_slim=grid_1d, - mask_1d=mask_1d, - xp=xp - ) + return grid_1d_native_from(grid_1d_slim=grid_1d, mask_1d=mask_1d, xp=xp) def grid_1d_slim_via_shape_slim_from( shape_slim: Tuple[int], pixel_scales: ty.PixelScales, origin: Tuple[float] = (0.0,), - xp=np + xp=np, ) -> np.ndarray: """ This routine computes the (x) scaled coordinates at the centre of every pixel defined by a 1D shape of the @@ -93,7 +89,7 @@ def grid_1d_slim_via_shape_slim_from( mask_1d=np.full(fill_value=False, shape=shape_slim), pixel_scales=pixel_scales, origin=origin, - xp=xp + xp=xp, ) @@ -101,7 +97,7 @@ def grid_1d_slim_via_mask_from( mask_1d: np.ndarray, pixel_scales: ty.PixelScales, origin: Tuple[float] = (0.0,), - xp=np + xp=np, ) -> np.ndarray: """ For a grid, every unmasked pixel of its 1D mask with shape (total_pixels,) is divided into a finer uniform @@ -207,7 +203,5 @@ def grid_1d_native_from( mapped from the slimmed grid. """ return array_1d_util.array_1d_native_from( - array_1d_slim=grid_1d_slim, - mask_1d=mask_1d, - xp=xp + array_1d_slim=grid_1d_slim, mask_1d=mask_1d, xp=xp ) diff --git a/autoarray/structures/grids/grid_2d_util.py b/autoarray/structures/grids/grid_2d_util.py index f391bb225..6dbe85979 100644 --- a/autoarray/structures/grids/grid_2d_util.py +++ b/autoarray/structures/grids/grid_2d_util.py @@ -121,16 +121,8 @@ def convert_grid_2d( if is_native == store_native: return grid_2d elif not store_native: - return grid_2d_slim_from( - grid_2d_native=grid_2d, - mask=mask_2d, - xp=xp - ) - return grid_2d_native_from( - grid_2d_slim=grid_2d, - mask_2d=mask_2d,# - xp=xp - ) + return grid_2d_slim_from(grid_2d_native=grid_2d, mask=mask_2d, xp=xp) + return grid_2d_native_from(grid_2d_slim=grid_2d, mask_2d=mask_2d, xp=xp) # def convert_grid_2d_to_slim( @@ -152,15 +144,12 @@ def convert_grid_2d_to_slim( """ if len(grid_2d.shape) == 2: return grid_2d - return grid_2d_slim_from( - grid_2d_native=grid_2d, - mask=mask_2d, - xp=xp - ) + return grid_2d_slim_from(grid_2d_native=grid_2d, mask=mask_2d, xp=xp) def convert_grid_2d_to_native( - grid_2d: Union[np.ndarray, List], mask_2d: Mask2D, + grid_2d: Union[np.ndarray, List], + mask_2d: Mask2D, ) -> np.ndarray: """ he `manual` classmethods in the Grid2D object take as input a list or ndarray which is returned as a Grid2D. @@ -302,11 +291,7 @@ def grid_2d_via_mask_from( mask_2d=mask_2d, pixel_scales=pixel_scales, origin=origin, xp=xp ) - return grid_2d_native_from( - grid_2d_slim=grid_2d_slim, - mask_2d=mask_2d, - xp=xp - ) + return grid_2d_native_from(grid_2d_slim=grid_2d_slim, mask_2d=mask_2d, xp=xp) def grid_2d_slim_via_shape_native_from( @@ -350,7 +335,7 @@ def grid_2d_slim_via_shape_native_from( mask_2d=xp.full(fill_value=False, shape=shape_native), pixel_scales=pixel_scales, origin=origin, - xp=xp + xp=xp, ) @@ -568,9 +553,7 @@ def grid_scaled_2d_slim_radial_projected_from( def grid_2d_slim_from( - grid_2d_native: np.ndarray, - mask: np.ndarray, - xp=np + grid_2d_native: np.ndarray, mask: np.ndarray, xp=np ) -> np.ndarray: """ For a native 2D grid and mask of shape [total_y_pixels, total_x_pixels, 2], map the values of all unmasked @@ -609,9 +592,7 @@ def grid_2d_slim_from( def grid_2d_native_from( - grid_2d_slim: np.ndarray, - mask_2d: np.ndarray, - xp=np + grid_2d_slim: np.ndarray, mask_2d: np.ndarray, xp=np ) -> np.ndarray: """ For a slimmed 2D grid of shape [total_unmasked_pixels, 2], that was computed by extracting the unmasked values @@ -640,15 +621,11 @@ def grid_2d_native_from( """ grid_2d_native_y = array_2d_util.array_2d_native_from( - array_2d_slim=grid_2d_slim[:, 0], - mask_2d=mask_2d, - xp=xp + array_2d_slim=grid_2d_slim[:, 0], mask_2d=mask_2d, xp=xp ) grid_2d_native_x = array_2d_util.array_2d_native_from( - array_2d_slim=grid_2d_slim[:, 1], - mask_2d=mask_2d, - xp=xp + array_2d_slim=grid_2d_slim[:, 1], mask_2d=mask_2d, xp=xp ) return xp.stack((grid_2d_native_y, grid_2d_native_x), axis=-1) @@ -721,7 +698,7 @@ def grid_2d_slim_via_shape_native_not_mask_from( shape_native: Tuple[int, int], pixel_scales: Tuple[float, float], origin: Tuple[float, float] = (0.0, 0.0), - xp=np + xp=np, ) -> np.ndarray: """ Build the slim (flattened) grid of all (y, x) pixel centres for a rectangular grid diff --git a/autoarray/structures/grids/irregular_2d.py b/autoarray/structures/grids/irregular_2d.py index b11f6f87c..bd07006c7 100644 --- a/autoarray/structures/grids/irregular_2d.py +++ b/autoarray/structures/grids/irregular_2d.py @@ -185,9 +185,9 @@ def squared_distances_to_coordinate_from( coordinate The (y,x) coordinate from which the squared distance of every *Coordinate* is computed. """ - squared_distances = self._xp.square(self.array[:, 0] - coordinate[0]) + self._xp.square( - self.array[:, 1] - coordinate[1] - ) + squared_distances = self._xp.square( + self.array[:, 0] - coordinate[0] + ) + self._xp.square(self.array[:, 1] - coordinate[1]) return ArrayIrregular(values=squared_distances) def distances_to_coordinate_from( @@ -267,4 +267,4 @@ def grid_of_closest_from(self, grid_pair: "Grid2DIrregular") -> "Grid2DIrregular # select closest points: shape (N2, 2) closest_points = self.array[closest_idx] - return Grid2DIrregular(closest_points) \ No newline at end of file + return Grid2DIrregular(closest_points) diff --git a/autoarray/structures/grids/uniform_2d.py b/autoarray/structures/grids/uniform_2d.py index 08ae60d6a..6292549b3 100644 --- a/autoarray/structures/grids/uniform_2d.py +++ b/autoarray/structures/grids/uniform_2d.py @@ -160,10 +160,7 @@ def __init__( """ values = grid_2d_util.convert_grid_2d( - grid_2d=values, - mask_2d=mask, - store_native=store_native, - xp=xp + grid_2d=values, mask_2d=mask, store_native=store_native, xp=xp ) super().__init__(values, xp=xp) @@ -556,14 +553,11 @@ def from_mask( mask_2d=mask.array, pixel_scales=mask.pixel_scales, origin=mask.origin, - xp=xp + xp=xp, ) return Grid2D( - values=grid_2d, - mask=mask, - over_sample_size=over_sample_size, - xp=xp + values=grid_2d, mask=mask, over_sample_size=over_sample_size, xp=xp ) @classmethod @@ -691,7 +685,9 @@ def blurring_grid_from( over_sample_size=over_sample_size, ) - def subtracted_from(self, offset: Tuple[(float, float), np.ndarray], xp=np) -> "Grid2D": + def subtracted_from( + self, offset: Tuple[(float, float), np.ndarray], xp=np + ) -> "Grid2D": mask = Mask2D( mask=self.mask, @@ -848,9 +844,9 @@ def squared_distances_to_coordinate_from( coordinate The (y,x) coordinate from which the squared distance of every grid (y,x) coordinate is computed. """ - squared_distances = self._xp.square(self.array[:, 0] - coordinate[0]) + self._xp.square( - self.array[:, 1] - coordinate[1] - ) + squared_distances = self._xp.square( + self.array[:, 0] - coordinate[0] + ) + self._xp.square(self.array[:, 1] - coordinate[1]) return Array2D(values=squared_distances, mask=self.mask) diff --git a/autoarray/structures/mesh/rectangular_2d.py b/autoarray/structures/mesh/rectangular_2d.py index 84fc9884e..2fa855443 100644 --- a/autoarray/structures/mesh/rectangular_2d.py +++ b/autoarray/structures/mesh/rectangular_2d.py @@ -62,7 +62,11 @@ def __init__( @classmethod def overlay_grid( - cls, shape_native: Tuple[int, int], grid: np.ndarray, buffer: float = 1e-8, xp=np + cls, + shape_native: Tuple[int, int], + grid: np.ndarray, + buffer: float = 1e-8, + xp=np, ) -> "Mesh2DRectangular": """ Creates a `Grid2DRecntagular` by overlaying the rectangular pixelization over an input grid of (y,x) @@ -100,10 +104,7 @@ def overlay_grid( origin = xp.array(((y_max + y_min) / 2.0, (x_max + x_min) / 2.0)) grid_slim = grid_2d_util.grid_2d_slim_via_shape_native_not_mask_from( - shape_native=shape_native, - pixel_scales=pixel_scales, - origin=origin, - xp=xp + shape_native=shape_native, pixel_scales=pixel_scales, origin=origin, xp=xp ) return cls( diff --git a/autoarray/structures/mock/mock_decorators.py b/autoarray/structures/mock/mock_decorators.py index 6ad4dfaf8..8ca7790d1 100644 --- a/autoarray/structures/mock/mock_decorators.py +++ b/autoarray/structures/mock/mock_decorators.py @@ -86,6 +86,7 @@ def __init__(self, centre=(0.0, 0.0), angle=0.0): self.centre = centre self.angle = angle + class MockGrid2DLikeObj: def __init__(self): self.centre = (0.0, 0.0) diff --git a/autoarray/structures/triangles/array.py b/autoarray/structures/triangles/array.py index 00dbae2d9..3f7d50049 100644 --- a/autoarray/structures/triangles/array.py +++ b/autoarray/structures/triangles/array.py @@ -58,7 +58,7 @@ def for_limits_and_scale( ) -> "AbstractTriangles": import jax.numpy as jnp - + height = scale * HEIGHT_FACTOR vertices = [] @@ -152,6 +152,7 @@ def means(self) -> np.ndarray: The mean of each triangle. """ import jax.numpy as jnp + return jnp.mean(self.triangles, axis=1) def containing_indices(self, shape: Shape) -> np.ndarray: @@ -168,6 +169,7 @@ def containing_indices(self, shape: Shape) -> np.ndarray: The triangles that intersect the shape. """ import jax.numpy as jnp + inside = shape.mask(self.triangles) return jnp.where( @@ -191,6 +193,7 @@ def for_indexes(self, indexes: np.ndarray) -> "ArrayTriangles": The new ArrayTriangles instance. """ import jax.numpy as jnp + selected_indices = select_and_handle_invalid( data=self.indices, indices=indexes, @@ -238,6 +241,7 @@ def for_indexes(self, indexes: np.ndarray) -> "ArrayTriangles": def _up_sample_triangle(self): import jax.numpy as jnp + triangles = self.triangles m01 = (triangles[:, 0] + triangles[:, 1]) / 2 @@ -270,6 +274,7 @@ def up_sample(self) -> "ArrayTriangles": def _neighborhood_triangles(self): import jax.numpy as jnp + triangles = self.triangles new_v0 = triangles[:, 1] + triangles[:, 2] - triangles[:, 0] @@ -380,6 +385,7 @@ def select_and_handle_invalid( An array with selected data, where invalid indices are replaced with `invalid_replacement`. """ import jax.numpy as jnp + invalid_mask = indices == invalid_value safe_indices = jnp.where(invalid_mask, 0, indices) selected_data = data[safe_indices] @@ -394,6 +400,7 @@ def select_and_handle_invalid( def remove_duplicates(new_triangles): import jax.numpy as jnp + unique_vertices, inverse_indices = jnp.unique( new_triangles.reshape(-1, 2), axis=0, diff --git a/autoarray/structures/triangles/coordinate_array.py b/autoarray/structures/triangles/coordinate_array.py index d7b1d6518..64c3717e2 100644 --- a/autoarray/structures/triangles/coordinate_array.py +++ b/autoarray/structures/triangles/coordinate_array.py @@ -35,6 +35,7 @@ def __init__( An y_offset to apply to the y coordinates so that up-sampled triangles align. """ import jax.numpy as jnp + self.coordinates = coordinates self.side_length = side_length self.flipped = flipped @@ -56,6 +57,7 @@ def for_limits_and_scale( **_, ): import jax.numpy as jnp + x_shift = int(2 * x_min / scale) y_shift = int(y_min / (HEIGHT_FACTOR * scale)) @@ -99,6 +101,7 @@ def tree_unflatten(cls, aux_data, children): def __len__(self): import jax.numpy as jnp + return jnp.count_nonzero(~jnp.isnan(self.coordinates).any(axis=1)) def __iter__(self): @@ -110,6 +113,7 @@ def centres(self) -> np.ndarray: The centres of the triangles. """ import jax.numpy as jnp + centres = self.scaling_factors * self.coordinates + jnp.array( [self.x_offset, self.y_offset] ) @@ -121,6 +125,7 @@ def vertex_coordinates(self) -> np.ndarray: The vertices of the triangles as an Nx3x2 array. """ import jax.numpy as jnp + coordinates = self.coordinates return jnp.concatenate( [ @@ -137,6 +142,7 @@ def triangles(self) -> np.ndarray: The vertices of the triangles as an Nx3x2 array. """ import jax.numpy as jnp + centres = self.centres return jnp.stack( ( @@ -177,6 +183,7 @@ def flip_array(self) -> np.ndarray: An array of 1s and -1s to flip the triangles. """ import jax.numpy as jnp + array = jnp.where(self.flip_mask, -1, 1) return array[:, None] @@ -185,6 +192,7 @@ def up_sample(self) -> "CoordinateArrayTriangles": Up-sample the triangles by adding a new vertex at the midpoint of each edge. """ import jax.numpy as jnp + coordinates = self.coordinates flip_mask = self.flip_mask @@ -217,6 +225,7 @@ def neighborhood(self) -> "CoordinateArrayTriangles": Ensures that the new triangles are unique and adjusts the mask accordingly. """ import jax.numpy as jnp + coordinates = self.coordinates flip_mask = self.flip_mask @@ -255,6 +264,7 @@ def neighborhood(self) -> "CoordinateArrayTriangles": @property def _vertices_and_indices(self): import jax.numpy as jnp + flat_triangles = self.triangles.reshape(-1, 2) vertices, inverse_indices = jnp.unique( flat_triangles, @@ -303,6 +313,7 @@ def for_indexes(self, indexes: np.ndarray) -> "CoordinateArrayTriangles": The new CoordinateArrayTriangles instance. """ import jax.numpy as jnp + mask = indexes == -1 safe_indexes = jnp.where(mask, 0, indexes) coordinates = jnp.take(self.coordinates, safe_indexes, axis=0) @@ -333,6 +344,7 @@ def indices(self) -> np.ndarray: @property def means(self): import jax.numpy as jnp + return jnp.mean(self.triangles, axis=1) @property diff --git a/test_autoarray/conftest.py b/test_autoarray/conftest.py index 571caaccc..bbfb6ea11 100644 --- a/test_autoarray/conftest.py +++ b/test_autoarray/conftest.py @@ -1,8 +1,10 @@ import jax.numpy as jnp + def pytest_configure(): _ = jnp.sum(jnp.array([0.0])) # Force backend init + import os from os import path import pytest diff --git a/test_autoarray/dataset/imaging/test_dataset.py b/test_autoarray/dataset/imaging/test_dataset.py index fc500aa24..98283298b 100644 --- a/test_autoarray/dataset/imaging/test_dataset.py +++ b/test_autoarray/dataset/imaging/test_dataset.py @@ -33,7 +33,6 @@ def make_test_data_path(): return test_data_path - def test__grid__uses_mask_and_settings( image_7x7, noise_map_7x7, diff --git a/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py b/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py index 62b05d945..82ea0d6d8 100644 --- a/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py +++ b/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py @@ -219,7 +219,9 @@ def test__data_vector_via_w_tilde_data_two_methods_agree(): image_native=image.native.array, noise_map_native=noise_map.native.array, kernel_native=kernel.native.array, - native_index_for_slim_index=mask.derive_indexes.native_for_slim.astype("int"), + native_index_for_slim_index=mask.derive_indexes.native_for_slim.astype( + "int" + ), ) ( @@ -229,7 +231,9 @@ def test__data_vector_via_w_tilde_data_two_methods_agree(): ) = aa.util.mapper_numba.data_slim_to_pixelization_unique_from( data_pixels=w_tilde_data.shape[0], pix_indexes_for_sub_slim_index=mapper.pix_indexes_for_sub_slim_index, - pix_sizes_for_sub_slim_index=mapper.pix_sizes_for_sub_slim_index.astype("int"), + pix_sizes_for_sub_slim_index=mapper.pix_sizes_for_sub_slim_index.astype( + "int" + ), pix_weights_for_sub_slim_index=mapper.pix_weights_for_sub_slim_index, pix_pixels=mapper.params, sub_size=grid.over_sample_size.array, @@ -330,7 +334,9 @@ def test__curvature_matrix_via_w_tilde_preload_two_methods_agree(): ) = aa.util.inversion_imaging_numba.w_tilde_curvature_preload_imaging_from( noise_map_native=noise_map.native.array, kernel_native=kernel.native.array, - native_index_for_slim_index=mask.derive_indexes.native_for_slim.astype("int"), + native_index_for_slim_index=mask.derive_indexes.native_for_slim.astype( + "int" + ), ) ( @@ -366,4 +372,6 @@ def test__curvature_matrix_via_w_tilde_preload_two_methods_agree(): noise_map=np.array(noise_map), ) - assert curvature_matrix_via_w_tilde == pytest.approx(curvature_matrix, abs=1.0e-4) + assert curvature_matrix_via_w_tilde == pytest.approx( + curvature_matrix, abs=1.0e-4 + ) diff --git a/test_autoarray/inversion/pixelization/mappers/test_rectangular.py b/test_autoarray/inversion/pixelization/mappers/test_rectangular.py index 1215b708d..866b25a7b 100644 --- a/test_autoarray/inversion/pixelization/mappers/test_rectangular.py +++ b/test_autoarray/inversion/pixelization/mappers/test_rectangular.py @@ -139,7 +139,7 @@ def test__edges_transformed(mask_2d_7x7): assert mapper.edges_transformed[3] == pytest.approx( np.array( - [1.5, 1.5], # left + [-1.5, 1.5], # left ), abs=1e-8, - ) \ No newline at end of file + ) diff --git a/test_autoarray/structures/arrays/test_kernel_2d.py b/test_autoarray/structures/arrays/test_kernel_2d.py index ca103ac04..ae4beae62 100644 --- a/test_autoarray/structures/arrays/test_kernel_2d.py +++ b/test_autoarray/structures/arrays/test_kernel_2d.py @@ -524,8 +524,7 @@ def test__convolve_imaged_from__via_fft__sizes_not_precomputed__compare_numerica blurring_image = aa.Array2D(values=image.native, mask=blurring_mask) blurred_fft = kernel_fft.convolved_image_from( - image=masked_image, - blurring_image=blurring_image + image=masked_image, blurring_image=blurring_image ) assert blurred_fft.native.array[13, 13] == pytest.approx(249.5, abs=1e-6) diff --git a/test_autoarray/structures/triangles/test_coordinate.py b/test_autoarray/structures/triangles/test_coordinate.py index 97c7389d7..bfd677874 100644 --- a/test_autoarray/structures/triangles/test_coordinate.py +++ b/test_autoarray/structures/triangles/test_coordinate.py @@ -1,4 +1,3 @@ - import numpy as np import pytest @@ -316,9 +315,6 @@ def one_triangle(): ) - - - def test_neighborhood(one_triangle): import jax