diff --git a/autoarray/config/visualize/include.yaml b/autoarray/config/visualize/include.yaml deleted file mode 100644 index f010d9f8b..000000000 --- a/autoarray/config/visualize/include.yaml +++ /dev/null @@ -1,20 +0,0 @@ -# The `include` settings customize every feature that appears on plotted images by default (e.g. a mask, the -# coordinate system's origin, etc.). - -# For example, if `include_2d -> mask:true`, the mask will not be plotted on any applicable figure by default. - -include_1d: - mask: false # Include a Mask ? - origin: false # Include the (x,) origin of the data's coordinate system ? -include_2d: - border: false # Include the border of the mask (all pixels on the outside of the mask) ? - grid: false # Include the data's 2D grid of (y,x) coordinates ? - mapper_image_plane_mesh_grid: false # For an Inversion, include the pixel centres computed in the image-plane / data frame? - mapper_source_plane_data_grid: false # For an Inversion, include the centres of the image-plane grid mapped to the source-plane / frame in source-plane figures? - mapper_source_plane_mesh_grid: false # For an Inversion, include the centres of the mesh pixels in the source-plane / source-plane? - mask: true # Include a mask ? - origin: false # Include the (y,x) origin of the data's coordinate system ? - positions: true # Include (y,x) coordinates specified via `Visuals2d.positions` ? - parallel_overscan: true - serial_overscan: true - serial_prescan: true \ No newline at end of file diff --git a/autoarray/dataset/grids.py b/autoarray/dataset/grids.py index b46abeb79..d97fd3f4d 100644 --- a/autoarray/dataset/grids.py +++ b/autoarray/dataset/grids.py @@ -3,11 +3,9 @@ from autoarray.mask.mask_2d import Mask2D from autoarray.structures.arrays.uniform_2d import Array2D from autoarray.structures.arrays.kernel_2d import Kernel2D -from autoarray.structures.grids.uniform_1d import Grid1D from autoarray.structures.grids.uniform_2d import Grid2D from autoarray.inversion.pixelization.border_relocator import BorderRelocator -from autoconf import cached_property from autoarray import exc diff --git a/autoarray/dataset/plot/imaging_plotters.py b/autoarray/dataset/plot/imaging_plotters.py index 0adcd6d11..e0c0772e3 100644 --- a/autoarray/dataset/plot/imaging_plotters.py +++ b/autoarray/dataset/plot/imaging_plotters.py @@ -2,21 +2,18 @@ from typing import Callable, Optional from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.include.two_d import Include2D from autoarray.plot.mat_plot.two_d import MatPlot2D from autoarray.plot.auto_labels import AutoLabels -from autoarray.plot.abstract_plotters import Plotter +from autoarray.plot.abstract_plotters import AbstractPlotter from autoarray.dataset.imaging.dataset import Imaging -class ImagingPlotterMeta(Plotter): +class ImagingPlotterMeta(AbstractPlotter): def __init__( self, dataset: Imaging, - get_visuals_2d: Callable, mat_plot_2d: MatPlot2D = None, visuals_2d: Visuals2D = None, - include_2d: Include2D = None, ): """ Plots the attributes of `Imaging` objects using the matplotlib method `imshow()` and many other matplotlib @@ -27,29 +24,21 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `Imaging` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `Imaging` and plotted via the visuals object. Parameters ---------- dataset The imaging dataset the plotter plots. - get_visuals_2d - A function which extracts from the `Imaging` the 2D visuals which are plotted on figures. mat_plot_2d Contains objects which wrap the matplotlib function calls that make 2D plots. visuals_2d Contains 2D visuals that can be overlaid on 2D plots. - include_2d - Specifies which attributes of the `Imaging` are extracted and plotted as visuals for 2D plots. """ - super().__init__( - mat_plot_2d=mat_plot_2d, include_2d=include_2d, visuals_2d=visuals_2d - ) + super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) self.dataset = dataset - self.get_visuals_2d = get_visuals_2d @property def imaging(self): @@ -91,21 +80,21 @@ def figures_2d( if data: self.mat_plot_2d.plot_array( array=self.dataset.data, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels(title=title_str or f" Data", filename="data"), ) if noise_map: self.mat_plot_2d.plot_array( array=self.dataset.noise_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels(title_str or f"Noise-Map", filename="noise_map"), ) if psf: self.mat_plot_2d.plot_array( array=self.dataset.psf, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title=title_str or f"Point Spread Function", filename="psf", @@ -116,7 +105,7 @@ def figures_2d( if signal_to_noise_map: self.mat_plot_2d.plot_array( array=self.dataset.signal_to_noise_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title=title_str or f"Signal-To-Noise Map", filename="signal_to_noise_map", @@ -127,7 +116,7 @@ def figures_2d( if over_sample_size_lp: self.mat_plot_2d.plot_array( array=self.dataset.grids.over_sample_size_lp, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title=title_str or f"Over Sample Size (Light Profiles)", filename="over_sample_size_lp", @@ -138,7 +127,7 @@ def figures_2d( if over_sample_size_pixelization: self.mat_plot_2d.plot_array( array=self.dataset.grids.over_sample_size_pixelization, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title=title_str or f"Over Sample Size (Pixelization)", filename="over_sample_size_pixelization", @@ -227,13 +216,12 @@ def subplot_dataset(self): self.mat_plot_2d.use_log10 = use_log10_original -class ImagingPlotter(Plotter): +class ImagingPlotter(AbstractPlotter): def __init__( self, dataset: Imaging, mat_plot_2d: MatPlot2D = None, visuals_2d: Visuals2D = None, - include_2d: Include2D = None, ): """ Plots the attributes of `Imaging` objects using the matplotlib method `imshow()` and many other matplotlib @@ -244,8 +232,7 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `Imaging` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `Imaging` and plotted via the visuals object. Parameters ---------- @@ -255,27 +242,18 @@ def __init__( Contains objects which wrap the matplotlib function calls that make 2D plots. visuals_2d Contains 2D visuals that can be overlaid on 2D plots. - include_2d - Specifies which attributes of the `Imaging` are extracted and plotted as visuals for 2D plots. """ - super().__init__( - mat_plot_2d=mat_plot_2d, include_2d=include_2d, visuals_2d=visuals_2d - ) + super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) self.dataset = dataset self._imaging_meta_plotter = ImagingPlotterMeta( dataset=self.dataset, - get_visuals_2d=self.get_visuals_2d, mat_plot_2d=self.mat_plot_2d, - include_2d=self.include_2d, visuals_2d=self.visuals_2d, ) self.figures_2d = self._imaging_meta_plotter.figures_2d self.subplot = self._imaging_meta_plotter.subplot self.subplot_dataset = self._imaging_meta_plotter.subplot_dataset - - def get_visuals_2d(self): - return self.get_2d.via_mask_from(mask=self.dataset.mask) diff --git a/autoarray/dataset/plot/interferometer_plotters.py b/autoarray/dataset/plot/interferometer_plotters.py index a50ef64b9..e69f53d38 100644 --- a/autoarray/dataset/plot/interferometer_plotters.py +++ b/autoarray/dataset/plot/interferometer_plotters.py @@ -1,298 +1,284 @@ -from autoarray.plot.abstract_plotters import Plotter -from autoarray.plot.visuals.one_d import Visuals1D -from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.include.one_d import Include1D -from autoarray.plot.include.two_d import Include2D -from autoarray.plot.mat_plot.one_d import MatPlot1D -from autoarray.plot.mat_plot.two_d import MatPlot2D -from autoarray.plot.auto_labels import AutoLabels -from autoarray.dataset.interferometer.dataset import Interferometer -from autoarray.structures.grids.irregular_2d import Grid2DIrregular - - -class InterferometerPlotter(Plotter): - def __init__( - self, - dataset: Interferometer, - mat_plot_1d: MatPlot1D = None, - visuals_1d: Visuals1D = None, - include_1d: Include1D = None, - mat_plot_2d: MatPlot2D = None, - visuals_2d: Visuals2D = None, - include_2d: Include2D = None, - ): - """ - Plots the attributes of `Interferometer` objects using the matplotlib methods `plot()`, `scatter()` and - `imshow()` and other matplotlib functions which customize the plot's appearance. - - The `mat_plot_1d` and `mat_plot_2d` attributes wrap matplotlib function calls to make the figure. By default, - the settings passed to every matplotlib function called are those specified in - the `config/visualize/mat_wrap/*.ini` files, but a user can manually input values into `MatPlot2d` to - customize the figure's appearance. - - Overlaid on the figure are visuals, contained in the `Visuals1D` and `Visuals2D` objects. Attributes may be - extracted from the `LightProfile` and plotted via the visuals object, if the corresponding entry is `True` in - the `Include1D` or `Include2D` object or the `config/visualize/include.ini` file. - - Parameters - ---------- - dataset - The interferometer dataset the plotter plots. - mat_plot_1d - Contains objects which wrap the matplotlib function calls that make 1D plots. - visuals_1d - Contains 1D visuals that can be overlaid on 1D plots. - include_1d - Specifies which attributes of the `Interferometer` are extracted and plotted as visuals for 1D plots. - mat_plot_2d - Contains objects which wrap the matplotlib function calls that make 2D plots. - visuals_2d - Contains 2D visuals that can be overlaid on 2D plots. - include_2d - Specifies which attributes of the `Interferometer` are extracted and plotted as visuals for 2D plots. - """ - self.dataset = dataset - - super().__init__( - mat_plot_1d=mat_plot_1d, - include_1d=include_1d, - visuals_1d=visuals_1d, - mat_plot_2d=mat_plot_2d, - include_2d=include_2d, - visuals_2d=visuals_2d, - ) - - @property - def interferometer(self): - return self.dataset - - def get_visuals_2d_real_space(self): - return self.get_2d.via_mask_from(mask=self.dataset.real_space_mask) - - def figures_2d( - self, - data: bool = False, - noise_map: bool = False, - u_wavelengths: bool = False, - v_wavelengths: bool = False, - uv_wavelengths: bool = False, - amplitudes_vs_uv_distances: bool = False, - phases_vs_uv_distances: bool = False, - dirty_image: bool = False, - dirty_noise_map: bool = False, - dirty_signal_to_noise_map: bool = False, - ): - """ - Plots the individual attributes of the plotter's `Interferometer` object in 1D and 2D. - - The API is such that every plottable attribute of the `Interferometer` object is an input parameter of type - bool of the function, which if switched to `True` means that it is plotted. - - Parameters - ---------- - data - Whether to make a 2D plot (via `scatter`) of the visibility data. - noise_map - Whether to make a 2D plot (via `scatter`) of the noise-map. - u_wavelengths - Whether to make a 1D plot (via `plot`) of the u-wavelengths. - v_wavelengths - Whether to make a 1D plot (via `plot`) of the v-wavelengths. - amplitudes_vs_uv_distances - Whether to make a 1D plot (via `plot`) of the amplitudes versis the uv distances. - phases_vs_uv_distances - Whether to make a 1D plot (via `plot`) of the phases versis the uv distances. - dirty_image - Whether to make a 2D plot (via `imshow`) of the dirty image. - dirty_noise_map - Whether to make a 2D plot (via `imshow`) of the dirty noise map. - dirty_signal_to_noise_map - Whether to make a 2D plot (via `imshow`) of the dirty signal-to-noise map. - """ - - if data: - self.mat_plot_2d.plot_grid( - grid=self.dataset.data.in_grid, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Visibilities", filename="data"), - ) - - if noise_map: - self.mat_plot_2d.plot_grid( - grid=self.dataset.data.in_grid, - visuals_2d=self.visuals_2d, - color_array=self.dataset.noise_map.real, - auto_labels=AutoLabels(title="Noise-Map", filename="noise_map"), - ) - - if u_wavelengths: - self.mat_plot_1d.plot_yx( - y=self.dataset.uv_wavelengths[:, 0], - x=None, - visuals_1d=self.visuals_1d, - auto_labels=AutoLabels( - title="U-Wavelengths", - filename="u_wavelengths", - ylabel="Wavelengths", - ), - plot_axis_type_override="linear", - ) - - if v_wavelengths: - self.mat_plot_1d.plot_yx( - y=self.dataset.uv_wavelengths[:, 1], - x=None, - visuals_1d=self.visuals_1d, - auto_labels=AutoLabels( - title="V-Wavelengths", - filename="v_wavelengths", - ylabel="Wavelengths", - ), - plot_axis_type_override="linear", - ) - - if uv_wavelengths: - self.mat_plot_2d.plot_grid( - grid=Grid2DIrregular.from_yx_1d( - y=self.dataset.uv_wavelengths[:, 1] / 10**3.0, - x=self.dataset.uv_wavelengths[:, 0] / 10**3.0, - ), - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="UV-Wavelengths", filename="uv_wavelengths" - ), - ) - - if amplitudes_vs_uv_distances: - self.mat_plot_1d.plot_yx( - y=self.dataset.amplitudes, - x=self.dataset.uv_distances / 10**3.0, - visuals_1d=self.visuals_1d, - auto_labels=AutoLabels( - title="Amplitudes vs UV-distances", - filename="amplitudes_vs_uv_distances", - yunit="Jy", - xunit="k$\lambda$", - ), - plot_axis_type_override="scatter", - ) - - if phases_vs_uv_distances: - self.mat_plot_1d.plot_yx( - y=self.dataset.phases, - x=self.dataset.uv_distances / 10**3.0, - visuals_1d=self.visuals_1d, - auto_labels=AutoLabels( - title="Phases vs UV-distances", - filename="phases_vs_uv_distances", - yunit="deg", - xunit="k$\lambda$", - ), - plot_axis_type_override="scatter", - ) - - if dirty_image: - self.mat_plot_2d.plot_array( - array=self.dataset.dirty_image, - visuals_2d=self.get_visuals_2d_real_space(), - auto_labels=AutoLabels(title="Dirty Image", filename="dirty_image"), - ) - - if dirty_noise_map: - self.mat_plot_2d.plot_array( - array=self.dataset.dirty_noise_map, - visuals_2d=self.get_visuals_2d_real_space(), - auto_labels=AutoLabels( - title="Dirty Noise Map", filename="dirty_noise_map" - ), - ) - - if dirty_signal_to_noise_map: - self.mat_plot_2d.plot_array( - array=self.dataset.dirty_signal_to_noise_map, - visuals_2d=self.get_visuals_2d_real_space(), - auto_labels=AutoLabels( - title="Dirty Signal-To-Noise Map", - filename="dirty_signal_to_noise_map", - ), - ) - - def subplot( - self, - data: bool = False, - noise_map: bool = False, - u_wavelengths: bool = False, - v_wavelengths: bool = False, - uv_wavelengths: bool = False, - amplitudes_vs_uv_distances: bool = False, - phases_vs_uv_distances: bool = False, - dirty_image: bool = False, - dirty_noise_map: bool = False, - dirty_signal_to_noise_map: bool = False, - auto_filename: str = "subplot_dataset", - ): - """ - Plots the individual attributes of the plotter's `Interferometer` object in 1D and 2D on a subplot. - - The API is such that every plottable attribute of the `Interferometer` object is an input parameter of type - bool of the function, which if switched to `True` means that it is included on the subplot. - - Parameters - ---------- - data - Whether to include a 2D plot (via `scatter`) of the visibility data. - noise_map - Whether to include a 2D plot (via `scatter`) of the noise-map. - u_wavelengths - Whether to include a 1D plot (via `plot`) of the u-wavelengths. - v_wavelengths - Whether to include a 1D plot (via `plot`) of the v-wavelengths. - amplitudes_vs_uv_distances - Whether to include a 1D plot (via `plot`) of the amplitudes versis the uv distances. - phases_vs_uv_distances - Whether to include a 1D plot (via `plot`) of the phases versis the uv distances. - dirty_image - Whether to include a 2D plot (via `imshow`) of the dirty image. - dirty_noise_map - Whether to include a 2D plot (via `imshow`) of the dirty noise map. - dirty_signal_to_noise_map - Whether to include a 2D plot (via `imshow`) of the dirty signal-to-noise map. - """ - self._subplot_custom_plot( - data=data, - noise_map=noise_map, - u_wavelengths=u_wavelengths, - v_wavelengths=v_wavelengths, - uv_wavelengths=uv_wavelengths, - amplitudes_vs_uv_distances=amplitudes_vs_uv_distances, - phases_vs_uv_distances=phases_vs_uv_distances, - dirty_image=dirty_image, - dirty_noise_map=dirty_noise_map, - dirty_signal_to_noise_map=dirty_signal_to_noise_map, - auto_labels=AutoLabels(filename=auto_filename), - ) - - def subplot_dataset(self): - """ - Standard subplot of the attributes of the plotter's `Interferometer` object. - """ - return self.subplot( - data=True, - uv_wavelengths=True, - amplitudes_vs_uv_distances=True, - phases_vs_uv_distances=True, - dirty_image=True, - dirty_signal_to_noise_map=True, - auto_filename="subplot_dataset", - ) - - def subplot_dirty_images(self): - """ - Standard subplot of the dirty attributes of the plotter's `Interferometer` object. - """ - return self.subplot( - dirty_image=True, - dirty_noise_map=True, - dirty_signal_to_noise_map=True, - auto_filename="subplot_dirty_images", - ) +from autoarray.plot.abstract_plotters import AbstractPlotter +from autoarray.plot.visuals.one_d import Visuals1D +from autoarray.plot.visuals.two_d import Visuals2D +from autoarray.plot.mat_plot.one_d import MatPlot1D +from autoarray.plot.mat_plot.two_d import MatPlot2D +from autoarray.plot.auto_labels import AutoLabels +from autoarray.dataset.interferometer.dataset import Interferometer +from autoarray.structures.grids.irregular_2d import Grid2DIrregular + + +class InterferometerPlotter(AbstractPlotter): + def __init__( + self, + dataset: Interferometer, + mat_plot_1d: MatPlot1D = None, + visuals_1d: Visuals1D = None, + mat_plot_2d: MatPlot2D = None, + visuals_2d: Visuals2D = None, + ): + """ + Plots the attributes of `Interferometer` objects using the matplotlib methods `plot()`, `scatter()` and + `imshow()` and other matplotlib functions which customize the plot's appearance. + + The `mat_plot_1d` and `mat_plot_2d` attributes wrap matplotlib function calls to make the figure. By default, + the settings passed to every matplotlib function called are those specified in + the `config/visualize/mat_wrap/*.ini` files, but a user can manually input values into `MatPlot2d` to + customize the figure's appearance. + + Overlaid on the figure are visuals, contained in the `Visuals1D` and `Visuals2D` objects. Attributes may be + extracted from the `LightProfile` and plotted via the visuals object. + + Parameters + ---------- + dataset + The interferometer dataset the plotter plots. + mat_plot_1d + Contains objects which wrap the matplotlib function calls that make 1D plots. + visuals_1d + Contains 1D visuals that can be overlaid on 1D plots. + mat_plot_2d + Contains objects which wrap the matplotlib function calls that make 2D plots. + visuals_2d + Contains 2D visuals that can be overlaid on 2D plots. + """ + self.dataset = dataset + + super().__init__( + mat_plot_1d=mat_plot_1d, + visuals_1d=visuals_1d, + mat_plot_2d=mat_plot_2d, + visuals_2d=visuals_2d, + ) + + @property + def interferometer(self): + return self.dataset + + def figures_2d( + self, + data: bool = False, + noise_map: bool = False, + u_wavelengths: bool = False, + v_wavelengths: bool = False, + uv_wavelengths: bool = False, + amplitudes_vs_uv_distances: bool = False, + phases_vs_uv_distances: bool = False, + dirty_image: bool = False, + dirty_noise_map: bool = False, + dirty_signal_to_noise_map: bool = False, + ): + """ + Plots the individual attributes of the plotter's `Interferometer` object in 1D and 2D. + + The API is such that every plottable attribute of the `Interferometer` object is an input parameter of type + bool of the function, which if switched to `True` means that it is plotted. + + Parameters + ---------- + data + Whether to make a 2D plot (via `scatter`) of the visibility data. + noise_map + Whether to make a 2D plot (via `scatter`) of the noise-map. + u_wavelengths + Whether to make a 1D plot (via `plot`) of the u-wavelengths. + v_wavelengths + Whether to make a 1D plot (via `plot`) of the v-wavelengths. + amplitudes_vs_uv_distances + Whether to make a 1D plot (via `plot`) of the amplitudes versis the uv distances. + phases_vs_uv_distances + Whether to make a 1D plot (via `plot`) of the phases versis the uv distances. + dirty_image + Whether to make a 2D plot (via `imshow`) of the dirty image. + dirty_noise_map + Whether to make a 2D plot (via `imshow`) of the dirty noise map. + dirty_signal_to_noise_map + Whether to make a 2D plot (via `imshow`) of the dirty signal-to-noise map. + """ + + if data: + self.mat_plot_2d.plot_grid( + grid=self.dataset.data.in_grid, + visuals_2d=self.visuals_2d, + auto_labels=AutoLabels(title="Visibilities", filename="data"), + ) + + if noise_map: + self.mat_plot_2d.plot_grid( + grid=self.dataset.data.in_grid, + visuals_2d=self.visuals_2d, + color_array=self.dataset.noise_map.real, + auto_labels=AutoLabels(title="Noise-Map", filename="noise_map"), + ) + + if u_wavelengths: + self.mat_plot_1d.plot_yx( + y=self.dataset.uv_wavelengths[:, 0], + x=None, + visuals_1d=self.visuals_1d, + auto_labels=AutoLabels( + title="U-Wavelengths", + filename="u_wavelengths", + ylabel="Wavelengths", + ), + plot_axis_type_override="linear", + ) + + if v_wavelengths: + self.mat_plot_1d.plot_yx( + y=self.dataset.uv_wavelengths[:, 1], + x=None, + visuals_1d=self.visuals_1d, + auto_labels=AutoLabels( + title="V-Wavelengths", + filename="v_wavelengths", + ylabel="Wavelengths", + ), + plot_axis_type_override="linear", + ) + + if uv_wavelengths: + self.mat_plot_2d.plot_grid( + grid=Grid2DIrregular.from_yx_1d( + y=self.dataset.uv_wavelengths[:, 1] / 10**3.0, + x=self.dataset.uv_wavelengths[:, 0] / 10**3.0, + ), + visuals_2d=self.visuals_2d, + auto_labels=AutoLabels( + title="UV-Wavelengths", filename="uv_wavelengths" + ), + ) + + if amplitudes_vs_uv_distances: + self.mat_plot_1d.plot_yx( + y=self.dataset.amplitudes, + x=self.dataset.uv_distances / 10**3.0, + visuals_1d=self.visuals_1d, + auto_labels=AutoLabels( + title="Amplitudes vs UV-distances", + filename="amplitudes_vs_uv_distances", + yunit="Jy", + xunit="k$\lambda$", + ), + plot_axis_type_override="scatter", + ) + + if phases_vs_uv_distances: + self.mat_plot_1d.plot_yx( + y=self.dataset.phases, + x=self.dataset.uv_distances / 10**3.0, + visuals_1d=self.visuals_1d, + auto_labels=AutoLabels( + title="Phases vs UV-distances", + filename="phases_vs_uv_distances", + yunit="deg", + xunit="k$\lambda$", + ), + plot_axis_type_override="scatter", + ) + + if dirty_image: + self.mat_plot_2d.plot_array( + array=self.dataset.dirty_image, + visuals_2d=self.visuals_2d, + auto_labels=AutoLabels(title="Dirty Image", filename="dirty_image"), + ) + + if dirty_noise_map: + self.mat_plot_2d.plot_array( + array=self.dataset.dirty_noise_map, + visuals_2d=self.visuals_2d, + auto_labels=AutoLabels( + title="Dirty Noise Map", filename="dirty_noise_map" + ), + ) + + if dirty_signal_to_noise_map: + self.mat_plot_2d.plot_array( + array=self.dataset.dirty_signal_to_noise_map, + visuals_2d=self.visuals_2d, + auto_labels=AutoLabels( + title="Dirty Signal-To-Noise Map", + filename="dirty_signal_to_noise_map", + ), + ) + + def subplot( + self, + data: bool = False, + noise_map: bool = False, + u_wavelengths: bool = False, + v_wavelengths: bool = False, + uv_wavelengths: bool = False, + amplitudes_vs_uv_distances: bool = False, + phases_vs_uv_distances: bool = False, + dirty_image: bool = False, + dirty_noise_map: bool = False, + dirty_signal_to_noise_map: bool = False, + auto_filename: str = "subplot_dataset", + ): + """ + Plots the individual attributes of the plotter's `Interferometer` object in 1D and 2D on a subplot. + + The API is such that every plottable attribute of the `Interferometer` object is an input parameter of type + bool of the function, which if switched to `True` means that it is included on the subplot. + + Parameters + ---------- + data + Whether to include a 2D plot (via `scatter`) of the visibility data. + noise_map + Whether to include a 2D plot (via `scatter`) of the noise-map. + u_wavelengths + Whether to include a 1D plot (via `plot`) of the u-wavelengths. + v_wavelengths + Whether to include a 1D plot (via `plot`) of the v-wavelengths. + amplitudes_vs_uv_distances + Whether to include a 1D plot (via `plot`) of the amplitudes versis the uv distances. + phases_vs_uv_distances + Whether to include a 1D plot (via `plot`) of the phases versis the uv distances. + dirty_image + Whether to include a 2D plot (via `imshow`) of the dirty image. + dirty_noise_map + Whether to include a 2D plot (via `imshow`) of the dirty noise map. + dirty_signal_to_noise_map + Whether to include a 2D plot (via `imshow`) of the dirty signal-to-noise map. + """ + self._subplot_custom_plot( + data=data, + noise_map=noise_map, + u_wavelengths=u_wavelengths, + v_wavelengths=v_wavelengths, + uv_wavelengths=uv_wavelengths, + amplitudes_vs_uv_distances=amplitudes_vs_uv_distances, + phases_vs_uv_distances=phases_vs_uv_distances, + dirty_image=dirty_image, + dirty_noise_map=dirty_noise_map, + dirty_signal_to_noise_map=dirty_signal_to_noise_map, + auto_labels=AutoLabels(filename=auto_filename), + ) + + def subplot_dataset(self): + """ + Standard subplot of the attributes of the plotter's `Interferometer` object. + """ + return self.subplot( + data=True, + uv_wavelengths=True, + amplitudes_vs_uv_distances=True, + phases_vs_uv_distances=True, + dirty_image=True, + dirty_signal_to_noise_map=True, + auto_filename="subplot_dataset", + ) + + def subplot_dirty_images(self): + """ + Standard subplot of the dirty attributes of the plotter's `Interferometer` object. + """ + return self.subplot( + dirty_image=True, + dirty_noise_map=True, + dirty_signal_to_noise_map=True, + auto_filename="subplot_dirty_images", + ) diff --git a/autoarray/fit/plot/fit_imaging_plotters.py b/autoarray/fit/plot/fit_imaging_plotters.py index 217e356d1..86aa0d34d 100644 --- a/autoarray/fit/plot/fit_imaging_plotters.py +++ b/autoarray/fit/plot/fit_imaging_plotters.py @@ -1,21 +1,18 @@ from typing import Callable -from autoarray.plot.abstract_plotters import Plotter +from autoarray.plot.abstract_plotters import AbstractPlotter from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.include.two_d import Include2D from autoarray.plot.mat_plot.two_d import MatPlot2D from autoarray.plot.auto_labels import AutoLabels from autoarray.fit.fit_imaging import FitImaging -class FitImagingPlotterMeta(Plotter): +class FitImagingPlotterMeta(AbstractPlotter): def __init__( self, fit, - get_visuals_2d: Callable, mat_plot_2d: MatPlot2D = None, visuals_2d: Visuals2D = None, - include_2d: Include2D = None, residuals_symmetric_cmap: bool = True, ): """ @@ -27,31 +24,23 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `FitImaging` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `FitImaging` and plotted via the visuals object. Parameters ---------- fit The fit to an imaging dataset the plotter plots. - get_visuals_2d - A function which extracts from the `FitImaging` the 2D visuals which are plotted on figures. mat_plot_2d Contains objects which wrap the matplotlib function calls that make the plot. visuals_2d Contains visuals that can be overlaid on the plot. - include_2d - Specifies which attributes of the `Array2D` are extracted and plotted as visuals. residuals_symmetric_cmap If true, the `residual_map` and `normalized_residual_map` are plotted with a symmetric color map such that `abs(vmin) = abs(vmax)`. """ - super().__init__( - mat_plot_2d=mat_plot_2d, include_2d=include_2d, visuals_2d=visuals_2d - ) + super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) self.fit = fit - self.get_visuals_2d = get_visuals_2d self.residuals_symmetric_cmap = residuals_symmetric_cmap def figures_2d( @@ -95,14 +84,14 @@ def figures_2d( if data: self.mat_plot_2d.plot_array( array=self.fit.data, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels(title="Data", filename=f"data{suffix}"), ) if noise_map: self.mat_plot_2d.plot_array( array=self.fit.noise_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Noise-Map", filename=f"noise_map{suffix}" ), @@ -111,7 +100,7 @@ def figures_2d( if signal_to_noise_map: self.mat_plot_2d.plot_array( array=self.fit.signal_to_noise_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Signal-To-Noise Map", filename=f"signal_to_noise_map{suffix}" ), @@ -120,7 +109,7 @@ def figures_2d( if model_image: self.mat_plot_2d.plot_array( array=self.fit.model_data, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Model Image", filename=f"model_image{suffix}" ), @@ -134,7 +123,7 @@ def figures_2d( if residual_map: self.mat_plot_2d.plot_array( array=self.fit.residual_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Residual Map", filename=f"residual_map{suffix}" ), @@ -143,7 +132,7 @@ def figures_2d( if normalized_residual_map: self.mat_plot_2d.plot_array( array=self.fit.normalized_residual_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Normalized Residual Map", filename=f"normalized_residual_map{suffix}", @@ -155,7 +144,7 @@ def figures_2d( if chi_squared_map: self.mat_plot_2d.plot_array( array=self.fit.chi_squared_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Chi-Squared Map", filename=f"chi_squared_map{suffix}" ), @@ -164,7 +153,7 @@ def figures_2d( if residual_flux_fraction_map: self.mat_plot_2d.plot_array( array=self.fit.residual_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Residual Flux Fraction Map", filename=f"residual_flux_fraction_map{suffix}", @@ -238,13 +227,12 @@ def subplot_fit(self): ) -class FitImagingPlotter(Plotter): +class FitImagingPlotter(AbstractPlotter): def __init__( self, fit: FitImaging, mat_plot_2d: MatPlot2D = None, visuals_2d: Visuals2D = None, - include_2d: Include2D = None, ): """ Plots the attributes of `FitImaging` objects using the matplotlib method `imshow()` and many other matplotlib @@ -255,8 +243,7 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `FitImaging` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `FitImaging` and plotted via the visuals object. Parameters ---------- @@ -266,26 +253,17 @@ def __init__( Contains objects which wrap the matplotlib function calls that make the plot. visuals_2d Contains visuals that can be overlaid on the plot. - include_2d - Specifies which attributes of the `Array2D` are extracted and plotted as visuals. """ - super().__init__( - mat_plot_2d=mat_plot_2d, include_2d=include_2d, visuals_2d=visuals_2d - ) + super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) self.fit = fit self._fit_imaging_meta_plotter = FitImagingPlotterMeta( fit=self.fit, - get_visuals_2d=self.get_visuals_2d, mat_plot_2d=self.mat_plot_2d, - include_2d=self.include_2d, visuals_2d=self.visuals_2d, ) self.figures_2d = self._fit_imaging_meta_plotter.figures_2d self.subplot = self._fit_imaging_meta_plotter.subplot self.subplot_fit = self._fit_imaging_meta_plotter.subplot_fit - - def get_visuals_2d(self) -> Visuals2D: - return self.get_2d.via_fit_imaging_from(fit=self.fit) diff --git a/autoarray/fit/plot/fit_interferometer_plotters.py b/autoarray/fit/plot/fit_interferometer_plotters.py index fa908e569..3ab2bd1e6 100644 --- a/autoarray/fit/plot/fit_interferometer_plotters.py +++ b/autoarray/fit/plot/fit_interferometer_plotters.py @@ -1,28 +1,22 @@ import numpy as np -from typing import Callable -from autoarray.plot.abstract_plotters import Plotter +from autoarray.plot.abstract_plotters import AbstractPlotter from autoarray.plot.visuals.one_d import Visuals1D from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.include.one_d import Include1D -from autoarray.plot.include.two_d import Include2D from autoarray.plot.mat_plot.one_d import MatPlot1D from autoarray.plot.mat_plot.two_d import MatPlot2D from autoarray.plot.auto_labels import AutoLabels from autoarray.fit.fit_interferometer import FitInterferometer -class FitInterferometerPlotterMeta(Plotter): +class FitInterferometerPlotterMeta(AbstractPlotter): def __init__( self, fit, - get_visuals_2d_real_space: Callable, mat_plot_1d: MatPlot1D, visuals_1d: Visuals1D, - include_1d: Include1D, mat_plot_2d: MatPlot2D = None, visuals_2d: Visuals2D = None, - include_2d: Include2D = None, residuals_symmetric_cmap: bool = True, ): """ @@ -35,42 +29,32 @@ def __init__( customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals1D` and `Visuals2D` objects. Attributes may be - extracted from the `FitInterferometer` and plotted via the visuals object, if the corresponding entry is `True` in - the `Include1D` or `Include2D` object or the `config/visualize/include.ini` file. + extracted from the `FitInterferometer` and plotted via the visuals object. Parameters ---------- fit The fit to an interferometer dataset the plotter plots. - get_visuals_2d - A function which extracts from the `FitInterferometer` the 2D visuals which are plotted on figures. mat_plot_1d Contains objects which wrap the matplotlib function calls that make 1D plots. visuals_1d Contains 1D visuals that can be overlaid on 1D plots. - include_1d - Specifies which attributes of the `FitInterferometer` are extracted and plotted as visuals for 1D plots. mat_plot_2d Contains objects which wrap the matplotlib function calls that make 2D plots. visuals_2d Contains 2D visuals that can be overlaid on 2D plots. - include_2d - Specifies which attributes of the `FitInterferometer` are extracted and plotted as visuals for 2D plots. residuals_symmetric_cmap If true, the `residual_map` and `normalized_residual_map` are plotted with a symmetric color map such that `abs(vmin) = abs(vmax)`. """ super().__init__( mat_plot_1d=mat_plot_1d, - include_1d=include_1d, visuals_1d=visuals_1d, mat_plot_2d=mat_plot_2d, - include_2d=include_2d, visuals_2d=visuals_2d, ) self.fit = fit - self.get_visuals_2d_real_space = get_visuals_2d_real_space self.residuals_symmetric_cmap = residuals_symmetric_cmap def figures_2d( @@ -268,14 +252,14 @@ def figures_2d( if dirty_image: self.mat_plot_2d.plot_array( array=self.fit.dirty_image, - visuals_2d=self.get_visuals_2d_real_space(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels(title="Dirty Image", filename="dirty_image"), ) if dirty_noise_map: self.mat_plot_2d.plot_array( array=self.fit.dirty_noise_map, - visuals_2d=self.get_visuals_2d_real_space(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Dirty Noise Map", filename="dirty_noise_map" ), @@ -284,7 +268,7 @@ def figures_2d( if dirty_signal_to_noise_map: self.mat_plot_2d.plot_array( array=self.fit.dirty_signal_to_noise_map, - visuals_2d=self.get_visuals_2d_real_space(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Dirty Signal-To-Noise Map", filename="dirty_signal_to_noise_map", @@ -294,7 +278,7 @@ def figures_2d( if dirty_model_image: self.mat_plot_2d.plot_array( array=self.fit.dirty_model_image, - visuals_2d=self.get_visuals_2d_real_space(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Dirty Model Image", filename="dirty_model_image_2d" ), @@ -308,7 +292,7 @@ def figures_2d( if dirty_residual_map: self.mat_plot_2d.plot_array( array=self.fit.dirty_residual_map, - visuals_2d=self.get_visuals_2d_real_space(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Dirty Residual Map", filename="dirty_residual_map_2d" ), @@ -317,7 +301,7 @@ def figures_2d( if dirty_normalized_residual_map: self.mat_plot_2d.plot_array( array=self.fit.dirty_normalized_residual_map, - visuals_2d=self.get_visuals_2d_real_space(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Dirty Normalized Residual Map", filename="dirty_normalized_residual_map_2d", @@ -330,7 +314,7 @@ def figures_2d( if dirty_chi_squared_map: self.mat_plot_2d.plot_array( array=self.fit.dirty_chi_squared_map, - visuals_2d=self.get_visuals_2d_real_space(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Dirty Chi-Squared Map", filename="dirty_chi_squared_map_2d" ), @@ -451,16 +435,14 @@ def subplot_fit_dirty_images(self): ) -class FitInterferometerPlotter(Plotter): +class FitInterferometerPlotter(AbstractPlotter): def __init__( self, fit: FitInterferometer, mat_plot_1d: MatPlot1D = None, visuals_1d: Visuals1D = None, - include_1d: Include1D = None, mat_plot_2d: MatPlot2D = None, visuals_2d: Visuals2D = None, - include_2d: Include2D = None, ): """ Plots the attributes of `FitInterferometer` objects using the matplotlib method `imshow()` and many other @@ -471,8 +453,7 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `FitInterferometer` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `FitInterferometer` and plotted via the visuals object. Parameters ---------- @@ -482,15 +463,11 @@ def __init__( Contains objects which wrap the matplotlib function calls that make the plot. visuals_2d Contains visuals that can be overlaid on the plot. - include_2d - Specifies which attributes of the `Array2D` are extracted and plotted as visuals. """ super().__init__( mat_plot_1d=mat_plot_1d, - include_1d=include_1d, visuals_1d=visuals_1d, mat_plot_2d=mat_plot_2d, - include_2d=include_2d, visuals_2d=visuals_2d, ) @@ -498,12 +475,9 @@ def __init__( self._fit_interferometer_meta_plotter = FitInterferometerPlotterMeta( fit=self.fit, - get_visuals_2d_real_space=self.get_visuals_2d_real_space, mat_plot_1d=self.mat_plot_1d, - include_1d=self.include_1d, visuals_1d=self.visuals_1d, mat_plot_2d=self.mat_plot_2d, - include_2d=self.include_2d, visuals_2d=self.visuals_2d, ) @@ -513,6 +487,3 @@ def __init__( self.subplot_fit_dirty_images = ( self._fit_interferometer_meta_plotter.subplot_fit_dirty_images ) - - def get_visuals_2d_real_space(self) -> Visuals2D: - return self.get_2d.via_mask_from(mask=self.fit.dataset.real_space_mask) diff --git a/autoarray/fit/plot/fit_vector_yx_plotters.py b/autoarray/fit/plot/fit_vector_yx_plotters.py index 466d5c8d9..9691e5680 100644 --- a/autoarray/fit/plot/fit_vector_yx_plotters.py +++ b/autoarray/fit/plot/fit_vector_yx_plotters.py @@ -1,22 +1,19 @@ from typing import Callable -from autoarray.plot.abstract_plotters import Plotter +from autoarray.plot.abstract_plotters import AbstractPlotter from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.include.two_d import Include2D from autoarray.plot.mat_plot.two_d import MatPlot2D from autoarray.plot.auto_labels import AutoLabels from autoarray.fit.fit_imaging import FitImaging from autoarray.fit.plot.fit_imaging_plotters import FitImagingPlotterMeta -class FitVectorYXPlotterMeta(Plotter): +class FitVectorYXPlotterMeta(AbstractPlotter): def __init__( self, fit, - get_visuals_2d: Callable, mat_plot_2d: MatPlot2D = None, visuals_2d: Visuals2D = None, - include_2d: Include2D = None, ): """ Plots the attributes of `FitImaging` objects using the matplotlib method `imshow()` and many other matplotlib @@ -27,28 +24,20 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `FitImaging` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `FitImaging` and plotted via the visuals object. Parameters ---------- fit The fit to an imaging dataset the plotter plots. - get_visuals_2d - A function which extracts from the `FitImaging` the 2D visuals which are plotted on figures. mat_plot_2d Contains objects which wrap the matplotlib function calls that make the plot. visuals_2d Contains visuals that can be overlaid on the plot. - include_2d - Specifies which attributes of the `Array2D` are extracted and plotted as visuals. """ - super().__init__( - mat_plot_2d=mat_plot_2d, include_2d=include_2d, visuals_2d=visuals_2d - ) + super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) self.fit = fit - self.get_visuals_2d = get_visuals_2d def figures_2d( self, @@ -84,26 +73,24 @@ def figures_2d( Whether to make a 2D plot (via `imshow`) of the chi-squared map. """ - fit_plotter_y = FitImaging(self.fit.data.y_array) - if image: self.mat_plot_2d.plot_array( array=self.fit.data, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels(title="Data", filename="image_2d"), ) if noise_map: self.mat_plot_2d.plot_array( array=self.fit.noise_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels(title="Noise-Map", filename="noise_map"), ) if signal_to_noise_map: self.mat_plot_2d.plot_array( array=self.fit.signal_to_noise_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Signal-To-Noise Map", filename="signal_to_noise_map" ), @@ -112,21 +99,21 @@ def figures_2d( if model_image: self.mat_plot_2d.plot_array( array=self.fit.model_data, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels(title="Model Image", filename="model_image"), ) if residual_map: self.mat_plot_2d.plot_array( array=self.fit.residual_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels(title="Residual Map", filename="residual_map"), ) if normalized_residual_map: self.mat_plot_2d.plot_array( array=self.fit.normalized_residual_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Normalized Residual Map", filename="normalized_residual_map" ), @@ -135,7 +122,7 @@ def figures_2d( if chi_squared_map: self.mat_plot_2d.plot_array( array=self.fit.chi_squared_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Chi-Squared Map", filename="chi_squared_map" ), @@ -204,13 +191,12 @@ def subplot_fit(self): ) -class FitImagingPlotter(Plotter): +class FitImagingPlotter(AbstractPlotter): def __init__( self, fit: FitImaging, mat_plot_2d: MatPlot2D = None, visuals_2d: Visuals2D = None, - include_2d: Include2D = None, ): """ Plots the attributes of `FitImaging` objects using the matplotlib method `imshow()` and many other matplotlib @@ -221,8 +207,7 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `FitImaging` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `FitImaging` and plotted via the visuals object. Parameters ---------- @@ -232,26 +217,17 @@ def __init__( Contains objects which wrap the matplotlib function calls that make the plot. visuals_2d Contains visuals that can be overlaid on the plot. - include_2d - Specifies which attributes of the `Array2D` are extracted and plotted as visuals. """ - super().__init__( - mat_plot_2d=mat_plot_2d, include_2d=include_2d, visuals_2d=visuals_2d - ) + super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) self.fit = fit self._fit_imaging_meta_plotter = FitImagingPlotterMeta( fit=self.fit, - get_visuals_2d=self.get_visuals_2d, mat_plot_2d=self.mat_plot_2d, - include_2d=self.include_2d, visuals_2d=self.visuals_2d, ) self.figures_2d = self._fit_imaging_meta_plotter.figures_2d self.subplot = self._fit_imaging_meta_plotter.subplot self.subplot_fit = self._fit_imaging_meta_plotter.subplot_fit - - def get_visuals_2d(self) -> Visuals2D: - return self.get_2d.via_fit_imaging_from(fit=self.fit) diff --git a/autoarray/inversion/pixelization/mappers/abstract.py b/autoarray/inversion/pixelization/mappers/abstract.py index f2f6b03b6..091713af1 100644 --- a/autoarray/inversion/pixelization/mappers/abstract.py +++ b/autoarray/inversion/pixelization/mappers/abstract.py @@ -291,7 +291,7 @@ def pixel_signals_from(self, signal_scale: float) -> np.ndarray: adapt_data=np.array(self.adapt_data), ) - def pix_indexes_for_slim_indexes(self, pix_indexes: List) -> List[List]: + def slim_indexes_for_pix_indexes(self, pix_indexes: List) -> List[List]: """ Returns the index mappings between every masked data-point (not subgridded) on the data and the mapper pixels / parameters that it maps too. @@ -299,7 +299,7 @@ def pix_indexes_for_slim_indexes(self, pix_indexes: List) -> List[List]: The `slim_index` refers to the masked data pixels (without subgridding) and `pix_indexes` the pixelization pixel indexes, for example: - - `pix_indexes_for_slim_indexes[0] = [2, 3]`: The data's first (index 0) pixel maps to the + - `slim_indexes_for_pix_indexes[0] = [2, 3]`: The data's first (index 0) pixel maps to the pixelization's third (index 2) and fourth (index 3) pixels. Parameters diff --git a/autoarray/inversion/plot/inversion_plotters.py b/autoarray/inversion/plot/inversion_plotters.py index 755224913..506388609 100644 --- a/autoarray/inversion/plot/inversion_plotters.py +++ b/autoarray/inversion/plot/inversion_plotters.py @@ -3,9 +3,8 @@ from autoconf import conf from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper -from autoarray.plot.abstract_plotters import Plotter +from autoarray.plot.abstract_plotters import AbstractPlotter from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.include.two_d import Include2D from autoarray.plot.mat_plot.two_d import MatPlot2D from autoarray.plot.auto_labels import AutoLabels from autoarray.structures.arrays.uniform_2d import Array2D @@ -14,13 +13,12 @@ from autoarray.inversion.plot.mapper_plotters import MapperPlotter -class InversionPlotter(Plotter): +class InversionPlotter(AbstractPlotter): def __init__( self, inversion: AbstractInversion, mat_plot_2d: MatPlot2D = None, visuals_2d: Visuals2D = None, - include_2d: Include2D = None, residuals_symmetric_cmap: bool = True, ): """ @@ -32,8 +30,7 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `Inversion` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `Inversion` and plotted via the visuals object. Parameters ---------- @@ -43,35 +40,12 @@ def __init__( Contains objects which wrap the matplotlib function calls that make 2D plots. visuals_2d Contains 2D visuals that can be overlaid on 2D plots. - include_2d - Specifies which attributes of the `Inversion` are extracted and plotted as visuals for 2D plots. - """ - super().__init__( - mat_plot_2d=mat_plot_2d, include_2d=include_2d, visuals_2d=visuals_2d - ) + super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) self.inversion = inversion self.residuals_symmetric_cmap = residuals_symmetric_cmap - def get_visuals_2d_for_data(self) -> Visuals2D: - try: - mapper = self.inversion.cls_list_from(cls=AbstractMapper)[0] - - visuals = self.get_2d.via_mapper_for_data_from(mapper=mapper) - - if self.visuals_2d.pix_indexes is not None: - indexes = mapper.pix_indexes_for_slim_indexes( - pix_indexes=self.visuals_2d.pix_indexes - ) - - visuals.indexes = indexes - - return visuals - - except (AttributeError, IndexError): - return self.visuals_2d - def mapper_plotter_from(self, mapper_index: int) -> MapperPlotter: """ Returns a `MapperPlotter` corresponding to the `Mapper` in the `Inversion`'s `linear_obj_list` given an input @@ -91,7 +65,6 @@ def mapper_plotter_from(self, mapper_index: int) -> MapperPlotter: mapper=self.inversion.cls_list_from(cls=AbstractMapper)[mapper_index], mat_plot_2d=self.mat_plot_2d, visuals_2d=self.visuals_2d, - include_2d=self.include_2d, ) def figures_2d(self, reconstructed_image: bool = False): @@ -109,7 +82,7 @@ def figures_2d(self, reconstructed_image: bool = False): if reconstructed_image: self.mat_plot_2d.plot_array( array=self.inversion.mapped_reconstructed_image, - visuals_2d=self.get_visuals_2d_for_data(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Reconstructed Image", filename="reconstructed_image" ), @@ -183,7 +156,7 @@ def figures_2d_of_pixelization( self.mat_plot_2d.plot_array( array=array, - visuals_2d=self.get_visuals_2d_for_data(), + visuals_2d=self.visuals_2d, grid_indexes=mapper_plotter.mapper.over_sampler.uniform_over_sampled, auto_labels=AutoLabels( title="Data Subtracted", filename="data_subtracted" @@ -199,7 +172,7 @@ def figures_2d_of_pixelization( self.mat_plot_2d.plot_array( array=array, - visuals_2d=self.get_visuals_2d_for_data(), + visuals_2d=self.visuals_2d, grid_indexes=mapper_plotter.mapper.over_sampler.uniform_over_sampled, auto_labels=AutoLabels( title="Reconstructed Image", filename="reconstructed_image" @@ -292,7 +265,7 @@ def figures_2d_of_pixelization( self.mat_plot_2d.plot_array( array=sub_size, - visuals_2d=self.get_visuals_2d_for_data(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Sub Pixels Per Image Pixels", filename="sub_pixels_per_image_pixels", @@ -307,7 +280,7 @@ def figures_2d_of_pixelization( self.mat_plot_2d.plot_array( array=mesh_pixels_per_image_pixels, - visuals_2d=self.get_visuals_2d_for_data(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Mesh Pixels Per Image Pixels", filename="mesh_pixels_per_image_pixels", @@ -350,10 +323,6 @@ def subplot_of_mapper( if self.mat_plot_2d.use_log10: self.mat_plot_2d.contour = False - mapper_image_plane_mesh_grid = self.include_2d._mapper_image_plane_mesh_grid - - self.include_2d._mapper_image_plane_mesh_grid = False - self.figures_2d_of_pixelization( pixelization_index=mapper_index, data_subtracted=True ) @@ -371,15 +340,21 @@ def subplot_of_mapper( self.mat_plot_2d.use_log10 = False - self.include_2d._mapper_image_plane_mesh_grid = mapper_image_plane_mesh_grid - self.include_2d._mapper_image_plane_mesh_grid = True + mapper = self.inversion.cls_list_from(cls=AbstractMapper)[mapper_index] + + self.visuals_2d += Visuals2D( + mesh_grid=mapper.mapper_grids.image_plane_mesh_grid + ) + self.set_title(label="Mesh Pixel Grid Overlaid") self.figures_2d_of_pixelization( pixelization_index=mapper_index, reconstructed_image=True ) self.set_title(label=None) - self.include_2d._mapper_image_plane_mesh_grid = False + self.visuals_2d.mesh_grid = None + + # self.include_2d._mapper_image_plane_mesh_grid = False self.figures_2d_of_pixelization( pixelization_index=mapper_index, reconstruction=True @@ -436,8 +411,6 @@ def subplot_mappings( ): self.open_subplot_figure(number_subplots=4) - self.include_2d._mapper_image_plane_mesh_grid = False - self.figures_2d_of_pixelization( pixelization_index=pixelization_index, data_subtracted=True ) @@ -456,9 +429,9 @@ def subplot_mappings( total_pixels=total_pixels, filter_neighbors=True ) - self.visuals_2d.pix_indexes = [ - [index] for index in pix_indexes[pixelization_index] - ] + indexes = mapper.slim_indexes_for_pix_indexes(pix_indexes=pix_indexes) + + self.visuals_2d.indexes = indexes self.figures_2d_of_pixelization( pixelization_index=pixelization_index, reconstructed_image=True diff --git a/autoarray/inversion/plot/mapper_plotters.py b/autoarray/inversion/plot/mapper_plotters.py index abd220f85..08b53a710 100644 --- a/autoarray/inversion/plot/mapper_plotters.py +++ b/autoarray/inversion/plot/mapper_plotters.py @@ -1,9 +1,8 @@ import numpy as np from typing import Union -from autoarray.plot.abstract_plotters import Plotter +from autoarray.plot.abstract_plotters import AbstractPlotter from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.include.two_d import Include2D from autoarray.plot.mat_plot.two_d import MatPlot2D from autoarray.plot.auto_labels import AutoLabels from autoarray.structures.arrays.uniform_2d import Array2D @@ -16,13 +15,12 @@ logger = logging.getLogger(__name__) -class MapperPlotter(Plotter): +class MapperPlotter(AbstractPlotter): def __init__( self, mapper: MapperRectangular, mat_plot_2d: MatPlot2D = None, visuals_2d: Visuals2D = None, - include_2d: Include2D = None, ): """ Plots the attributes of `Mapper` objects using the matplotlib method `imshow()` and many other matplotlib @@ -33,8 +31,7 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `Mapper` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `Mapper` and plotted via the visuals object. Parameters ---------- @@ -44,23 +41,13 @@ def __init__( Contains objects which wrap the matplotlib function calls that make 2D plots. visuals_2d Contains 2D visuals that can be overlaid on 2D plots. - include_2d - Specifies which attributes of the `Mapper` are extracted and plotted as visuals for 2D plots. """ - super().__init__( - visuals_2d=visuals_2d, include_2d=include_2d, mat_plot_2d=mat_plot_2d - ) + super().__init__(visuals_2d=visuals_2d, mat_plot_2d=mat_plot_2d) self.mapper = mapper - def get_visuals_2d_for_data(self) -> Visuals2D: - return self.get_2d.via_mapper_for_data_from(mapper=self.mapper) - - def get_visuals_2d_for_source(self) -> Visuals2D: - return self.get_2d.via_mapper_for_source_from(mapper=self.mapper) - def figure_2d( - self, interpolate_to_uniform: bool = True, solution_vector: bool = None + self, interpolate_to_uniform: bool = False, solution_vector: bool = None ): """ Plots the plotter's `Mapper` object in 2D. @@ -76,7 +63,7 @@ def figure_2d( """ self.mat_plot_2d.plot_mapper( mapper=self.mapper, - visuals_2d=self.get_2d.via_mapper_for_source_from(mapper=self.mapper), + visuals_2d=self.visuals_2d, interpolate_to_uniform=interpolate_to_uniform, pixel_values=solution_vector, auto_labels=AutoLabels( @@ -84,8 +71,19 @@ def figure_2d( ), ) + def figure_2d_image(self, image): + + self.mat_plot_2d.plot_array( + array=image, + visuals_2d=self.visuals_2d, + grid_indexes=self.mapper.mapper_grids.image_plane_data_grid.over_sampled, + auto_labels=AutoLabels( + title="Image (Image-Plane)", filename="mapper_image" + ), + ) + def subplot_image_and_mapper( - self, image: Array2D, interpolate_to_uniform: bool = True + self, image: Array2D, interpolate_to_uniform: bool = False ): """ Make a subplot of an input image and the `Mapper`'s source-plane reconstruction. @@ -105,22 +103,7 @@ def subplot_image_and_mapper( """ self.open_subplot_figure(number_subplots=2) - self.mat_plot_2d.plot_array( - array=image, - visuals_2d=self.get_visuals_2d_for_data(), - auto_labels=AutoLabels(title="Image (Image-Plane)"), - ) - - if self.visuals_2d.pix_indexes is not None: - indexes = self.mapper.pix_indexes_for_slim_indexes( - pix_indexes=self.visuals_2d.pix_indexes - ) - - self.mat_plot_2d.index_scatter.scatter_grid_indexes( - grid=self.mapper.over_sampler.uniform_over_sampled, - indexes=indexes, - ) - + self.figure_2d_image(image=image) self.figure_2d(interpolate_to_uniform=interpolate_to_uniform) self.mat_plot_2d.output.subplot_to_figure( @@ -154,7 +137,7 @@ def plot_source_from( try: self.mat_plot_2d.plot_mapper( mapper=self.mapper, - visuals_2d=self.get_visuals_2d_for_source(), + visuals_2d=self.visuals_2d, auto_labels=auto_labels, pixel_values=pixel_values, zoom_to_brightest=zoom_to_brightest, diff --git a/autoarray/plot/__init__.py b/autoarray/plot/__init__.py index 71d7abb45..c45d31702 100644 --- a/autoarray/plot/__init__.py +++ b/autoarray/plot/__init__.py @@ -22,6 +22,7 @@ from autoarray.plot.wrap.two_d.array_overlay import ArrayOverlay from autoarray.plot.wrap.two_d.contour import Contour +from autoarray.plot.wrap.two_d.fill import Fill from autoarray.plot.wrap.two_d.grid_scatter import GridScatter from autoarray.plot.wrap.two_d.grid_plot import GridPlot from autoarray.plot.wrap.two_d.grid_errorbar import GridErrorbar @@ -43,12 +44,8 @@ from autoarray.plot.wrap.two_d.serial_prescan_plot import SerialPrescanPlot from autoarray.plot.wrap.two_d.serial_overscan_plot import SerialOverscanPlot -from autoarray.plot.get_visuals.one_d import GetVisuals1D -from autoarray.plot.get_visuals.two_d import GetVisuals2D from autoarray.plot.mat_plot.one_d import MatPlot1D from autoarray.plot.mat_plot.two_d import MatPlot2D -from autoarray.plot.include.one_d import Include1D -from autoarray.plot.include.two_d import Include2D from autoarray.plot.visuals.one_d import Visuals1D from autoarray.plot.visuals.two_d import Visuals2D from autoarray.plot.auto_labels import AutoLabels diff --git a/autoarray/plot/abstract_plotters.py b/autoarray/plot/abstract_plotters.py index 5752155f0..07ec41354 100644 --- a/autoarray/plot/abstract_plotters.py +++ b/autoarray/plot/abstract_plotters.py @@ -8,12 +8,8 @@ from autoarray.plot.visuals.one_d import Visuals1D from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.include.one_d import Include1D -from autoarray.plot.include.two_d import Include2D from autoarray.plot.mat_plot.one_d import MatPlot1D from autoarray.plot.mat_plot.two_d import MatPlot2D -from autoarray.plot.get_visuals.one_d import GetVisuals1D -from autoarray.plot.get_visuals.two_d import GetVisuals2D class AbstractPlotter: @@ -21,17 +17,13 @@ def __init__( self, mat_plot_1d: MatPlot1D = None, visuals_1d: Visuals1D = None, - include_1d: Include1D = None, mat_plot_2d: MatPlot2D = None, visuals_2d: Visuals2D = None, - include_2d: Include2D = None, ): self.visuals_1d = visuals_1d or Visuals1D() - self.include_1d = include_1d or Include1D() self.mat_plot_1d = mat_plot_1d or MatPlot1D() self.visuals_2d = visuals_2d or Visuals2D() - self.include_2d = include_2d or Include2D() self.mat_plot_2d = mat_plot_2d or MatPlot2D() self.subplot_figsize = None @@ -219,13 +211,3 @@ def subplot_of_plotters_figure(self, plotter_list, name): self.mat_plot_2d.output.subplot_to_figure(auto_filename=f"subplot_{name}") self.close_subplot_figure() - - -class Plotter(AbstractPlotter): - @property - def get_1d(self): - return GetVisuals1D(visuals=self.visuals_1d, include=self.include_1d) - - @property - def get_2d(self): - return GetVisuals2D(visuals=self.visuals_2d, include=self.include_2d) diff --git a/autoarray/plot/get_visuals/__init__.py b/autoarray/plot/get_visuals/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/autoarray/plot/get_visuals/abstract.py b/autoarray/plot/get_visuals/abstract.py deleted file mode 100644 index f149fcf19..000000000 --- a/autoarray/plot/get_visuals/abstract.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Optional, Union - -from autoarray.plot.include.one_d import Include1D -from autoarray.plot.include.two_d import Include2D -from autoarray.plot.visuals.one_d import Visuals1D -from autoarray.plot.visuals.two_d import Visuals2D - - -class AbstractGetVisuals: - def __init__( - self, include: Union[Include1D, Include2D], visuals: Union[Visuals1D, Visuals2D] - ): - """ - Class which gets attributes and adds them to a `Visuals` objects, such that they are plotted on figures. - - For a visual to be extracted and added for plotting, it must have a `True` value in its corresponding entry in - the `Include` object. If this entry is `False`, the `GetVisuals.get` method returns a None and the attribute - is omitted from the plot. - - The `GetVisuals` class adds new visuals to a pre-existing `Visuals` object that is passed to its `__init__` - method. This only adds a new entry if the visual are not already in this object. - - Parameters - ---------- - include - Sets which visuals are included on the figure that is to be plotted (only entries which are `True` - are extracted via the `GetVisuals` object). - visuals - The pre-existing visuals of the plotter which new visuals are added too via the `GetVisuals` class. - """ - self.include = include - self.visuals = visuals - - def get(self, name: str, value, include_name: Optional[str] = None): - """ - Get an attribute for plotting in a `Visuals1D` object based on the following criteria: - - 1) If `visuals_1d` already has a value for the attribute this is returned, over-riding the input `value` of - that attribute. - - 2) If `visuals_1d` do not contain the attribute, the input `value` is returned provided its corresponding - entry in the `Include1D` class is `True`. - - 3) If the `Include1D` entry is `False` a None is returned and the attribute is therefore not plotted. - - Parameters - ---------- - name - The name of the attribute which is to be extracted. - value - The `value` of the attribute, which is used when criteria 2 above is met. - - Returns - ------- - The collection of attributes that can be plotted by a `Plotter` object. - """ - - if include_name is None: - include_name = name - - if getattr(self.visuals, name) is not None: - return getattr(self.visuals, name) - elif getattr(self.include, include_name): - return value diff --git a/autoarray/plot/get_visuals/one_d.py b/autoarray/plot/get_visuals/one_d.py deleted file mode 100644 index 449120c2c..000000000 --- a/autoarray/plot/get_visuals/one_d.py +++ /dev/null @@ -1,54 +0,0 @@ -from autoarray.plot.get_visuals.abstract import AbstractGetVisuals -from autoarray.plot.include.one_d import Include1D -from autoarray.plot.visuals.one_d import Visuals1D -from autoarray.structures.arrays.uniform_1d import Array1D - - -class GetVisuals1D(AbstractGetVisuals): - def __init__(self, include: Include1D, visuals: Visuals1D): - """ - Class which gets 1D attributes and adds them to a `Visuals1D` objects, such that they are plotted on 1D figures. - - For a visual to be extracted and added for plotting, it must have a `True` value in its corresponding entry in - the `Include1D` object. If this entry is `False`, the `GetVisuals1D.get` method returns a None and the attribute - is omitted from the plot. - - The `GetVisuals1D` class adds new visuals to a pre-existing `Visuals1D` object that is passed to its `__init__` - method. This only adds a new entry if the visual are not already in this object. - - Parameters - ---------- - include - Sets which 1D visuals are included on the figure that is to be plotted (only entries which are `True` - are extracted via the `GetVisuals1D` object). - visuals - The pre-existing visuals of the plotter which new visuals are added too via the `GetVisuals1D` class. - """ - super().__init__(include=include, visuals=visuals) - - def via_array_1d_from(self, array_1d: Array1D) -> Visuals1D: - """ - From an `Array1D` get its attributes that can be plotted and return them in a `Visuals1D` object. - - Only attributes not already in `self.visuals` and with `True` entries in the `Include1D` object are extracted - for plotting. - - From an `Array1D` the following attributes can be extracted for plotting: - - - origin: the (y,x) origin of the 1D array's coordinate system. - - mask: the mask of the 1D array. - - Parameters - ---------- - array - The 1D array whose attributes are extracted for plotting. - - Returns - ------- - Visuals1D - The collection of attributes that are plotted by a `Plotter` object. - """ - return self.visuals + self.visuals.__class__( - origin=self.get("origin", array_1d.origin), - mask=self.get("mask", array_1d.mask), - ) diff --git a/autoarray/plot/get_visuals/two_d.py b/autoarray/plot/get_visuals/two_d.py deleted file mode 100644 index c2b99a173..000000000 --- a/autoarray/plot/get_visuals/two_d.py +++ /dev/null @@ -1,231 +0,0 @@ -from typing import Union - -from autoarray.fit.fit_imaging import FitImaging -from autoarray.inversion.pixelization.mappers.rectangular import ( - MapperRectangular, -) -from autoarray.mask.mask_2d import Mask2D -from autoarray.plot.get_visuals.abstract import AbstractGetVisuals -from autoarray.plot.include.two_d import Include2D -from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.structures.grids.uniform_2d import Grid2D -from autoarray.structures.grids.irregular_2d import Grid2DIrregular - -from autoarray.type import Grid2DLike - - -class GetVisuals2D(AbstractGetVisuals): - def __init__(self, include: Include2D, visuals: Visuals2D): - """ - Class which gets 2D attributes and adds them to a `Visuals2D` objects, such that they are plotted on 2D figures. - - For a visual to be extracted and added for plotting, it must have a `True` value in its corresponding entry in - the `Include2D` object. If this entry is `False`, the `GetVisuals2D.get` method returns a None and the - attribute is omitted from the plot. - - The `GetVisuals2D` class adds new visuals to a pre-existing `Visuals2D` object that is passed to - its `__init__` method. This only adds a new entry if the visual are not already in this object. - - Parameters - ---------- - include - Sets which 2D visuals are included on the figure that is to be plotted (only entries which are `True` - are extracted via the `GetVisuals2D` object). - visuals - The pre-existing visuals of the plotter which new visuals are added too via the `GetVisuals2D` class. - """ - super().__init__(include=include, visuals=visuals) - - def origin_via_mask_from(self, mask: Mask2D) -> Grid2DIrregular: - """ - From a `Mask2D` get its origin for plotter, which is only extracted if an origin is not already - in `self.visuals` and with `True` entries in the `Include2D` object are extracted for plotting. - - Parameters - ---------- - mask - The 2D mask whose origin is extracted for plotting. - - Returns - ------- - Visuals2D - The collection of attributes that are plotted by a `Plotter` object, which include the origin if it is - extracted. - """ - return self.get("origin", Grid2DIrregular(values=[mask.origin])) - - def via_mask_from(self, mask: Mask2D) -> Visuals2D: - """ - From a `Mask2D` get its attributes that can be plotted and return them in a `Visuals2D` object. - - Only attributes not already in `self.visuals` and with `True` entries in the `Include2D` object are extracted - for plotting. - - From a `Mask2D` the following attributes can be extracted for plotting: - - - origin: the (y,x) origin of the 2D coordinate system. - - mask: the 2D mask. - - border: the border of the 2D mask, which are all of the mask's exterior edge pixels. - - Parameters - ---------- - mask - The 2D mask whose attributes are extracted for plotting. - - Returns - ------- - Visuals2D - The collection of attributes that are plotted by a `Plotter` object. - """ - origin = self.origin_via_mask_from(mask=mask) - mask_visuals = self.get("mask", mask) - border = self.get("border", mask.derive_grid.border) - - return self.visuals + self.visuals.__class__( - origin=origin, mask=mask_visuals, border=border - ) - - def via_grid_from(self, grid: Grid2DLike) -> Visuals2D: - """ - From a `Grid2D` get its attributes that can be plotted and return them in a `Visuals2D` object. - - Only attributes not already in `self.visuals` and with `True` entries in the `Include2D` object are extracted - for plotting. - - From a `Grid2D` the following attributes can be extracted for plotting: - - - origin: the (y,x) origin of the grid's coordinate system. - - Parameters - ---------- - grid : Grid2D - The grid whose attributes are extracted for plotting. - - Returns - ------- - Visuals2D - The collection of attributes that can be plotted by a `Plotter` object. - """ - if not isinstance(grid, Grid2D): - return self.visuals - - origin = self.origin_via_mask_from(mask=grid.mask) - - return self.visuals + self.visuals.__class__(origin=origin) - - def via_mapper_for_data_from(self, mapper: MapperRectangular) -> Visuals2D: - """ - From a `Mapper` get its attributes that can be plotted in the mapper's data-plane (e.g. the reconstructed - data) and return them in a `Visuals2D` object. - - Only attributes not already in `self.visuals` and with `True` entries in the `Include2D` object are extracted - for plotting. - - From a `Mapper` the following attributes can be extracted for plotting in the data-plane: - - - origin: the (y,x) origin of the `Array2D`'s coordinate system in the data plane. - - mask : the `Mask2D` defined in the data-plane containing the data that is used by the `Mapper`. - - mapper_image_plane_mesh_grid: the `Mapper`'s pixelization's mesh in the data-plane. - - mapper_border_grid: the border of the `Mapper`'s full grid in the data-plane. - - Parameters - ---------- - mapper - The mapper whose data-plane attributes are extracted for plotting. - - Returns - ------- - Visuals2D - The collection of attributes that can be plotted by a `Plotter` object. - """ - - visuals_via_mask = self.via_mask_from(mask=mapper.mapper_grids.mask) - - mesh_grid = self.get( - "mesh_grid", mapper.image_plane_mesh_grid, "mapper_image_plane_mesh_grid" - ) - - return ( - self.visuals - + visuals_via_mask - + self.visuals.__class__(mesh_grid=mesh_grid) - ) - - def via_mapper_for_source_from(self, mapper: MapperRectangular) -> Visuals2D: - """ - From a `Mapper` get its attributes that can be plotted in the mapper's source-plane (e.g. the reconstruction) - and return them in a `Visuals2D` object. - - Only attributes not already in `self.visuals` and with `True` entries in the `Include2D` object are extracted - for plotting. - - From a `Mapper` the following attributes can be extracted for plotting in the source-plane: - - - origin: the (y,x) origin of the coordinate system in the source plane. - - mapper_source_plane_data_grid: the (y,x) grid of coordinates in the mapper's source-plane which are paired with - the mapper's pixelization's mesh pixels. - - mapper_source_plane_mesh_grid: the `Mapper`'s pixelization's mesh grid in the source-plane. - - mapper_border_grid: the border of the `Mapper`'s full grid in the data-plane. - - Parameters - ---------- - mapper - The mapper whose source-plane attributes are extracted for plotting. - - Returns - ------- - Visuals2D - The collection of attributes that can be plotted by a `Plotter2D` object. - """ - - origin = self.get( - "origin", Grid2DIrregular(values=[mapper.source_plane_mesh_grid.origin]) - ) - - grid = self.get( - "grid", - mapper.source_plane_data_grid.over_sampled, - "mapper_source_plane_data_grid", - ) - - try: - border_grid = mapper.mapper_grids.source_plane_data_grid.over_sampled[ - mapper.border_relocator.sub_border_slim - ] - border = self.get("border", border_grid) - - except AttributeError: - border = None - - mesh_grid = self.get( - "mesh_grid", mapper.source_plane_mesh_grid, "mapper_source_plane_mesh_grid" - ) - - return self.visuals + self.visuals.__class__( - origin=origin, grid=grid, border=border, mesh_grid=mesh_grid - ) - - def via_fit_imaging_from(self, fit: FitImaging) -> Visuals2D: - """ - From a `FitImaging` get its attributes that can be plotted and return them in a `Visuals2D` object. - - Only attributes not already in `self.visuals` and with `True` entries in the `Include2D` object are extracted - for plotting. - - From a `FitImaging` the following attributes can be extracted for plotting: - - - origin: the (y,x) origin of the 2D coordinate system. - - mask: the 2D mask. - - border: the border of the 2D mask, which are all of the mask's exterior edge pixels. - - Parameters - ---------- - fit - The fit imaging object whose attributes are extracted for plotting. - - Returns - ------- - Visuals2D - The collection of attributes that are plotted by a `Plotter` object. - """ - return self.via_mask_from(mask=fit.mask) diff --git a/autoarray/plot/include/__init__.py b/autoarray/plot/include/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/autoarray/plot/include/abstract.py b/autoarray/plot/include/abstract.py deleted file mode 100644 index baaaa586b..000000000 --- a/autoarray/plot/include/abstract.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import Optional - -from autoconf import conf - - -class AbstractInclude: - def __init__(self, origin: Optional[bool] = None, mask: Optional[bool] = None): - """ - Sets which `Visuals` are included on a figure that is plotted using a `Plotter`. - - The `Include` object is used to extract the visuals of the plotted data structure (e.g. `Array2D`, `Grid2D`) so - they can be used in plot functions. Only visuals with a `True` entry in the `Include` object are extracted and t - plotted. - - If an entry is not input into the class (e.g. it retains its default entry of `None`) then the bool is - loaded from the `config/visualize/include.ini` config file. This means the default visuals of a project - can be specified in a config file. - - Parameters - ---------- - origin - If `True`, the `origin` of the plotted data structure (e.g. `Array2D`, `Grid2D`) is included on the figure. - mask - if `True`, the `mask` of the plotted data structure (e.g. `Array2D`, `Grid2D`) is included on the figure. - """ - - self._origin = origin - self._mask = mask - - def load(self, value, name): - if value is True: - return True - elif value is False: - return False - elif value is None: - return conf.instance["visualize"]["include"][self.section][name] - - @property - def section(self): - raise NotImplementedError - - @property - def origin(self): - return self.load(value=self._origin, name="origin") - - @property - def mask(self): - return self.load(value=self._mask, name="mask") diff --git a/autoarray/plot/include/one_d.py b/autoarray/plot/include/one_d.py deleted file mode 100644 index 593471a74..000000000 --- a/autoarray/plot/include/one_d.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Optional - -from autoarray.plot.include.abstract import AbstractInclude - - -class Include1D(AbstractInclude): - def __init__(self, origin: Optional[bool] = None, mask: Optional[bool] = None): - """ - Sets which `Visuals1D` are included on a figure plotting 1D data that is plotted using a `Plotter1D`. - - The `Include` object is used to extract the visuals of the plotted 1D data structures so they can be used in - plot functions. Only visuals with a `True` entry in the `Include` object are extracted and plotted. - - If an entry is not input into the class (e.g. it retains its default entry of `None`) then the bool is - loaded from the `config/visualize/include.ini` config file. This means the default visuals of a project - can be specified in a config file. - - Parameters - ---------- - origin - If `True`, the `origin` of the plotted data structure (e.g. `Line`) is included on the figure. - mask - if `True`, the `mask` of the plotted data structure (e.g. `Line`) is included on the figure. - """ - super().__init__(origin=origin, mask=mask) - - @property - def section(self): - return "include_1d" diff --git a/autoarray/plot/include/two_d.py b/autoarray/plot/include/two_d.py deleted file mode 100644 index ebf284e29..000000000 --- a/autoarray/plot/include/two_d.py +++ /dev/null @@ -1,104 +0,0 @@ -from typing import Optional - -from autoarray.plot.include.abstract import AbstractInclude - - -class Include2D(AbstractInclude): - def __init__( - self, - origin: Optional[bool] = None, - mask: Optional[bool] = None, - border: Optional[bool] = None, - grid: Optional[bool] = None, - mapper_image_plane_mesh_grid: Optional[bool] = None, - mapper_source_plane_mesh_grid: Optional[bool] = None, - mapper_source_plane_data_grid: Optional[bool] = None, - parallel_overscan: Optional[bool] = None, - serial_prescan: Optional[bool] = None, - serial_overscan: Optional[bool] = None, - ): - """ - Sets which `Visuals2D` are included on a figure plotting 2D data that is plotted using a `Plotter`. - - The `Include` object is used to extract the visuals of the plotted 2D data structures so they can be used in - plot functions. Only visuals with a `True` entry in the `Include` object are extracted and plotted. - - If an entry is not input into the class (e.g. it retains its default entry of `None`) then the bool is - loaded from the `config/visualize/include.ini` config file. This means the default visuals of a project - can be specified in a config file. - - Parameters - ---------- - origin - If `True`, the `origin` of the plotted data structure (e.g. `Array2D`, `Grid2D`) is included on the figure. - mask - if `True`, the `mask` of the plotted data structure (e.g. `Array2D`, `Grid2D`) is included on the figure. - border - If `True`, the `border` of the plotted data structure (e.g. `Array2D`, `Grid2D`) is included on the figure. - mapper_image_plane_mesh_grid - If `True`, the pixelization grid in the data plane of a plotted `Mapper` is included on the figure. - mapper_source_plane_mesh_grid - If `True`, the pixelization grid in the source plane of a plotted `Mapper` is included on the figure. - parallel_overscan - If `True`, the parallel overscan of a plotted `Frame2D` is included on the figure. - serial_prescan - If `True`, the serial prescan of a plotted `Frame2D` is included on the figure. - serial_overscan - If `True`, the serial overscan of a plotted `Frame2D` is included on the figure. - """ - - super().__init__(origin=origin, mask=mask) - - self._border = border - self._grid = grid - self._mapper_image_plane_mesh_grid = mapper_image_plane_mesh_grid - self._mapper_source_plane_mesh_grid = mapper_source_plane_mesh_grid - self._mapper_source_plane_data_grid = mapper_source_plane_data_grid - self._parallel_overscan = parallel_overscan - self._serial_prescan = serial_prescan - self._serial_overscan = serial_overscan - - @property - def section(self): - return "include_2d" - - @property - def border(self): - return self.load(value=self._border, name="border") - - @property - def grid(self): - return self.load(value=self._grid, name="grid") - - @property - def mapper_image_plane_mesh_grid(self): - return self.load( - value=self._mapper_image_plane_mesh_grid, - name="mapper_image_plane_mesh_grid", - ) - - @property - def mapper_source_plane_mesh_grid(self): - return self.load( - value=self._mapper_source_plane_mesh_grid, - name="mapper_source_plane_mesh_grid", - ) - - @property - def mapper_source_plane_data_grid(self): - return self.load( - value=self._mapper_source_plane_data_grid, - name="mapper_source_plane_data_grid", - ) - - @property - def parallel_overscan(self): - return self.load(value=self._parallel_overscan, name="parallel_overscan") - - @property - def serial_prescan(self): - return self.load(value=self._serial_prescan, name="serial_prescan") - - @property - def serial_overscan(self): - return self.load(value=self._serial_overscan, name="serial_overscan") diff --git a/autoarray/plot/mat_plot/two_d.py b/autoarray/plot/mat_plot/two_d.py index d585cbbcd..5e38bd896 100644 --- a/autoarray/plot/mat_plot/two_d.py +++ b/autoarray/plot/mat_plot/two_d.py @@ -1,6 +1,6 @@ import matplotlib.pyplot as plt import numpy as np -from typing import Optional, List, Tuple, Union +from typing import Optional, List, Union from autoconf import conf @@ -43,6 +43,7 @@ def __init__( legend: Optional[wb.Legend] = None, output: Optional[wb.Output] = None, array_overlay: Optional[w2d.ArrayOverlay] = None, + fill: Optional[w2d.Fill] = None, contour: Optional[w2d.Contour] = None, grid_scatter: Optional[w2d.GridScatter] = None, grid_plot: Optional[w2d.GridPlot] = None, @@ -63,6 +64,7 @@ def __init__( serial_prescan_plot: Optional[w2d.SerialPrescanPlot] = None, serial_overscan_plot: Optional[w2d.SerialOverscanPlot] = None, use_log10: bool = False, + plot_mask: bool = True, ): """ Visualizes 2D data structures (e.g an `Array2D`, `Grid2D`, `VectorField`, etc.) using Matplotlib. @@ -121,6 +123,8 @@ def __init__( Sets if the figure is displayed on the user's screen or output to `.png` using `plt.show` and `plt.savefig` array_overlay Overlays an input `Array2D` over the figure using `plt.imshow`. + fill + Sets the fill of the figure using `plt.fill` and customizes its appearance, such as the color and alpha. contour Overlays contours of an input `Array2D` over the figure using `plt.contour`. grid_scatter @@ -179,6 +183,7 @@ def __init__( ) self.array_overlay = array_overlay or w2d.ArrayOverlay(is_default=True) + self.fill = fill or w2d.Fill(is_default=True) self.contour = contour or w2d.Contour(is_default=True) @@ -219,6 +224,7 @@ def __init__( ) self.use_log10 = use_log10 + self.plot_mask = plot_mask self.is_for_subplot = False @@ -364,9 +370,13 @@ def plot_array( except ValueError: pass - visuals_2d.plot_via_plotter( - plotter=self, grid_indexes=grid_indexes, geometry=array.geometry - ) + if self.plot_mask and visuals_2d.mask is None: + + if not array.mask.is_all_false: + + self.mask_scatter.scatter_grid(grid=array.mask.derive_grid.edge.array) + + visuals_2d.plot_via_plotter(plotter=self, grid_indexes=grid_indexes) if not self.is_for_subplot and not bypass: self.output.to_figure(structure=array, auto_filename=auto_labels.filename) @@ -476,9 +486,7 @@ def plot_grid( if self.contour is not False: self.contour.set(array=color_array, extent=extent, use_log10=self.use_log10) - visuals_2d.plot_via_plotter( - plotter=self, grid_indexes=grid.array, geometry=grid.geometry - ) + visuals_2d.plot_via_plotter(plotter=self, grid_indexes=grid.array) if not self.is_for_subplot: self.output.to_figure(structure=grid, auto_filename=auto_labels.filename) @@ -590,10 +598,7 @@ def _plot_rectangular_mapper( self.xlabel.set() visuals_2d.plot_via_plotter( - plotter=self, - grid_indexes=mapper.source_plane_data_grid.over_sampled, - mapper=mapper, - geometry=mapper.mapper_grids.mask.geometry, + plotter=self, grid_indexes=mapper.source_plane_data_grid.over_sampled ) if not self.is_for_subplot: @@ -674,10 +679,7 @@ def _plot_delaunay_mapper( self.xlabel.set() visuals_2d.plot_via_plotter( - plotter=self, - grid_indexes=mapper.source_plane_data_grid.over_sampled, - mapper=mapper, - geometry=mapper.mapper_grids.mask.geometry, + plotter=self, grid_indexes=mapper.source_plane_data_grid.over_sampled ) if not self.is_for_subplot: @@ -757,10 +759,7 @@ def _plot_voronoi_mapper( self.xlabel.set() visuals_2d.plot_via_plotter( - plotter=self, - grid_indexes=mapper.source_plane_data_grid.over_sampled, - mapper=mapper, - geometry=mapper.mapper_grids.mask.geometry, + plotter=self, grid_indexes=mapper.source_plane_data_grid.over_sampled ) if pixel_values is not None: diff --git a/autoarray/plot/visuals/abstract.py b/autoarray/plot/visuals/abstract.py index b9d4b0e26..35583e985 100644 --- a/autoarray/plot/visuals/abstract.py +++ b/autoarray/plot/visuals/abstract.py @@ -12,25 +12,19 @@ def __add__(self, other): mask = Mask2D.circular(shape_native=(100, 100), pixel_scales=0.1, radius=3.0) array = Array2D.ones(shape_native=(100, 100), pixel_scales=0.1) masked_array = al.Array2D(values=array, mask=mask) - include_2d = Include2D(mask=True) - array_plotter = aplt.Array2DPlotter(array=masked_array, include_2d=include_2d) + array_plotter = aplt.Array2DPlotter(array=masked_array) array_plotter.figure() - Because `mask=True` in `Include2D` the function `figure` extracts the `Mask2D` from the `masked_array` - and plots it. It does this by creating a new `Visuals2D` object. - If the user did not manually input a `Visuals2D` object, the one created in `function_array` is the one used to plot the image However, if the user specifies their own `Visuals2D` object and passed it to the plotter, e.g.: visuals_2d = Visuals2D(origin=(0.0, 0.0)) - include_2d = Include2D(mask=True) - array_plotter = aplt.Array2DPlotter(array=masked_array, include_2d=include_2d) + array_plotter = aplt.Array2DPlotter(array=masked_array) - We now wish for the `Plotter` to plot the `origin` in the user's input `Visuals2D` object and the `Mask2d` - extracted via the `Include2D`. To achieve this, two `Visuals2D` objects are created: (i) the user's input - instance (with an origin) and; (ii) the one created by the `Include2D` object (with a mask). + We now wish for the `Plotter` to plot the `origin` in the user's input `Visuals2D` object. To achieve this, + one `Visuals2D` object is created: (i) the user's input instance (with an origin). This `__add__` override means we can add the two together to make the final `Visuals2D` object that is plotted on the figure containing both the `origin` and `Mask2D`.: diff --git a/autoarray/plot/visuals/one_d.py b/autoarray/plot/visuals/one_d.py index 8e3e33584..b84a832b3 100644 --- a/autoarray/plot/visuals/one_d.py +++ b/autoarray/plot/visuals/one_d.py @@ -2,7 +2,6 @@ from typing import List, Optional, Union from autoarray.mask.mask_1d import Mask1D -from autoarray.plot.include.one_d import Include1D from autoarray.plot.visuals.abstract import AbstractVisuals from autoarray.structures.arrays.uniform_1d import Array1D from autoarray.structures.grids.uniform_1d import Grid1D @@ -23,10 +22,6 @@ def __init__( self.vertical_line = vertical_line self.shaded_region = shaded_region - @property - def include(self): - return Include1D() - def plot_via_plotter(self, plotter): if self.points is not None: plotter.yx_scatter.scatter_yx(y=self.points, x=np.arange(len(self.points))) diff --git a/autoarray/plot/visuals/two_d.py b/autoarray/plot/visuals/two_d.py index a573e2b33..863e54eb9 100644 --- a/autoarray/plot/visuals/two_d.py +++ b/autoarray/plot/visuals/two_d.py @@ -1,4 +1,3 @@ -import numpy as np from matplotlib import patches as ptch from typing import List, Optional, Union @@ -23,13 +22,12 @@ def __init__( mesh_grid: Optional[Grid2D] = None, vectors: Optional[VectorYX2DIrregular] = None, patches: Optional[List[ptch.Patch]] = None, + fill_region: Optional[List] = None, array_overlay: Optional[Array2D] = None, parallel_overscan=None, serial_prescan=None, serial_overscan=None, indexes=None, - pix_indexes=None, - indexes_via_scatter=False, ): self.origin = origin self.mask = mask @@ -40,28 +38,34 @@ def __init__( self.mesh_grid = mesh_grid self.vectors = vectors self.patches = patches + self.fill_region = fill_region self.array_overlay = array_overlay self.parallel_overscan = parallel_overscan self.serial_prescan = serial_prescan self.serial_overscan = serial_overscan self.indexes = indexes - self.pix_indexes = pix_indexes - self.indexes_via_scatter = indexes_via_scatter - def plot_via_plotter(self, plotter, grid_indexes=None, mapper=None, geometry=None): + def plot_via_plotter(self, plotter, grid_indexes=None): + + if self.mask is not None: + plotter.mask_scatter.scatter_grid(grid=self.mask.derive_grid.edge.array) + if self.origin is not None: plotter.origin_scatter.scatter_grid( grid=Grid2DIrregular(values=self.origin).array ) - if self.mask is not None: - plotter.mask_scatter.scatter_grid(grid=self.mask.derive_grid.edge.array) - if self.border is not None: - plotter.border_scatter.scatter_grid(grid=self.border.array) + try: + plotter.border_scatter.scatter_grid(grid=self.border.array) + except AttributeError: + plotter.border_scatter.scatter_grid(grid=self.border) if self.grid is not None: - plotter.grid_scatter.scatter_grid(grid=self.grid.array) + try: + plotter.grid_scatter.scatter_grid(grid=self.grid.array) + except AttributeError: + plotter.grid_scatter.scatter_grid(grid=self.grid) if self.mesh_grid is not None: plotter.mesh_grid_scatter.scatter_grid(grid=self.mesh_grid.array) @@ -75,32 +79,15 @@ def plot_via_plotter(self, plotter, grid_indexes=None, mapper=None, geometry=Non if self.patches is not None: plotter.patch_overlay.overlay_patches(patches=self.patches) + if self.fill_region is not None: + plotter.fill.plot_fill(fill_region=self.fill_region) + if self.lines is not None: plotter.grid_plot.plot_grid(grid=self.lines) if self.indexes is not None and grid_indexes is not None: - if not self.indexes_via_scatter: - plotter.index_plot.plot_grid_indexes_multi( - grid=grid_indexes, indexes=self.indexes, geometry=geometry - ) - - else: - plotter.index_scatter.scatter_grid_indexes( - grid=grid_indexes, - indexes=self.indexes, - ) - - if self.pix_indexes is not None and mapper is not None: - indexes = mapper.pix_indexes_for_slim_indexes(pix_indexes=self.pix_indexes) - - if not self.indexes_via_scatter: - plotter.index_plot.plot_grid_indexes_x1( - grid=grid_indexes, - indexes=indexes, - ) - - else: - plotter.index_scatter.scatter_grid_indexes( - grid=mapper.source_plane_data_grid.over_sampled, - indexes=indexes, - ) + + plotter.index_scatter.scatter_grid_indexes( + grid=grid_indexes, + indexes=self.indexes, + ) diff --git a/autoarray/plot/wrap/two_d/__init__.py b/autoarray/plot/wrap/two_d/__init__.py index 5eb85eeab..5b438f4f8 100644 --- a/autoarray/plot/wrap/two_d/__init__.py +++ b/autoarray/plot/wrap/two_d/__init__.py @@ -1,5 +1,6 @@ from .array_overlay import ArrayOverlay from .contour import Contour +from .fill import Fill from .grid_scatter import GridScatter from .grid_plot import GridPlot from .grid_errorbar import GridErrorbar diff --git a/autoarray/plot/wrap/two_d/fill.py b/autoarray/plot/wrap/two_d/fill.py new file mode 100644 index 000000000..f580dde54 --- /dev/null +++ b/autoarray/plot/wrap/two_d/fill.py @@ -0,0 +1,38 @@ +import logging + +import matplotlib.pyplot as plt + +from autoarray.plot.wrap.two_d.abstract import AbstractMatWrap2D + + +logger = logging.getLogger(__name__) + + +class Fill(AbstractMatWrap2D): + def __init__(self, **kwargs): + """ + The settings used to customize plots using fill on a figure + + This object wraps the following Matplotlib methods: + + - plt.fill https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.pyplot.fill.html + + Parameters + ---------- + symmetric + If True, the colormap normalization (e.g. `vmin` and `vmax`) span the same absolute values producing a + symmetric color bar. + """ + + super().__init__(**kwargs) + + def plot_fill(self, fill_region): + + try: + y_fill = fill_region[:, 0] + x_fill = fill_region[:, 1] + except TypeError: + y_fill = fill_region[0] + x_fill = fill_region[1] + + plt.fill(x_fill, y_fill, **self.config_dict) diff --git a/autoarray/plot/wrap/two_d/grid_plot.py b/autoarray/plot/wrap/two_d/grid_plot.py index 1f77fcf75..cce6ea336 100644 --- a/autoarray/plot/wrap/two_d/grid_plot.py +++ b/autoarray/plot/wrap/two_d/grid_plot.py @@ -106,67 +106,11 @@ def plot_grid_list(self, grid_list: Union[List[Grid2D], List[Grid2DIrregular]]): try: for grid in grid_list: - plt.plot(grid[:, 1], grid[:, 0], c=next(color), **config_dict) + try: + plt.plot(grid[:, 1], grid[:, 0], c=next(color), **config_dict) + except ValueError: + plt.plot( + grid.array[:, 1], grid.array[:, 0], c=next(color), **config_dict + ) except IndexError: pass - - def plot_grid_indexes_x1( - self, - grid: Union[np.ndarray, Grid2D, Grid2DIrregular], - indexes: np.ndarray, - ): - - import matplotlib.pyplot as plt - - color = itertools.cycle(self.config_dict["c"]) - config_dict = self.config_dict - config_dict.pop("c") - - if isinstance(indexes[0], int): - indexes = [indexes] - - for index_list in indexes: - grid_contour = Grid2DContour( - grid=grid[index_list, :], - pixel_scales=None, - shape_native=None, - ) - - grid_hull = grid_contour.hull - - if grid_hull is not None: - plt.plot( - grid_hull[:, 1], grid_hull[:, 0], color=next(color), **config_dict - ) - - def plot_grid_indexes_multi( - self, - grid: Union[np.ndarray, Grid2D, Grid2DIrregular], - indexes: np.ndarray, - geometry: Geometry2D, - ): - import matplotlib.pyplot as plt - - color = itertools.cycle(self.config_dict["c"]) - config_dict = self.config_dict - config_dict.pop("c") - - if isinstance(indexes[0], int): - indexes = [indexes] - - for index_list in indexes: - grid_in = grid[index_list, :] - - if isinstance(index_list[0], tuple): - grid_in = grid_in[0] - - grid_contour = Grid2DContour( - grid=grid_in, - pixel_scales=geometry.pixel_scales, - shape_native=geometry.shape_native, - ) - - color_plot = next(color) - - for contour in grid_contour.contour_list: - plt.plot(contour[:, 1], contour[:, 0], color=color_plot, **config_dict) diff --git a/autoarray/structures/plot/structure_plotters.py b/autoarray/structures/plot/structure_plotters.py index 0596ebcd4..7e7cf655e 100644 --- a/autoarray/structures/plot/structure_plotters.py +++ b/autoarray/structures/plot/structure_plotters.py @@ -1,11 +1,9 @@ import numpy as np from typing import List, Optional, Union -from autoarray.plot.abstract_plotters import Plotter +from autoarray.plot.abstract_plotters import AbstractPlotter from autoarray.plot.visuals.one_d import Visuals1D from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.include.one_d import Include1D -from autoarray.plot.include.two_d import Include2D from autoarray.plot.mat_plot.one_d import MatPlot1D from autoarray.plot.mat_plot.two_d import MatPlot2D from autoarray.plot.auto_labels import AutoLabels @@ -15,13 +13,12 @@ from autoarray.structures.grids.uniform_2d import Grid2D -class Array2DPlotter(Plotter): +class Array2DPlotter(AbstractPlotter): def __init__( self, array: Array2D, mat_plot_2d: MatPlot2D = None, visuals_2d: Visuals2D = None, - include_2d: Include2D = None, ): """ Plots `Array2D` objects using the matplotlib method `imshow()` and many other matplotlib functions which @@ -32,8 +29,7 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `Array2D` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `Array2D` and plotted via the visuals object. Parameters ---------- @@ -43,36 +39,28 @@ def __init__( Contains objects which wrap the matplotlib function calls that make 2D plots. visuals_2d Contains 2D visuals that can be overlaid on 2D plots. - include_2d - Specifies which attributes of the `Array2D` are extracted and plotted as visuals for 2D plots. """ - super().__init__( - visuals_2d=visuals_2d, include_2d=include_2d, mat_plot_2d=mat_plot_2d - ) + super().__init__(visuals_2d=visuals_2d, mat_plot_2d=mat_plot_2d) self.array = array - def get_visuals_2d(self) -> Visuals2D: - return self.get_2d.via_mask_from(mask=self.array.mask) - def figure_2d(self): """ Plots the plotter's `Array2D` object in 2D. """ self.mat_plot_2d.plot_array( array=self.array, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels(title="Array2D", filename="array"), ) -class Grid2DPlotter(Plotter): +class Grid2DPlotter(AbstractPlotter): def __init__( self, grid: Grid2D, mat_plot_2d: MatPlot2D = None, visuals_2d: Visuals2D = None, - include_2d: Include2D = None, ): """ Plots `Grid2D` objects using the matplotlib method `scatter()` and many other matplotlib functions which @@ -83,8 +71,7 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `Grid2D` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `Grid2D` and plotted via the visuals object. Parameters ---------- @@ -94,18 +81,11 @@ def __init__( Contains objects which wrap the matplotlib function calls that make 2D plots. visuals_2d Contains 2D visuals that can be overlaid on 2D plots. - include_2d - Specifies which attributes of the `Grid2D` are extracted and plotted as visuals for 2D plots. """ - super().__init__( - visuals_2d=visuals_2d, include_2d=include_2d, mat_plot_2d=mat_plot_2d - ) + super().__init__(visuals_2d=visuals_2d, mat_plot_2d=mat_plot_2d) self.grid = grid - def get_visuals_2d(self) -> Visuals2D: - return self.get_2d.via_grid_from(grid=self.grid) - def figure_2d( self, color_array: np.ndarray = None, @@ -128,7 +108,7 @@ def figure_2d( """ self.mat_plot_2d.plot_grid( grid=self.grid, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels(title="Grid2D", filename="grid"), color_array=color_array, plot_grid_lines=plot_grid_lines, @@ -136,14 +116,13 @@ def figure_2d( ) -class YX1DPlotter(Plotter): +class YX1DPlotter(AbstractPlotter): def __init__( self, y: Union[Array1D, List], x: Optional[Union[Array1D, Grid1D, List]] = None, mat_plot_1d: MatPlot1D = None, visuals_1d: Visuals1D = None, - include_1d: Include1D = None, should_plot_grid: bool = False, should_plot_zero: bool = False, plot_axis_type: Optional[str] = None, @@ -159,8 +138,7 @@ def __init__( but a user can manually input values into `MatPlot1d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals1D` object. Attributes may be extracted from - the `Array1D` and plotted via the visuals object, if the corresponding entry is `True` in the `Include1D` - object or the `config/visualize/include.ini` file. + the `Array1D` and plotted via the visuals object. Parameters ---------- @@ -172,8 +150,6 @@ def __init__( Contains objects which wrap the matplotlib function calls that make 1D plots. visuals_1d Contains 1D visuals that can be overlaid on 1D plots. - include_1d - Specifies which attributes of the `Array1D` are extracted and plotted as visuals for 1D plots. """ if isinstance(y, list): @@ -182,9 +158,7 @@ def __init__( if isinstance(x, list): x = Array1D.no_mask(values=x, pixel_scales=1.0) - super().__init__( - visuals_1d=visuals_1d, include_1d=include_1d, mat_plot_1d=mat_plot_1d - ) + super().__init__(visuals_1d=visuals_1d, mat_plot_1d=mat_plot_1d) self.y = y self.x = y.grid_radial if x is None else x @@ -194,9 +168,6 @@ def __init__( self.plot_yx_dict = plot_yx_dict or {} self.auto_labels = auto_labels - def get_visuals_1d(self) -> Visuals1D: - return self.get_1d.via_array_1d_from(array_1d=self.x) - def figure_1d(self): """ Plots the plotter's y and x values in 1D. @@ -205,7 +176,7 @@ def figure_1d(self): self.mat_plot_1d.plot_yx( y=self.y, x=self.x, - visuals_1d=self.get_visuals_1d(), + visuals_1d=self.visuals_1d, auto_labels=self.auto_labels, should_plot_grid=self.should_plot_grid, should_plot_zero=self.should_plot_zero, diff --git a/test_autoarray/config/visualize.yaml b/test_autoarray/config/visualize.yaml index 8934bb465..d631ae7e9 100644 --- a/test_autoarray/config/visualize.yaml +++ b/test_autoarray/config/visualize.yaml @@ -4,16 +4,6 @@ general: imshow_origin: upper zoom_around_mask: true disable_zoom_for_fits: true # If True, the zoom-in around the masked region is disabled when outputting .fits files, which is useful to retain the same dimensions as the input data. - include_2d: - border: false - mapper_image_plane_mesh_grid: false - mapper_source_plane_data_grid: false - mapper_source_plane_mesh_grid: false - mask: true - origin: true - parallel_overscan: true - serial_overscan: true - serial_prescan: true subplot_shape: 1: (1, 1) # The shape of subplots for a figure with 1 subplot. 2: (2, 2) # The shape of subplots for a figure with 2 subplots. @@ -28,21 +18,6 @@ general: 64: (8, 8) # The shape of subplots for a figure with 64 (or less than the above value) of subplots. 81: (9, 9) # The shape of subplots for a figure with 81 (or less than the above value) of subplots. 100: (10, 10) # The shape of subplots for a figure with 100 (or less than the above value) of subplots. -include: - include_1d: - mask: false - origin: false - include_2d: - border: false - mapper_image_plane_mesh_grid: false - mapper_source_plane_data_grid: false - mapper_source_plane_mesh_grid: false - mask: true - origin: true - parallel_overscan: true - positions: true - serial_overscan: false - serial_prescan: true mat_wrap: Axis: figure: diff --git a/test_autoarray/fit/plot/test_fit_imaging_plotters.py b/test_autoarray/fit/plot/test_fit_imaging_plotters.py index 31c288a4c..22223ff61 100644 --- a/test_autoarray/fit/plot/test_fit_imaging_plotters.py +++ b/test_autoarray/fit/plot/test_fit_imaging_plotters.py @@ -19,7 +19,6 @@ def make_plot_path_setup(): def test__fit_quantities_are_output(fit_imaging_7x7, plot_path, plot_patch): fit_plotter = aplt.FitImagingPlotter( fit=fit_imaging_7x7, - include_2d=aplt.Include2D(origin=True, mask=True, border=True), mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), ) @@ -63,7 +62,6 @@ def test__fit_quantities_are_output(fit_imaging_7x7, plot_path, plot_patch): def test__fit_sub_plot(fit_imaging_7x7, plot_path, plot_patch): fit_plotter = aplt.FitImagingPlotter( fit=fit_imaging_7x7, - include_2d=aplt.Include2D(origin=True, mask=True, border=True), mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), ) @@ -77,7 +75,6 @@ def test__output_as_fits__correct_output_format( ): fit_plotter = aplt.FitImagingPlotter( fit=fit_imaging_7x7, - include_2d=aplt.Include2D(origin=True, mask=True, border=True), mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="fits")), ) diff --git a/test_autoarray/inversion/pixelization/mappers/test_abstract.py b/test_autoarray/inversion/pixelization/mappers/test_abstract.py index 925ec7360..271e03772 100644 --- a/test_autoarray/inversion/pixelization/mappers/test_abstract.py +++ b/test_autoarray/inversion/pixelization/mappers/test_abstract.py @@ -16,7 +16,7 @@ def test__pix_indexes_for_slim_indexes__different_types_of_lists_input(): parameters=9, ) - pixe_indexes_for_slim_indexes = mapper.pix_indexes_for_slim_indexes( + pixe_indexes_for_slim_indexes = mapper.slim_indexes_for_pix_indexes( pix_indexes=[0, 1] ) @@ -31,7 +31,7 @@ def test__pix_indexes_for_slim_indexes__different_types_of_lists_input(): parameters=9, ) - pixe_indexes_for_slim_indexes = mapper.pix_indexes_for_slim_indexes( + pixe_indexes_for_slim_indexes = mapper.slim_indexes_for_pix_indexes( pix_indexes=[[0], [4]] ) diff --git a/test_autoarray/inversion/plot/test_inversion_plotters.py b/test_autoarray/inversion/plot/test_inversion_plotters.py index 735365ecd..62737ec87 100644 --- a/test_autoarray/inversion/plot/test_inversion_plotters.py +++ b/test_autoarray/inversion/plot/test_inversion_plotters.py @@ -25,7 +25,7 @@ def test__individual_attributes_are_output_for_all_mappers( ): inversion_plotter = aplt.InversionPlotter( inversion=rectangular_inversion_7x7_3x3, - visuals_2d=aplt.Visuals2D(indexes=[0], pix_indexes=[1]), + visuals_2d=aplt.Visuals2D(indexes=[0]), mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), ) @@ -55,7 +55,7 @@ def test__individual_attributes_are_output_for_all_mappers( inversion_plotter = aplt.InversionPlotter( inversion=voronoi_inversion_9_3x3, - visuals_2d=aplt.Visuals2D(indexes=[0], pix_indexes=[1]), + visuals_2d=aplt.Visuals2D(indexes=[0]), mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), ) @@ -101,7 +101,7 @@ def test__inversion_subplot_of_mapper__is_output_for_all_inversions( ): inversion_plotter = aplt.InversionPlotter( inversion=rectangular_inversion_7x7_3x3, - visuals_2d=aplt.Visuals2D(indexes=[0], pix_indexes=[1]), + visuals_2d=aplt.Visuals2D(indexes=[0]), mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), ) diff --git a/test_autoarray/inversion/plot/test_mapper_plotters.py b/test_autoarray/inversion/plot/test_mapper_plotters.py index 3dabe5514..1b91c3ad3 100644 --- a/test_autoarray/inversion/plot/test_mapper_plotters.py +++ b/test_autoarray/inversion/plot/test_mapper_plotters.py @@ -14,91 +14,6 @@ def make_plot_path_setup(): ) -def test__get_2d__via_mapper_for_data_from(rectangular_mapper_7x7_3x3): - include = aplt.Include2D( - origin=True, mask=True, mapper_image_plane_mesh_grid=True, border=True - ) - - mapper_plotter = aplt.MapperPlotter( - mapper=rectangular_mapper_7x7_3x3, include_2d=include - ) - - get_2d = mapper_plotter.get_2d.via_mapper_for_data_from( - mapper=rectangular_mapper_7x7_3x3 - ) - - assert get_2d.origin.in_list == [(0.0, 0.0)] - assert (get_2d.mask == rectangular_mapper_7x7_3x3.mapper_grids.mask).all() - assert get_2d.grid == None - - include = aplt.Include2D( - origin=False, mask=False, mapper_image_plane_mesh_grid=False, border=False - ) - - mapper_plotter = aplt.MapperPlotter( - mapper=rectangular_mapper_7x7_3x3, include_2d=include - ) - - get_2d = mapper_plotter.get_2d.via_mapper_for_data_from( - mapper=rectangular_mapper_7x7_3x3 - ) - - assert get_2d.origin == None - assert get_2d.mask == None - assert get_2d.grid == None - assert get_2d.border == None - - -def test__get_2d__via_mapper_for_source_from(rectangular_mapper_7x7_3x3): - include = aplt.Include2D( - origin=True, - mapper_source_plane_data_grid=True, - mapper_source_plane_mesh_grid=True, - border=True, - ) - - mapper_plotter = aplt.MapperPlotter( - mapper=rectangular_mapper_7x7_3x3, include_2d=include - ) - - get_2d = mapper_plotter.get_2d.via_mapper_for_source_from( - mapper=rectangular_mapper_7x7_3x3 - ) - - assert mapper_plotter.visuals_2d.origin == None - assert get_2d.origin.in_list == [(0.0, 0.0)] - assert ( - get_2d.grid == rectangular_mapper_7x7_3x3.source_plane_data_grid.over_sampled - ).all() - assert (get_2d.mesh_grid == rectangular_mapper_7x7_3x3.source_plane_mesh_grid).all() - border_grid = ( - rectangular_mapper_7x7_3x3.mapper_grids.source_plane_data_grid.over_sampled[ - rectangular_mapper_7x7_3x3.border_relocator.sub_border_slim - ] - ) - assert (get_2d.border == border_grid).all() - - include = aplt.Include2D( - origin=False, - border=False, - mapper_source_plane_data_grid=False, - mapper_source_plane_mesh_grid=False, - ) - - mapper_plotter = aplt.MapperPlotter( - mapper=rectangular_mapper_7x7_3x3, include_2d=include - ) - - get_2d = mapper_plotter.get_2d.via_mapper_for_source_from( - mapper=rectangular_mapper_7x7_3x3 - ) - - assert get_2d.origin == None - assert get_2d.grid == None - assert get_2d.mesh_grid == None - assert get_2d.border == None - - def test__figure_2d( rectangular_mapper_7x7_3x3, delaunay_mapper_9_3x3, @@ -107,20 +22,17 @@ def test__figure_2d( plot_patch, ): visuals_2d = aplt.Visuals2D( - indexes=[[(0, 0), (0, 1)], [(1, 2)]], pix_indexes=[[0, 1], [2]] + indexes=[[(0, 0), (0, 1)], [(1, 2)]], ) mat_plot_2d = aplt.MatPlot2D( output=aplt.Output(path=plot_path, filename="mapper1", format="png") ) - include_2d = aplt.Include2D(origin=True, mapper_source_plane_mesh_grid=True) - mapper_plotter = aplt.MapperPlotter( mapper=rectangular_mapper_7x7_3x3, visuals_2d=visuals_2d, mat_plot_2d=mat_plot_2d, - include_2d=include_2d, ) mapper_plotter.figure_2d() @@ -133,10 +45,9 @@ def test__figure_2d( mapper=delaunay_mapper_9_3x3, visuals_2d=visuals_2d, mat_plot_2d=mat_plot_2d, - include_2d=include_2d, ) - mapper_plotter.figure_2d() + mapper_plotter.figure_2d(interpolate_to_uniform=True) assert path.join(plot_path, "mapper1.png") in plot_patch.paths @@ -151,10 +62,9 @@ def test__figure_2d( mapper=voronoi_mapper_9_3x3, visuals_2d=visuals_2d, mat_plot_2d=mat_plot_2d, - include_2d=include_2d, ) - mapper_plotter.figure_2d() + mapper_plotter.figure_2d(interpolate_to_uniform=True) assert path.join(plot_path, "mapper1.png") in plot_patch.paths @@ -167,16 +77,17 @@ def test__subplot_image_and_mapper( plot_path, plot_patch, ): - visuals_2d = aplt.Visuals2D(indexes=[0, 1, 2], pix_indexes=[[0, 1], [2]]) + visuals_2d = aplt.Visuals2D(indexes=[0, 1, 2]) mapper_plotter = aplt.MapperPlotter( mapper=rectangular_mapper_7x7_3x3, visuals_2d=visuals_2d, mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - include_2d=aplt.Include2D(mapper_source_plane_mesh_grid=True), ) - mapper_plotter.subplot_image_and_mapper(image=imaging_7x7.data) + mapper_plotter.subplot_image_and_mapper( + image=imaging_7x7.data, interpolate_to_uniform=True + ) assert path.join(plot_path, "subplot_image_and_mapper.png") in plot_patch.paths plot_patch.paths = [] @@ -185,10 +96,11 @@ def test__subplot_image_and_mapper( mapper=delaunay_mapper_9_3x3, visuals_2d=visuals_2d, mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - include_2d=aplt.Include2D(mapper_source_plane_mesh_grid=True), ) - mapper_plotter.subplot_image_and_mapper(image=imaging_7x7.data) + mapper_plotter.subplot_image_and_mapper( + image=imaging_7x7.data, interpolate_to_uniform=True + ) assert path.join(plot_path, "subplot_image_and_mapper.png") in plot_patch.paths pytest.importorskip( @@ -202,8 +114,9 @@ def test__subplot_image_and_mapper( mapper=voronoi_mapper_9_3x3, visuals_2d=visuals_2d, mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - include_2d=aplt.Include2D(mapper_source_plane_mesh_grid=True), ) - mapper_plotter.subplot_image_and_mapper(image=imaging_7x7.data) + mapper_plotter.subplot_image_and_mapper( + image=imaging_7x7.data, interpolate_to_uniform=True + ) assert path.join(plot_path, "subplot_image_and_mapper.png") in plot_patch.paths diff --git a/test_autoarray/plot/get_visuals/__init__.py b/test_autoarray/plot/get_visuals/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test_autoarray/plot/get_visuals/test_one_d.py b/test_autoarray/plot/get_visuals/test_one_d.py deleted file mode 100644 index 73f05bf43..000000000 --- a/test_autoarray/plot/get_visuals/test_one_d.py +++ /dev/null @@ -1,32 +0,0 @@ -from os import path -import pytest - -import autoarray.plot as aplt - - -@pytest.fixture(name="plot_path") -def make_plot_path_setup(): - return path.join( - "{}".format(path.dirname(path.realpath(__file__))), "files", "plots", "imaging" - ) - - -def test__via_array_1d_from(array_1d_7): - visuals_1d = aplt.Visuals1D(origin=(1.0, 1.0)) - include_1d = aplt.Include1D(origin=True, mask=True) - - get_visuals = aplt.GetVisuals1D(include=include_1d, visuals=visuals_1d) - - visuals_1d_via = get_visuals.via_array_1d_from(array_1d=array_1d_7) - - assert visuals_1d_via.origin == (1.0, 1.0) - assert (visuals_1d_via.mask == array_1d_7.mask).all() - - include_1d = aplt.Include1D(origin=False, mask=False) - - get_visuals = aplt.GetVisuals1D(include=include_1d, visuals=visuals_1d) - - visuals_1d_via = get_visuals.via_array_1d_from(array_1d=array_1d_7) - - assert visuals_1d_via.origin == (1.0, 1.0) - assert visuals_1d_via.mask == None diff --git a/test_autoarray/plot/get_visuals/test_two_d.py b/test_autoarray/plot/get_visuals/test_two_d.py deleted file mode 100644 index 5494bcf22..000000000 --- a/test_autoarray/plot/get_visuals/test_two_d.py +++ /dev/null @@ -1,164 +0,0 @@ -from os import path -import pytest - -import autoarray.plot as aplt - - -@pytest.fixture(name="plot_path") -def make_plot_path_setup(): - return path.join( - "{}".format(path.dirname(path.realpath(__file__))), "files", "plots", "imaging" - ) - - -def test__via_mask_from(mask_2d_7x7): - visuals_2d = aplt.Visuals2D(origin=(1.0, 1.0), vectors=2) - include_2d = aplt.Include2D(origin=True, mask=True, border=True) - - get_visuals = aplt.GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_mask_from(mask=mask_2d_7x7) - - assert visuals_2d_via.origin == (1.0, 1.0) - assert (visuals_2d_via.mask == mask_2d_7x7).all() - assert (visuals_2d_via.border == mask_2d_7x7.derive_grid.border).all() - assert visuals_2d_via.vectors == 2 - - include_2d = aplt.Include2D(origin=False, mask=False, border=False) - - get_visuals = aplt.GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_mask_from(mask=mask_2d_7x7) - - assert visuals_2d_via.origin == (1.0, 1.0) - assert visuals_2d_via.mask == None - assert visuals_2d_via.border == None - assert visuals_2d_via.vectors == 2 - - -def test__via_grid_from(grid_2d_7x7): - visuals_2d = aplt.Visuals2D() - include_2d = aplt.Include2D(origin=True) - - get_visuals = aplt.GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_grid_from(grid=grid_2d_7x7) - - assert (visuals_2d_via.origin == grid_2d_7x7.origin).all() - - include_2d = aplt.Include2D(origin=False) - - get_visuals = aplt.GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_grid_from(grid=grid_2d_7x7) - - assert visuals_2d_via.origin == None - - -def test__via_mapper_for_data_from(voronoi_mapper_9_3x3): - visuals_2d = aplt.Visuals2D(origin=(1.0, 1.0)) - include_2d = aplt.Include2D( - origin=True, mask=True, border=True, mapper_image_plane_mesh_grid=True - ) - - get_visuals = aplt.GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_mapper_for_data_from(mapper=voronoi_mapper_9_3x3) - - assert visuals_2d.origin == (1.0, 1.0) - assert (visuals_2d_via.mask == voronoi_mapper_9_3x3.mapper_grids.mask).all() - assert ( - visuals_2d_via.border - == voronoi_mapper_9_3x3.mapper_grids.mask.derive_grid.border - ).all() - - assert ( - visuals_2d_via.mesh_grid == voronoi_mapper_9_3x3.image_plane_mesh_grid - ).all() - - include_2d = aplt.Include2D( - origin=False, mask=False, border=False, mapper_image_plane_mesh_grid=False - ) - - get_visuals = aplt.GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_mapper_for_data_from(mapper=voronoi_mapper_9_3x3) - - assert visuals_2d.origin == (1.0, 1.0) - assert visuals_2d_via.mask == None - assert visuals_2d_via.border == None - assert visuals_2d_via.mesh_grid == None - - -def test__via_mapper_for_source_from(rectangular_mapper_7x7_3x3): - visuals_2d = aplt.Visuals2D(origin=(1.0, 1.0)) - include_2d = aplt.Include2D( - origin=True, - border=True, - mapper_source_plane_data_grid=True, - mapper_source_plane_mesh_grid=True, - ) - - get_visuals = aplt.GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_mapper_for_source_from( - mapper=rectangular_mapper_7x7_3x3 - ) - - assert visuals_2d.origin == (1.0, 1.0) - assert ( - visuals_2d_via.grid - == rectangular_mapper_7x7_3x3.source_plane_data_grid.over_sampled - ).all() - border_grid = ( - rectangular_mapper_7x7_3x3.mapper_grids.source_plane_data_grid.over_sampled[ - rectangular_mapper_7x7_3x3.border_relocator.sub_border_slim - ] - ) - assert (visuals_2d_via.border == border_grid).all() - assert ( - visuals_2d_via.mesh_grid == rectangular_mapper_7x7_3x3.source_plane_mesh_grid - ).all() - - include_2d = aplt.Include2D( - origin=False, - border=False, - mapper_source_plane_data_grid=False, - mapper_source_plane_mesh_grid=False, - ) - - get_visuals = aplt.GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_mapper_for_source_from( - mapper=rectangular_mapper_7x7_3x3 - ) - - assert visuals_2d.origin == (1.0, 1.0) - assert visuals_2d_via.grid == None - assert visuals_2d_via.border == None - assert visuals_2d_via.mesh_grid == None - - -def test__via_fit_imaging_from(fit_imaging_7x7): - visuals_2d = aplt.Visuals2D(origin=(1.0, 1.0), vectors=2) - include_2d = aplt.Include2D(origin=True, mask=True, border=True) - - get_visuals = aplt.GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_fit_imaging_from(fit=fit_imaging_7x7) - - assert visuals_2d_via.origin == (1.0, 1.0) - assert (visuals_2d_via.mask == fit_imaging_7x7.mask).all() - assert (visuals_2d_via.border == fit_imaging_7x7.mask.derive_grid.border).all() - assert visuals_2d_via.vectors == 2 - - include_2d = aplt.Include2D(origin=False, mask=False, border=False) - - get_visuals = aplt.GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_fit_imaging_from(fit=fit_imaging_7x7) - - assert visuals_2d_via.origin == (1.0, 1.0) - assert visuals_2d_via.mask == None - assert visuals_2d_via.border == None - assert visuals_2d_via.vectors == 2 diff --git a/test_autoarray/plot/include/__init__.py b/test_autoarray/plot/include/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test_autoarray/plot/include/test_include.py b/test_autoarray/plot/include/test_include.py deleted file mode 100644 index b32616e9d..000000000 --- a/test_autoarray/plot/include/test_include.py +++ /dev/null @@ -1,21 +0,0 @@ -import autoarray.plot as aplt - - -def test__loads_default_values_from_config_if_not_input(): - include = aplt.Include2D() - - assert include.origin is True - assert include.mask == True - assert include.border is False - assert include.parallel_overscan is True - assert include.serial_prescan is True - assert include.serial_overscan is False - - include = aplt.Include2D(origin=False, border=False, serial_overscan=True) - - assert include.origin is False - assert include.mask == True - assert include.border is False - assert include.parallel_overscan is True - assert include.serial_prescan is True - assert include.serial_overscan is True diff --git a/test_autoarray/plot/test_abstract_plotters.py b/test_autoarray/plot/test_abstract_plotters.py index 8d44b1320..05c905f7c 100644 --- a/test_autoarray/plot/test_abstract_plotters.py +++ b/test_autoarray/plot/test_abstract_plotters.py @@ -105,33 +105,3 @@ def test__uses_figure_or_subplot_configs_correctly(): assert plotter.mat_plot_2d.figure.config_dict["aspect"] == "square" assert plotter.mat_plot_2d.cmap.config_dict["cmap"] == "default" assert plotter.mat_plot_2d.cmap.config_dict["norm"] == "linear" - - -def test__get__visuals(): - visuals_2d = aplt.Visuals2D() - include_2d = aplt.Include2D(origin=False) - - plotter = abstract_plotters.Plotter(visuals_2d=visuals_2d, include_2d=include_2d) - attr = plotter.get_2d.get(name="origin", value=1) - - assert attr == None - - include_2d = aplt.Include2D(origin=True) - plotter = abstract_plotters.Plotter(visuals_2d=visuals_2d, include_2d=include_2d) - attr = plotter.get_2d.get(name="origin", value=1) - - assert attr == 1 - - visuals_2d = aplt.Visuals2D(origin=10) - - include_2d = aplt.Include2D(origin=False) - plotter = abstract_plotters.Plotter(visuals_2d=visuals_2d, include_2d=include_2d) - attr = plotter.get_2d.get(name="origin", value=2) - - assert attr == 10 - - include_2d = aplt.Include2D(origin=True) - plotter = abstract_plotters.Plotter(visuals_2d=visuals_2d, include_2d=include_2d) - attr = plotter.get_2d.get(name="origin", value=2) - - assert attr == 10 diff --git a/test_autoarray/plot/test_multi_plotters.py b/test_autoarray/plot/test_multi_plotters.py index 7485a3d6a..9c2048ac3 100644 --- a/test_autoarray/plot/test_multi_plotters.py +++ b/test_autoarray/plot/test_multi_plotters.py @@ -47,14 +47,12 @@ def __init__( x, mat_plot_1d: aplt.MatPlot1D = None, visuals_1d: aplt.Visuals1D = None, - include_1d: aplt.Include1D = None, ): super().__init__( y=y, x=x, mat_plot_1d=mat_plot_1d, visuals_1d=visuals_1d, - include_1d=include_1d, ) def figures_1d(self, figure_name=False): diff --git a/test_autoarray/structures/plot/test_structure_plotters.py b/test_autoarray/structures/plot/test_structure_plotters.py index 45fff5bf1..53b796798 100644 --- a/test_autoarray/structures/plot/test_structure_plotters.py +++ b/test_autoarray/structures/plot/test_structure_plotters.py @@ -57,7 +57,6 @@ def test__array( array_plotter = aplt.Array2DPlotter( array=array_2d_7x7, - include_2d=aplt.Include2D(origin=True, mask=True, border=True), mat_plot_2d=aplt.MatPlot2D( output=aplt.Output(path=plot_path, filename="array2", format="png") ), @@ -138,7 +137,6 @@ def test__grid( mat_plot_2d=aplt.MatPlot2D( output=aplt.Output(path=plot_path, filename="grid2", format="png") ), - include_2d=aplt.Include2D(origin=True, mask=True, border=True), ) grid_2d_plotter.figure_2d(color_array=color_array)