diff --git a/autoarray/__init__.py b/autoarray/__init__.py index b48a9c2ae..8fce98366 100644 --- a/autoarray/__init__.py +++ b/autoarray/__init__.py @@ -66,6 +66,7 @@ from .layout.layout import Layout2D from .structures.arrays.uniform_1d import Array1D from .structures.arrays.uniform_2d import Array2D +from .structures.arrays.rgb import Array2DRGB from .structures.arrays.irregular import ArrayIrregular from .structures.grids.uniform_1d import Grid1D from .structures.grids.uniform_2d import Grid2D diff --git a/autoarray/fixtures.py b/autoarray/fixtures.py index acfd277df..00ad08607 100644 --- a/autoarray/fixtures.py +++ b/autoarray/fixtures.py @@ -68,6 +68,10 @@ def make_array_2d_7x7(): return aa.Array2D.ones(shape_native=(7, 7), pixel_scales=(1.0, 1.0)) +def make_array_2d_rgb_7x7(): + return aa.Array2DRGB(values=np.ones((7, 7, 3)), mask=make_mask_2d_7x7()) + + def make_layout_2d_7x7(): return aa.Layout2D( shape_2d=(7, 7), diff --git a/autoarray/mask/derive/zoom_2d.py b/autoarray/mask/derive/zoom_2d.py index 69c49b7cd..5cae65df4 100644 --- a/autoarray/mask/derive/zoom_2d.py +++ b/autoarray/mask/derive/zoom_2d.py @@ -4,8 +4,10 @@ if TYPE_CHECKING: from autoarray.structures.arrays.uniform_2d import Array2D + from autoarray.structures.arrays.rgb import Array2DRGB from autoarray.mask.mask_2d import Mask2D + from autoarray.structures.arrays import array_2d_util from autoarray.structures.grids import grid_2d_util @@ -242,8 +244,12 @@ def array_2d_from(self, array: Array2D, buffer: int = 1) -> Array2D: The number pixels around the extracted array used as a buffer. """ from autoarray.structures.arrays.uniform_2d import Array2D + from autoarray.structures.arrays.rgb import Array2DRGB from autoarray.mask.mask_2d import Mask2D + if isinstance(array, Array2DRGB): + return self.array_2d_rgb_from(array=array, buffer=buffer) + extracted_array_2d = array_2d_util.extracted_array_2d_from( array_2d=np.array(array.native), y0=self.region[0] - buffer, @@ -269,3 +275,53 @@ def array_2d_from(self, array: Array2D, buffer: int = 1) -> Array2D: arr = array_2d_util.convert_array_2d(array_2d=extracted_array_2d, mask_2d=mask) return Array2D(values=arr, mask=mask, header=array.header) + + def array_2d_rgb_from(self, array: Array2DRGB, buffer: int = 1) -> Array2DRGB: + """ + Extract the 2D region of an RGB array corresponding to the rectangle encompassing all unmasked values. + + This works the same as the `array_2d_from` method, but for RGB arrays, meaning that it iterates over the three + channels of the RGB array and extracts the region for each channel separately. + + This is used to extract and visualize only the region of an RGB image that is used in an analysis. + + Parameters + ---------- + buffer + The number pixels around the extracted array used as a buffer. + """ + from autoarray.structures.arrays.rgb import Array2DRGB + from autoarray.mask.mask_2d import Mask2D + + for i in range(3): + + extracted_array_2d = array_2d_util.extracted_array_2d_from( + array_2d=np.array(array.native[:, :, i]), + y0=self.region[0] - buffer, + y1=self.region[1] + buffer, + x0=self.region[2] - buffer, + x1=self.region[3] + buffer, + ) + + if i == 0: + array_2d_rgb = np.zeros( + (extracted_array_2d.shape[0], extracted_array_2d.shape[1], 3) + ) + + array_2d_rgb[:, :, i] = extracted_array_2d + + extracted_mask_2d = array_2d_util.extracted_array_2d_from( + array_2d=np.array(self.mask), + y0=self.region[0] - buffer, + y1=self.region[1] + buffer, + x0=self.region[2] - buffer, + x1=self.region[3] + buffer, + ) + + mask = Mask2D( + mask=extracted_mask_2d, + pixel_scales=array.pixel_scales, + origin=array.mask.mask_centre, + ) + + return Array2DRGB(values=array_2d_rgb.astype("int"), mask=mask) diff --git a/autoarray/plot/mat_plot/two_d.py b/autoarray/plot/mat_plot/two_d.py index e000a91a2..b1f77bc15 100644 --- a/autoarray/plot/mat_plot/two_d.py +++ b/autoarray/plot/mat_plot/two_d.py @@ -14,6 +14,7 @@ from autoarray.plot.visuals.two_d import Visuals2D from autoarray.mask.derive.zoom_2d import Zoom2D from autoarray.structures.arrays.uniform_2d import Array2D +from autoarray.structures.arrays.rgb import Array2DRGB from autoarray.structures.arrays import array_2d_util @@ -243,7 +244,6 @@ def plot_array( bypass If `True`, `plt.close` is omitted and the matplotlib figure remains open. This is used when making subplots. """ - if array is None or np.all(array == 0): return @@ -280,14 +280,25 @@ def plot_array( origin = conf.instance["visualize"]["general"]["general"]["imshow_origin"] - plt.imshow( - X=array.native.array, - aspect=aspect, - cmap=self.cmap.cmap, - norm=norm, - extent=extent, - origin=origin, - ) + if isinstance(array, Array2DRGB): + + plt.imshow( + X=array.native.array, + aspect=aspect, + extent=extent, + origin=origin, + ) + + else: + + plt.imshow( + X=array.native.array, + aspect=aspect, + cmap=self.cmap.cmap, + norm=norm, + extent=extent, + origin=origin, + ) if visuals_2d.array_overlay is not None: self.array_overlay.overlay_array( @@ -317,7 +328,12 @@ def plot_array( pixels=array.shape_native[1], ) - self.title.set(auto_title=auto_labels.title, use_log10=self.use_log10) + if isinstance(array, Array2DRGB): + title = "RGB" + else: + title = auto_labels.title + + self.title.set(auto_title=title, use_log10=self.use_log10) self.ylabel.set() self.xlabel.set() @@ -332,6 +348,7 @@ def plot_array( [annotate.set() for annotate in self.annotate] if self.colorbar is not False: + cb = self.colorbar.set( units=self.units, ax=ax, diff --git a/autoarray/structures/arrays/rgb.py b/autoarray/structures/arrays/rgb.py new file mode 100644 index 000000000..8e2171f23 --- /dev/null +++ b/autoarray/structures/arrays/rgb.py @@ -0,0 +1,45 @@ +from autoarray.abstract_ndarray import AbstractNDArray +from autoarray.structures.arrays.uniform_2d import Array2D + + +class Array2DRGB(Array2D): + + def __init__(self, values, mask): + """ + A container for RGB images which have a final dimension of 3, which allows them to be visualized using + the same functionality as `Array2D` objects. + + By passing an RGB image to this class, the following visualization functionality is used when the RGB + image is used in `Plotter` objects: + + - The RGB image is plotted using the `imshow` function of Matplotlib. + - Functionality which sets the scale of the axis, zooms the image, and sets the axis limits is used. + - The colorbar is set to the RGB image, which is a 3D array with a final dimension of 3. + - The formatting of the image is identical to that of `Array2D` objects, which means the image is plotted + with the same aspect ratio as the original image making for easy subplot formatting. + + This class always assumes the array is in its `native` representation, but with a final dimension of 3. + + Parameters + ---------- + values + The values of the RGB image, which is a 3D array with a final dimension of 3. + mask + The 2D mask associated with the array, defining the pixels each array value in its ``slim`` representation + is paired with. + """ + + array = values + + while isinstance(array, AbstractNDArray): + array = array.array + + self._array = array + self.mask = mask + + @property + def native(self) -> "Array2D": + """ + Returns the RGB ndarray of shape [total_y_pixels, total_x_pixels, 3] in its `native` representation. + """ + return self diff --git a/test_autoarray/conftest.py b/test_autoarray/conftest.py index 1dbe19e19..34ba8e48b 100644 --- a/test_autoarray/conftest.py +++ b/test_autoarray/conftest.py @@ -69,6 +69,11 @@ def make_array_2d_7x7(): return fixtures.make_array_2d_7x7() +@pytest.fixture(name="array_2d_rgb_7x7") +def make_array_2d_rgb_7x7(): + return fixtures.make_array_2d_rgb_7x7() + + @pytest.fixture(name="layout_2d_7x7") def make_layout_2d_7x7(): return fixtures.make_layout_2d_7x7() diff --git a/test_autoarray/structures/plot/test_structure_plotters.py b/test_autoarray/structures/plot/test_structure_plotters.py index ad1ca0251..45fff5bf1 100644 --- a/test_autoarray/structures/plot/test_structure_plotters.py +++ b/test_autoarray/structures/plot/test_structure_plotters.py @@ -167,3 +167,20 @@ def test__grid( grid_2d_plotter.figure_2d(color_array=color_array) assert path.join(plot_path, "grid3.png") in plot_patch.paths + + +def test__array_rgb( + array_2d_rgb_7x7, + plot_path, + plot_patch, +): + array_plotter = aplt.Array2DPlotter( + array=array_2d_rgb_7x7, + mat_plot_2d=aplt.MatPlot2D( + output=aplt.Output(path=plot_path, filename="array_rgb", format="png") + ), + ) + + array_plotter.figure_2d() + + assert path.join(plot_path, "array_rgb.png") in plot_patch.paths