diff --git a/autoarray/__init__.py b/autoarray/__init__.py index 947377b59..2dd361dd8 100644 --- a/autoarray/__init__.py +++ b/autoarray/__init__.py @@ -1,3 +1,4 @@ +from autoconf import jax_wrapper from autoconf.dictable import register_parser from autoconf import conf diff --git a/autoarray/abstract_ndarray.py b/autoarray/abstract_ndarray.py index 6d71f0983..6c8b1d9e8 100644 --- a/autoarray/abstract_ndarray.py +++ b/autoarray/abstract_ndarray.py @@ -4,8 +4,6 @@ from abc import ABC from abc import abstractmethod -import jax.numpy as jnp -from jax._src.tree_util import register_pytree_node import numpy as np @@ -75,20 +73,20 @@ def __init__(self, array, xp=np): while isinstance(array, AbstractNDArray): array = array.array self._array = array - try: - register_pytree_node( - type(self), - self.instance_flatten, - self.instance_unflatten, - ) - except ValueError: - pass + # try: + # register_pytree_node( + # type(self), + # self.instance_flatten, + # self.instance_unflatten, + # ) + # except ValueError: + # pass self._xp = xp def invert(self): new = self.copy() - new._array = jnp.invert(new._array) + new._array = self._xp.invert(new._array) return new @classmethod @@ -117,7 +115,7 @@ def instance_unflatten(cls, aux_data, children): setattr(instance, key, value) return instance - def with_new_array(self, array: jnp.ndarray) -> "AbstractNDArray": + def with_new_array(self, array: np.ndarray) -> "AbstractNDArray": """ Copy this object but give it a new array. @@ -137,10 +135,9 @@ def with_new_array(self, array: jnp.ndarray) -> "AbstractNDArray": new_array._array = array return new_array - @staticmethod - def flip_hdu_for_ds9(values): + def flip_hdu_for_ds9(self, values): if conf.instance["general"]["fits"]["flip_for_ds9"]: - return jnp.flipud(values) + return self._xp.flipud(values) return values def copy(self): @@ -170,7 +167,7 @@ def __iter__(self): @to_new_array def sqrt(self): - return jnp.sqrt(self._array) + return self._xp.sqrt(self._array) @property def array(self): @@ -333,20 +330,28 @@ def __getattr__(self, item): ) def __getitem__(self, item): + result = self._array[item] + if isinstance(item, slice): result = self.with_new_array(result) - if isinstance(result, jnp.ndarray): - result = self.with_new_array(result) + + try: + import jax.numpy as jnp + if isinstance(result, jnp.ndarray): + result = self.with_new_array(result) + except ImportError: + pass + return result def __setitem__(self, key, value): - from jax import Array - if isinstance(key, (jnp.ndarray, AbstractNDArray, Array)): - self._array = jnp.where(key, value, self._array) - else: + if isinstance(self._array, np.ndarray): self._array[key] = value + else: + import jax.numpy as jnp + self._array = jnp.where(key, value, self._array) def __repr__(self): return repr(self._array).replace( diff --git a/autoarray/config/general.yaml b/autoarray/config/general.yaml index 224001ba2..a80402109 100644 --- a/autoarray/config/general.yaml +++ b/autoarray/config/general.yaml @@ -1,5 +1,3 @@ -jax: - use_jax: true # If True, uses JAX internally, whereas False uses normal Numpy. fits: flip_for_ds9: false # If True, the image is flipped before output to a .fits file, which is useful for viewing in DS9. psf: diff --git a/autoarray/fit/fit_dataset.py b/autoarray/fit/fit_dataset.py index 745ae5a45..02089fcc6 100644 --- a/autoarray/fit/fit_dataset.py +++ b/autoarray/fit/fit_dataset.py @@ -163,7 +163,7 @@ def subtracted_from(grid, offset): if grid is None: return None - return grid.subtracted_from(offset=offset) + return grid.subtracted_from(offset=offset, xp=self._xp) lp = subtracted_from( grid=self.dataset.grids.lp, offset=self.dataset_model.grid_offset diff --git a/autoarray/inversion/inversion/interferometer/abstract.py b/autoarray/inversion/inversion/interferometer/abstract.py index dd37952d9..cbe508c15 100644 --- a/autoarray/inversion/inversion/interferometer/abstract.py +++ b/autoarray/inversion/inversion/interferometer/abstract.py @@ -65,7 +65,7 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]: """ return [ self.transformer.transform_mapping_matrix( - mapping_matrix=linear_obj.mapping_matrix + mapping_matrix=linear_obj.mapping_matrix, xp=self._xp ) for linear_obj in self.linear_obj_list ] diff --git a/autoarray/mask/derive/indexes_2d.py b/autoarray/mask/derive/indexes_2d.py index 13bdec3b9..b9f996164 100644 --- a/autoarray/mask/derive/indexes_2d.py +++ b/autoarray/mask/derive/indexes_2d.py @@ -2,7 +2,6 @@ import logging import numpy as np -from jax._src.tree_util import register_pytree_node_class from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -14,7 +13,6 @@ logger = logging.getLogger(__name__) -@register_pytree_node_class class DeriveIndexes2D: def __init__(self, mask: Mask2D, xp=np): diff --git a/autoarray/operators/over_sampling/over_sampler.py b/autoarray/operators/over_sampling/over_sampler.py index 62a3b8ba6..beb75bbbe 100644 --- a/autoarray/operators/over_sampling/over_sampler.py +++ b/autoarray/operators/over_sampling/over_sampler.py @@ -1,6 +1,5 @@ import numpy as np -from jax._src.tree_util import register_pytree_node_class from typing import Union from autoconf import conf @@ -11,7 +10,6 @@ from autoarray.operators.over_sampling import over_sample_util -@register_pytree_node_class class OverSampler: def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]): """ @@ -229,6 +227,7 @@ def binned_array_2d_from(self, array: Array2D, xp=np) -> "Array2D": Sub-pixels that are part of the same mask array pixel are indexed next to one another, such that the second sub-pixel in the first pixel has index 1, its next sub-pixel has index 2, and so forth. """ + if conf.instance["general"]["structures"]["native_binned_only"]: return self @@ -245,16 +244,28 @@ def binned_array_2d_from(self, array: Array2D, xp=np) -> "Array2D": else: - import jax + if xp.__name__.startswith("jax"): - # Compute the group means + import jax + + sums = jax.ops.segment_sum( + array, self.segment_ids, self.mask.pixels_in_mask + ) + counts = jax.ops.segment_sum( + xp.ones_like(array), self.segment_ids, self.mask.pixels_in_mask + ) + + else: + + # Sum values per segment + sums = np.bincount(self.segment_ids, weights=array, minlength=self.mask.pixels_in_mask) + + # Count number of items per segment + counts = np.bincount(self.segment_ids, minlength=self.mask.pixels_in_mask) + + # Avoid division by zero + counts[counts == 0] = 1 - sums = jax.ops.segment_sum( - array, self.segment_ids, self.mask.pixels_in_mask - ) - counts = jax.ops.segment_sum( - xp.ones_like(array), self.segment_ids, self.mask.pixels_in_mask - ) binned_array_2d = sums / counts return Array2D( diff --git a/autoarray/operators/transformer.py b/autoarray/operators/transformer.py index b2d7fc916..72065dedd 100644 --- a/autoarray/operators/transformer.py +++ b/autoarray/operators/transformer.py @@ -39,7 +39,6 @@ def __init__( uv_wavelengths: np.ndarray, real_space_mask: Mask2D, preload_transform: bool = True, - xp=np, ): """ A direct Fourier transform (DFT) operator for radio interferometric imaging. @@ -112,9 +111,7 @@ def __init__( 2.0 * self.grid.shape_native[1] ) - self._xp = xp - - def visibilities_from(self, image: Array2D) -> Visibilities: + def visibilities_from(self, image: Array2D, xp=np) -> Visibilities: """ Computes the visibilities from a real-space image using the direct Fourier transform (DFT). @@ -138,19 +135,20 @@ def visibilities_from(self, image: Array2D) -> Visibilities: image_1d=image.array, preloaded_reals=self.preload_real_transforms, preloaded_imags=self.preload_imag_transforms, - xp=self._xp, + xp=xp, ) else: visibilities = transformer_util.visibilities_from( image_1d=image.slim.array, grid_radians=self.grid.array, uv_wavelengths=self.uv_wavelengths, + xp=xp ) - return Visibilities(visibilities=self._xp.array(visibilities)) + return Visibilities(visibilities=xp.array(visibilities)) def image_from( - self, visibilities: Visibilities, use_adjoint_scaling: bool = False + self, visibilities: Visibilities, use_adjoint_scaling: bool = False, xp=np ) -> Array2D: """ Computes the real-space image from a set of visibilities using the adjoint of the DFT. @@ -178,12 +176,12 @@ def image_from( ) image_native = array_2d_util.array_2d_native_from( - array_2d_slim=image_slim, mask_2d=self.real_space_mask, xp=self._xp + array_2d_slim=image_slim, mask_2d=self.real_space_mask, xp=xp ) return Array2D(values=image_native, mask=self.real_space_mask) - def transform_mapping_matrix(self, mapping_matrix: np.ndarray) -> np.ndarray: + def transform_mapping_matrix(self, mapping_matrix: np.ndarray, xp=np) -> np.ndarray: """ Applies the DFT to a mapping matrix that maps source pixels to image pixels. @@ -310,8 +308,6 @@ def __init__( 2.0 * self.grid.shape_native[1] ) - self._xp = xp - def initialize_plan(self, ratio: int = 2, interp_kernel: Tuple[int, int] = (6, 6)): """ Initializes the PyNUFFT plan for performing the NUFFT operation. @@ -394,7 +390,7 @@ def visibilities_from(self, image: Array2D) -> Visibilities: ) def image_from( - self, visibilities: Visibilities, use_adjoint_scaling: bool = False + self, visibilities: Visibilities, use_adjoint_scaling: bool = False, xp=np ) -> Array2D: """ Reconstructs a real-space image from visibilities using the NUFFT adjoint transform. @@ -425,24 +421,24 @@ def image_from( return Array2D(values=image, mask=self.real_space_mask) - def transform_mapping_matrix(self, mapping_matrix: np.ndarray) -> np.ndarray: + def transform_mapping_matrix(self, mapping_matrix: np.ndarray, xp=np) -> np.ndarray: """ - Applies the NUFFT forward transform to each column of a mapping matrix, producing transformed visibilities. + Applies the NUFFT forward transform to each column of a mapping matrix, producing transformed visibilities. - Parameters - ---------- - mapping_matrix - A 2D array where each column corresponds to a source-plane pixel intensity distribution flattened into image space. + Parameters + ---------- + mapping_matrix + A 2D array where each column corresponds to a source-plane pixel intensity distribution flattened into image space. - Returns + Returns ------- - A complex-valued 2D array where each column contains the visibilities corresponding to the respective column - in the input mapping matrix. + A complex-valued 2D array where each column contains the visibilities corresponding to the respective column + in the input mapping matrix. - Notes - ----- - - Each column of the input mapping matrix is reshaped into the native 2D image grid before transformation. - - This method repeatedly calls `visibilities_from` for each column, which may be computationally intensive. + Notes + ----- + - Each column of the input mapping matrix is reshaped into the native 2D image grid before transformation. + - This method repeatedly calls `visibilities_from` for each column, which may be computationally intensive. """ transformed_mapping_matrix = 0 + 0j * np.zeros( (self.uv_wavelengths.shape[0], mapping_matrix.shape[1]) @@ -452,7 +448,7 @@ def transform_mapping_matrix(self, mapping_matrix: np.ndarray) -> np.ndarray: image_2d = array_2d_util.array_2d_native_from( array_2d_slim=mapping_matrix[:, source_pixel_1d_index], mask_2d=self.grid.mask, - xp=self._xp, + xp=xp, ) image = Array2D(values=image_2d, mask=self.grid.mask) diff --git a/autoarray/operators/transformer_util.py b/autoarray/operators/transformer_util.py index 3ff0cf868..2beb7e145 100644 --- a/autoarray/operators/transformer_util.py +++ b/autoarray/operators/transformer_util.py @@ -120,7 +120,7 @@ def visibilities_via_preload_from( def visibilities_from( - image_1d: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray + image_1d: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray, xp=np ) -> np.ndarray: """ Compute complex visibilities from an input sky image using the Fourier transform, @@ -150,19 +150,19 @@ def visibilities_from( # Compute the dot product for each pixel-uv pair phase = ( -2.0 - * np.pi + * xp.pi * ( - np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) - + np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) + xp.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) + + xp.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) ) ) # shape (n_pixels, n_vis) # Multiply image values with phase terms - vis_real = image_1d[:, None] * np.cos(phase) - vis_imag = image_1d[:, None] * np.sin(phase) + vis_real = image_1d[:, None] * xp.cos(phase) + vis_imag = image_1d[:, None] * xp.sin(phase) # Sum over all pixels for each visibility - visibilities = np.sum(vis_real + 1j * vis_imag, axis=0) + visibilities = xp.sum(vis_real + 1j * vis_imag, axis=0) return visibilities @@ -247,7 +247,7 @@ def transformed_mapping_matrix_via_preload_from( def transformed_mapping_matrix_from( - mapping_matrix: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray + mapping_matrix: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray, xp=np ) -> np.ndarray: """ Computes the Fourier-transformed mapping matrix used in radio interferometric imaging. @@ -273,16 +273,16 @@ def transformed_mapping_matrix_from( # Compute phase term: (n_image_pixels, n_visibilities) phase = ( -2.0 - * np.pi + * xp.pi * ( - np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) # y * u - + np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) # x * v + xp.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) # y * u + + xp.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) # x * v ) ) # Compute real and imaginary Fourier matrices - fourier_real = np.cos(phase) - fourier_imag = np.sin(phase) + fourier_real = xp.cos(phase) + fourier_imag = xp.sin(phase) # Only compute contributions from non-zero mapping entries # This matrix multiplication is: (n_visibilities x n_image_pixels) dot (n_image_pixels x n_source_pixels) diff --git a/autoarray/preloads.py b/autoarray/preloads.py index 0cc076eba..5b258ffc0 100644 --- a/autoarray/preloads.py +++ b/autoarray/preloads.py @@ -9,9 +9,7 @@ def mapper_indices_from(total_linear_light_profiles, total_mapper_pixels): - import jax.numpy as jnp - - return jnp.arange( + return np.arange( total_linear_light_profiles, total_linear_light_profiles + total_mapper_pixels, dtype=int, @@ -54,8 +52,6 @@ def __init__( is fixed to the maximum likelihood solution, allowing the blurred mapping matrix to be preloaded, but the intensity values will still be solved for during the inversion. """ - import jax.numpy as jnp - self.mapper_indices = None self.source_pixel_zeroed_indices = None self.source_pixel_zeroed_indices_to_keep = None @@ -63,22 +59,21 @@ def __init__( if mapper_indices is not None: - self.mapper_indices = jnp.array(mapper_indices) + self.mapper_indices = np.array(mapper_indices) if source_pixel_zeroed_indices is not None: - self.source_pixel_zeroed_indices = jnp.array(source_pixel_zeroed_indices) + self.source_pixel_zeroed_indices = np.array(source_pixel_zeroed_indices) - ids_zeros = jnp.array(source_pixel_zeroed_indices, dtype=int) + ids_zeros = np.array(source_pixel_zeroed_indices, dtype=int) - values_to_solve = jnp.ones(np.max(mapper_indices), dtype=bool) - values_to_solve = values_to_solve.at[ids_zeros].set(False) + values_to_solve = np.ones(np.max(mapper_indices)+1, dtype=bool) + values_to_solve[ids_zeros] = False - # Get the indices where values_to_solve is True - self.source_pixel_zeroed_indices_to_keep = jnp.where(values_to_solve)[0] + self.source_pixel_zeroed_indices_to_keep = np.where(values_to_solve)[0] if linear_light_profile_blurred_mapping_matrix is not None: - self.linear_light_profile_blurred_mapping_matrix = jnp.array( + self.linear_light_profile_blurred_mapping_matrix = np.array( linear_light_profile_blurred_mapping_matrix ) diff --git a/autoarray/structures/triangles/abstract.py b/autoarray/structures/triangles/abstract.py index 3ae5e4718..adaae7fa8 100644 --- a/autoarray/structures/triangles/abstract.py +++ b/autoarray/structures/triangles/abstract.py @@ -12,21 +12,27 @@ def __len__(self): return len(self.triangles) @property - @abstractmethod def area(self) -> float: """ The total area covered by the triangles. """ + triangles = self.triangles + return ( + 0.5 + * np.abs( + (triangles[:, 0, 0] * (triangles[:, 1, 1] - triangles[:, 2, 1])) + + (triangles[:, 1, 0] * (triangles[:, 2, 1] - triangles[:, 0, 1])) + + (triangles[:, 2, 0] * (triangles[:, 0, 1] - triangles[:, 1, 1])) + ).sum() + ) @property - @abstractmethod def indices(self): - pass + return self._indices @property - @abstractmethod def vertices(self): - pass + return self._vertices def __str__(self): return f"{self.__class__.__name__} with {len(self.indices)} triangles" diff --git a/autoarray/structures/triangles/array.py b/autoarray/structures/triangles/array.py index 3f7d50049..f644728e8 100644 --- a/autoarray/structures/triangles/array.py +++ b/autoarray/structures/triangles/array.py @@ -1,7 +1,5 @@ import numpy as np -from jax.tree_util import register_pytree_node_class - from autoarray.structures.triangles.abstract import HEIGHT_FACTOR from autoarray.structures.triangles.abstract import AbstractTriangles @@ -10,7 +8,6 @@ MAX_CONTAINING_SIZE = 15 -@register_pytree_node_class class ArrayTriangles(AbstractTriangles): def __init__( self, @@ -120,14 +117,6 @@ def add_vertex(v): max_containing_size=max_containing_size, ) - @property - def indices(self): - return self._indices - - @property - def vertices(self): - return self._vertices - @property def triangles(self) -> np.ndarray: """ @@ -324,21 +313,6 @@ def with_vertices(self, vertices: np.ndarray) -> "ArrayTriangles": max_containing_size=self.max_containing_size, ) - @property - def area(self) -> float: - """ - The total area covered by the triangles. - """ - triangles = self.triangles - return ( - 0.5 - * np.abs( - (triangles[:, 0, 0] * (triangles[:, 1, 1] - triangles[:, 2, 1])) - + (triangles[:, 1, 0] * (triangles[:, 2, 1] - triangles[:, 0, 1])) - + (triangles[:, 2, 0] * (triangles[:, 0, 1] - triangles[:, 1, 1])) - ).sum() - ) - def tree_flatten(self): """ Flatten this model as a PyTree. diff --git a/autoarray/structures/triangles/array_np.py b/autoarray/structures/triangles/array_np.py new file mode 100644 index 000000000..f0840c587 --- /dev/null +++ b/autoarray/structures/triangles/array_np.py @@ -0,0 +1,251 @@ +from abc import ABC + +import numpy as np + +from autoarray.structures.triangles.abstract import AbstractTriangles +from autoarray.structures.triangles.shape import Shape +from autoarray.structures.triangles.abstract import HEIGHT_FACTOR + + +class ArrayTrianglesNp(AbstractTriangles, ABC): + + def __init__( + self, + indices, + vertices, + **kwargs, + ): + """ + Represents a set of triangles in efficient NumPy arrays. + + Parameters + ---------- + indices + The indices of the vertices of the triangles. This is a 2D array where each row is a triangle + with the three indices of the vertices. + vertices + The vertices of the triangles. + """ + + self._indices = indices + self._vertices = vertices + + @classmethod + def for_limits_and_scale( + cls, + y_min: float, + y_max: float, + x_min: float, + x_max: float, + scale: float, + **kwargs, + ) -> "AbstractTriangles": + height = scale * HEIGHT_FACTOR + + vertices = [] + indices = [] + vertex_dict = {} + + def add_vertex(v): + if v not in vertex_dict: + vertex_dict[v] = len(vertices) + vertices.append(v) + return vertex_dict[v] + + rows = [] + for row_y in np.arange(y_min, y_max + height, height): + row = [] + offset = (len(rows) % 2) * scale / 2 + for col_x in np.arange(x_min - offset, x_max + scale, scale): + row.append((row_y, col_x)) + rows.append(row) + + for i in range(len(rows) - 1): + row = rows[i] + next_row = rows[i + 1] + for j in range(len(row)): + if i % 2 == 0 and j < len(next_row) - 1: + t1 = [ + add_vertex(row[j]), + add_vertex(next_row[j]), + add_vertex(next_row[j + 1]), + ] + if j < len(row) - 1: + t2 = [ + add_vertex(row[j]), + add_vertex(row[j + 1]), + add_vertex(next_row[j + 1]), + ] + indices.append(t2) + elif i % 2 == 1 and j < len(next_row) - 1: + t1 = [ + add_vertex(row[j]), + add_vertex(next_row[j]), + add_vertex(row[j + 1]), + ] + indices.append(t1) + if j < len(next_row) - 1: + t2 = [ + add_vertex(next_row[j]), + add_vertex(next_row[j + 1]), + add_vertex(row[j + 1]), + ] + indices.append(t2) + else: + continue + indices.append(t1) + + vertices = np.array(vertices) + indices = np.array(indices) + + return ArrayTrianglesNp( + indices=indices, + vertices=vertices, + **kwargs, + ) + + @property + def triangles(self): + return self.vertices[self.indices] + + @property + def means(self): + return np.mean(self.triangles, axis=1) + + def containing_indices(self, shape: Shape) -> np.ndarray: + """ + Find the triangles that insect with a given shape. + + Parameters + ---------- + shape + The shape + + Returns + ------- + The triangles that intersect the shape. + """ + inside = shape.mask(self.triangles) + + return np.where(inside)[0] + + def for_indexes(self, indexes: np.ndarray) -> "ArrayTrianglesNp": + """ + Create a new ArrayTrianglesNp containing indices and vertices corresponding to the given indexes + but without duplicate vertices. + + Parameters + ---------- + indexes + The indexes of the triangles to include in the new ArrayTrianglesNp. + + Returns + ------- + The new ArrayTrianglesNp instance. + """ + selected_indices = self.indices[indexes] + + flat_indices = selected_indices.flatten() + unique_vertices, inverse_indices = np.unique( + self.vertices[flat_indices], axis=0, return_inverse=True + ) + + new_indices = inverse_indices.reshape(selected_indices.shape) + + return ArrayTrianglesNp(indices=new_indices, vertices=unique_vertices) + + def up_sample(self) -> "ArrayTrianglesNp": + """ + Up-sample the triangles by adding a new vertex at the midpoint of each edge. + + This means each triangle becomes four smaller triangles. + """ + unique_vertices, inverse_indices = np.unique( + self._up_sample_triangle().reshape(-1, 2), axis=0, return_inverse=True + ) + new_indices = inverse_indices.reshape(-1, 3) + + return ArrayTrianglesNp( + indices=new_indices, + vertices=unique_vertices, + ) + + def neighborhood(self) -> "ArrayTrianglesNp": + """ + Create a new set of triangles that are the neighborhood of the current triangles. + + Includes the current triangles and the triangles that share an edge with the current triangles. + """ + unique_vertices, inverse_indices = np.unique( + self._neighborhood_triangles().reshape(-1, 2), + axis=0, + return_inverse=True, + ) + new_indices = inverse_indices.reshape(-1, 3) + + new_indices_sorted = np.sort(new_indices, axis=1) + + unique_triangles_indices, unique_index_positions = np.unique( + new_indices_sorted, axis=0, return_index=True + ) + + return ArrayTrianglesNp( + indices=unique_triangles_indices, + vertices=unique_vertices, + ) + + def with_vertices(self, vertices: np.ndarray) -> "ArrayTrianglesNp": + """ + Create a new set of triangles with the vertices replaced. + + Parameters + ---------- + vertices + The new vertices to use. + + Returns + ------- + The new set of triangles with the new vertices. + """ + bbbb + return ArrayTrianglesNp( + indices=self.indices, + vertices=vertices, + ) + + def _up_sample_triangle(self): + triangles = self.triangles + + m01 = (triangles[:, 0] + triangles[:, 1]) / 2 + m12 = (triangles[:, 1] + triangles[:, 2]) / 2 + m20 = (triangles[:, 2] + triangles[:, 0]) / 2 + + return np.concatenate( + [ + np.stack([triangles[:, 1], m12, m01], axis=1), + np.stack([triangles[:, 2], m20, m12], axis=1), + np.stack([m01, m12, m20], axis=1), + np.stack([triangles[:, 0], m01, m20], axis=1), + ], + axis=0, + ) + + def _neighborhood_triangles(self): + triangles = self.triangles + + new_v0 = triangles[:, 1] + triangles[:, 2] - triangles[:, 0] + new_v1 = triangles[:, 0] + triangles[:, 2] - triangles[:, 1] + new_v2 = triangles[:, 0] + triangles[:, 1] - triangles[:, 2] + + return np.concatenate( + [ + np.stack([new_v0, triangles[:, 1], triangles[:, 2]], axis=1), + np.stack([triangles[:, 0], new_v1, triangles[:, 2]], axis=1), + np.stack([triangles[:, 0], triangles[:, 1], new_v2], axis=1), + triangles, + ], + axis=0, + ) + + def __iter__(self): + return iter(self.triangles) diff --git a/autoarray/structures/triangles/coordinate_array.py b/autoarray/structures/triangles/coordinate_array.py index 64c3717e2..2d674c58e 100644 --- a/autoarray/structures/triangles/coordinate_array.py +++ b/autoarray/structures/triangles/coordinate_array.py @@ -2,14 +2,11 @@ import numpy as np -from jax._src.tree_util import register_pytree_node_class - from autoarray.structures.triangles.abstract import HEIGHT_FACTOR from autoarray.structures.triangles.abstract import AbstractTriangles from autoarray.structures.triangles.array import ArrayTriangles -@register_pytree_node_class class CoordinateArrayTriangles(AbstractTriangles, ABC): def __init__( diff --git a/autoarray/structures/triangles/coordinate_array_np.py b/autoarray/structures/triangles/coordinate_array_np.py new file mode 100644 index 000000000..f37c66f17 --- /dev/null +++ b/autoarray/structures/triangles/coordinate_array_np.py @@ -0,0 +1,281 @@ +import numpy as np + +from autoarray.structures.triangles.abstract import HEIGHT_FACTOR +from autoarray.structures.triangles.abstract import AbstractTriangles +from autoarray.structures.triangles.array_np import ArrayTrianglesNp +from autoarray.structures.triangles.shape import Shape +from autoconf import cached_property + + +class CoordinateArrayTrianglesNp(AbstractTriangles): + + def __init__( + self, + coordinates: np.ndarray, + side_length: float = 1.0, + x_offset: float = 0.0, + y_offset: float = 0.0, + flipped: bool = False, + ): + """ + Represents a set of triangles by integer coordinates. + + Parameters + ---------- + coordinates + Integer x y coordinates for each triangle. + side_length + The side length of the triangles. + flipped + Whether the triangles are flipped upside down. + y_offset + An y_offset to apply to the y coordinates so that up-sampled triangles align. + """ + self.coordinates = coordinates + self.side_length = side_length + self.flipped = flipped + + self.scaling_factors = np.array( + [0.5 * side_length, HEIGHT_FACTOR * side_length] + ) + self.x_offset = x_offset + self.y_offset = y_offset + + @property + def vertices(self) -> np.ndarray: + """ + The unique vertices of the triangles. + """ + return self._vertices_and_indices[0] + + @property + def indices(self) -> np.ndarray: + """ + The indices of the vertices of the triangles. + """ + return self._vertices_and_indices[1] + + @property + def centres(self) -> np.ndarray: + """ + The centres of the triangles. + """ + return self.scaling_factors * self.coordinates + np.array( + [self.x_offset, self.y_offset] + ) + + @cached_property + def flip_mask(self) -> np.ndarray: + """ + A mask for the triangles that are flipped. + + Every other triangle is flipped so that they tessellate. + """ + mask = (self.coordinates[:, 0] + self.coordinates[:, 1]) % 2 != 0 + if self.flipped: + mask = ~mask + return mask + + @cached_property + def flip_array(self) -> np.ndarray: + """ + An array of 1s and -1s to flip the triangles. + """ + array = np.ones( + self.coordinates.shape[0], + dtype=np.int32, + ) + array[self.flip_mask] = -1 + + return array[:, np.newaxis] + + @classmethod + def for_limits_and_scale( + cls, + x_min: float, + x_max: float, + y_min: float, + y_max: float, + scale: float = 1.0, + **_, + ): + x_shift = int(2 * x_min / scale) + y_shift = int(y_min / (HEIGHT_FACTOR * scale)) + + coordinates = [] + + for x in range(x_shift, int(2 * x_max / scale) + 1): + for y in range(y_shift - 1, int(y_max / (HEIGHT_FACTOR * scale)) + 2): + coordinates.append([x, y]) + + return CoordinateArrayTrianglesNp( + coordinates=np.array(coordinates, dtype=np.int32), + side_length=scale, + ) + + @cached_property + def triangles(self) -> np.ndarray: + """ + The vertices of the triangles as an Nx3x2 array. + """ + centres = self.centres + return np.stack( + ( + centres + + self.flip_array + * np.array( + [0.0, 0.5 * self.side_length * HEIGHT_FACTOR], + ), + centres + + self.flip_array + * np.array( + [0.5 * self.side_length, -0.5 * self.side_length * HEIGHT_FACTOR] + ), + centres + + self.flip_array + * np.array( + [-0.5 * self.side_length, -0.5 * self.side_length * HEIGHT_FACTOR] + ), + ), + axis=1, + ) + + def up_sample(self) -> "CoordinateArrayTrianglesNp": + """ + Up-sample the triangles by adding a new vertex at the midpoint of each edge. + """ + new_coordinates = np.zeros( + (4 * self.coordinates.shape[0], 2), + dtype=np.int32, + ) + n_normal = 4 * np.sum(~self.flip_mask) + + new_coordinates[:n_normal] = np.vstack( + ( + 2 * self.coordinates[~self.flip_mask], + 2 * self.coordinates[~self.flip_mask] + np.array([1, 0]), + 2 * self.coordinates[~self.flip_mask] + np.array([-1, 0]), + 2 * self.coordinates[~self.flip_mask] + np.array([0, 1]), + ) + ) + new_coordinates[n_normal:] = np.vstack( + ( + 2 * self.coordinates[self.flip_mask], + 2 * self.coordinates[self.flip_mask] + np.array([1, 1]), + 2 * self.coordinates[self.flip_mask] + np.array([-1, 1]), + 2 * self.coordinates[self.flip_mask] + np.array([0, 1]), + ) + ) + + return CoordinateArrayTrianglesNp( + coordinates=new_coordinates, + side_length=self.side_length / 2, + y_offset=self.y_offset + -0.25 * HEIGHT_FACTOR * self.side_length, + x_offset=self.x_offset, + flipped=True, + ) + + def neighborhood(self) -> "CoordinateArrayTrianglesNp": + """ + Create a new set of triangles that are the neighborhood of the current triangles. + + Ensures that the new triangles are unique. + """ + new_coordinates = np.zeros( + (4 * self.coordinates.shape[0], 2), + dtype=np.int32, + ) + n_normal = 4 * np.sum(~self.flip_mask) + + new_coordinates[:n_normal] = np.vstack( + ( + self.coordinates[~self.flip_mask], + self.coordinates[~self.flip_mask] + np.array([1, 0]), + self.coordinates[~self.flip_mask] + np.array([-1, 0]), + self.coordinates[~self.flip_mask] + np.array([0, -1]), + ) + ) + new_coordinates[n_normal:] = np.vstack( + ( + self.coordinates[self.flip_mask], + self.coordinates[self.flip_mask] + np.array([1, 0]), + self.coordinates[self.flip_mask] + np.array([-1, 0]), + self.coordinates[self.flip_mask] + np.array([0, 1]), + ) + ) + return CoordinateArrayTrianglesNp( + coordinates=np.unique(new_coordinates, axis=0), + side_length=self.side_length, + y_offset=self.y_offset, + x_offset=self.x_offset, + flipped=self.flipped, + ) + + @cached_property + def _vertices_and_indices(self): + flat_triangles = self.triangles.reshape(-1, 2) + vertices, inverse_indices = np.unique( + flat_triangles, + axis=0, + return_inverse=True, + ) + indices = inverse_indices.reshape(-1, 3) + return vertices, indices + + def with_vertices(self, vertices: np.ndarray) -> ArrayTrianglesNp: + """ + Create a new set of triangles with the vertices replaced. + + Parameters + ---------- + vertices + The new vertices to use. + + Returns + ------- + The new set of triangles with the new vertices. + """ + return ArrayTrianglesNp( + indices=self.indices, + vertices=vertices, + ) + + def for_indexes(self, indexes: np.ndarray) -> "CoordinateArrayTrianglesNp": + """ + Create a new CoordinateArrayTrianglesNp containing triangles corresponding to the given indexes + + Parameters + ---------- + indexes + The indexes of the triangles to include in the new CoordinateArrayTrianglesNp. + + Returns + ------- + The new CoordinateArrayTrianglesNp instance. + """ + return CoordinateArrayTrianglesNp( + coordinates=self.coordinates[indexes], + side_length=self.side_length, + y_offset=self.y_offset, + x_offset=self.x_offset, + flipped=self.flipped, + ) + + def containing_indices(self, shape: Shape) -> np.ndarray: + """ + Find the triangles that insect with a given shape. + + Parameters + ---------- + shape + The shape + + Returns + ------- + The indices of triangles that intersect the shape. + """ + return self.with_vertices(self.vertices).containing_indices(shape) + + @property + def means(self): + return np.mean(self.triangles, axis=1) diff --git a/autoarray/structures/triangles/shape.py b/autoarray/structures/triangles/shape.py index 71370dae1..3da2fe94a 100644 --- a/autoarray/structures/triangles/shape.py +++ b/autoarray/structures/triangles/shape.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from jax._src.tree_util import register_pytree_node_class from typing import List, Tuple import numpy as np @@ -34,7 +33,6 @@ def mask(self, triangles: np.ndarray) -> np.ndarray: """ -@register_pytree_node_class class Point(Shape): def __init__(self, x: float, y: float): """ @@ -107,7 +105,6 @@ def centroid(triangles: np.ndarray): return (x1 + x2 + x3) / 3, (y1 + y2 + y3) / 3 -@register_pytree_node_class class Circle(Point): def __init__( self, diff --git a/test_autoarray/structures/triangles/conftest.py b/test_autoarray/structures/triangles/conftest.py index 9677ce6c7..aa9339c99 100644 --- a/test_autoarray/structures/triangles/conftest.py +++ b/test_autoarray/structures/triangles/conftest.py @@ -1,12 +1,10 @@ -from autoconf.jax_wrapper import np -from autoarray.structures.triangles.array import ArrayTriangles -from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles - +import jax.numpy as jnp from matplotlib import pyplot as plt - - import pytest +from autoarray.structures.triangles.array import ArrayTriangles +from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles + @pytest.fixture def plot(): @@ -14,8 +12,8 @@ def plot(): def plot(triangles, color="black"): for triangle in triangles: - triangle = np.array(triangle) - triangle = np.append(triangle, np.array([triangle[0]]), axis=0) + triangle = jnp.array(triangle) + triangle = jnp.append(triangle, jnp.array([triangle[0]]), axis=0) plt.plot(triangle[:, 0], triangle[:, 1], "o-", color=color) yield plot @@ -26,13 +24,13 @@ def plot(triangles, color="black"): @pytest.fixture def compare_with_nans(): def compare_with_nans_(arr1, arr2): - nan_mask1 = np.isnan(arr1) - nan_mask2 = np.isnan(arr2) + nan_mask1 = jnp.isnan(arr1) + nan_mask2 = jnp.isnan(arr2) arr1 = arr1[~nan_mask1] arr2 = arr2[~nan_mask2] - return np.all(arr1 == arr2) + return jnp.all(arr1 == arr2) return compare_with_nans_ @@ -40,13 +38,13 @@ def compare_with_nans_(arr1, arr2): @pytest.fixture def triangles(): return ArrayTriangles( - indices=np.array( + indices=jnp.array( [ [0, 1, 2], [1, 2, 3], ] ), - vertices=np.array( + vertices=jnp.array( [ [0.0, 0.0], [1.0, 0.0], @@ -60,7 +58,7 @@ def triangles(): @pytest.fixture def one_triangle(): return CoordinateArrayTriangles( - coordinates=np.array([[0, 0]]), + coordinates=jnp.array([[0, 0]]), side_length=1.0, ) @@ -68,6 +66,6 @@ def one_triangle(): @pytest.fixture def two_triangles(): return CoordinateArrayTriangles( - coordinates=np.array([[0, 0], [1, 0]]), + coordinates=jnp.array([[0, 0], [1, 0]]), side_length=1.0, ) diff --git a/test_autoarray/structures/triangles/test_coordinate.py b/test_autoarray/structures/triangles/test_coordinate.py index bfd677874..52a95f863 100644 --- a/test_autoarray/structures/triangles/test_coordinate.py +++ b/test_autoarray/structures/triangles/test_coordinate.py @@ -1,14 +1,16 @@ import numpy as np +from jax.tree_util import register_pytree_node_class import pytest from autoarray.structures.triangles.abstract import HEIGHT_FACTOR -from autoarray.structures.triangles.shape import Point from autoarray.structures.triangles.coordinate_array import ( CoordinateArrayTriangles, ) +CoordinateArrayTriangles = register_pytree_node_class(CoordinateArrayTriangles) + def test__two(two_triangles): diff --git a/test_autoarray/structures/triangles/test_jax.py b/test_autoarray/structures/triangles/test_jax.py index 63e1b1293..e62e8d295 100644 --- a/test_autoarray/structures/triangles/test_jax.py +++ b/test_autoarray/structures/triangles/test_jax.py @@ -1,25 +1,26 @@ -from jax import numpy as np import jax - -jax.config.update("jax_log_compiles", True) +import jax.numpy as jnp +from jax.tree_util import register_pytree_node_class import pytest - from autoarray.structures.triangles.shape import Point from autoarray.structures.triangles.array import ArrayTriangles +ArrayTriangles = register_pytree_node_class(ArrayTriangles) +Point = register_pytree_node_class(Point) + @pytest.fixture def triangles(): return ArrayTriangles( - indices=np.array( + indices=jnp.array( [ [0, 1, 2], [1, 2, 3], ] ), - vertices=np.array( + vertices=jnp.array( [ [0.0, 0.0], [1.0, 0.0], @@ -36,29 +37,29 @@ def triangles(): [ ( Point(0.1, 0.1), - np.array( + jnp.array( [ [0.0, 0.0], [0.0, 1.0], [1.0, 0.0], ] ), - np.array([0, -1, -1, -1, -1]), + jnp.array([0, -1, -1, -1, -1]), ), ( Point(0.6, 0.6), - np.array( + jnp.array( [ [0.0, 1.0], [1.0, 0.0], [1.0, 1.0], ] ), - np.array([1, -1, -1, -1, -1]), + jnp.array([1, -1, -1, -1, -1]), ), ( Point(0.5, 0.5), - np.array( + jnp.array( [ [0.0, 0.0], [0.0, 1.0], @@ -66,7 +67,7 @@ def triangles(): [1.0, 1.0], ] ), - np.array([0, 1, -1, -1, -1]), + jnp.array([0, 1, -1, -1, -1]), ), ], ) @@ -85,38 +86,38 @@ def test_contains_vertices( "indexes, vertices, indices", [ ( - np.array([0]), - np.array( + jnp.array([0]), + jnp.array( [ [0.0, 0.0], [0.0, 1.0], [1.0, 0.0], ] ), - np.array( + jnp.array( [ [0, 1, 2], ] ), ), ( - np.array([1]), - np.array( + jnp.array([1]), + jnp.array( [ [0.0, 1.0], [1.0, 0.0], [1.0, 1.0], ] ), - np.array( + jnp.array( [ [0, 1, 2], ] ), ), ( - np.array([0, 1]), - np.array( + jnp.array([0, 1]), + jnp.array( [ [0.0, 0.0], [0.0, 1.0], @@ -124,7 +125,7 @@ def test_contains_vertices( [1.0, 1.0], ], ), - np.array( + jnp.array( [ [0, 1, 2], [1, 2, 3], @@ -153,13 +154,13 @@ def test_negative_index( triangles, compare_with_nans, ): - indexes = np.array([0, -1]) + indexes = jnp.array([0, -1]) containing = jax.jit(triangles.for_indexes)(indexes) assert ( containing.indices - == np.array( + == jnp.array( [ [-1, -1, -1], [0, 1, 2], @@ -168,7 +169,7 @@ def test_negative_index( ).all() assert compare_with_nans( containing.vertices, - np.array( + jnp.array( [ [0.0, 0.0], [0.0, 1.0], @@ -186,7 +187,7 @@ def test_up_sample( assert compare_with_nans( up_sampled.vertices, - np.array( + jnp.array( [ [0.0, 0.0], [0.0, 0.5], @@ -203,7 +204,7 @@ def test_up_sample( assert ( up_sampled.indices - == np.array( + == jnp.array( [ [0, 1, 3], [1, 2, 4], @@ -224,12 +225,12 @@ def test_up_sample( ) def test_simple_neighborhood(offset, compare_with_nans): triangles = ArrayTriangles( - indices=np.array( + indices=jnp.array( [ [0, 1, 2], ] ), - vertices=np.array( + vertices=jnp.array( [ [0.0, 0.0], [1.0, 0.0], @@ -242,7 +243,7 @@ def test_simple_neighborhood(offset, compare_with_nans): assert compare_with_nans( jax.jit(triangles.neighborhood)().triangles, ( - np.array( + jnp.array( [ [[-1.0, 1.0], [0.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]], @@ -260,7 +261,7 @@ def test_neighborhood(triangles, compare_with_nans): assert compare_with_nans( neighborhood.vertices, - np.array( + jnp.array( [ [-1.0, 1.0], [0.0, 0.0], @@ -276,7 +277,7 @@ def test_neighborhood(triangles, compare_with_nans): assert ( neighborhood.indices - == np.array( + == jnp.array( [ [0, 1, 2], [1, 2, 5], @@ -294,7 +295,7 @@ def test_neighborhood(triangles, compare_with_nans): def test_means(triangles): means = triangles.means assert means == pytest.approx( - np.array( + jnp.array( [ [0.33333333, 0.33333333], [0.66666667, 0.66666667], diff --git a/test_autoarray/structures/triangles/test_np.py b/test_autoarray/structures/triangles/test_np.py new file mode 100644 index 000000000..fac76f350 --- /dev/null +++ b/test_autoarray/structures/triangles/test_np.py @@ -0,0 +1,111 @@ +import numpy as np +import pytest + +from autoarray.structures.triangles.shape import Point +from autoarray.structures.triangles.array_np import ArrayTrianglesNp + + +@pytest.fixture +def triangles(): + return ArrayTrianglesNp( + indices=np.array( + [ + [0, 1, 2], + [1, 2, 3], + ] + ), + vertices=np.array( + [ + [0.0, 0.0], + [1.0, 0.0], + [0.0, 1.0], + [1.0, 1.0], + ] + ), + max_containing_size=5, + ) + + + + +@pytest.mark.parametrize( + "offset", + [-1, 0, 1], +) +def test_simple_neighborhood(offset, compare_with_nans): + triangles = ArrayTrianglesNp( + indices=np.array( + [ + [0, 1, 2], + ] + ), + vertices=np.array( + [ + [0.0, 0.0], + [1.0, 0.0], + [0.0, 1.0], + ] + ) + + offset, + ) + + assert compare_with_nans( + triangles.neighborhood().triangles, + ( + np.array( + [ + [[-1.0, 1.0], [0.0, 0.0], [0.0, 1.0]], + [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]], + [[0.0, 0.0], [1.0, -1.0], [1.0, 0.0]], + [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], + ] + ) + + offset + ), + ) + + +def test_neighborhood(triangles, compare_with_nans): + neighborhood = triangles.neighborhood() + + assert compare_with_nans( + neighborhood.vertices, + np.array( + [ + [-1.0, 1.0], + [0.0, 0.0], + [0.0, 1.0], + [0.0, 2.0], + [1.0, -1.0], + [1.0, 0.0], + [1.0, 1.0], + [2.0, 0.0], + ] + ), + ) + + assert ( + neighborhood.indices + == np.array( + [ + [0, 1, 2], + [1, 2, 5], + [1, 4, 5], + [2, 3, 6], + [2, 5, 6], + [5, 6, 7], + ] + ) + ).all() + + +def test_means(triangles): + means = triangles.means + assert means == pytest.approx( + np.array( + [ + [0.33333333, 0.33333333], + [0.66666667, 0.66666667], + ] + ) + )