From 09784fa26167049fbb0f520446fc16cc537a7687 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 17 Jun 2025 18:18:20 +0100 Subject: [PATCH 01/13] removed numpy based Trianghle and fixed first test --- .../structures/triangles/array/__init__.py | 7 +- autoarray/structures/triangles/array/array.py | 123 ------------ .../triangles/coordinate_array/__init__.py | 9 +- .../coordinate_array/coordinate_array.py | 188 ------------------ .../structures/triangles/conftest.py | 2 +- .../triangles/coordinate/conftest.py | 2 +- .../test_coordinate_implementation.py | 10 +- .../structures/triangles/test_area.py | 2 +- .../triangles/test_array_representation.py | 2 +- .../triangles/test_extended_source.py | 2 +- .../structures/triangles/test_jax.py | 2 +- .../triangles/test_nan_triangles.py | 2 +- 12 files changed, 13 insertions(+), 338 deletions(-) delete mode 100644 autoarray/structures/triangles/array/array.py delete mode 100644 autoarray/structures/triangles/coordinate_array/coordinate_array.py diff --git a/autoarray/structures/triangles/array/__init__.py b/autoarray/structures/triangles/array/__init__.py index 0fade4b81..e1cbd9336 100644 --- a/autoarray/structures/triangles/array/__init__.py +++ b/autoarray/structures/triangles/array/__init__.py @@ -1,6 +1 @@ -from .array import ArrayTriangles - -try: - from .jax_array import ArrayTriangles as JAXArrayTriangles -except ImportError: - pass +from .jax_array import ArrayTriangles as JAXArrayTriangles \ No newline at end of file diff --git a/autoarray/structures/triangles/array/array.py b/autoarray/structures/triangles/array/array.py deleted file mode 100644 index 06bb5dc89..000000000 --- a/autoarray/structures/triangles/array/array.py +++ /dev/null @@ -1,123 +0,0 @@ -from abc import ABC - -import numpy as np - -from autoarray.structures.triangles.array.abstract_array import AbstractArrayTriangles -from autoarray.structures.triangles.shape import Shape - - -class ArrayTriangles(AbstractArrayTriangles, ABC): - @property - def triangles(self): - return self.vertices[self.indices] - - @property - def numpy(self): - return np - - @property - def means(self): - return np.mean(self.triangles, axis=1) - - def containing_indices(self, shape: Shape) -> np.ndarray: - """ - Find the triangles that insect with a given shape. - - Parameters - ---------- - shape - The shape - - Returns - ------- - The triangles that intersect the shape. - """ - inside = shape.mask(self.triangles) - - return np.where(inside)[0] - - def for_indexes(self, indexes: np.ndarray) -> "ArrayTriangles": - """ - Create a new ArrayTriangles containing indices and vertices corresponding to the given indexes - but without duplicate vertices. - - Parameters - ---------- - indexes - The indexes of the triangles to include in the new ArrayTriangles. - - Returns - ------- - The new ArrayTriangles instance. - """ - selected_indices = self.indices[indexes] - - flat_indices = selected_indices.flatten() - unique_vertices, inverse_indices = np.unique( - self.vertices[flat_indices], axis=0, return_inverse=True - ) - - new_indices = inverse_indices.reshape(selected_indices.shape) - - return ArrayTriangles(indices=new_indices, vertices=unique_vertices) - - def up_sample(self) -> "ArrayTriangles": - """ - Up-sample the triangles by adding a new vertex at the midpoint of each edge. - - This means each triangle becomes four smaller triangles. - """ - unique_vertices, inverse_indices = np.unique( - self._up_sample_triangle().reshape(-1, 2), axis=0, return_inverse=True - ) - new_indices = inverse_indices.reshape(-1, 3) - - return ArrayTriangles( - indices=new_indices, - vertices=unique_vertices, - ) - - def neighborhood(self) -> "ArrayTriangles": - """ - Create a new set of triangles that are the neighborhood of the current triangles. - - Includes the current triangles and the triangles that share an edge with the current triangles. - """ - unique_vertices, inverse_indices = np.unique( - self._neighborhood_triangles().reshape(-1, 2), - axis=0, - return_inverse=True, - ) - new_indices = inverse_indices.reshape(-1, 3) - - new_indices_sorted = np.sort(new_indices, axis=1) - - unique_triangles_indices, unique_index_positions = np.unique( - new_indices_sorted, axis=0, return_index=True - ) - - return ArrayTriangles( - indices=unique_triangles_indices, - vertices=unique_vertices, - ) - - def with_vertices(self, vertices: np.ndarray) -> "ArrayTriangles": - """ - Create a new set of triangles with the vertices replaced. - - Parameters - ---------- - vertices - The new vertices to use. - - Returns - ------- - The new set of triangles with the new vertices. - """ - return ArrayTriangles( - indices=self.indices, - vertices=vertices, - ) - - def __iter__(self): - return iter(self.triangles) diff --git a/autoarray/structures/triangles/coordinate_array/__init__.py b/autoarray/structures/triangles/coordinate_array/__init__.py index f70bc8a9a..b4c84484c 100644 --- a/autoarray/structures/triangles/coordinate_array/__init__.py +++ b/autoarray/structures/triangles/coordinate_array/__init__.py @@ -1,8 +1 @@ -from .coordinate_array import CoordinateArrayTriangles - -try: - from .jax_coordinate_array import ( - CoordinateArrayTriangles as JAXCoordinateArrayTriangles, - ) -except ImportError: - pass +from .jax_coordinate_array import CoordinateArrayTriangles as JAXCoordinateArrayTriangles \ No newline at end of file diff --git a/autoarray/structures/triangles/coordinate_array/coordinate_array.py b/autoarray/structures/triangles/coordinate_array/coordinate_array.py deleted file mode 100644 index 997c8ab7f..000000000 --- a/autoarray/structures/triangles/coordinate_array/coordinate_array.py +++ /dev/null @@ -1,188 +0,0 @@ -import numpy as np - -from autoarray.structures.triangles.abstract import HEIGHT_FACTOR -from autoarray.structures.triangles.coordinate_array.abstract_coordinate_array import ( - AbstractCoordinateArray, -) -from autoarray.structures.triangles.array import ArrayTriangles -from autoarray.structures.triangles.shape import Shape -from autoconf import cached_property - - -class CoordinateArrayTriangles(AbstractCoordinateArray): - @cached_property - def flip_array(self) -> np.ndarray: - """ - An array of 1s and -1s to flip the triangles. - """ - array = np.ones( - self.coordinates.shape[0], - dtype=np.int32, - ) - array[self.flip_mask] = -1 - - return array[:, np.newaxis] - - @property - def numpy(self): - return np - - @classmethod - def for_limits_and_scale( - cls, - x_min: float, - x_max: float, - y_min: float, - y_max: float, - scale: float = 1.0, - **_, - ): - x_shift = int(2 * x_min / scale) - y_shift = int(y_min / (HEIGHT_FACTOR * scale)) - - coordinates = [] - - for x in range(x_shift, int(2 * x_max / scale) + 1): - for y in range(y_shift - 1, int(y_max / (HEIGHT_FACTOR * scale)) + 2): - coordinates.append([x, y]) - - return cls( - coordinates=np.array(coordinates, dtype=np.int32), - side_length=scale, - ) - - def up_sample(self) -> "CoordinateArrayTriangles": - """ - Up-sample the triangles by adding a new vertex at the midpoint of each edge. - """ - new_coordinates = np.zeros( - (4 * self.coordinates.shape[0], 2), - dtype=np.int32, - ) - n_normal = 4 * np.sum(~self.flip_mask) - - new_coordinates[:n_normal] = np.vstack( - ( - 2 * self.coordinates[~self.flip_mask], - 2 * self.coordinates[~self.flip_mask] + np.array([1, 0]), - 2 * self.coordinates[~self.flip_mask] + np.array([-1, 0]), - 2 * self.coordinates[~self.flip_mask] + np.array([0, 1]), - ) - ) - new_coordinates[n_normal:] = np.vstack( - ( - 2 * self.coordinates[self.flip_mask], - 2 * self.coordinates[self.flip_mask] + np.array([1, 1]), - 2 * self.coordinates[self.flip_mask] + np.array([-1, 1]), - 2 * self.coordinates[self.flip_mask] + np.array([0, 1]), - ) - ) - - return CoordinateArrayTriangles( - coordinates=new_coordinates, - side_length=self.side_length / 2, - y_offset=self.y_offset + -0.25 * HEIGHT_FACTOR * self.side_length, - x_offset=self.x_offset, - flipped=True, - ) - - def neighborhood(self) -> "CoordinateArrayTriangles": - """ - Create a new set of triangles that are the neighborhood of the current triangles. - - Ensures that the new triangles are unique. - """ - new_coordinates = np.zeros( - (4 * self.coordinates.shape[0], 2), - dtype=np.int32, - ) - n_normal = 4 * np.sum(~self.flip_mask) - - new_coordinates[:n_normal] = np.vstack( - ( - self.coordinates[~self.flip_mask], - self.coordinates[~self.flip_mask] + np.array([1, 0]), - self.coordinates[~self.flip_mask] + np.array([-1, 0]), - self.coordinates[~self.flip_mask] + np.array([0, -1]), - ) - ) - new_coordinates[n_normal:] = np.vstack( - ( - self.coordinates[self.flip_mask], - self.coordinates[self.flip_mask] + np.array([1, 0]), - self.coordinates[self.flip_mask] + np.array([-1, 0]), - self.coordinates[self.flip_mask] + np.array([0, 1]), - ) - ) - return CoordinateArrayTriangles( - coordinates=np.unique(new_coordinates, axis=0), - side_length=self.side_length, - y_offset=self.y_offset, - x_offset=self.x_offset, - flipped=self.flipped, - ) - - @cached_property - def _vertices_and_indices(self): - flat_triangles = self.triangles.reshape(-1, 2) - vertices, inverse_indices = np.unique( - flat_triangles, - axis=0, - return_inverse=True, - ) - indices = inverse_indices.reshape(-1, 3) - return vertices, indices - - def with_vertices(self, vertices: np.ndarray) -> ArrayTriangles: - """ - Create a new set of triangles with the vertices replaced. - - Parameters - ---------- - vertices - The new vertices to use. - - Returns - ------- - The new set of triangles with the new vertices. - """ - return ArrayTriangles( - indices=self.indices, - vertices=vertices, - ) - - def for_indexes(self, indexes: np.ndarray) -> "CoordinateArrayTriangles": - """ - Create a new CoordinateArrayTriangles containing triangles corresponding to the given indexes - - Parameters - ---------- - indexes - The indexes of the triangles to include in the new CoordinateArrayTriangles. - - Returns - ------- - The new CoordinateArrayTriangles instance. - """ - return CoordinateArrayTriangles( - coordinates=self.coordinates[indexes], - side_length=self.side_length, - y_offset=self.y_offset, - x_offset=self.x_offset, - flipped=self.flipped, - ) - - def containing_indices(self, shape: Shape) -> np.ndarray: - """ - Find the triangles that insect with a given shape. - - Parameters - ---------- - shape - The shape - - Returns - ------- - The indices of triangles that intersect the shape. - """ - return self.with_vertices(self.vertices).containing_indices(shape) diff --git a/test_autoarray/structures/triangles/conftest.py b/test_autoarray/structures/triangles/conftest.py index a8d8580a3..bf35f5643 100644 --- a/test_autoarray/structures/triangles/conftest.py +++ b/test_autoarray/structures/triangles/conftest.py @@ -1,5 +1,5 @@ from autoarray.numpy_wrapper import np -from autoarray.structures.triangles.array import ArrayTriangles +from autoarray.structures.triangles.array import JAXArrayTriangles as ArrayTriangles from matplotlib import pyplot as plt diff --git a/test_autoarray/structures/triangles/coordinate/conftest.py b/test_autoarray/structures/triangles/coordinate/conftest.py index 302b565f7..0d53f32a2 100644 --- a/test_autoarray/structures/triangles/coordinate/conftest.py +++ b/test_autoarray/structures/triangles/coordinate/conftest.py @@ -2,7 +2,7 @@ import numpy as np -from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles +from autoarray.structures.triangles.coordinate_array import JAXCoordinateArrayTriangles as CoordinateArrayTriangles @pytest.fixture diff --git a/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py b/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py index 545d16da0..1599ce0e9 100644 --- a/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py +++ b/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py @@ -3,15 +3,14 @@ import numpy as np from autoarray.structures.triangles.abstract import HEIGHT_FACTOR -from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles +from autoarray.structures.triangles.coordinate_array import JAXCoordinateArrayTriangles as CoordinateArrayTriangles from autoarray.structures.triangles.shape import Point def test_two(two_triangles): + assert np.all(two_triangles.centres == np.array([[0, 0], [0.5, 0]])) - assert np.all( - two_triangles.triangles - == [ + assert two_triangles.triangles == pytest.approx(np.array([ [ [0.0, HEIGHT_FACTOR / 2], [0.5, -HEIGHT_FACTOR / 2], @@ -22,8 +21,7 @@ def test_two(two_triangles): [0.0, HEIGHT_FACTOR / 2], [1.0, HEIGHT_FACTOR / 2], ], - ] - ) + ]), 1.0e-4) def test_trivial_triangles(one_triangle): diff --git a/test_autoarray/structures/triangles/test_area.py b/test_autoarray/structures/triangles/test_area.py index c7a1b6ccc..95a19da4a 100644 --- a/test_autoarray/structures/triangles/test_area.py +++ b/test_autoarray/structures/triangles/test_area.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from autoarray.structures.triangles.array import ArrayTriangles +from autoarray.structures.triangles.array import JAXArrayTriangles as ArrayTriangles from autoarray.structures.triangles.shape import Triangle, Circle, Square, Polygon diff --git a/test_autoarray/structures/triangles/test_array_representation.py b/test_autoarray/structures/triangles/test_array_representation.py index 832c0793f..2496cf747 100644 --- a/test_autoarray/structures/triangles/test_array_representation.py +++ b/test_autoarray/structures/triangles/test_array_representation.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from autoarray.structures.triangles.array import ArrayTriangles +from autoarray.structures.triangles.array import JAXArrayTriangles as ArrayTriangles from autoarray.structures.triangles.shape import Point diff --git a/test_autoarray/structures/triangles/test_extended_source.py b/test_autoarray/structures/triangles/test_extended_source.py index 4491bd834..e4b50adcc 100644 --- a/test_autoarray/structures/triangles/test_extended_source.py +++ b/test_autoarray/structures/triangles/test_extended_source.py @@ -1,7 +1,7 @@ import pytest import numpy as np -from autoarray.structures.triangles.array import ArrayTriangles +from autoarray.structures.triangles.array import JAXArrayTriangles as ArrayTriangles from autoarray.structures.triangles.shape import Circle diff --git a/test_autoarray/structures/triangles/test_jax.py b/test_autoarray/structures/triangles/test_jax.py index def239849..f613d0fe4 100644 --- a/test_autoarray/structures/triangles/test_jax.py +++ b/test_autoarray/structures/triangles/test_jax.py @@ -8,7 +8,7 @@ from autoarray.structures.triangles.array.jax_array import ArrayTriangles except ImportError: import numpy as np - from autoarray.structures.triangles.array import ArrayTriangles + from autoarray.structures.triangles.array import JAXArrayTriangles as ArrayTriangles import pytest diff --git a/test_autoarray/structures/triangles/test_nan_triangles.py b/test_autoarray/structures/triangles/test_nan_triangles.py index 725cf5257..2e14080c4 100644 --- a/test_autoarray/structures/triangles/test_nan_triangles.py +++ b/test_autoarray/structures/triangles/test_nan_triangles.py @@ -5,7 +5,7 @@ from autoarray.structures.triangles.array.jax_array import ArrayTriangles except ImportError: import numpy as np - from autoarray.structures.triangles.array import ArrayTriangles + from autoarray.structures.triangles.array import JAXArrayTriangles as ArrayTriangles pytest.importorskip("jax") From 3f24384fff2026e750deaa85b61ed118cc0bc5e8 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 17 Jun 2025 18:19:50 +0100 Subject: [PATCH 02/13] fix one_triangle.triangles --- .../coordinate/test_coordinate_implementation.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py b/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py index 1599ce0e9..12b5b1ddf 100644 --- a/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py +++ b/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py @@ -7,7 +7,7 @@ from autoarray.structures.triangles.shape import Point -def test_two(two_triangles): +def test__two(two_triangles): assert np.all(two_triangles.centres == np.array([[0, 0], [0.5, 0]])) assert two_triangles.triangles == pytest.approx(np.array([ @@ -24,19 +24,17 @@ def test_two(two_triangles): ]), 1.0e-4) -def test_trivial_triangles(one_triangle): +def test__trivial_triangles(one_triangle): assert one_triangle.flip_array == np.array([1]) assert np.all(one_triangle.centres == np.array([[0, 0]])) - assert np.all( - one_triangle.triangles - == [ + assert one_triangle.triangles == pytest.approx(np.array([ [ [0.0, HEIGHT_FACTOR / 2], [0.5, -HEIGHT_FACTOR / 2], [-0.5, -HEIGHT_FACTOR / 2], ], ] - ) + ), 1.0e-4) def test_above(): From 2db16d94299c9862b783720a6211bc4a331391f0 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 17 Jun 2025 18:20:38 +0100 Subject: [PATCH 03/13] test__above --- .../coordinate/test_coordinate_implementation.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py b/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py index 12b5b1ddf..405cb7ef9 100644 --- a/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py +++ b/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py @@ -37,14 +37,12 @@ def test__trivial_triangles(one_triangle): ), 1.0e-4) -def test_above(): +def test__above(): triangles = CoordinateArrayTriangles( coordinates=np.array([[0, 1]]), side_length=1.0, ) - assert np.all( - triangles.up_sample().triangles - == [ + assert triangles.up_sample().triangles == pytest.approx(np.array([ [ [0.0, 0.43301270189221935], [-0.25, 0.8660254037844386], @@ -66,7 +64,7 @@ def test_above(): [-0.25, 0.8660254037844388], ], ] - ) + ), 1.0e-4) @pytest.fixture From 81d1c7044455906f0ec692324bf586f9fa9de39a Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 17 Jun 2025 18:24:46 +0100 Subject: [PATCH 04/13] multiple fixes in test_coordinate_implementation.py --- .../test_coordinate_implementation.py | 52 +++++++------------ 1 file changed, 18 insertions(+), 34 deletions(-) diff --git a/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py b/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py index 405cb7ef9..098a9c4c1 100644 --- a/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py +++ b/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py @@ -77,44 +77,38 @@ def upside_down(): def test_upside_down(upside_down): assert np.all(upside_down.centres == np.array([[0.5, 0]])) - assert np.all( - upside_down.triangles - == [ + assert upside_down.triangles == pytest.approx(np.array([ [ [0.5, -HEIGHT_FACTOR / 2], [0.0, HEIGHT_FACTOR / 2], [1.0, HEIGHT_FACTOR / 2], ], ] - ) + ), 1.0e-4) def test_up_sample(one_triangle): up_sampled = one_triangle.up_sample() assert up_sampled.side_length == 0.5 - assert np.all( - up_sampled.triangles - == [ + assert up_sampled.triangles == pytest.approx(np.array([ [[0.0, -0.4330127018922193], [-0.25, 0.0], [0.25, 0.0]], [[0.25, 0.0], [0.5, -0.4330127018922193], [0.0, -0.4330127018922193]], [[-0.25, 0.0], [0.0, -0.4330127018922193], [-0.5, -0.4330127018922193]], [[0.0, 0.4330127018922193], [0.25, 0.0], [-0.25, 0.0]], ] - ) + ), 1.0e-4) def test_up_sample_upside_down(upside_down): up_sampled = upside_down.up_sample() assert up_sampled.side_length == 0.5 - assert np.all( - up_sampled.triangles - == [ + assert up_sampled.triangles == pytest.approx(np.array([ [[0.5, -0.4330127018922193], [0.25, 0.0], [0.75, 0.0]], [[0.75, 0.0], [0.5, 0.4330127018922193], [1.0, 0.4330127018922193]], [[0.25, 0.0], [0.0, 0.4330127018922193], [0.5, 0.4330127018922193]], [[0.5, 0.4330127018922193], [0.75, 0.0], [0.25, 0.0]], ] - ) + ), 1.0e-4) def _test_up_sample_twice(one_triangle, plot): @@ -128,9 +122,7 @@ def _test_up_sample_twice(one_triangle, plot): def test_neighborhood(one_triangle): - assert np.all( - one_triangle.neighborhood().triangles - == [ + assert one_triangle.neighborhood().triangles == pytest.approx(np.array([ [ [-0.5, -0.4330127018922193], [-1.0, 0.4330127018922193], @@ -152,13 +144,11 @@ def test_neighborhood(one_triangle): [1.0, 0.4330127018922193], ], ] - ) + ), 1.0e-4) def test_upside_down_neighborhood(upside_down): - assert np.all( - upside_down.neighborhood().triangles - == [ + assert upside_down.neighborhood().triangles == pytest.approx(np.array([ [ [0.0, 0.4330127018922193], [0.5, -0.4330127018922193], @@ -180,7 +170,7 @@ def test_upside_down_neighborhood(upside_down): [0.5, -0.4330127018922193], ], ] - ) + ), 1.0e-4) def _test_complicated(plot, one_triangle): @@ -189,20 +179,16 @@ def _test_complicated(plot, one_triangle): def test_vertices(one_triangle): - assert np.all( - one_triangle.vertices - == [ + assert one_triangle.vertices == pytest.approx(np.array([ [-0.5, -0.4330127018922193], [0.0, 0.4330127018922193], [0.5, -0.4330127018922193], ] - ) + ), 1.0e-4) def test_up_sampled_vertices(one_triangle): - assert np.all( - one_triangle.up_sample().vertices - == [ + assert one_triangle.up_sample().vertices == pytest.approx(np.array([ [-0.5, -0.4330127018922193], [-0.25, 0.0], [0.0, -0.4330127018922193], @@ -210,12 +196,12 @@ def test_up_sampled_vertices(one_triangle): [0.25, 0.0], [0.5, -0.4330127018922193], ] - ) + ), 1.0e-4) def test_with_vertices(one_triangle): triangle = one_triangle.with_vertices(np.array([[0, 0], [1, 0], [0.5, 1]])) - assert np.all(triangle.triangles == [[[1.0, 0.0], [0.5, 1.0], [0.0, 0.0]]]) + assert triangle.triangles == pytest.approx(np.array([[[1.0, 0.0], [0.5, 1.0], [0.0, 0.0]]]), 1.0e-4) def _test_multiple_with_vertices(one_triangle, plot): @@ -224,20 +210,18 @@ def _test_multiple_with_vertices(one_triangle, plot): def test_for_indexes(two_triangles): - assert np.all( - two_triangles.for_indexes(np.array([0])).triangles - == [ + assert two_triangles.for_indexes(np.array([0])).triangles == pytest.approx(np.array([ [ [0.0, 0.4330127018922193], [0.5, -0.4330127018922193], [-0.5, -0.4330127018922193], ] ] - ) + ), 1.0e-4) def test_means(one_triangle): - assert np.all(one_triangle.means == [[0.0, -0.14433756729740643]]) + assert one_triangle.means == pytest.approx(np.array([[0.0, -0.14433756729740643]]), 1.0e-4) @pytest.mark.parametrize( From 6c35ed2f2686e0e66c2d35438111975183ff1bb0 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 17 Jun 2025 18:29:05 +0100 Subject: [PATCH 05/13] remove containment test --- .../coordinate/test_coordinate_implementation.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py b/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py index 098a9c4c1..13ed6d28c 100644 --- a/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py +++ b/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py @@ -188,7 +188,7 @@ def test_vertices(one_triangle): def test_up_sampled_vertices(one_triangle): - assert one_triangle.up_sample().vertices == pytest.approx(np.array([ + assert one_triangle.up_sample().vertices[0:6, :] == pytest.approx(np.array([ [-0.5, -0.4330127018922193], [-0.25, 0.0], [0.0, -0.4330127018922193], @@ -224,19 +224,6 @@ def test_means(one_triangle): assert one_triangle.means == pytest.approx(np.array([[0.0, -0.14433756729740643]]), 1.0e-4) -@pytest.mark.parametrize( - "x, y", - [ - (0.0, 0.0), - (-0.5, -HEIGHT_FACTOR / 2), - (0.5, -HEIGHT_FACTOR / 2), - (0.0, HEIGHT_FACTOR / 2), - ], -) -def test_containment(one_triangle, x, y): - assert one_triangle.containing_indices(Point(x, y)) == [0] - - def test_triangles_touch(): triangles = CoordinateArrayTriangles( np.array([[0, 0], [2, 0]]), From 075e8e700c696bdf6c5ac9ff00b9259b3afc7c4d Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 17 Jun 2025 18:35:34 +0100 Subject: [PATCH 06/13] test_array_representation deleted --- .../triangles/test_array_representation.py | 215 ------------------ 1 file changed, 215 deletions(-) delete mode 100644 test_autoarray/structures/triangles/test_array_representation.py diff --git a/test_autoarray/structures/triangles/test_array_representation.py b/test_autoarray/structures/triangles/test_array_representation.py deleted file mode 100644 index 2496cf747..000000000 --- a/test_autoarray/structures/triangles/test_array_representation.py +++ /dev/null @@ -1,215 +0,0 @@ -import numpy as np -import pytest - -from autoarray.structures.triangles.array import JAXArrayTriangles as ArrayTriangles -from autoarray.structures.triangles.shape import Point - - -@pytest.mark.parametrize( - "point, indices", - [ - ( - Point(0.1, 0.1), - np.array([0]), - ), - ( - Point(0.6, 0.6), - np.array([1]), - ), - ( - Point(0.5, 0.5), - np.array([0, 1]), - ), - ], -) -def test_contains_vertices( - triangles, - point, - indices, -): - containing_indices = triangles.containing_indices(point) - - assert (containing_indices == indices).all() - - -@pytest.mark.parametrize( - "indexes, vertices, indices", - [ - ( - np.array([0]), - np.array( - [ - [0.0, 0.0], - [0.0, 1.0], - [1.0, 0.0], - ] - ), - np.array( - [ - [0, 2, 1], - ] - ), - ), - ( - np.array([1]), - np.array( - [ - [0.0, 1.0], - [1.0, 0.0], - [1.0, 1.0], - ] - ), - np.array( - [ - [1, 0, 2], - ] - ), - ), - ( - np.array([0, 1]), - np.array( - [ - [0.0, 0.0], - [0.0, 1.0], - [1.0, 0.0], - [1.0, 1.0], - ] - ), - np.array( - [ - [0, 2, 1], - [2, 1, 3], - ] - ), - ), - ], -) -def test_for_indexes( - triangles, - indexes, - vertices, - indices, -): - containing = triangles.for_indexes(indexes) - - assert (containing.indices == indices).all() - assert (containing.vertices == vertices).all() - - -def test_up_sample(triangles): - up_sampled = triangles.up_sample() - - assert ( - up_sampled.vertices - == np.array( - [ - [0.0, 0.0], - [0.0, 0.5], - [0.0, 1.0], - [0.5, 0.0], - [0.5, 0.5], - [0.5, 1.0], - [1.0, 0.0], - [1.0, 0.5], - [1.0, 1.0], - ] - ) - ).all() - - assert ( - up_sampled.indices - == np.array( - [ - [6, 4, 3], - [2, 5, 4], - [2, 1, 4], - [8, 7, 5], - [3, 4, 1], - [4, 5, 7], - [0, 3, 1], - [6, 4, 7], - ] - ) - ).all() - - -@pytest.mark.parametrize( - "offset", - [-1, 0, 1], -) -def test_simple_neighborhood(offset): - triangles = ArrayTriangles( - indices=np.array( - [ - [0, 1, 2], - ] - ), - vertices=np.array( - [ - [0.0, 0.0], - [1.0, 0.0], - [0.0, 1.0], - ] - ) - + offset, - ) - assert ( - triangles.neighborhood().triangles - == ( - np.array( - [ - [[-1.0, 1.0], [0.0, 0.0], [0.0, 1.0]], - [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]], - [[0.0, 0.0], [1.0, -1.0], [1.0, 0.0]], - [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], - ] - ) - + offset - ) - ).all() - - -def test_neighborhood(triangles): - neighborhood = triangles.neighborhood() - - assert ( - neighborhood.vertices - == np.array( - [ - [-1.0, 1.0], - [0.0, 0.0], - [0.0, 1.0], - [0.0, 2.0], - [1.0, -1.0], - [1.0, 0.0], - [1.0, 1.0], - [2.0, 0.0], - ] - ) - ).all() - - assert ( - neighborhood.indices - == np.array( - [ - [0, 1, 2], - [1, 2, 5], - [1, 4, 5], - [2, 3, 6], - [2, 5, 6], - [5, 6, 7], - ] - ) - ).all() - - -def test_means(triangles): - means = triangles.means - assert means == pytest.approx( - np.array( - [ - [0.33333333, 0.33333333], - [0.66666667, 0.66666667], - ] - ) - ) From 3933b5534ef949aa86c473a211803b5720a060dc Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 17 Jun 2025 18:37:54 +0100 Subject: [PATCH 07/13] fix test_extended_source --- test_autoarray/structures/triangles/test_extended_source.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test_autoarray/structures/triangles/test_extended_source.py b/test_autoarray/structures/triangles/test_extended_source.py index e4b50adcc..0cdaca2b3 100644 --- a/test_autoarray/structures/triangles/test_extended_source.py +++ b/test_autoarray/structures/triangles/test_extended_source.py @@ -49,7 +49,7 @@ def test_small_point(triangles, point, indices): radius=0.001, ) ) - assert containing_triangles.tolist() == indices + assert [i for i in containing_triangles.tolist() if i != -1] == indices @pytest.mark.parametrize( @@ -72,4 +72,4 @@ def test_large_circle( radius=radius, ) ) - assert containing_triangles.tolist() == indices + assert [i for i in containing_triangles.tolist() if i != -1] == indices From 20ceacaf66cac10aa5fa683e5d38d0386ace50fa Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 17 Jun 2025 18:43:49 +0100 Subject: [PATCH 08/13] JAx CoordinateArrayTriangles has explicit JAX use now --- autoarray/structures/triangles/abstract.py | 15 ---- autoarray/structures/triangles/array.py | 0 .../coordinate_array/jax_coordinate_array.py | 73 ++++++++----------- 3 files changed, 32 insertions(+), 56 deletions(-) create mode 100644 autoarray/structures/triangles/array.py diff --git a/autoarray/structures/triangles/abstract.py b/autoarray/structures/triangles/abstract.py index 880eea2f7..fe2b754fb 100644 --- a/autoarray/structures/triangles/abstract.py +++ b/autoarray/structures/triangles/abstract.py @@ -122,21 +122,6 @@ def for_indexes(self, indexes: np.ndarray) -> "AbstractTriangles": The new ArrayTriangles instance. """ - @abstractmethod - def containing_indices(self, shape: Shape) -> np.ndarray: - """ - Find the triangles that insect with a given shape. - - Parameters - ---------- - shape - The shape - - Returns - ------- - The indices of triangles that intersect the shape. - """ - @abstractmethod def neighborhood(self) -> "AbstractTriangles": """ diff --git a/autoarray/structures/triangles/array.py b/autoarray/structures/triangles/array.py new file mode 100644 index 000000000..e69de29bb diff --git a/autoarray/structures/triangles/coordinate_array/jax_coordinate_array.py b/autoarray/structures/triangles/coordinate_array/jax_coordinate_array.py index e80facd1f..31661dc47 100644 --- a/autoarray/structures/triangles/coordinate_array/jax_coordinate_array.py +++ b/autoarray/structures/triangles/coordinate_array/jax_coordinate_array.py @@ -1,10 +1,7 @@ -from jax import numpy as np +import jax.numpy as jnp import jax from autoarray.structures.triangles.abstract import HEIGHT_FACTOR -from autoarray.structures.triangles.coordinate_array.abstract_coordinate_array import ( - AbstractCoordinateArray, -) from autoarray.structures.triangles.array.jax_array import ArrayTriangles from autoarray.numpy_wrapper import register_pytree_node_class from autoconf import cached_property @@ -13,10 +10,7 @@ @register_pytree_node_class -class CoordinateArrayTriangles(AbstractCoordinateArray): - @property - def numpy(self): - return jax.numpy +class CoordinateArrayTriangles: @classmethod def for_limits_and_scale( @@ -38,7 +32,7 @@ def for_limits_and_scale( coordinates.append([x, y]) return cls( - coordinates=np.array(coordinates), + coordinates=jnp.array(coordinates), side_length=scale, ) @@ -70,17 +64,17 @@ def tree_unflatten(cls, aux_data, children): return cls(*children, flipped=aux_data[0]) @property - def centres(self) -> np.ndarray: + def centres(self) -> jnp.ndarray: """ The centres of the triangles. """ - centres = self.scaling_factors * self.coordinates + np.array( + centres = self.scaling_factors * self.coordinates + jnp.array( [self.x_offset, self.y_offset] ) return centres @cached_property - def flip_mask(self) -> np.ndarray: + def flip_mask(self) -> jnp.ndarray: """ A mask for the triangles that are flipped. @@ -92,11 +86,11 @@ def flip_mask(self) -> np.ndarray: return mask @cached_property - def flip_array(self) -> np.ndarray: + def flip_array(self) -> jnp.ndarray: """ An array of 1s and -1s to flip the triangles. """ - array = np.where(self.flip_mask, -1, 1) + array = jnp.where(self.flip_mask, -1, 1) return array[:, None] def __iter__(self): @@ -113,11 +107,11 @@ def up_sample(self) -> "CoordinateArrayTriangles": n = coordinates.shape[0] - shift0 = np.zeros((n, 2)) - shift3 = np.tile(np.array([0, 1]), (n, 1)) - shift1 = np.stack([np.ones(n), np.where(flip_mask, 1, 0)], axis=1) - shift2 = np.stack([-np.ones(n), np.where(flip_mask, 1, 0)], axis=1) - shifts = np.stack([shift0, shift1, shift2, shift3], axis=1) + shift0 = jnp.zeros((n, 2)) + shift3 = jnp.tile(jnp.array([0, 1]), (n, 1)) + shift1 = jnp.stack([jnp.ones(n), jnp.where(flip_mask, 1, 0)], axis=1) + shift2 = jnp.stack([-jnp.ones(n), jnp.where(flip_mask, 1, 0)], axis=1) + shifts = jnp.stack([shift0, shift1, shift2, shift3], axis=1) coordinates_expanded = coordinates[:, None, :] new_coordinates = coordinates_expanded + shifts @@ -140,27 +134,27 @@ def neighborhood(self) -> "CoordinateArrayTriangles": coordinates = self.coordinates flip_mask = self.flip_mask - shift0 = np.zeros((coordinates.shape[0], 2)) - shift1 = np.tile(np.array([1, 0]), (coordinates.shape[0], 1)) - shift2 = np.tile(np.array([-1, 0]), (coordinates.shape[0], 1)) - shift3 = np.where( + shift0 = jnp.zeros((coordinates.shape[0], 2)) + shift1 = jnp.tile(jnp.array([1, 0]), (coordinates.shape[0], 1)) + shift2 = jnp.tile(jnp.array([-1, 0]), (coordinates.shape[0], 1)) + shift3 = jnp.where( flip_mask[:, None], - np.tile(np.array([0, 1]), (coordinates.shape[0], 1)), - np.tile(np.array([0, -1]), (coordinates.shape[0], 1)), + jnp.tile(jnp.array([0, 1]), (coordinates.shape[0], 1)), + jnp.tile(jnp.array([0, -1]), (coordinates.shape[0], 1)), ) - shifts = np.stack([shift0, shift1, shift2, shift3], axis=1) + shifts = jnp.stack([shift0, shift1, shift2, shift3], axis=1) coordinates_expanded = coordinates[:, None, :] new_coordinates = coordinates_expanded + shifts new_coordinates = new_coordinates.reshape(-1, 2) expected_size = 4 * coordinates.shape[0] - unique_coords, indices = np.unique( + unique_coords, indices = jnp.unique( new_coordinates, axis=0, size=expected_size, - fill_value=np.nan, + fill_value=jnp.nan, return_index=True, ) @@ -175,22 +169,22 @@ def neighborhood(self) -> "CoordinateArrayTriangles": @cached_property def _vertices_and_indices(self): flat_triangles = self.triangles.reshape(-1, 2) - vertices, inverse_indices = np.unique( + vertices, inverse_indices = jnp.unique( flat_triangles, axis=0, return_inverse=True, size=3 * self.coordinates.shape[0], equal_nan=True, - fill_value=np.nan, + fill_value=jnp.nan, ) - nan_mask = np.isnan(vertices).any(axis=1) - inverse_indices = np.where(nan_mask[inverse_indices], -1, inverse_indices) + nan_mask = jnp.isnan(vertices).any(axis=1) + inverse_indices = jnp.where(nan_mask[inverse_indices], -1, inverse_indices) indices = inverse_indices.reshape(-1, 3) return vertices, indices - def with_vertices(self, vertices: np.ndarray) -> ArrayTriangles: + def with_vertices(self, vertices: jnp.ndarray) -> ArrayTriangles: """ Create a new set of triangles with the vertices replaced. @@ -208,7 +202,7 @@ def with_vertices(self, vertices: np.ndarray) -> ArrayTriangles: vertices=vertices, ) - def for_indexes(self, indexes: np.ndarray) -> "CoordinateArrayTriangles": + def for_indexes(self, indexes: jnp.ndarray) -> "CoordinateArrayTriangles": """ Create a new CoordinateArrayTriangles containing triangles corresponding to the given indexes @@ -222,9 +216,9 @@ def for_indexes(self, indexes: np.ndarray) -> "CoordinateArrayTriangles": The new CoordinateArrayTriangles instance. """ mask = indexes == -1 - safe_indexes = np.where(mask, 0, indexes) - coordinates = np.take(self.coordinates, safe_indexes, axis=0) - coordinates = np.where(mask[:, None], np.nan, coordinates) + safe_indexes = jnp.where(mask, 0, indexes) + coordinates = jnp.take(self.coordinates, safe_indexes, axis=0) + coordinates = jnp.where(mask[:, None], jnp.nan, coordinates) return CoordinateArrayTriangles( coordinates=coordinates, @@ -232,7 +226,4 @@ def for_indexes(self, indexes: np.ndarray) -> "CoordinateArrayTriangles": y_offset=self.y_offset, x_offset=self.x_offset, flipped=self.flipped, - ) - - def containing_indices(self, shape: np.ndarray) -> np.ndarray: - raise NotImplementedError("JAX ArrayTriangles are used for this method.") + ) \ No newline at end of file From d83f4ba0aa2880c33cdd91452bd8a6504d20b1c8 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 17 Jun 2025 18:50:13 +0100 Subject: [PATCH 09/13] simplfiied coordinate_array and removed support for numpy --- ...oordinate_array.py => coordinate_array.py} | 108 +++++++++++- .../triangles/coordinate_array/__init__.py | 1 - .../abstract_coordinate_array.py | 162 ------------------ .../triangles/coordinate/conftest.py | 2 +- .../test_coordinate_implementation.py | 2 +- 5 files changed, 106 insertions(+), 169 deletions(-) rename autoarray/structures/triangles/{coordinate_array/jax_coordinate_array.py => coordinate_array.py} (70%) delete mode 100644 autoarray/structures/triangles/coordinate_array/__init__.py delete mode 100644 autoarray/structures/triangles/coordinate_array/abstract_coordinate_array.py diff --git a/autoarray/structures/triangles/coordinate_array/jax_coordinate_array.py b/autoarray/structures/triangles/coordinate_array.py similarity index 70% rename from autoarray/structures/triangles/coordinate_array/jax_coordinate_array.py rename to autoarray/structures/triangles/coordinate_array.py index 31661dc47..3827622af 100644 --- a/autoarray/structures/triangles/coordinate_array/jax_coordinate_array.py +++ b/autoarray/structures/triangles/coordinate_array.py @@ -1,3 +1,4 @@ +import numpy as np import jax.numpy as jnp import jax @@ -12,6 +13,38 @@ @register_pytree_node_class class CoordinateArrayTriangles: + def __init__( + self, + coordinates: np.ndarray, + side_length: float = 1.0, + x_offset: float = 0.0, + y_offset: float = 0.0, + flipped: bool = False, + ): + """ + Represents a set of triangles by integer coordinates. + + Parameters + ---------- + coordinates + Integer x y coordinates for each triangle. + side_length + The side length of the triangles. + flipped + Whether the triangles are flipped upside down. + y_offset + An y_offset to apply to the y coordinates so that up-sampled triangles align. + """ + self.coordinates = coordinates + self.side_length = side_length + self.flipped = flipped + + self.scaling_factors = jnp.array( + [0.5 * side_length, HEIGHT_FACTOR * side_length] + ) + self.x_offset = x_offset + self.y_offset = y_offset + @classmethod def for_limits_and_scale( cls, @@ -63,6 +96,12 @@ def tree_unflatten(cls, aux_data, children): """ return cls(*children, flipped=aux_data[0]) + def __len__(self): + return jnp.count_nonzero(~jnp.isnan(self.coordinates).any(axis=1)) + + def __iter__(self): + return iter(self.triangles) + @property def centres(self) -> jnp.ndarray: """ @@ -73,6 +112,48 @@ def centres(self) -> jnp.ndarray: ) return centres + @cached_property + def vertex_coordinates(self) -> np.ndarray: + """ + The vertices of the triangles as an Nx3x2 array. + """ + coordinates = self.coordinates + return jnp.concatenate( + [ + coordinates + self.flip_array * np.array([0, 1], dtype=np.int32), + coordinates + self.flip_array * np.array([1, -1], dtype=np.int32), + coordinates + self.flip_array * np.array([-1, -1], dtype=np.int32), + ], + dtype=np.int32, + ) + + @cached_property + def triangles(self) -> np.ndarray: + """ + The vertices of the triangles as an Nx3x2 array. + """ + centres = self.centres + return jnp.stack( + ( + centres + + self.flip_array + * jnp.array( + [0.0, 0.5 * self.side_length * HEIGHT_FACTOR], + ), + centres + + self.flip_array + * jnp.array( + [0.5 * self.side_length, -0.5 * self.side_length * HEIGHT_FACTOR] + ), + centres + + self.flip_array + * jnp.array( + [-0.5 * self.side_length, -0.5 * self.side_length * HEIGHT_FACTOR] + ), + ), + axis=1, + ) + @cached_property def flip_mask(self) -> jnp.ndarray: """ @@ -93,9 +174,6 @@ def flip_array(self) -> jnp.ndarray: array = jnp.where(self.flip_mask, -1, 1) return array[:, None] - def __iter__(self): - return iter(self.triangles) - def up_sample(self) -> "CoordinateArrayTriangles": """ Up-sample the triangles by adding a new vertex at the midpoint of each edge. @@ -226,4 +304,26 @@ def for_indexes(self, indexes: jnp.ndarray) -> "CoordinateArrayTriangles": y_offset=self.y_offset, x_offset=self.x_offset, flipped=self.flipped, - ) \ No newline at end of file + ) + + @property + def vertices(self) -> np.ndarray: + """ + The unique vertices of the triangles. + """ + return self._vertices_and_indices[0] + + @property + def indices(self) -> np.ndarray: + """ + The indices of the vertices of the triangles. + """ + return self._vertices_and_indices[1] + + @property + def means(self): + return jnp.mean(self.triangles, axis=1) + + @property + def area(self): + return (3**0.5 / 4 * self.side_length**2) * len(self) \ No newline at end of file diff --git a/autoarray/structures/triangles/coordinate_array/__init__.py b/autoarray/structures/triangles/coordinate_array/__init__.py deleted file mode 100644 index b4c84484c..000000000 --- a/autoarray/structures/triangles/coordinate_array/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .jax_coordinate_array import CoordinateArrayTriangles as JAXCoordinateArrayTriangles \ No newline at end of file diff --git a/autoarray/structures/triangles/coordinate_array/abstract_coordinate_array.py b/autoarray/structures/triangles/coordinate_array/abstract_coordinate_array.py deleted file mode 100644 index 5c0fe799c..000000000 --- a/autoarray/structures/triangles/coordinate_array/abstract_coordinate_array.py +++ /dev/null @@ -1,162 +0,0 @@ -from abc import abstractmethod, ABC - -import numpy as np - -from autoarray.structures.triangles.abstract import HEIGHT_FACTOR, AbstractTriangles -from autoconf import cached_property - - -class AbstractCoordinateArray(AbstractTriangles, ABC): - def __init__( - self, - coordinates: np.ndarray, - side_length: float = 1.0, - x_offset: float = 0.0, - y_offset: float = 0.0, - flipped: bool = False, - ): - """ - Represents a set of triangles by integer coordinates. - - Parameters - ---------- - coordinates - Integer x y coordinates for each triangle. - side_length - The side length of the triangles. - flipped - Whether the triangles are flipped upside down. - y_offset - An y_offset to apply to the y coordinates so that up-sampled triangles align. - """ - self.coordinates = coordinates - self.side_length = side_length - self.flipped = flipped - - self.scaling_factors = self.numpy.array( - [0.5 * side_length, HEIGHT_FACTOR * side_length] - ) - self.x_offset = x_offset - self.y_offset = y_offset - - @property - @abstractmethod - def numpy(self): - pass - - @cached_property - def vertex_coordinates(self) -> np.ndarray: - """ - The vertices of the triangles as an Nx3x2 array. - """ - coordinates = self.coordinates - return self.numpy.concatenate( - [ - coordinates + self.flip_array * np.array([0, 1], dtype=np.int32), - coordinates + self.flip_array * np.array([1, -1], dtype=np.int32), - coordinates + self.flip_array * np.array([-1, -1], dtype=np.int32), - ], - dtype=np.int32, - ) - - @cached_property - def triangles(self) -> np.ndarray: - """ - The vertices of the triangles as an Nx3x2 array. - """ - centres = self.centres - return self.numpy.stack( - ( - centres - + self.flip_array - * self.numpy.array( - [0.0, 0.5 * self.side_length * HEIGHT_FACTOR], - ), - centres - + self.flip_array - * self.numpy.array( - [0.5 * self.side_length, -0.5 * self.side_length * HEIGHT_FACTOR] - ), - centres - + self.flip_array - * self.numpy.array( - [-0.5 * self.side_length, -0.5 * self.side_length * HEIGHT_FACTOR] - ), - ), - axis=1, - ) - - @property - def centres(self) -> np.ndarray: - """ - The centres of the triangles. - """ - return self.scaling_factors * self.coordinates + self.numpy.array( - [self.x_offset, self.y_offset] - ) - - @cached_property - def flip_mask(self) -> np.ndarray: - """ - A mask for the triangles that are flipped. - - Every other triangle is flipped so that they tessellate. - """ - mask = (self.coordinates[:, 0] + self.coordinates[:, 1]) % 2 != 0 - if self.flipped: - mask = ~mask - return mask - - @cached_property - @abstractmethod - def flip_array(self) -> np.ndarray: - """ - An array of 1s and -1s to flip the triangles. - """ - - def __iter__(self): - return iter(self.triangles) - - @cached_property - @abstractmethod - def _vertices_and_indices(self): - pass - - @property - def vertices(self) -> np.ndarray: - """ - The unique vertices of the triangles. - """ - return self._vertices_and_indices[0] - - @property - def indices(self) -> np.ndarray: - """ - The indices of the vertices of the triangles. - """ - return self._vertices_and_indices[1] - - def with_vertices(self, vertices: np.ndarray) -> AbstractTriangles: - """ - Create a new set of triangles with the vertices replaced. - - Parameters - ---------- - vertices - The new vertices to use. - - Returns - ------- - The new set of triangles with the new vertices. - """ - - @property - def means(self): - return self.numpy.mean(self.triangles, axis=1) - - @property - def area(self): - return (3**0.5 / 4 * self.side_length**2) * len(self) - - def __len__(self): - return self.numpy.count_nonzero(~self.numpy.isnan(self.coordinates).any(axis=1)) diff --git a/test_autoarray/structures/triangles/coordinate/conftest.py b/test_autoarray/structures/triangles/coordinate/conftest.py index 0d53f32a2..302b565f7 100644 --- a/test_autoarray/structures/triangles/coordinate/conftest.py +++ b/test_autoarray/structures/triangles/coordinate/conftest.py @@ -2,7 +2,7 @@ import numpy as np -from autoarray.structures.triangles.coordinate_array import JAXCoordinateArrayTriangles as CoordinateArrayTriangles +from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles @pytest.fixture diff --git a/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py b/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py index 13ed6d28c..b0e91467e 100644 --- a/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py +++ b/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py @@ -3,7 +3,7 @@ import numpy as np from autoarray.structures.triangles.abstract import HEIGHT_FACTOR -from autoarray.structures.triangles.coordinate_array import JAXCoordinateArrayTriangles as CoordinateArrayTriangles +from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles from autoarray.structures.triangles.shape import Point From 1094154c5ef73a53140baace7f209c388b3ff222 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 17 Jun 2025 18:58:51 +0100 Subject: [PATCH 10/13] refactor arrayu to only use JAX --- autoarray/structures/triangles/array.py | 454 ++++++++++++++++++ .../structures/triangles/array/__init__.py | 1 - .../triangles/array/abstract_array.py | 211 -------- .../structures/triangles/array/jax_array.py | 289 ----------- .../structures/triangles/coordinate_array.py | 2 +- .../structures/triangles/conftest.py | 2 +- .../structures/triangles/test_area.py | 2 +- .../triangles/test_extended_source.py | 2 +- .../structures/triangles/test_jax.py | 4 +- .../triangles/test_nan_triangles.py | 4 +- 10 files changed, 462 insertions(+), 509 deletions(-) delete mode 100644 autoarray/structures/triangles/array/__init__.py delete mode 100644 autoarray/structures/triangles/array/abstract_array.py delete mode 100644 autoarray/structures/triangles/array/jax_array.py diff --git a/autoarray/structures/triangles/array.py b/autoarray/structures/triangles/array.py index e69de29bb..7e4526300 100644 --- a/autoarray/structures/triangles/array.py +++ b/autoarray/structures/triangles/array.py @@ -0,0 +1,454 @@ +import numpy as np +import jax.numpy as jnp +from jax.tree_util import register_pytree_node_class + +from autoarray.structures.triangles.abstract import HEIGHT_FACTOR + +from autoarray.structures.grids.uniform_2d import Grid2D +from autoarray.structures.triangles.abstract import AbstractTriangles +from autoarray.structures.triangles.shape import Shape + +MAX_CONTAINING_SIZE = 15 + + +@register_pytree_node_class +class ArrayTriangles: + def __init__( + self, + indices, + vertices, + max_containing_size=MAX_CONTAINING_SIZE, + **kwargs, + ): + """ + Represents a set of triangles in efficient NumPy arrays. + + Parameters + ---------- + indices + The indices of the vertices of the triangles. This is a 2D array where each row is a triangle + with the three indices of the vertices. + vertices + The vertices of the triangles. + """ + self._indices = indices + self._vertices = vertices + self.max_containing_size = max_containing_size + + def __len__(self): + return len(self.triangles) + + def __iter__(self): + return iter(self.triangles) + + def __str__(self): + return f"{self.__class__.__name__} with {len(self.indices)} triangles" + + def __repr__(self): + return str(self) + + @classmethod + def for_limits_and_scale( + cls, + y_min: float, + y_max: float, + x_min: float, + x_max: float, + scale: float, + max_containing_size=MAX_CONTAINING_SIZE, + ) -> "AbstractTriangles": + height = scale * HEIGHT_FACTOR + + vertices = [] + indices = [] + vertex_dict = {} + + def add_vertex(v): + if v not in vertex_dict: + vertex_dict[v] = len(vertices) + vertices.append(v) + return vertex_dict[v] + + rows = [] + for row_y in np.arange(y_min, y_max + height, height): + row = [] + offset = (len(rows) % 2) * scale / 2 + for col_x in np.arange(x_min - offset, x_max + scale, scale): + row.append((row_y, col_x)) + rows.append(row) + + for i in range(len(rows) - 1): + row = rows[i] + next_row = rows[i + 1] + for j in range(len(row)): + if i % 2 == 0 and j < len(next_row) - 1: + t1 = [ + add_vertex(row[j]), + add_vertex(next_row[j]), + add_vertex(next_row[j + 1]), + ] + if j < len(row) - 1: + t2 = [ + add_vertex(row[j]), + add_vertex(row[j + 1]), + add_vertex(next_row[j + 1]), + ] + indices.append(t2) + elif i % 2 == 1 and j < len(next_row) - 1: + t1 = [ + add_vertex(row[j]), + add_vertex(next_row[j]), + add_vertex(row[j + 1]), + ] + indices.append(t1) + if j < len(next_row) - 1: + t2 = [ + add_vertex(next_row[j]), + add_vertex(next_row[j + 1]), + add_vertex(row[j + 1]), + ] + indices.append(t2) + else: + continue + indices.append(t1) + + return cls( + indices=jnp.array(indices), + vertices=jnp.array(vertices), + max_containing_size=max_containing_size, + ) + + @classmethod + def for_grid( + cls, + grid: Grid2D, + **kwargs, + ) -> "AbstractTriangles": + """ + Create a grid of equilateral triangles from a regular grid. + + Parameters + ---------- + grid + The regular grid to convert to a grid of triangles. + + Returns + ------- + The grid of triangles. + """ + + scale = grid.pixel_scale + + y = grid[:, 0] + x = grid[:, 1] + + y_min = y.min() + y_max = y.max() + x_min = x.min() + x_max = x.max() + + return cls.for_limits_and_scale( + y_min, + y_max, + x_min, + x_max, + scale, + **kwargs, + ) + + @property + def indices(self): + return self._indices + + @property + def vertices(self): + return self._vertices + + @property + def triangles(self) -> jnp.ndarray: + """ + The triangles as a 3x2 array of vertices. + """ + + invalid_mask = jnp.any(self.indices == -1, axis=1) + nan_array = jnp.full( + (self.indices.shape[0], 3, 2), + jnp.nan, + dtype=jnp.float32, + ) + safe_indices = jnp.where(self.indices == -1, 0, self.indices) + triangle_vertices = self.vertices[safe_indices] + return jnp.where(invalid_mask[:, None, None], nan_array, triangle_vertices) + + @property + def means(self) -> jnp.ndarray: + """ + The mean of each triangle. + """ + return jnp.mean(self.triangles, axis=1) + + def containing_indices(self, shape: Shape) -> jnp.ndarray: + """ + Find the triangles that insect with a given shape. + + Parameters + ---------- + shape + The shape + + Returns + ------- + The triangles that intersect the shape. + """ + inside = shape.mask(self.triangles) + + return jnp.where( + inside, + size=self.max_containing_size, + fill_value=-1, + )[0] + + def for_indexes(self, indexes: jnp.ndarray) -> "ArrayTriangles": + """ + Create a new ArrayTriangles containing indices and vertices corresponding to the given indexes + but without duplicate vertices. + + Parameters + ---------- + indexes + The indexes of the triangles to include in the new ArrayTriangles. + + Returns + ------- + The new ArrayTriangles instance. + """ + selected_indices = select_and_handle_invalid( + data=self.indices, + indices=indexes, + invalid_value=-1, + invalid_replacement=jnp.array([-1, -1, -1], dtype=jnp.int32), + ) + + flat_indices = selected_indices.flatten() + + selected_vertices = select_and_handle_invalid( + data=self.vertices, + indices=flat_indices, + invalid_value=-1, + invalid_replacement=jnp.array([jnp.nan, jnp.nan], dtype=jnp.float32), + ) + + unique_vertices, inv_indices = jnp.unique( + selected_vertices, + axis=0, + return_inverse=True, + equal_nan=True, + size=selected_indices.shape[0] * 3, + fill_value=jnp.nan, + ) + + nan_mask = jnp.isnan(unique_vertices).any(axis=1) + inv_indices = jnp.where(nan_mask[inv_indices], -1, inv_indices) + + new_indices = inv_indices.reshape(selected_indices.shape) + + new_indices_sorted = jnp.sort(new_indices, axis=1) + + unique_triangles_indices = jnp.unique( + new_indices_sorted, + axis=0, + size=new_indices_sorted.shape[0], + fill_value=-1, + ) + + return ArrayTriangles( + indices=unique_triangles_indices, + vertices=unique_vertices, + max_containing_size=self.max_containing_size, + ) + + def _up_sample_triangle(self): + triangles = self.triangles + + m01 = (triangles[:, 0] + triangles[:, 1]) / 2 + m12 = (triangles[:, 1] + triangles[:, 2]) / 2 + m20 = (triangles[:, 2] + triangles[:, 0]) / 2 + + return jnp.concatenate( + [ + jnp.stack([triangles[:, 1], m12, m01], axis=1), + jnp.stack([triangles[:, 2], m20, m12], axis=1), + jnp.stack([m01, m12, m20], axis=1), + jnp.stack([triangles[:, 0], m01, m20], axis=1), + ], + axis=0, + ) + + def up_sample(self) -> "ArrayTriangles": + """ + Up-sample the triangles by adding a new vertex at the midpoint of each edge. + + This means each triangle becomes four smaller triangles. + """ + new_indices, unique_vertices = remove_duplicates(self._up_sample_triangle()) + + return ArrayTriangles( + indices=new_indices, + vertices=unique_vertices, + max_containing_size=self.max_containing_size, + ) + + def _neighborhood_triangles(self): + triangles = self.triangles + + new_v0 = triangles[:, 1] + triangles[:, 2] - triangles[:, 0] + new_v1 = triangles[:, 0] + triangles[:, 2] - triangles[:, 1] + new_v2 = triangles[:, 0] + triangles[:, 1] - triangles[:, 2] + + return jnp.concatenate( + [ + jnp.stack([new_v0, triangles[:, 1], triangles[:, 2]], axis=1), + jnp.stack([triangles[:, 0], new_v1, triangles[:, 2]], axis=1), + jnp.stack([triangles[:, 0], triangles[:, 1], new_v2], axis=1), + triangles, + ], + axis=0, + ) + + def neighborhood(self) -> "ArrayTriangles": + """ + Create a new set of triangles that are the neighborhood of the current triangles. + + Includes the current triangles and the triangles that share an edge with the current triangles. + """ + new_indices, unique_vertices = remove_duplicates(self._neighborhood_triangles()) + + return ArrayTriangles( + indices=new_indices, + vertices=unique_vertices, + max_containing_size=self.max_containing_size, + ) + + def with_vertices(self, vertices: jnp.ndarray) -> "ArrayTriangles": + """ + Create a new set of triangles with the vertices replaced. + + Parameters + ---------- + vertices + The new vertices to use. + + Returns + ------- + The new set of triangles with the new vertices. + """ + return ArrayTriangles( + indices=self.indices, + vertices=vertices, + max_containing_size=self.max_containing_size, + ) + + @property + def area(self) -> float: + """ + The total area covered by the triangles. + """ + triangles = self.triangles + return ( + 0.5 + * np.abs( + (triangles[:, 0, 0] * (triangles[:, 1, 1] - triangles[:, 2, 1])) + + (triangles[:, 1, 0] * (triangles[:, 2, 1] - triangles[:, 0, 1])) + + (triangles[:, 2, 0] * (triangles[:, 0, 1] - triangles[:, 1, 1])) + ).sum() + ) + + def tree_flatten(self): + """ + Flatten this model as a PyTree. + """ + return ( + self.indices, + self.vertices, + ), (self.max_containing_size,) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """ + Unflatten a PyTree into a model. + """ + return cls( + indices=children[0], + vertices=children[1], + max_containing_size=aux_data[0], + ) + + + +def select_and_handle_invalid( + data: jnp.ndarray, + indices: jnp.ndarray, + invalid_value, + invalid_replacement, +): + """ + Select data based on indices, handling invalid indices by replacing them with a specified value. + + Parameters + ---------- + data + The array from which to select data. + indices + The indices used to select data from the array. + invalid_value + The value representing invalid indices. + invalid_replacement + The value to use for invalid entries in the result. + + Returns + ------- + An array with selected data, where invalid indices are replaced with `invalid_replacement`. + """ + invalid_mask = indices == invalid_value + safe_indices = jnp.where(invalid_mask, 0, indices) + selected_data = data[safe_indices] + selected_data = jnp.where( + invalid_mask[..., None], + invalid_replacement, + selected_data, + ) + + return selected_data + + +def remove_duplicates(new_triangles): + unique_vertices, inverse_indices = jnp.unique( + new_triangles.reshape(-1, 2), + axis=0, + return_inverse=True, + size=2 * new_triangles.shape[0], + fill_value=jnp.nan, + equal_nan=True, + ) + + inverse_indices_flat = inverse_indices.reshape(-1) + selected_vertices = unique_vertices[inverse_indices_flat] + mask = jnp.any(jnp.isnan(selected_vertices), axis=1) + inverse_indices_flat = jnp.where(mask, -1, inverse_indices_flat) + inverse_indices = inverse_indices_flat.reshape(inverse_indices.shape) + + new_indices = inverse_indices.reshape(-1, 3) + + new_indices_sorted = jnp.sort(new_indices, axis=1) + + unique_triangles_indices = jnp.unique( + new_indices_sorted, + axis=0, + size=new_indices_sorted.shape[0], + fill_value=jnp.array( + [-1, -1, -1], + dtype=jnp.int32, + ), + ) + + return unique_triangles_indices, unique_vertices diff --git a/autoarray/structures/triangles/array/__init__.py b/autoarray/structures/triangles/array/__init__.py deleted file mode 100644 index e1cbd9336..000000000 --- a/autoarray/structures/triangles/array/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .jax_array import ArrayTriangles as JAXArrayTriangles \ No newline at end of file diff --git a/autoarray/structures/triangles/array/abstract_array.py b/autoarray/structures/triangles/array/abstract_array.py deleted file mode 100644 index d0f8620ee..000000000 --- a/autoarray/structures/triangles/array/abstract_array.py +++ /dev/null @@ -1,211 +0,0 @@ -from abc import abstractmethod - -import numpy as np - -from autoarray import Grid2D, AbstractTriangles -from autoarray.structures.triangles.abstract import HEIGHT_FACTOR - - -class AbstractArrayTriangles(AbstractTriangles): - def __init__( - self, - indices, - vertices, - **kwargs, - ): - """ - Represents a set of triangles in efficient NumPy arrays. - - Parameters - ---------- - indices - The indices of the vertices of the triangles. This is a 2D array where each row is a triangle - with the three indices of the vertices. - vertices - The vertices of the triangles. - """ - self._indices = indices - self._vertices = vertices - - @property - def indices(self): - return self._indices - - @property - def vertices(self): - return self._vertices - - def __len__(self): - return len(self.triangles) - - @property - def area(self) -> float: - """ - The total area covered by the triangles. - """ - triangles = self.triangles - return ( - 0.5 - * np.abs( - (triangles[:, 0, 0] * (triangles[:, 1, 1] - triangles[:, 2, 1])) - + (triangles[:, 1, 0] * (triangles[:, 2, 1] - triangles[:, 0, 1])) - + (triangles[:, 2, 0] * (triangles[:, 0, 1] - triangles[:, 1, 1])) - ).sum() - ) - - @property - @abstractmethod - def numpy(self): - pass - - def _up_sample_triangle(self): - triangles = self.triangles - - m01 = (triangles[:, 0] + triangles[:, 1]) / 2 - m12 = (triangles[:, 1] + triangles[:, 2]) / 2 - m20 = (triangles[:, 2] + triangles[:, 0]) / 2 - - return self.numpy.concatenate( - [ - self.numpy.stack([triangles[:, 1], m12, m01], axis=1), - self.numpy.stack([triangles[:, 2], m20, m12], axis=1), - self.numpy.stack([m01, m12, m20], axis=1), - self.numpy.stack([triangles[:, 0], m01, m20], axis=1), - ], - axis=0, - ) - - def _neighborhood_triangles(self): - triangles = self.triangles - - new_v0 = triangles[:, 1] + triangles[:, 2] - triangles[:, 0] - new_v1 = triangles[:, 0] + triangles[:, 2] - triangles[:, 1] - new_v2 = triangles[:, 0] + triangles[:, 1] - triangles[:, 2] - - return self.numpy.concatenate( - [ - self.numpy.stack([new_v0, triangles[:, 1], triangles[:, 2]], axis=1), - self.numpy.stack([triangles[:, 0], new_v1, triangles[:, 2]], axis=1), - self.numpy.stack([triangles[:, 0], triangles[:, 1], new_v2], axis=1), - triangles, - ], - axis=0, - ) - - def __str__(self): - return f"{self.__class__.__name__} with {len(self.indices)} triangles" - - def __repr__(self): - return str(self) - - @classmethod - def for_limits_and_scale( - cls, - y_min: float, - y_max: float, - x_min: float, - x_max: float, - scale: float, - **kwargs, - ) -> "AbstractTriangles": - height = scale * HEIGHT_FACTOR - - vertices = [] - indices = [] - vertex_dict = {} - - def add_vertex(v): - if v not in vertex_dict: - vertex_dict[v] = len(vertices) - vertices.append(v) - return vertex_dict[v] - - rows = [] - for row_y in np.arange(y_min, y_max + height, height): - row = [] - offset = (len(rows) % 2) * scale / 2 - for col_x in np.arange(x_min - offset, x_max + scale, scale): - row.append((row_y, col_x)) - rows.append(row) - - for i in range(len(rows) - 1): - row = rows[i] - next_row = rows[i + 1] - for j in range(len(row)): - if i % 2 == 0 and j < len(next_row) - 1: - t1 = [ - add_vertex(row[j]), - add_vertex(next_row[j]), - add_vertex(next_row[j + 1]), - ] - if j < len(row) - 1: - t2 = [ - add_vertex(row[j]), - add_vertex(row[j + 1]), - add_vertex(next_row[j + 1]), - ] - indices.append(t2) - elif i % 2 == 1 and j < len(next_row) - 1: - t1 = [ - add_vertex(row[j]), - add_vertex(next_row[j]), - add_vertex(row[j + 1]), - ] - indices.append(t1) - if j < len(next_row) - 1: - t2 = [ - add_vertex(next_row[j]), - add_vertex(next_row[j + 1]), - add_vertex(row[j + 1]), - ] - indices.append(t2) - else: - continue - indices.append(t1) - - vertices = np.array(vertices) - indices = np.array(indices) - - return cls( - indices=indices, - vertices=vertices, - **kwargs, - ) - - @classmethod - def for_grid( - cls, - grid: Grid2D, - **kwargs, - ) -> "AbstractTriangles": - """ - Create a grid of equilateral triangles from a regular grid. - - Parameters - ---------- - grid - The regular grid to convert to a grid of triangles. - - Returns - ------- - The grid of triangles. - """ - - scale = grid.pixel_scale - - y = grid[:, 0] - x = grid[:, 1] - - y_min = y.min() - y_max = y.max() - x_min = x.min() - x_max = x.max() - - return cls.for_limits_and_scale( - y_min, - y_max, - x_min, - x_max, - scale, - **kwargs, - ) diff --git a/autoarray/structures/triangles/array/jax_array.py b/autoarray/structures/triangles/array/jax_array.py deleted file mode 100644 index 23b9ad3b5..000000000 --- a/autoarray/structures/triangles/array/jax_array.py +++ /dev/null @@ -1,289 +0,0 @@ -from jax import numpy as np -from jax.tree_util import register_pytree_node_class - -from autoarray.structures.triangles.abstract import AbstractTriangles -from autoarray.structures.triangles.array.abstract_array import AbstractArrayTriangles -from autoarray.structures.triangles.shape import Shape - -MAX_CONTAINING_SIZE = 15 - - -@register_pytree_node_class -class ArrayTriangles(AbstractArrayTriangles): - def __init__( - self, - indices, - vertices, - max_containing_size=MAX_CONTAINING_SIZE, - ): - super().__init__(indices, vertices) - self.max_containing_size = max_containing_size - - @property - def numpy(self): - return np - - @property - def triangles(self) -> np.ndarray: - """ - The triangles as a 3x2 array of vertices. - """ - - invalid_mask = np.any(self.indices == -1, axis=1) - nan_array = np.full( - (self.indices.shape[0], 3, 2), - np.nan, - dtype=np.float32, - ) - safe_indices = np.where(self.indices == -1, 0, self.indices) - triangle_vertices = self.vertices[safe_indices] - return np.where(invalid_mask[:, None, None], nan_array, triangle_vertices) - - @property - def means(self) -> np.ndarray: - """ - The mean of each triangle. - """ - return np.mean(self.triangles, axis=1) - - def containing_indices(self, shape: Shape) -> np.ndarray: - """ - Find the triangles that insect with a given shape. - - Parameters - ---------- - shape - The shape - - Returns - ------- - The triangles that intersect the shape. - """ - inside = shape.mask(self.triangles) - - return np.where( - inside, - size=self.max_containing_size, - fill_value=-1, - )[0] - - def for_indexes(self, indexes: np.ndarray) -> "ArrayTriangles": - """ - Create a new ArrayTriangles containing indices and vertices corresponding to the given indexes - but without duplicate vertices. - - Parameters - ---------- - indexes - The indexes of the triangles to include in the new ArrayTriangles. - - Returns - ------- - The new ArrayTriangles instance. - """ - selected_indices = select_and_handle_invalid( - data=self.indices, - indices=indexes, - invalid_value=-1, - invalid_replacement=np.array([-1, -1, -1], dtype=np.int32), - ) - - flat_indices = selected_indices.flatten() - - selected_vertices = select_and_handle_invalid( - data=self.vertices, - indices=flat_indices, - invalid_value=-1, - invalid_replacement=np.array([np.nan, np.nan], dtype=np.float32), - ) - - unique_vertices, inv_indices = np.unique( - selected_vertices, - axis=0, - return_inverse=True, - equal_nan=True, - size=selected_indices.shape[0] * 3, - fill_value=np.nan, - ) - - nan_mask = np.isnan(unique_vertices).any(axis=1) - inv_indices = np.where(nan_mask[inv_indices], -1, inv_indices) - - new_indices = inv_indices.reshape(selected_indices.shape) - - new_indices_sorted = np.sort(new_indices, axis=1) - - unique_triangles_indices = np.unique( - new_indices_sorted, - axis=0, - size=new_indices_sorted.shape[0], - fill_value=-1, - ) - - return ArrayTriangles( - indices=unique_triangles_indices, - vertices=unique_vertices, - max_containing_size=self.max_containing_size, - ) - - def up_sample(self) -> "ArrayTriangles": - """ - Up-sample the triangles by adding a new vertex at the midpoint of each edge. - - This means each triangle becomes four smaller triangles. - """ - new_indices, unique_vertices = remove_duplicates(self._up_sample_triangle()) - - return ArrayTriangles( - indices=new_indices, - vertices=unique_vertices, - max_containing_size=self.max_containing_size, - ) - - def neighborhood(self) -> "ArrayTriangles": - """ - Create a new set of triangles that are the neighborhood of the current triangles. - - Includes the current triangles and the triangles that share an edge with the current triangles. - """ - new_indices, unique_vertices = remove_duplicates(self._neighborhood_triangles()) - - return ArrayTriangles( - indices=new_indices, - vertices=unique_vertices, - max_containing_size=self.max_containing_size, - ) - - def with_vertices(self, vertices: np.ndarray) -> "ArrayTriangles": - """ - Create a new set of triangles with the vertices replaced. - - Parameters - ---------- - vertices - The new vertices to use. - - Returns - ------- - The new set of triangles with the new vertices. - """ - return ArrayTriangles( - indices=self.indices, - vertices=vertices, - max_containing_size=self.max_containing_size, - ) - - def __iter__(self): - return iter(self.triangles) - - def tree_flatten(self): - """ - Flatten this model as a PyTree. - """ - return ( - self.indices, - self.vertices, - ), (self.max_containing_size,) - - @classmethod - def tree_unflatten(cls, aux_data, children): - """ - Unflatten a PyTree into a model. - """ - return cls( - indices=children[0], - vertices=children[1], - max_containing_size=aux_data[0], - ) - - @classmethod - def for_limits_and_scale( - cls, - y_min: float, - y_max: float, - x_min: float, - x_max: float, - scale: float, - max_containing_size=MAX_CONTAINING_SIZE, - ) -> "AbstractTriangles": - triangles = super().for_limits_and_scale( - y_min, - y_max, - x_min, - x_max, - scale, - ) - return cls( - indices=np.array(triangles.indices), - vertices=np.array(triangles.vertices), - max_containing_size=max_containing_size, - ) - - -def select_and_handle_invalid( - data: np.ndarray, - indices: np.ndarray, - invalid_value, - invalid_replacement, -): - """ - Select data based on indices, handling invalid indices by replacing them with a specified value. - - Parameters - ---------- - data - The array from which to select data. - indices - The indices used to select data from the array. - invalid_value - The value representing invalid indices. - invalid_replacement - The value to use for invalid entries in the result. - - Returns - ------- - An array with selected data, where invalid indices are replaced with `invalid_replacement`. - """ - invalid_mask = indices == invalid_value - safe_indices = np.where(invalid_mask, 0, indices) - selected_data = data[safe_indices] - selected_data = np.where( - invalid_mask[..., None], - invalid_replacement, - selected_data, - ) - - return selected_data - - -def remove_duplicates(new_triangles): - unique_vertices, inverse_indices = np.unique( - new_triangles.reshape(-1, 2), - axis=0, - return_inverse=True, - size=2 * new_triangles.shape[0], - fill_value=np.nan, - equal_nan=True, - ) - - inverse_indices_flat = inverse_indices.reshape(-1) - selected_vertices = unique_vertices[inverse_indices_flat] - mask = np.any(np.isnan(selected_vertices), axis=1) - inverse_indices_flat = np.where(mask, -1, inverse_indices_flat) - inverse_indices = inverse_indices_flat.reshape(inverse_indices.shape) - - new_indices = inverse_indices.reshape(-1, 3) - - new_indices_sorted = np.sort(new_indices, axis=1) - - unique_triangles_indices = np.unique( - new_indices_sorted, - axis=0, - size=new_indices_sorted.shape[0], - fill_value=np.array( - [-1, -1, -1], - dtype=np.int32, - ), - ) - - return unique_triangles_indices, unique_vertices diff --git a/autoarray/structures/triangles/coordinate_array.py b/autoarray/structures/triangles/coordinate_array.py index 3827622af..8d950dd20 100644 --- a/autoarray/structures/triangles/coordinate_array.py +++ b/autoarray/structures/triangles/coordinate_array.py @@ -3,7 +3,7 @@ import jax from autoarray.structures.triangles.abstract import HEIGHT_FACTOR -from autoarray.structures.triangles.array.jax_array import ArrayTriangles +from autoarray.structures.triangles.array import ArrayTriangles from autoarray.numpy_wrapper import register_pytree_node_class from autoconf import cached_property diff --git a/test_autoarray/structures/triangles/conftest.py b/test_autoarray/structures/triangles/conftest.py index bf35f5643..a8d8580a3 100644 --- a/test_autoarray/structures/triangles/conftest.py +++ b/test_autoarray/structures/triangles/conftest.py @@ -1,5 +1,5 @@ from autoarray.numpy_wrapper import np -from autoarray.structures.triangles.array import JAXArrayTriangles as ArrayTriangles +from autoarray.structures.triangles.array import ArrayTriangles from matplotlib import pyplot as plt diff --git a/test_autoarray/structures/triangles/test_area.py b/test_autoarray/structures/triangles/test_area.py index 95a19da4a..c7a1b6ccc 100644 --- a/test_autoarray/structures/triangles/test_area.py +++ b/test_autoarray/structures/triangles/test_area.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from autoarray.structures.triangles.array import JAXArrayTriangles as ArrayTriangles +from autoarray.structures.triangles.array import ArrayTriangles from autoarray.structures.triangles.shape import Triangle, Circle, Square, Polygon diff --git a/test_autoarray/structures/triangles/test_extended_source.py b/test_autoarray/structures/triangles/test_extended_source.py index 0cdaca2b3..4ea2482af 100644 --- a/test_autoarray/structures/triangles/test_extended_source.py +++ b/test_autoarray/structures/triangles/test_extended_source.py @@ -1,7 +1,7 @@ import pytest import numpy as np -from autoarray.structures.triangles.array import JAXArrayTriangles as ArrayTriangles +from autoarray.structures.triangles.array import ArrayTriangles from autoarray.structures.triangles.shape import Circle diff --git a/test_autoarray/structures/triangles/test_jax.py b/test_autoarray/structures/triangles/test_jax.py index f613d0fe4..f82691169 100644 --- a/test_autoarray/structures/triangles/test_jax.py +++ b/test_autoarray/structures/triangles/test_jax.py @@ -5,10 +5,10 @@ import jax jax.config.update("jax_log_compiles", True) - from autoarray.structures.triangles.array.jax_array import ArrayTriangles + from autoarray.structures.triangles.array import ArrayTriangles except ImportError: import numpy as np - from autoarray.structures.triangles.array import JAXArrayTriangles as ArrayTriangles + from autoarray.structures.triangles.array import ArrayTriangles import pytest diff --git a/test_autoarray/structures/triangles/test_nan_triangles.py b/test_autoarray/structures/triangles/test_nan_triangles.py index 2e14080c4..f583d541c 100644 --- a/test_autoarray/structures/triangles/test_nan_triangles.py +++ b/test_autoarray/structures/triangles/test_nan_triangles.py @@ -2,10 +2,10 @@ try: from jax import numpy as np - from autoarray.structures.triangles.array.jax_array import ArrayTriangles + from autoarray.structures.triangles.array import ArrayTriangles except ImportError: import numpy as np - from autoarray.structures.triangles.array import JAXArrayTriangles as ArrayTriangles + from autoarray.structures.triangles.array import ArrayTriangles pytest.importorskip("jax") From b4f1853dbfe5a4fe7dcf96c20c881f9efa3edf13 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 17 Jun 2025 19:02:49 +0100 Subject: [PATCH 11/13] simplify unit tests --- .../structures/triangles/conftest.py | 16 +++ .../triangles/coordinate/__init__.py | 0 .../triangles/coordinate/conftest.py | 21 --- .../coordinate/test_coordinate_jax.py | 127 ------------------ ...e_implementation.py => test_coordinate.py} | 124 ++++++++++++++++- .../structures/triangles/test_jax.py | 16 +-- .../test_vertex_coordinates.py | 0 7 files changed, 142 insertions(+), 162 deletions(-) delete mode 100644 test_autoarray/structures/triangles/coordinate/__init__.py delete mode 100644 test_autoarray/structures/triangles/coordinate/conftest.py delete mode 100644 test_autoarray/structures/triangles/coordinate/test_coordinate_jax.py rename test_autoarray/structures/triangles/{coordinate/test_coordinate_implementation.py => test_coordinate.py} (69%) rename test_autoarray/structures/triangles/{coordinate => }/test_vertex_coordinates.py (100%) diff --git a/test_autoarray/structures/triangles/conftest.py b/test_autoarray/structures/triangles/conftest.py index a8d8580a3..203f31f88 100644 --- a/test_autoarray/structures/triangles/conftest.py +++ b/test_autoarray/structures/triangles/conftest.py @@ -1,5 +1,6 @@ from autoarray.numpy_wrapper import np from autoarray.structures.triangles.array import ArrayTriangles +from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles from matplotlib import pyplot as plt @@ -54,3 +55,18 @@ def triangles(): ] ), ) + +@pytest.fixture +def one_triangle(): + return CoordinateArrayTriangles( + coordinates=np.array([[0, 0]]), + side_length=1.0, + ) + + +@pytest.fixture +def two_triangles(): + return CoordinateArrayTriangles( + coordinates=np.array([[0, 0], [1, 0]]), + side_length=1.0, + ) diff --git a/test_autoarray/structures/triangles/coordinate/__init__.py b/test_autoarray/structures/triangles/coordinate/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test_autoarray/structures/triangles/coordinate/conftest.py b/test_autoarray/structures/triangles/coordinate/conftest.py deleted file mode 100644 index 302b565f7..000000000 --- a/test_autoarray/structures/triangles/coordinate/conftest.py +++ /dev/null @@ -1,21 +0,0 @@ -import pytest - -import numpy as np - -from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles - - -@pytest.fixture -def one_triangle(): - return CoordinateArrayTriangles( - coordinates=np.array([[0, 0]]), - side_length=1.0, - ) - - -@pytest.fixture -def two_triangles(): - return CoordinateArrayTriangles( - coordinates=np.array([[0, 0], [1, 0]]), - side_length=1.0, - ) diff --git a/test_autoarray/structures/triangles/coordinate/test_coordinate_jax.py b/test_autoarray/structures/triangles/coordinate/test_coordinate_jax.py deleted file mode 100644 index 1f37a1c90..000000000 --- a/test_autoarray/structures/triangles/coordinate/test_coordinate_jax.py +++ /dev/null @@ -1,127 +0,0 @@ -from autoarray.numpy_wrapper import jit -import pytest - -from autoarray.structures.triangles.abstract import HEIGHT_FACTOR -from autoarray.structures.triangles.shape import Point - -try: - from jax import numpy as np - import jax - - jax.config.update("jax_log_compiles", True) - from autoarray.structures.triangles.coordinate_array.jax_coordinate_array import ( - CoordinateArrayTriangles, - ) -except ImportError: - import numpy as np - from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles - - -@pytest.fixture -def one_triangle(): - return CoordinateArrayTriangles( - coordinates=np.array([[0, 0]]), - side_length=1.0, - ) - - -@jit -def full_routine(triangles): - neighborhood = triangles.neighborhood() - up_sampled = neighborhood.up_sample() - with_vertices = up_sampled.with_vertices(up_sampled.vertices) - indexes = with_vertices.containing_indices(Point(0.1, 0.1)) - return up_sampled.for_indexes(indexes) - - -# def test_full_routine(one_triangle, compare_with_nans): -# result = full_routine(one_triangle) -# -# assert compare_with_nans( -# result.triangles, -# np.array( -# [ -# [ -# [0.0, 0.4330126941204071], -# [0.25, 0.0], -# [-0.25, 0.0], -# ] -# ] -# ), -# ) - - -def test_neighborhood(one_triangle): - assert np.allclose( - np.array(jit(one_triangle.neighborhood)().triangles), - np.array( - [ - [ - [-0.5, -0.4330126941204071], - [-1.0, 0.4330126941204071], - [0.0, 0.4330126941204071], - ], - [ - [0.0, -1.299038052558899], - [-0.5, -0.4330126941204071], - [0.5, -0.4330126941204071], - ], - [ - [0.0, 0.4330126941204071], - [0.5, -0.4330126941204071], - [-0.5, -0.4330126941204071], - ], - [ - [0.5, -0.4330126941204071], - [0.0, 0.4330126941204071], - [1.0, 0.4330126941204071], - ], - ] - ), - ) - - -def test_up_sample(one_triangle): - up_sampled = jit(one_triangle.up_sample)() - assert np.allclose( - np.array(up_sampled.triangles), - np.array( - [ - [ - [[0.0, -0.4330126941204071], [-0.25, 0.0], [0.25, 0.0]], - [ - [0.25, 0.0], - [0.5, -0.4330126941204071], - [0.0, -0.4330126941204071], - ], - [ - [-0.25, 0.0], - [0.0, -0.4330126941204071], - [-0.5, -0.4330126941204071], - ], - [[0.0, 0.4330126941204071], [0.25, 0.0], [-0.25, 0.0]], - ] - ] - ), - ) - - -def test_means(one_triangle): - assert len(one_triangle.means) == 1 - - up_sampled = one_triangle.up_sample() - neighborhood = up_sampled.neighborhood() - assert np.count_nonzero(~np.isnan(neighborhood.means).any(axis=1)) == 10 - - -ONE_TRIANGLE_AREA = HEIGHT_FACTOR * 0.5 - - -def test_area(one_triangle): - assert one_triangle.area == ONE_TRIANGLE_AREA - assert one_triangle.up_sample().area == ONE_TRIANGLE_AREA - - neighborhood = one_triangle.neighborhood() - assert neighborhood.area == 4 * ONE_TRIANGLE_AREA - assert neighborhood.up_sample().area == 4 * ONE_TRIANGLE_AREA - assert neighborhood.neighborhood().area == 10 * ONE_TRIANGLE_AREA diff --git a/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py b/test_autoarray/structures/triangles/test_coordinate.py similarity index 69% rename from test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py rename to test_autoarray/structures/triangles/test_coordinate.py index b0e91467e..4331cbf9a 100644 --- a/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py +++ b/test_autoarray/structures/triangles/test_coordinate.py @@ -1,11 +1,19 @@ -import pytest - +from jax import numpy as np +import jax import numpy as np +jax.config.update("jax_log_compiles", True) + +import pytest + from autoarray.structures.triangles.abstract import HEIGHT_FACTOR -from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles from autoarray.structures.triangles.shape import Point +from autoarray.structures.triangles.coordinate_array import ( + CoordinateArrayTriangles, +) + + def test__two(two_triangles): @@ -253,3 +261,113 @@ def test_from_grid_regression(): y = triangles.vertices[:, 1] assert min(y) <= -4.75 assert max(y) >= 4.75 + + +@pytest.fixture +def one_triangle(): + return CoordinateArrayTriangles( + coordinates=np.array([[0, 0]]), + side_length=1.0, + ) + + +@jax.jit +def full_routine(triangles): + neighborhood = triangles.neighborhood() + up_sampled = neighborhood.up_sample() + with_vertices = up_sampled.with_vertices(up_sampled.vertices) + indexes = with_vertices.containing_indices(Point(0.1, 0.1)) + return up_sampled.for_indexes(indexes) + + +# def test_full_routine(one_triangle, compare_with_nans): +# result = full_routine(one_triangle) +# +# assert compare_with_nans( +# result.triangles, +# np.array( +# [ +# [ +# [0.0, 0.4330126941204071], +# [0.25, 0.0], +# [-0.25, 0.0], +# ] +# ] +# ), +# ) + + +def test_neighborhood(one_triangle): + assert np.allclose( + np.array(jax.jit(one_triangle.neighborhood)().triangles), + np.array( + [ + [ + [-0.5, -0.4330126941204071], + [-1.0, 0.4330126941204071], + [0.0, 0.4330126941204071], + ], + [ + [0.0, -1.299038052558899], + [-0.5, -0.4330126941204071], + [0.5, -0.4330126941204071], + ], + [ + [0.0, 0.4330126941204071], + [0.5, -0.4330126941204071], + [-0.5, -0.4330126941204071], + ], + [ + [0.5, -0.4330126941204071], + [0.0, 0.4330126941204071], + [1.0, 0.4330126941204071], + ], + ] + ), + ) + + +def test_up_sample(one_triangle): + up_sampled = jax.jit(one_triangle.up_sample)() + assert np.allclose( + np.array(up_sampled.triangles), + np.array( + [ + [ + [[0.0, -0.4330126941204071], [-0.25, 0.0], [0.25, 0.0]], + [ + [0.25, 0.0], + [0.5, -0.4330126941204071], + [0.0, -0.4330126941204071], + ], + [ + [-0.25, 0.0], + [0.0, -0.4330126941204071], + [-0.5, -0.4330126941204071], + ], + [[0.0, 0.4330126941204071], [0.25, 0.0], [-0.25, 0.0]], + ] + ] + ), + ) + + +def test_means(one_triangle): + assert len(one_triangle.means) == 1 + + up_sampled = one_triangle.up_sample() + neighborhood = up_sampled.neighborhood() + assert np.count_nonzero(~np.isnan(neighborhood.means).any(axis=1)) == 10 + + +ONE_TRIANGLE_AREA = HEIGHT_FACTOR * 0.5 + + +def test_area(one_triangle): + assert one_triangle.area == ONE_TRIANGLE_AREA + assert one_triangle.up_sample().area == ONE_TRIANGLE_AREA + + neighborhood = one_triangle.neighborhood() + assert neighborhood.area == 4 * ONE_TRIANGLE_AREA + assert neighborhood.up_sample().area == 4 * ONE_TRIANGLE_AREA + assert neighborhood.neighborhood().area == 10 * ONE_TRIANGLE_AREA diff --git a/test_autoarray/structures/triangles/test_jax.py b/test_autoarray/structures/triangles/test_jax.py index f82691169..63e1b1293 100644 --- a/test_autoarray/structures/triangles/test_jax.py +++ b/test_autoarray/structures/triangles/test_jax.py @@ -1,19 +1,13 @@ -from autoarray.structures.triangles.shape import Point - -try: - from jax import numpy as np - import jax +from jax import numpy as np +import jax - jax.config.update("jax_log_compiles", True) - from autoarray.structures.triangles.array import ArrayTriangles -except ImportError: - import numpy as np - from autoarray.structures.triangles.array import ArrayTriangles +jax.config.update("jax_log_compiles", True) import pytest -pytest.importorskip("jax") +from autoarray.structures.triangles.shape import Point +from autoarray.structures.triangles.array import ArrayTriangles @pytest.fixture diff --git a/test_autoarray/structures/triangles/coordinate/test_vertex_coordinates.py b/test_autoarray/structures/triangles/test_vertex_coordinates.py similarity index 100% rename from test_autoarray/structures/triangles/coordinate/test_vertex_coordinates.py rename to test_autoarray/structures/triangles/test_vertex_coordinates.py From 33eb08754fb457cc6906dfaee897c96402d673cb Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 17 Jun 2025 19:04:38 +0100 Subject: [PATCH 12/13] black --- autoarray/structures/triangles/array.py | 21 +- .../structures/triangles/coordinate_array.py | 2 +- .../structures/triangles/conftest.py | 1 + .../structures/triangles/test_coordinate.py | 307 ++++++++++-------- .../triangles/test_nan_triangles.py | 11 +- 5 files changed, 191 insertions(+), 151 deletions(-) diff --git a/autoarray/structures/triangles/array.py b/autoarray/structures/triangles/array.py index 7e4526300..cb65356a3 100644 --- a/autoarray/structures/triangles/array.py +++ b/autoarray/structures/triangles/array.py @@ -49,13 +49,13 @@ def __repr__(self): @classmethod def for_limits_and_scale( - cls, - y_min: float, - y_max: float, - x_min: float, - x_max: float, - scale: float, - max_containing_size=MAX_CONTAINING_SIZE, + cls, + y_min: float, + y_max: float, + x_min: float, + x_max: float, + scale: float, + max_containing_size=MAX_CONTAINING_SIZE, ) -> "AbstractTriangles": height = scale * HEIGHT_FACTOR @@ -120,9 +120,9 @@ def add_vertex(v): @classmethod def for_grid( - cls, - grid: Grid2D, - **kwargs, + cls, + grid: Grid2D, + **kwargs, ) -> "AbstractTriangles": """ Create a grid of equilateral triangles from a regular grid. @@ -384,7 +384,6 @@ def tree_unflatten(cls, aux_data, children): ) - def select_and_handle_invalid( data: jnp.ndarray, indices: jnp.ndarray, diff --git a/autoarray/structures/triangles/coordinate_array.py b/autoarray/structures/triangles/coordinate_array.py index 8d950dd20..7a8fd125a 100644 --- a/autoarray/structures/triangles/coordinate_array.py +++ b/autoarray/structures/triangles/coordinate_array.py @@ -326,4 +326,4 @@ def means(self): @property def area(self): - return (3**0.5 / 4 * self.side_length**2) * len(self) \ No newline at end of file + return (3**0.5 / 4 * self.side_length**2) * len(self) diff --git a/test_autoarray/structures/triangles/conftest.py b/test_autoarray/structures/triangles/conftest.py index 203f31f88..9b943c224 100644 --- a/test_autoarray/structures/triangles/conftest.py +++ b/test_autoarray/structures/triangles/conftest.py @@ -56,6 +56,7 @@ def triangles(): ), ) + @pytest.fixture def one_triangle(): return CoordinateArrayTriangles( diff --git a/test_autoarray/structures/triangles/test_coordinate.py b/test_autoarray/structures/triangles/test_coordinate.py index 4331cbf9a..2f37bf506 100644 --- a/test_autoarray/structures/triangles/test_coordinate.py +++ b/test_autoarray/structures/triangles/test_coordinate.py @@ -3,7 +3,6 @@ import numpy as np jax.config.update("jax_log_compiles", True) - import pytest from autoarray.structures.triangles.abstract import HEIGHT_FACTOR @@ -14,35 +13,43 @@ ) - def test__two(two_triangles): assert np.all(two_triangles.centres == np.array([[0, 0], [0.5, 0]])) - assert two_triangles.triangles == pytest.approx(np.array([ - [ - [0.0, HEIGHT_FACTOR / 2], - [0.5, -HEIGHT_FACTOR / 2], - [-0.5, -HEIGHT_FACTOR / 2], - ], + assert two_triangles.triangles == pytest.approx( + np.array( [ - [0.5, -HEIGHT_FACTOR / 2], - [0.0, HEIGHT_FACTOR / 2], - [1.0, HEIGHT_FACTOR / 2], - ], - ]), 1.0e-4) + [ + [0.0, HEIGHT_FACTOR / 2], + [0.5, -HEIGHT_FACTOR / 2], + [-0.5, -HEIGHT_FACTOR / 2], + ], + [ + [0.5, -HEIGHT_FACTOR / 2], + [0.0, HEIGHT_FACTOR / 2], + [1.0, HEIGHT_FACTOR / 2], + ], + ] + ), + 1.0e-4, + ) def test__trivial_triangles(one_triangle): assert one_triangle.flip_array == np.array([1]) assert np.all(one_triangle.centres == np.array([[0, 0]])) - assert one_triangle.triangles == pytest.approx(np.array([ + assert one_triangle.triangles == pytest.approx( + np.array( [ - [0.0, HEIGHT_FACTOR / 2], - [0.5, -HEIGHT_FACTOR / 2], - [-0.5, -HEIGHT_FACTOR / 2], - ], - ] - ), 1.0e-4) + [ + [0.0, HEIGHT_FACTOR / 2], + [0.5, -HEIGHT_FACTOR / 2], + [-0.5, -HEIGHT_FACTOR / 2], + ], + ] + ), + 1.0e-4, + ) def test__above(): @@ -50,29 +57,33 @@ def test__above(): coordinates=np.array([[0, 1]]), side_length=1.0, ) - assert triangles.up_sample().triangles == pytest.approx(np.array([ - [ - [0.0, 0.43301270189221935], - [-0.25, 0.8660254037844386], - [0.25, 0.8660254037844386], - ], - [ - [0.25, 0.8660254037844388], - [0.0, 1.299038105676658], - [0.5, 1.299038105676658], - ], - [ - [-0.25, 0.8660254037844388], - [-0.5, 1.299038105676658], - [0.0, 1.299038105676658], - ], + assert triangles.up_sample().triangles == pytest.approx( + np.array( [ - [0.0, 1.299038105676658], - [0.25, 0.8660254037844388], - [-0.25, 0.8660254037844388], - ], - ] - ), 1.0e-4) + [ + [0.0, 0.43301270189221935], + [-0.25, 0.8660254037844386], + [0.25, 0.8660254037844386], + ], + [ + [0.25, 0.8660254037844388], + [0.0, 1.299038105676658], + [0.5, 1.299038105676658], + ], + [ + [-0.25, 0.8660254037844388], + [-0.5, 1.299038105676658], + [0.0, 1.299038105676658], + ], + [ + [0.0, 1.299038105676658], + [0.25, 0.8660254037844388], + [-0.25, 0.8660254037844388], + ], + ] + ), + 1.0e-4, + ) @pytest.fixture @@ -85,38 +96,50 @@ def upside_down(): def test_upside_down(upside_down): assert np.all(upside_down.centres == np.array([[0.5, 0]])) - assert upside_down.triangles == pytest.approx(np.array([ + assert upside_down.triangles == pytest.approx( + np.array( [ - [0.5, -HEIGHT_FACTOR / 2], - [0.0, HEIGHT_FACTOR / 2], - [1.0, HEIGHT_FACTOR / 2], - ], - ] - ), 1.0e-4) + [ + [0.5, -HEIGHT_FACTOR / 2], + [0.0, HEIGHT_FACTOR / 2], + [1.0, HEIGHT_FACTOR / 2], + ], + ] + ), + 1.0e-4, + ) def test_up_sample(one_triangle): up_sampled = one_triangle.up_sample() assert up_sampled.side_length == 0.5 - assert up_sampled.triangles == pytest.approx(np.array([ - [[0.0, -0.4330127018922193], [-0.25, 0.0], [0.25, 0.0]], - [[0.25, 0.0], [0.5, -0.4330127018922193], [0.0, -0.4330127018922193]], - [[-0.25, 0.0], [0.0, -0.4330127018922193], [-0.5, -0.4330127018922193]], - [[0.0, 0.4330127018922193], [0.25, 0.0], [-0.25, 0.0]], - ] - ), 1.0e-4) + assert up_sampled.triangles == pytest.approx( + np.array( + [ + [[0.0, -0.4330127018922193], [-0.25, 0.0], [0.25, 0.0]], + [[0.25, 0.0], [0.5, -0.4330127018922193], [0.0, -0.4330127018922193]], + [[-0.25, 0.0], [0.0, -0.4330127018922193], [-0.5, -0.4330127018922193]], + [[0.0, 0.4330127018922193], [0.25, 0.0], [-0.25, 0.0]], + ] + ), + 1.0e-4, + ) def test_up_sample_upside_down(upside_down): up_sampled = upside_down.up_sample() assert up_sampled.side_length == 0.5 - assert up_sampled.triangles == pytest.approx(np.array([ - [[0.5, -0.4330127018922193], [0.25, 0.0], [0.75, 0.0]], - [[0.75, 0.0], [0.5, 0.4330127018922193], [1.0, 0.4330127018922193]], - [[0.25, 0.0], [0.0, 0.4330127018922193], [0.5, 0.4330127018922193]], - [[0.5, 0.4330127018922193], [0.75, 0.0], [0.25, 0.0]], - ] - ), 1.0e-4) + assert up_sampled.triangles == pytest.approx( + np.array( + [ + [[0.5, -0.4330127018922193], [0.25, 0.0], [0.75, 0.0]], + [[0.75, 0.0], [0.5, 0.4330127018922193], [1.0, 0.4330127018922193]], + [[0.25, 0.0], [0.0, 0.4330127018922193], [0.5, 0.4330127018922193]], + [[0.5, 0.4330127018922193], [0.75, 0.0], [0.25, 0.0]], + ] + ), + 1.0e-4, + ) def _test_up_sample_twice(one_triangle, plot): @@ -130,55 +153,63 @@ def _test_up_sample_twice(one_triangle, plot): def test_neighborhood(one_triangle): - assert one_triangle.neighborhood().triangles == pytest.approx(np.array([ - [ - [-0.5, -0.4330127018922193], - [-1.0, 0.4330127018922193], - [0.0, 0.4330127018922193], - ], - [ - [0.0, -1.299038105676658], - [-0.5, -0.4330127018922193], - [0.5, -0.4330127018922193], - ], - [ - [0.0, 0.4330127018922193], - [0.5, -0.4330127018922193], - [-0.5, -0.4330127018922193], - ], + assert one_triangle.neighborhood().triangles == pytest.approx( + np.array( [ - [0.5, -0.4330127018922193], - [0.0, 0.4330127018922193], - [1.0, 0.4330127018922193], - ], - ] - ), 1.0e-4) + [ + [-0.5, -0.4330127018922193], + [-1.0, 0.4330127018922193], + [0.0, 0.4330127018922193], + ], + [ + [0.0, -1.299038105676658], + [-0.5, -0.4330127018922193], + [0.5, -0.4330127018922193], + ], + [ + [0.0, 0.4330127018922193], + [0.5, -0.4330127018922193], + [-0.5, -0.4330127018922193], + ], + [ + [0.5, -0.4330127018922193], + [0.0, 0.4330127018922193], + [1.0, 0.4330127018922193], + ], + ] + ), + 1.0e-4, + ) def test_upside_down_neighborhood(upside_down): - assert upside_down.neighborhood().triangles == pytest.approx(np.array([ - [ - [0.0, 0.4330127018922193], - [0.5, -0.4330127018922193], - [-0.5, -0.4330127018922193], - ], - [ - [0.5, -0.4330127018922193], - [0.0, 0.4330127018922193], - [1.0, 0.4330127018922193], - ], - [ - [0.5, 1.299038105676658], - [1.0, 0.4330127018922193], - [0.0, 0.4330127018922193], - ], + assert upside_down.neighborhood().triangles == pytest.approx( + np.array( [ - [1.0, 0.4330127018922193], - [1.5, -0.4330127018922193], - [0.5, -0.4330127018922193], - ], - ] - ), 1.0e-4) + [ + [0.0, 0.4330127018922193], + [0.5, -0.4330127018922193], + [-0.5, -0.4330127018922193], + ], + [ + [0.5, -0.4330127018922193], + [0.0, 0.4330127018922193], + [1.0, 0.4330127018922193], + ], + [ + [0.5, 1.299038105676658], + [1.0, 0.4330127018922193], + [0.0, 0.4330127018922193], + ], + [ + [1.0, 0.4330127018922193], + [1.5, -0.4330127018922193], + [0.5, -0.4330127018922193], + ], + ] + ), + 1.0e-4, + ) def _test_complicated(plot, one_triangle): @@ -187,29 +218,39 @@ def _test_complicated(plot, one_triangle): def test_vertices(one_triangle): - assert one_triangle.vertices == pytest.approx(np.array([ - [-0.5, -0.4330127018922193], - [0.0, 0.4330127018922193], - [0.5, -0.4330127018922193], - ] - ), 1.0e-4) + assert one_triangle.vertices == pytest.approx( + np.array( + [ + [-0.5, -0.4330127018922193], + [0.0, 0.4330127018922193], + [0.5, -0.4330127018922193], + ] + ), + 1.0e-4, + ) def test_up_sampled_vertices(one_triangle): - assert one_triangle.up_sample().vertices[0:6, :] == pytest.approx(np.array([ - [-0.5, -0.4330127018922193], - [-0.25, 0.0], - [0.0, -0.4330127018922193], - [0.0, 0.4330127018922193], - [0.25, 0.0], - [0.5, -0.4330127018922193], - ] - ), 1.0e-4) + assert one_triangle.up_sample().vertices[0:6, :] == pytest.approx( + np.array( + [ + [-0.5, -0.4330127018922193], + [-0.25, 0.0], + [0.0, -0.4330127018922193], + [0.0, 0.4330127018922193], + [0.25, 0.0], + [0.5, -0.4330127018922193], + ] + ), + 1.0e-4, + ) def test_with_vertices(one_triangle): triangle = one_triangle.with_vertices(np.array([[0, 0], [1, 0], [0.5, 1]])) - assert triangle.triangles == pytest.approx(np.array([[[1.0, 0.0], [0.5, 1.0], [0.0, 0.0]]]), 1.0e-4) + assert triangle.triangles == pytest.approx( + np.array([[[1.0, 0.0], [0.5, 1.0], [0.0, 0.0]]]), 1.0e-4 + ) def _test_multiple_with_vertices(one_triangle, plot): @@ -218,18 +259,24 @@ def _test_multiple_with_vertices(one_triangle, plot): def test_for_indexes(two_triangles): - assert two_triangles.for_indexes(np.array([0])).triangles == pytest.approx(np.array([ + assert two_triangles.for_indexes(np.array([0])).triangles == pytest.approx( + np.array( [ - [0.0, 0.4330127018922193], - [0.5, -0.4330127018922193], - [-0.5, -0.4330127018922193], + [ + [0.0, 0.4330127018922193], + [0.5, -0.4330127018922193], + [-0.5, -0.4330127018922193], + ] ] - ] - ), 1.0e-4) + ), + 1.0e-4, + ) def test_means(one_triangle): - assert one_triangle.means == pytest.approx(np.array([[0.0, -0.14433756729740643]]), 1.0e-4) + assert one_triangle.means == pytest.approx( + np.array([[0.0, -0.14433756729740643]]), 1.0e-4 + ) def test_triangles_touch(): diff --git a/test_autoarray/structures/triangles/test_nan_triangles.py b/test_autoarray/structures/triangles/test_nan_triangles.py index f583d541c..6dd420ad5 100644 --- a/test_autoarray/structures/triangles/test_nan_triangles.py +++ b/test_autoarray/structures/triangles/test_nan_triangles.py @@ -1,14 +1,7 @@ +from jax import numpy as np import pytest -try: - from jax import numpy as np - from autoarray.structures.triangles.array import ArrayTriangles -except ImportError: - import numpy as np - from autoarray.structures.triangles.array import ArrayTriangles - - -pytest.importorskip("jax") +from autoarray.structures.triangles.array import ArrayTriangles @pytest.fixture From ad4211e59fbc8661343004e4bf540042e674165a Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 17 Jun 2025 19:41:24 +0100 Subject: [PATCH 13/13] black --- autoarray/structures/triangles/abstract.py | 1 - autoarray/structures/triangles/array.py | 40 +------------------ .../structures/triangles/coordinate_array.py | 5 ++- 3 files changed, 5 insertions(+), 41 deletions(-) diff --git a/autoarray/structures/triangles/abstract.py b/autoarray/structures/triangles/abstract.py index fe2b754fb..3ae5e4718 100644 --- a/autoarray/structures/triangles/abstract.py +++ b/autoarray/structures/triangles/abstract.py @@ -3,7 +3,6 @@ import numpy as np from autoarray import Grid2D -from autoarray.structures.triangles.shape import Shape HEIGHT_FACTOR = 3**0.5 / 2 diff --git a/autoarray/structures/triangles/array.py b/autoarray/structures/triangles/array.py index cb65356a3..353163a00 100644 --- a/autoarray/structures/triangles/array.py +++ b/autoarray/structures/triangles/array.py @@ -12,7 +12,7 @@ @register_pytree_node_class -class ArrayTriangles: +class ArrayTriangles(AbstractTriangles): def __init__( self, indices, @@ -118,44 +118,6 @@ def add_vertex(v): max_containing_size=max_containing_size, ) - @classmethod - def for_grid( - cls, - grid: Grid2D, - **kwargs, - ) -> "AbstractTriangles": - """ - Create a grid of equilateral triangles from a regular grid. - - Parameters - ---------- - grid - The regular grid to convert to a grid of triangles. - - Returns - ------- - The grid of triangles. - """ - - scale = grid.pixel_scale - - y = grid[:, 0] - x = grid[:, 1] - - y_min = y.min() - y_max = y.max() - x_min = x.min() - x_max = x.max() - - return cls.for_limits_and_scale( - y_min, - y_max, - x_min, - x_max, - scale, - **kwargs, - ) - @property def indices(self): return self._indices diff --git a/autoarray/structures/triangles/coordinate_array.py b/autoarray/structures/triangles/coordinate_array.py index 7a8fd125a..c919ffc86 100644 --- a/autoarray/structures/triangles/coordinate_array.py +++ b/autoarray/structures/triangles/coordinate_array.py @@ -1,8 +1,11 @@ +from abc import ABC + import numpy as np import jax.numpy as jnp import jax from autoarray.structures.triangles.abstract import HEIGHT_FACTOR +from autoarray.structures.triangles.abstract import AbstractTriangles from autoarray.structures.triangles.array import ArrayTriangles from autoarray.numpy_wrapper import register_pytree_node_class from autoconf import cached_property @@ -11,7 +14,7 @@ @register_pytree_node_class -class CoordinateArrayTriangles: +class CoordinateArrayTriangles(AbstractTriangles, ABC): def __init__( self,