Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions autoarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
from .layout.layout import Layout2D
from .structures.arrays.uniform_1d import Array1D
from .structures.arrays.uniform_2d import Array2D
from .structures.arrays.rgb import Array2DRGB
from .structures.arrays.irregular import ArrayIrregular
from .structures.grids.uniform_1d import Grid1D
from .structures.grids.uniform_2d import Grid2D
Expand Down
4 changes: 4 additions & 0 deletions autoarray/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def make_array_2d_7x7():
return aa.Array2D.ones(shape_native=(7, 7), pixel_scales=(1.0, 1.0))


def make_array_2d_rgb_7x7():
return aa.Array2DRGB(values=np.ones((7, 7, 3)), mask=make_mask_2d_7x7())


def make_layout_2d_7x7():
return aa.Layout2D(
shape_2d=(7, 7),
Expand Down
56 changes: 56 additions & 0 deletions autoarray/mask/derive/zoom_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

if TYPE_CHECKING:
from autoarray.structures.arrays.uniform_2d import Array2D
from autoarray.structures.arrays.rgb import Array2DRGB
from autoarray.mask.mask_2d import Mask2D


from autoarray.structures.arrays import array_2d_util
from autoarray.structures.grids import grid_2d_util

Expand Down Expand Up @@ -242,8 +244,12 @@ def array_2d_from(self, array: Array2D, buffer: int = 1) -> Array2D:
The number pixels around the extracted array used as a buffer.
"""
from autoarray.structures.arrays.uniform_2d import Array2D
from autoarray.structures.arrays.rgb import Array2DRGB
from autoarray.mask.mask_2d import Mask2D

if isinstance(array, Array2DRGB):
return self.array_2d_rgb_from(array=array, buffer=buffer)

extracted_array_2d = array_2d_util.extracted_array_2d_from(
array_2d=np.array(array.native),
y0=self.region[0] - buffer,
Expand All @@ -269,3 +275,53 @@ def array_2d_from(self, array: Array2D, buffer: int = 1) -> Array2D:
arr = array_2d_util.convert_array_2d(array_2d=extracted_array_2d, mask_2d=mask)

return Array2D(values=arr, mask=mask, header=array.header)

def array_2d_rgb_from(self, array: Array2DRGB, buffer: int = 1) -> Array2DRGB:
"""
Extract the 2D region of an RGB array corresponding to the rectangle encompassing all unmasked values.

This works the same as the `array_2d_from` method, but for RGB arrays, meaning that it iterates over the three
channels of the RGB array and extracts the region for each channel separately.

This is used to extract and visualize only the region of an RGB image that is used in an analysis.

Parameters
----------
buffer
The number pixels around the extracted array used as a buffer.
"""
from autoarray.structures.arrays.rgb import Array2DRGB
from autoarray.mask.mask_2d import Mask2D

for i in range(3):

extracted_array_2d = array_2d_util.extracted_array_2d_from(
array_2d=np.array(array.native[:, :, i]),
y0=self.region[0] - buffer,
y1=self.region[1] + buffer,
x0=self.region[2] - buffer,
x1=self.region[3] + buffer,
)

if i == 0:
array_2d_rgb = np.zeros(
(extracted_array_2d.shape[0], extracted_array_2d.shape[1], 3)
)

array_2d_rgb[:, :, i] = extracted_array_2d

extracted_mask_2d = array_2d_util.extracted_array_2d_from(
array_2d=np.array(self.mask),
y0=self.region[0] - buffer,
y1=self.region[1] + buffer,
x0=self.region[2] - buffer,
x1=self.region[3] + buffer,
)

mask = Mask2D(
mask=extracted_mask_2d,
pixel_scales=array.pixel_scales,
origin=array.mask.mask_centre,
)

return Array2DRGB(values=array_2d_rgb.astype("int"), mask=mask)
37 changes: 27 additions & 10 deletions autoarray/plot/mat_plot/two_d.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from autoarray.plot.visuals.two_d import Visuals2D
from autoarray.mask.derive.zoom_2d import Zoom2D
from autoarray.structures.arrays.uniform_2d import Array2D
from autoarray.structures.arrays.rgb import Array2DRGB

from autoarray.structures.arrays import array_2d_util

Expand Down Expand Up @@ -243,7 +244,6 @@ def plot_array(
bypass
If `True`, `plt.close` is omitted and the matplotlib figure remains open. This is used when making subplots.
"""

if array is None or np.all(array == 0):
return

Expand Down Expand Up @@ -280,14 +280,25 @@ def plot_array(

origin = conf.instance["visualize"]["general"]["general"]["imshow_origin"]

plt.imshow(
X=array.native.array,
aspect=aspect,
cmap=self.cmap.cmap,
norm=norm,
extent=extent,
origin=origin,
)
if isinstance(array, Array2DRGB):

plt.imshow(
X=array.native.array,
aspect=aspect,
extent=extent,
origin=origin,
)

else:

plt.imshow(
X=array.native.array,
aspect=aspect,
cmap=self.cmap.cmap,
norm=norm,
extent=extent,
origin=origin,
)

if visuals_2d.array_overlay is not None:
self.array_overlay.overlay_array(
Expand Down Expand Up @@ -317,7 +328,12 @@ def plot_array(
pixels=array.shape_native[1],
)

self.title.set(auto_title=auto_labels.title, use_log10=self.use_log10)
if isinstance(array, Array2DRGB):
title = "RGB"
else:
title = auto_labels.title

self.title.set(auto_title=title, use_log10=self.use_log10)
self.ylabel.set()
self.xlabel.set()

Expand All @@ -332,6 +348,7 @@ def plot_array(
[annotate.set() for annotate in self.annotate]

if self.colorbar is not False:

cb = self.colorbar.set(
units=self.units,
ax=ax,
Expand Down
45 changes: 45 additions & 0 deletions autoarray/structures/arrays/rgb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from autoarray.abstract_ndarray import AbstractNDArray
from autoarray.structures.arrays.uniform_2d import Array2D


class Array2DRGB(Array2D):

def __init__(self, values, mask):
"""
A container for RGB images which have a final dimension of 3, which allows them to be visualized using
the same functionality as `Array2D` objects.

By passing an RGB image to this class, the following visualization functionality is used when the RGB
image is used in `Plotter` objects:

- The RGB image is plotted using the `imshow` function of Matplotlib.
- Functionality which sets the scale of the axis, zooms the image, and sets the axis limits is used.
- The colorbar is set to the RGB image, which is a 3D array with a final dimension of 3.
- The formatting of the image is identical to that of `Array2D` objects, which means the image is plotted
with the same aspect ratio as the original image making for easy subplot formatting.

This class always assumes the array is in its `native` representation, but with a final dimension of 3.

Parameters
----------
values
The values of the RGB image, which is a 3D array with a final dimension of 3.
mask
The 2D mask associated with the array, defining the pixels each array value in its ``slim`` representation
is paired with.
"""

array = values

while isinstance(array, AbstractNDArray):
array = array.array

self._array = array
self.mask = mask

@property
def native(self) -> "Array2D":
"""
Returns the RGB ndarray of shape [total_y_pixels, total_x_pixels, 3] in its `native` representation.
"""
return self
5 changes: 5 additions & 0 deletions test_autoarray/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ def make_array_2d_7x7():
return fixtures.make_array_2d_7x7()


@pytest.fixture(name="array_2d_rgb_7x7")
def make_array_2d_rgb_7x7():
return fixtures.make_array_2d_rgb_7x7()


@pytest.fixture(name="layout_2d_7x7")
def make_layout_2d_7x7():
return fixtures.make_layout_2d_7x7()
Expand Down
17 changes: 17 additions & 0 deletions test_autoarray/structures/plot/test_structure_plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,20 @@ def test__grid(
grid_2d_plotter.figure_2d(color_array=color_array)

assert path.join(plot_path, "grid3.png") in plot_patch.paths


def test__array_rgb(
array_2d_rgb_7x7,
plot_path,
plot_patch,
):
array_plotter = aplt.Array2DPlotter(
array=array_2d_rgb_7x7,
mat_plot_2d=aplt.MatPlot2D(
output=aplt.Output(path=plot_path, filename="array_rgb", format="png")
),
)

array_plotter.figure_2d()

assert path.join(plot_path, "array_rgb.png") in plot_patch.paths
Loading