Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
240 commits
Select commit Hold shift + click to select a range
340c34d
pylops removed
Jammy2211 Apr 4, 2025
39fbf72
unit tests fixed
Jammy2211 Apr 4, 2025
dd7ca45
inversion now uses Convoler again, all numba, factory tests pass
Jammy2211 Apr 5, 2025
119866d
fix inversion plotters
Jammy2211 Apr 5, 2025
580be1c
fix test overlay
Jammy2211 Apr 5, 2025
994e048
fixn interferometer conversion
Jammy2211 Apr 5, 2025
9c5354c
more plot unit tests pass
Jammy2211 Apr 5, 2025
2b6754c
tst coverage complete
Jammy2211 Apr 5, 2025
2049042
small changes to jax array handling
Jammy2211 Apr 5, 2025
0da7719
clean up grid contou
Jammy2211 Apr 5, 2025
5e31ce0
arrays now store maks in jax
Jammy2211 Apr 6, 2025
3357cbb
array 2d always stored as jax
Jammy2211 Apr 6, 2025
b9e4cb5
update array_1d_slim_from
Jammy2211 Apr 6, 2025
d132365
array 1D stuff all updated to support JAX
Jammy2211 Apr 6, 2025
bd88fee
grid_1d_slim_via_mask_from to JAX
Jammy2211 Apr 6, 2025
5f8877b
grids now use JAX
Jammy2211 Apr 6, 2025
6b7d433
fix vectors
Jammy2211 Apr 6, 2025
a6480a9
fix some visualizatoin
Jammy2211 Apr 6, 2025
1fde8b1
risky change to figure open
Jammy2211 Apr 6, 2025
875483c
fix structurre plotters now JAX arrays used
Jammy2211 Apr 6, 2025
1088baf
over sampler fix
Jammy2211 Apr 6, 2025
f72b15e
fix kmesh
Jammy2211 Apr 6, 2025
ac68a08
fix test repr
Jammy2211 Apr 6, 2025
678673e
fix test preprocess
Jammy2211 Apr 6, 2025
c611bb6
fix abstract dataset tests
Jammy2211 Apr 6, 2025
e1899a1
fix test imaigng data
Jammy2211 Apr 6, 2025
7d5ba26
fix simulator
Jammy2211 Apr 6, 2025
1feabff
fix interferometrer
Jammy2211 Apr 6, 2025
2c02e39
dataset tests pass
Jammy2211 Apr 6, 2025
87931b6
fix fit tests
Jammy2211 Apr 6, 2025
0fa9646
mask array 2d choose type based on if input is jax or numpy
Jammy2211 Apr 6, 2025
578bc4c
array 1d now use convert rules
Jammy2211 Apr 6, 2025
32f55f7
grid 2d casting
Jammy2211 Apr 6, 2025
b5601e8
grid 2d follows rules
Jammy2211 Apr 6, 2025
49b8dc5
grid 1d no conversion
Jammy2211 Apr 6, 2025
d3c5bbf
same rules for grids
Jammy2211 Apr 6, 2025
61a72b8
fix dataset unit tests
Jammy2211 Apr 6, 2025
15cd0b5
fix some unit tests
Jammy2211 Apr 6, 2025
d33ab44
mask now also follows strictire typing rules
Jammy2211 Apr 6, 2025
3aef08a
but of typing simplication
Jammy2211 Apr 6, 2025
5d96ba5
check conversion needed
Jammy2211 Apr 6, 2025
ccccf1a
add small non zero numerical value to make grid radials plot correctly
Jammy2211 Apr 7, 2025
7948e90
small shift in projected grid
Jammy2211 Apr 7, 2025
5fc2b63
fix some unit tests due to numerical addition on projeted grid
Jammy2211 Apr 8, 2025
135d5b8
black
Jammy2211 Apr 8, 2025
ca6a934
Merge pull request #170 from Jammy2211/feature/jax_ndarray_casting_rules
Jammy2211 Apr 8, 2025
6451bc9
furthest_distances_to_other_coordinates now supports JAX
Jammy2211 Apr 9, 2025
35604bb
grid_of_closest_from convert to JAX
Jammy2211 Apr 9, 2025
a625672
black
Jammy2211 Apr 9, 2025
107e470
moved over most zoom-y stuff
Apr 21, 2025
5f3edd1
removed more zoom stuf
Apr 21, 2025
ec5af95
more stuff moved over
Apr 21, 2025
ec82687
black
Apr 21, 2025
191e69c
updated to use new zoom API
Apr 21, 2025
92a01e3
zoom refactor done
Apr 29, 2025
ec8423a
remove some tests
Apr 29, 2025
a4ce9db
update extracted_array_2d_from to not use numba
May 4, 2025
80bc2a5
updat eunit test shapes
May 4, 2025
7251e10
unit test fix
May 4, 2025
70c4694
fix visualization output
May 4, 2025
3820704
unit test fixes
May 4, 2025
c3e4ce3
black and minor fixes
Jun 11, 2025
0edcf7d
updated factory test on two mappers to use more stable solver
Jun 12, 2025
0521b1b
fix test__identical_inversion_source_and_image_loops
Jun 13, 2025
506d512
check for nans when raising InversionException
Jun 13, 2025
e184330
test__inversion_imaging__via_linear_obj_func_and_mapper__force_edge_p…
Jun 13, 2025
39fa6fc
fix test__inversion_imaging__linear_obj_func_with_w_tilde
Jun 13, 2025
7eae31e
fix test__inversion_imaging__linear_obj_func_and_non_func_give_same_t…
Jun 13, 2025
c20897b
fix test__inversion_imaging__linear_obj_func_with_w_tilde
Jun 13, 2025
949da4a
test__identical_inversion_values_for_two_methods
Jun 13, 2025
fd2e601
midway through some refactoring
Jun 13, 2025
c085e88
remove force_edge_image_pixels_to_zeros
Jun 15, 2025
7d91158
convert JAX Array
Jun 15, 2025
9c70932
ndarray conversions for JAX arrays
Jun 15, 2025
97abd12
fix all inversion unit tests
Jun 15, 2025
8051e35
black
Jun 15, 2025
c161cda
Merge pull request #175 from Jammy2211/feature/jax_linear_light
Jammy2211 Jun 15, 2025
09784fa
removed numpy based Trianghle and fixed first test
Jun 17, 2025
3f24384
fix one_triangle.triangles
Jun 17, 2025
2db16d9
test__above
Jun 17, 2025
81d1c70
multiple fixes in test_coordinate_implementation.py
Jun 17, 2025
6c35ed2
remove containment test
Jun 17, 2025
075e8e7
test_array_representation deleted
Jun 17, 2025
3933b55
fix test_extended_source
Jun 17, 2025
20ceaca
JAx CoordinateArrayTriangles has explicit JAX use now
Jun 17, 2025
d83f4ba
simplfiied coordinate_array and removed support for numpy
Jun 17, 2025
1094154
refactor arrayu to only use JAX
Jun 17, 2025
b4f1853
simplify unit tests
Jun 17, 2025
33eb087
black
Jun 17, 2025
ad4211e
black
Jun 17, 2025
e0c70a7
update to much simpler tests for transformer
Jun 17, 2025
fd383e1
remove pylops legact
Jun 17, 2025
7961446
fix unitt est with shaping
Jun 17, 2025
25a5096
preload_real_transforms converted to numpy
Jun 18, 2025
88748fd
preload_imag_transforms
Jun 18, 2025
4415121
rename functions
Jun 18, 2025
ca034e7
transformed_mapping_matrix_from
Jun 18, 2025
b49e920
transformed_mapping_matrix_via_preload_from
Jun 18, 2025
19c14a4
simplify Transformer
Jun 18, 2025
4c41eb1
TransformerDFT docstring
Jun 18, 2025
5500b22
transformer now uses jax arrays in DFT
Jun 18, 2025
93af09f
remove ordered_1d
Jun 18, 2025
6526dde
JAX intereferometer grad works
Jun 18, 2025
ea9d772
dft_preload_transform added to Interferometer inputs
Jun 18, 2025
3a6a5dc
black
Jun 18, 2025
aac0fbd
Merge pull request #176 from Jammy2211/feature/jax_only_triangles
Jammy2211 Jun 18, 2025
5ad6997
refactor use of dirty image in inversion
Jun 18, 2025
c798d17
JAX on transforfed mapping matrtix
Jun 18, 2025
ff0a76a
update plot
Jun 18, 2025
a6d7729
Merge pull request #177 from Jammy2211/feature/jax_interferometer
Jammy2211 Jun 18, 2025
af405a6
resized_array_2d_from converted to numpy
Jun 18, 2025
f095be5
remove replace_noise_map_2d_values_where_image_2d_values_are_negative
Jun 18, 2025
648a815
index_2d_for_index_slim_from no longer uses numba
Jun 18, 2025
4a8814a
index_slim_for_index_2d_from does not use numba
Jun 18, 2025
d69e39f
remove array_2d_slim_complex_from
Jun 18, 2025
3832bca
remove array_2d_native_complex_via_indexes_from
Jun 18, 2025
783ff84
remove numba from two more functions ing rid_2d_util
Jun 18, 2025
ab45cf5
remove grid_2d_slim_upscaled_from
Jun 18, 2025
8bb449a
remove native_sub_index_for_slim_sub_index_2d_from
Jun 18, 2025
0f03df7
remove redudnant tests
Jun 18, 2025
5a163af
slim_index_for_sub_slim_index_via_mask_2d_from no longer uses numba
Jun 18, 2025
fbdcf80
removed sub_slim_index_for_sub_native_index_from
Jun 18, 2025
91e8805
remove oversample_mask_2d_from
Jun 18, 2025
c4a6339
sub_size_radial_bins_from no longer uses numba
Jun 18, 2025
7b1c3b2
grid_2d_slim_over_sampled_via_mask_from does not use mask
Jun 19, 2025
d5d28c9
removed final numba functions not used for inversion
Jun 19, 2025
8fca9ba
numba stops pixelization if not installed
Jun 19, 2025
5c41ba1
test
Jun 19, 2025
650c8a0
merge
Jun 19, 2025
fdac408
use simpler grid_2d_slim_over_sampled_via_mask_from which works
Jun 19, 2025
a0e65da
temporary solution
Jun 19, 2025
ae26201
remove profile func
Jun 23, 2025
a8e48c1
remove run_time_dict
Jun 23, 2025
fee7499
remove run time dict
Jun 23, 2025
50d7ae8
removed all profilng and tested
Jun 24, 2025
f0892ba
black
Jun 24, 2025
8c71cae
Merge pull request #179 from Jammy2211/feature/jax_remove_profiling
Jammy2211 Jun 24, 2025
fd11b17
rectangular uses intterpolation with JAX support now
Jun 24, 2025
8a69509
fix visualization unit tests
Jun 24, 2025
0e1f23b
test_autoarray/inversion/pixelization/mappers/test_rectangular.py
Jun 24, 2025
a95464f
test_autoarray/inversion/pixelization/mappers/test_factory.py -> rect…
Jun 24, 2025
ecfe98d
interpolate works but now need to remove convolver
Jun 24, 2025
5b08271
convolver convler mapping matrix replaced for psf, fixes issue
Jun 24, 2025
d7bdb3c
fix test__inversion_matrices__x2_mappers
Jun 24, 2025
4482a24
fix test_autoarray/inversion/inversion/imaging/test_imaging.py
Jun 24, 2025
84aaa3c
fix test__curvature_matrix
Jun 24, 2025
4d7f1c3
fix test__inversion_imaging__via_mapper
Jun 24, 2025
e01c918
all unit tests pass
Jun 24, 2025
2c0edb8
Merge pull request #181 from Jammy2211/feature/jax_remove_convolver
Jammy2211 Jun 24, 2025
18d6a46
clean up imports
Jun 24, 2025
9397291
comment out test
Jun 25, 2025
e8d767a
fix fit _util whilst reudcing imports
Jun 25, 2025
b4c0db4
remove many np.array() conversions
Jun 25, 2025
5f539ce
black
Jun 25, 2025
2736e33
remove cached properties from BorderRelocator for easier JAx
Jun 25, 2025
cd50e74
removed for loop from border function
Jun 25, 2025
ea252a2
converted relocated_grid_from to use JAX
Jun 25, 2025
b7cb3c0
border relocator function converred to JAX
Jun 25, 2025
dec6b51
grid_2d_slim_via_shape_native_not_mask_from
Jun 25, 2025
629e7e7
Rectangular fidxes
Jun 25, 2025
0766edd
fix w tild eiwth some ndarray conversions
Jun 26, 2025
3d0233f
fix over sampling tests
Jun 26, 2025
9673b4c
fix plotting
Jun 26, 2025
7e3509c
fix some more tests due to numba jax
Jun 26, 2025
b35e32d
coment out test to get past it for now, think its just linear lagebra…
Jun 26, 2025
74bfcda
updated _Reducedd matrices to use zeroing
Jun 26, 2025
1d6d517
regularization_matrix_Reduced
Jun 26, 2025
e4219b0
fix test
Jun 26, 2025
1b5b64f
add preloading in order to pass mapper indexes
Jun 26, 2025
93157b8
full JAX success
Jun 26, 2025
0b424c4
adaptive_pixel_signals_from JAX-d
Jun 26, 2025
a92d828
convert mapped_to_source_via_mapping_matrix_from to numpy
Jun 26, 2025
5677589
update data_weight_total_for_pix_from
Jun 26, 2025
3dc49b0
moved sub_slim_indexes_for_pix_index to inversion_interferometer_util
Jun 26, 2025
9a316ef
Merge pull request #182 from Jammy2211/feature/jax_speed_up_general
Jammy2211 Jul 12, 2025
101b704
convert soem regularization util functions from numba to numpy
Jul 13, 2025
2896a0c
remove minus ones in data_weight_total_for_pix_from
Jul 13, 2025
6ebd7a3
mapper index list returns ndarray
Jul 13, 2025
e758555
fix autoarray/inversion/inversion/interferometer/w_tilde.py
Jul 13, 2025
9aa26b5
fixed bug where regularization matrix was returned for curvature_reg_…
Jul 13, 2025
968a994
fix case where border relocator is off
Jul 13, 2025
8af880d
w tilde now default to false
Jul 13, 2025
bbef0e3
remove old source pixel zeroing functionality
Jul 13, 2025
415926f
docuemnet preloasds and mapper_index_list -> mapper_indices
Jul 13, 2025
12b2d72
fix last unit test
Jul 13, 2025
67a5830
black
Jul 13, 2025
741ff6a
Merge pull request #183 from Jammy2211/feature/jax_inversion
Jammy2211 Jul 13, 2025
6e7be6e
removed all reference to include in plotting
Jul 14, 2025
0cb75eb
visual clean up complete
Jul 21, 2025
5d94e9a
Merge pull request #184 from Jammy2211/feature/jax_simplify_visualiza…
Jammy2211 Jul 22, 2025
1d6043b
regularization util JAX conversons, seem to work
Jul 23, 2025
8a6951a
gaussian kernel converted successfully
Jul 23, 2025
a509014
fix constant zeroth
Jul 23, 2025
150a360
move utils to their specific modules
Jul 23, 2025
a63369d
regulsirztion refactor complete and rectangular works
Jul 23, 2025
f0a4a8d
updates for rectangular grid sorted
Jul 24, 2025
9b97bcb
Merge pull request #185 from Jammy2211/feature/jax_rectangular_slam
Jammy2211 Jul 24, 2025
ad33991
fix pos neg
Jul 28, 2025
e46b087
refactor kernel mapping for speed
Aug 11, 2025
3fecb5e
backup
Aug 18, 2025
8f00795
adpative stuff implemented
Sep 15, 2025
e3332b6
edges unit test passes
Sep 15, 2025
972690b
rectangular edges
Sep 16, 2025
c939bd0
edges_transformed
Sep 16, 2025
5c352bb
remove InversionException tests
Sep 18, 2025
c1fff0e
comment out edges test
Sep 18, 2025
dde10c4
data_vector_via_w_tilde_data_imaging_from converted to JAX
Sep 18, 2025
6fd813b
_data_vector_mapper in JAX
Sep 18, 2025
9eea7c2
_curvature_matrix_x1_mapper
Sep 18, 2025
74c8d39
mapped_reconstructed_data_via_w_tilde_from
Sep 18, 2025
c38ed07
lots of hacky stuff
Sep 21, 2025
e3a8341
remove commented code
Sep 22, 2025
9b1c91c
fixed mappings and weights
Sep 23, 2025
216a8fc
test__edges_transformed
Sep 23, 2025
2ecd33b
test__areas_transformed
Sep 23, 2025
9c52d10
all of inversion tests pass
Sep 23, 2025
52ca2ac
all unit tests pass
Sep 23, 2025
e5e1a07
Merge pull request #186 from Jammy2211/feature/jax_wrapper_rectangula…
Jammy2211 Sep 23, 2025
d7d31f6
inversion_util
Sep 25, 2025
42aea76
back to stable
Sep 25, 2025
485e076
inned supeer sampling array uses jax.ops.segment_sum for fast compile
Sep 30, 2025
007b7b2
remove offset subtraction
Oct 1, 2025
b2fb2a6
moved first 3 numba functions
Oct 1, 2025
d9a1582
moved more functions
Oct 1, 2025
e5ee4dd
update dataset w tilde to use numba util
Oct 1, 2025
88a8d93
update w_tilde to use numba utils
Oct 1, 2025
6c8a325
fix _data_vector_x1_mapper
Oct 1, 2025
3126e93
more direct split of numba code
Oct 1, 2025
673f075
fix unitt est due to jax operated mapping matrix
Oct 1, 2025
724af26
test__data_linear_func_matrix_dict
Oct 1, 2025
6f96966
update data_linear_func_matrix_from to not use numba
Oct 1, 2025
6877618
remove use of frames in off diag
Oct 1, 2025
a3ab883
docstring
Oct 1, 2025
aa71fb3
remove convolver
Oct 1, 2025
179ef69
split mappeR_util with numba
Oct 1, 2025
e3a0c7f
split mesh_util into numba
Oct 1, 2025
6d6a3e7
replace more unit tests so they use numba util
Oct 1, 2025
5525b9c
fix mesh numab improt
Oct 1, 2025
ea3067c
JAX complet
Oct 6, 2025
06ba8e3
Merge pull request #187 from Jammy2211/feature/jax_numba_mix
Jammy2211 Oct 6, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- uses: actions/cache@v2
- uses: actions/cache@v3
id: cache-pip
with:
path: ~/.cache/pip
Expand All @@ -36,9 +36,9 @@ jobs:
pip3 install setuptools
pip3 install wheel
pip3 install pytest coverage pytest-cov
pip3 install -r PyAutoConf/requirements.txt
pip3 install -r PyAutoArray/requirements.txt
pip3 install -r PyAutoArray/optional_requirements.txt
pip install ./PyAutoConf
pip install ./PyAutoArray
pip install ./PyAutoArray[optional]

cd PyAutoArray/autoarray/util/nn/src/nn
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/runner/work/PyAutoArray/PyAutoArray/PyAutoArray/autoarray/util/nn/src/nn
Expand Down
12 changes: 10 additions & 2 deletions autoarray/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from autoconf.dictable import register_parser
from autoconf import conf

conf.instance.register(__file__)

from . import exc
from . import type
from . import util
from . import fixtures
from . import mock as m
from .numba_util import profile_func
from .dataset import preprocess
from .dataset.abstract.dataset import AbstractDataset
from .dataset.abstract.w_tilde import AbstractWTilde
Expand Down Expand Up @@ -38,20 +42,21 @@
from .inversion.pixelization.mappers.rectangular import MapperRectangular
from .inversion.pixelization.mappers.delaunay import MapperDelaunay
from .inversion.pixelization.mappers.voronoi import MapperVoronoi
from .inversion.pixelization.mappers.rectangular_uniform import MapperRectangularUniform
from .inversion.pixelization.image_mesh.abstract import AbstractImageMesh
from .inversion.pixelization.mesh.abstract import AbstractMesh
from .inversion.inversion.imaging.mapping import InversionImagingMapping
from .inversion.inversion.imaging.w_tilde import InversionImagingWTilde
from .inversion.inversion.interferometer.w_tilde import InversionInterferometerWTilde
from .inversion.inversion.interferometer.mapping import InversionInterferometerMapping
from .inversion.inversion.interferometer.lop import InversionInterferometerMappingPyLops
from .inversion.linear_obj.linear_obj import LinearObj
from .inversion.linear_obj.func_list import AbstractLinearObjFuncList
from .mask.derive.indexes_2d import DeriveIndexes2D
from .mask.derive.mask_1d import DeriveMask1D
from .mask.derive.mask_2d import DeriveMask2D
from .mask.derive.grid_1d import DeriveGrid1D
from .mask.derive.grid_2d import DeriveGrid2D
from .mask.derive.zoom_2d import Zoom2D
from .mask.mask_1d import Mask1D
from .mask.mask_2d import Mask2D
from .operators.transformer import TransformerDFT
Expand All @@ -60,14 +65,17 @@
from .operators.contour import Grid2DContour
from .layout.layout import Layout1D
from .layout.layout import Layout2D
from .preloads import Preloads
from .structures.arrays.uniform_1d import Array1D
from .structures.arrays.uniform_2d import Array2D
from .structures.arrays.rgb import Array2DRGB
from .structures.arrays.irregular import ArrayIrregular
from .structures.grids.uniform_1d import Grid1D
from .structures.grids.uniform_2d import Grid2D
from .operators.over_sampling.over_sampler import OverSampler
from .structures.grids.irregular_2d import Grid2DIrregular
from .structures.mesh.rectangular_2d import Mesh2DRectangular
from .structures.mesh.rectangular_2d_uniform import Mesh2DRectangularUniform
from .structures.mesh.voronoi_2d import Mesh2DVoronoi
from .structures.mesh.delaunay_2d import Mesh2DDelaunay
from .structures.arrays.kernel_2d import Kernel2D
Expand Down
2 changes: 1 addition & 1 deletion autoarray/abstract_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def output_to_fits(self, file_path: str, overwrite: bool = False):
If a file already exists at the path, if overwrite=True it is overwritten else an error is raised.
"""
output_to_fits(
values=self.native.array,
values=self.native.array.astype("float"),
file_path=file_path,
overwrite=overwrite,
header_dict=self.mask.header_dict,
Expand Down
2 changes: 2 additions & 0 deletions autoarray/config/general.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
jax:
use_jax: true # If True, uses JAX internally, whereas False uses normal Numpy.
fits:
flip_for_ds9: false # If True, the image is flipped before output to a .fits file, which is useful for viewing in DS9.
inversion:
Expand Down
1 change: 0 additions & 1 deletion autoarray/config/visualize/general.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ general:
log10_min_value: 1.0e-4 # If negative values are being plotted on a log10 scale, values below this value are rounded up to it (e.g. to remove negative values).
log10_max_value: 1.0e99 # If positive values are being plotted on a log10 scale, values above this value are rounded down to it (e.g. to prevent white blobs).
zoom_around_mask: true # If True, plots of data structures with a mask automatically zoom in the masked region.
disable_zoom_for_fits: true # If True, the zoom-in around the masked region is disabled when outputting .fits files, which is useful to retain the same dimensions as the input data.
inversion:
reconstruction_vmax_factor: 0.5
total_mappings_pixels : 8 # The number of source pixels used when plotting the subplot_mappings of a pixelization.
Expand Down
20 changes: 0 additions & 20 deletions autoarray/config/visualize/include.yaml

This file was deleted.

78 changes: 18 additions & 60 deletions autoarray/dataset/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from autoarray.mask.mask_2d import Mask2D
from autoarray.structures.arrays.uniform_2d import Array2D
from autoarray.structures.arrays.kernel_2d import Kernel2D
from autoarray.structures.grids.uniform_1d import Grid1D
from autoarray.structures.grids.uniform_2d import Grid2D

from autoarray.inversion.pixelization.border_relocator import BorderRelocator
from autoconf import cached_property

from autoarray import exc


class GridsDataset:
Expand All @@ -24,7 +24,7 @@ def __init__(

The following grids are contained:

- `uniform`: A grids of (y,x) coordinates which aligns with the centre of every image pixel of the image data,
- `lp`: A grids of (y,x) coordinates which aligns with the centre of every image pixel of the image data,
which is used for most normal calculations (e.g. evaluating the amount of light that falls in an pixel
from a light profile).

Expand Down Expand Up @@ -60,72 +60,30 @@ def __init__(
self.over_sample_size_pixelization = over_sample_size_pixelization
self.psf = psf

@cached_property
def lp(self) -> Union[Grid1D, Grid2D]:
"""
Returns the grid of (y,x) Cartesian coordinates at the centre of every pixel in the masked data, which is used
to perform most normal calculations (e.g. evaluating the amount of light that falls in an pixel from a light
profile).

This grid is computed based on the mask, in particular its pixel-scale and sub-grid size.

Returns
-------
The (y,x) coordinates of every pixel in the data.
"""
return Grid2D.from_mask(
self.lp = Grid2D.from_mask(
mask=self.mask,
over_sample_size=self.over_sample_size_lp,
)
self.lp.over_sampled

@cached_property
def pixelization(self) -> Grid2D:
"""
Returns the grid of (y,x) Cartesian coordinates of every pixel in the masked data which is used
specifically for calculations associated with a pixelization.

The `pixelization` grid is identical to the `uniform` grid but often uses a different over sampling scheme
when performing calculations. For example, the pixelization may benefit from using a a higher `sub_size` than
the `uniform` grid, in order to better prevent aliasing effects.

This grid is computed based on the mask, in particular its pixel-scale and sub-grid size.

Returns
-------
The (y,x) coordinates of every pixel in the data, used for pixelization / inversion calculations.
"""
return Grid2D.from_mask(
self.pixelization = Grid2D.from_mask(
mask=self.mask,
over_sample_size=self.over_sample_size_pixelization,
)

@cached_property
def blurring(self) -> Optional[Grid2D]:
"""
Returns a blurring-grid from a mask and the 2D shape of the PSF kernel.

A blurring grid consists of all pixels that are masked (and therefore have their values set to (0.0, 0.0)),
but are close enough to the unmasked pixels that their values will be convolved into the unmasked those pixels.
This when computing images from light profile objects.

This uses lazy allocation such that the calculation is only performed when the blurring grid is used, ensuring
efficient set up of the `Imaging` class.

Returns
-------
The blurring grid given the mask of the imaging data.
"""
self.pixelization.over_sampled

if self.psf is None:
return None

return self.lp.blurring_grid_via_kernel_shape_from(
kernel_shape_native=self.psf.shape_native,
)

@cached_property
def border_relocator(self) -> BorderRelocator:
return BorderRelocator(
self.blurring = None
else:
try:
self.blurring = self.lp.blurring_grid_via_kernel_shape_from(
kernel_shape_native=self.psf.shape_native,
)
self.blurring.over_sampled
except exc.MaskException:
self.blurring = None

self.border_relocator = BorderRelocator(
mask=self.mask, sub_size=self.over_sample_size_pixelization
)

Expand Down
44 changes: 33 additions & 11 deletions autoarray/dataset/imaging/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from autoarray import exc
from autoarray.operators.over_sampling import over_sample_util
from autoarray.inversion.inversion.imaging import inversion_imaging_util
from autoarray.inversion.inversion.imaging import inversion_imaging_numba_util

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -159,9 +159,26 @@ def __init__(
"""
)

if psf is not None and use_normalized_psf:
if psf is not None:

if not data.mask.is_all_false:

image_mask = data.mask
blurring_mask = data.mask.derive_mask.blurring_from(
kernel_shape_native=psf.shape_native
)

else:

image_mask = None
blurring_mask = None

psf = Kernel2D.no_mask(
values=psf.native._array, pixel_scales=psf.pixel_scales, normalize=True
values=psf.native._array,
pixel_scales=psf.pixel_scales,
normalize=use_normalized_psf,
image_mask=image_mask,
blurring_mask=blurring_mask,
)

self.psf = psf
Expand All @@ -170,9 +187,7 @@ def __init__(
if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0:
raise exc.KernelException("Kernel2D Kernel2D must be odd")

@cached_property
def grids(self):
return GridsDataset(
self.grids = GridsDataset(
mask=self.data.mask,
over_sample_size_lp=self.over_sample_size_lp,
over_sample_size_pixelization=self.over_sample_size_pixelization,
Expand Down Expand Up @@ -203,17 +218,22 @@ def w_tilde(self):
curvature_preload,
indexes,
lengths,
) = inversion_imaging_util.w_tilde_curvature_preload_imaging_from(
) = inversion_imaging_numba_util.w_tilde_curvature_preload_imaging_from(
noise_map_native=np.array(self.noise_map.native.array).astype("float64"),
kernel_native=np.array(self.psf.native.array).astype("float64"),
native_index_for_slim_index=np.array(self.mask.derive_indexes.native_for_slim).astype("int"),
native_index_for_slim_index=np.array(
self.mask.derive_indexes.native_for_slim
).astype("int"),
)

return WTildeImaging(
curvature_preload=curvature_preload,
indexes=indexes.astype("int"),
lengths=lengths.astype("int"),
noise_map_value=self.noise_map[0],
noise_map=self.noise_map,
psf=self.psf,
mask=self.mask,
)

@classmethod
Expand Down Expand Up @@ -409,12 +429,12 @@ def apply_noise_scaling(
"""

if signal_to_noise_value is None:
noise_map = np.array(self.noise_map.native.array)
noise_map = self.noise_map.native
noise_map[mask.array == False] = noise_value
else:
noise_map = np.where(
mask == False,
np.median(self.data.native.array[mask.derive_mask.edge == False])
np.median(self.data.native[mask.derive_mask.edge == False])
/ signal_to_noise_value,
self.noise_map.native.array,
)
Expand Down Expand Up @@ -488,7 +508,7 @@ def apply_over_sampling(
passed into the calculations performed in the `inversion` module.
"""

return Imaging(
dataset = Imaging(
data=self.data,
noise_map=self.noise_map,
psf=self.psf,
Expand All @@ -499,6 +519,8 @@ def apply_over_sampling(
check_noise_map=False,
)

return dataset

def output_to_fits(
self,
data_path: Union[Path, str],
Expand Down
6 changes: 4 additions & 2 deletions autoarray/dataset/imaging/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def via_image_from(
pixel_scales=image.pixel_scales,
)

if np.isnan(noise_map).any():
if np.isnan(noise_map.array).any():
raise exc.DatasetException(
"The noise-map has NaN values in it. This suggests your exposure time and / or"
"background sky levels are too low, creating signal counts at or close to 0.0."
Expand All @@ -161,7 +161,9 @@ def via_image_from(
image = image - background_sky_map

mask = Mask2D.all_false(
shape_native=image.shape_native, pixel_scales=image.pixel_scales
shape_native=image.shape_native,
pixel_scales=image.pixel_scales,
origin=image.origin,
)

image = Array2D(values=image, mask=mask)
Expand Down
Loading
Loading