From 3e43e8c9720f462d829e12994b92d1c049d6d245 Mon Sep 17 00:00:00 2001 From: Dieter Weber Date: Tue, 2 Sep 2025 17:17:37 +0200 Subject: [PATCH 1/4] Default with correct type --- src/temgym_core/source.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/temgym_core/source.py b/src/temgym_core/source.py index 670c9d2..a0a6879 100755 --- a/src/temgym_core/source.py +++ b/src/temgym_core/source.py @@ -98,7 +98,7 @@ class PointSource(Source): """ z: float semi_conv: float - offset_xy: CoordsXY = (0.0, 0.0) + offset_xy: CoordsXY = CoordsXY(x=0.0, y=0.0) def generate_array(self, num: int, random: bool = False) -> np.ndarray: """Generate rays with varying slopes within a cone of semi-convergence. From f5c0f6dfe894b8a4c1e70112da3d3610f8fefa36 Mon Sep 17 00:00:00 2001 From: Dieter Weber Date: Mon, 8 Sep 2025 14:00:19 +0200 Subject: [PATCH 2/4] Add more tests ...in particular for `run_iter()` and more components. Test descan error implementation for correctly handling the names, not only indices. Drive-by flake8 fixes --- tests/test_component.py | 125 ++++++++++++++++++++++++++++------------ tests/test_gaussians.py | 26 ++++++--- tests/test_run.py | 92 +++++++++++++++++++++++++++++ 3 files changed, 199 insertions(+), 44 deletions(-) create mode 100644 tests/test_run.py diff --git a/tests/test_component.py b/tests/test_component.py index c0e4b1a..942b349 100755 --- a/tests/test_component.py +++ b/tests/test_component.py @@ -1,10 +1,14 @@ import pytest +from numpy.testing import assert_allclose + import numpy as np from jax import jacobian import jax.numpy as jnp import jax_dataclasses as jdc -from temgym_core.components import ScanGrid, Detector, Descanner, DescanError, Component +from temgym_core.components import ( + ScanGrid, Detector, Descanner, Scanner, DescanError, Component, Plane +) from temgym_core.ray import Ray from temgym_core.utils import custom_jacobian_matrix @@ -86,8 +90,8 @@ def test_scan_grid_metres_to_pixels(xy, rotation, expected_pixel_coords): shape=(11, 11), ) pixel_coords_y, pixel_coords_x = scan_grid.metres_to_pixels(xy) - np.testing.assert_allclose(pixel_coords_y, expected_pixel_coords[0], atol=1e-6) - np.testing.assert_allclose(pixel_coords_x, expected_pixel_coords[1], atol=1e-6) + assert_allclose(pixel_coords_y, expected_pixel_coords[0], atol=1e-6) + assert_allclose(pixel_coords_x, expected_pixel_coords[1], atol=1e-6) @pytest.mark.parametrize( @@ -115,8 +119,8 @@ def test_scan_grid_pixels_to_metres(pixel_coords, rotation, expected_xy): shape=(11, 11), ) metres_coords_x, metres_coords_y = scan_grid.pixels_to_metres(pixel_coords) - np.testing.assert_allclose(metres_coords_x, expected_xy[0], atol=1e-6) - np.testing.assert_allclose(metres_coords_y, expected_xy[1], atol=1e-6) + assert_allclose(metres_coords_x, expected_xy[0], atol=1e-6) + assert_allclose(metres_coords_y, expected_xy[1], atol=1e-6) @pytest.mark.parametrize( @@ -137,8 +141,8 @@ def test_detector_metres_to_pixels(xy, expected_pixel_coords): flip_y=False, ) pixel_coords_y, pixel_coords_x = detector.metres_to_pixels(xy) - np.testing.assert_allclose(pixel_coords_y, expected_pixel_coords[0], atol=1e-6) - np.testing.assert_allclose(pixel_coords_x, expected_pixel_coords[1], atol=1e-6) + assert_allclose(pixel_coords_y, expected_pixel_coords[0], atol=1e-6) + assert_allclose(pixel_coords_x, expected_pixel_coords[1], atol=1e-6) # Test cases for Detector: @@ -161,46 +165,93 @@ def test_detector_pixels_to_metres(pixel_coords, expected_xy): flip_y=False, ) metres_coords_x, metres_coords_y = detector.pixels_to_metres(pixel_coords) - np.testing.assert_allclose(metres_coords_x, expected_xy[0], atol=1e-6) - np.testing.assert_allclose(metres_coords_y, expected_xy[1], atol=1e-6) + assert_allclose(metres_coords_x, expected_xy[0], atol=1e-6) + assert_allclose(metres_coords_y, expected_xy[1], atol=1e-6) + + +def test_plane(): + x, y, dx, dy, z, pathlength = np.random.uniform(-5.0, 5.0, size=6) + ray = Ray(x=x, y=y, dx=dx, dy=dy, _one=1.0, z=0.0, pathlength=0.0) + comp = Plane(z=23) + out = comp(ray) + for attr in ('x', 'y', 'dx', 'dy', '_one', 'z', 'pathlength'): + assert getattr(ray, attr) == getattr(out, attr) + + +def test_scanner_random(): + # Randomly chosen scan position and ray parameters + sp_x, sp_y = np.random.uniform(-5.0, 5.0), np.random.uniform(-5.0, 5.0) + st_x, st_y = np.random.uniform(-0.5, 0.5), np.random.uniform(0.5, 0.5) + x, y, dx, dy, z, pathlength = np.random.uniform(-5.0, 5.0, size=6) + + sc = Scanner( + z=23, + scan_pos_x=sp_x, scan_pos_y=sp_y, + scan_tilt_y=st_y, scan_tilt_x=st_x, + ) + ray = Ray(x=x, y=y, dx=dx, dy=dy, _one=1.0, z=z, pathlength=pathlength) + out = sc(ray) + + # Expected values computed using the same formula as in the implementation + exp_x = x + sp_x + exp_y = y + sp_y + exp_dx = dx + st_x + exp_dy = dy + st_y + + assert_allclose(out.x, exp_x, atol=1e-6) + assert_allclose(out.y, exp_y, atol=1e-6) + assert_allclose(out.dx, exp_dx, atol=1e-6) + assert_allclose(out.dy, exp_dy, atol=1e-6) + for attr in ('_one', 'z', 'pathlength'): + assert getattr(ray, attr) == getattr(out, attr) def test_descanner_random_descan_error(): # Randomly chosen scan position and ray parameters sp_x, sp_y = np.random.uniform(-5.0, 5.0), np.random.uniform(-5.0, 5.0) - x, y, dx, dy = np.random.uniform(-5.0, 5.0, size=4) + st_x, st_y = np.random.uniform(-0.5, 0.5), np.random.uniform(0.5, 0.5) + x, y, dx, dy, z, pathlength = np.random.uniform(-5.0, 5.0, size=6) # Randomly chosen non-zero descan error (length 12) - err = np.random.rand(12) + (pxo_pxi, pxo_pyi, pyo_pxi, pyo_pyi, + sxo_pxi, sxo_pyi, syo_pxi, syo_pyi, + offpxi, offpyi, offsxi, offsyi) = np.random.rand(12) err = DescanError( - pxo_pxi=err[0], - pxo_pyi=err[1], - pyo_pxi=err[2], - pyo_pyi=err[3], - sxo_pxi=err[4], - sxo_pyi=err[5], - syo_pxi=err[6], - syo_pyi=err[7], - offpxi=err[8], - offpyi=err[9], - offsxi=err[10], - offsyi=err[11], + pxo_pxi=pxo_pxi, + pxo_pyi=pxo_pyi, + pyo_pxi=pyo_pxi, + pyo_pyi=pyo_pyi, + sxo_pxi=sxo_pxi, + sxo_pyi=sxo_pyi, + syo_pxi=syo_pxi, + syo_pyi=syo_pyi, + offpxi=offpxi, + offpyi=offpyi, + offsxi=offsxi, + offsyi=offsyi, ) - desc = Descanner(z=0.0, scan_pos_x=sp_x, scan_pos_y=sp_y, descan_error=err) - ray = Ray(x=x, y=y, dx=dx, dy=dy, _one=1.0, z=0.0, pathlength=0.0) + desc = Descanner( + z=23, + scan_pos_x=sp_x, scan_pos_y=sp_y, + scan_tilt_y=st_y, scan_tilt_x=st_x, + descan_error=err + ) + ray = Ray(x=x, y=y, dx=dx, dy=dy, _one=1.0, z=z, pathlength=pathlength) out = desc(ray) # Expected values computed using the same formula as in the implementation - exp_x = x + sp_x * err[0] + sp_y * err[1] + err[8] - sp_x - exp_y = y + sp_x * err[2] + sp_y * err[3] + err[9] - sp_y - exp_dx = dx + sp_x * err[4] + sp_y * err[5] + err[10] - exp_dy = dy + sp_x * err[6] + sp_y * err[7] + err[11] + exp_x = x + sp_x * pxo_pxi + sp_y * pxo_pyi + offpxi - sp_x + exp_y = y + sp_x * pyo_pxi + sp_y * pyo_pyi + offpyi - sp_y + exp_dx = dx + sp_x * sxo_pxi + sp_y * sxo_pyi + offsxi - st_x + exp_dy = dy + sp_x * syo_pxi + sp_y * syo_pyi + offsyi - st_y - np.testing.assert_allclose(out.x, exp_x, atol=1e-6) - np.testing.assert_allclose(out.y, exp_y, atol=1e-6) - np.testing.assert_allclose(out.dx, exp_dx, atol=1e-6) - np.testing.assert_allclose(out.dy, exp_dy, atol=1e-6) + assert_allclose(out.x, exp_x, atol=1e-6) + assert_allclose(out.y, exp_y, atol=1e-6) + assert_allclose(out.dx, exp_dx, atol=1e-6) + assert_allclose(out.dy, exp_dy, atol=1e-6) + for attr in ('_one', 'z', 'pathlength'): + assert getattr(ray, attr) == getattr(out, attr) def test_descanner_offset_consistency(): @@ -223,7 +274,7 @@ def test_descanner_offset_consistency(): offsyi=err[11], ) desc = Descanner( - z=0.0, scan_pos_x=scan_pos_x, scan_pos_y=scan_pos_y, descan_error=err + z=11, scan_pos_x=scan_pos_x, scan_pos_y=scan_pos_y, descan_error=err ) # generate a batch of random rays @@ -251,7 +302,7 @@ def test_descanner_offset_consistency(): # assert that all rays have received the same offset first = offsets[0] for off in offsets: - np.testing.assert_allclose(off, first, atol=1e-6) + assert_allclose(off, first, atol=1e-6) def test_descanner_jacobian_matrix(): @@ -294,7 +345,7 @@ def test_descanner_jacobian_matrix(): [0.0, 0.0, 0.0, 0.0, 1.0], ] ) - np.testing.assert_allclose(J, T, atol=1e-6) + assert_allclose(J, T, atol=1e-6) @pytest.mark.parametrize("repeat", tuple(range(5))) @@ -319,7 +370,7 @@ def test_scan_grid_rotation_random(repeat): # expected rotated step vector = R(scan_rot) @ [step_x, 0] theta = np.deg2rad(scan_rot) exp_scan = np.array([np.cos(theta) * step[0], -np.sin(theta) * step[0]]) - np.testing.assert_allclose(vec_scan, exp_scan, atol=1e-6) + assert_allclose(vec_scan, exp_scan, atol=1e-6) def test_singular_component_jacobian(): diff --git a/tests/test_gaussians.py b/tests/test_gaussians.py index 0ad4094..e65d599 100644 --- a/tests/test_gaussians.py +++ b/tests/test_gaussians.py @@ -17,7 +17,9 @@ calculate_z1_and_z2_from_M_and_f, ) from skimage.restoration import unwrap_phase -from temgym_core.utils import make_aperture, zero_phase, FresnelPropagator, fresnel_lens_imaging_solution +from temgym_core.utils import ( + make_aperture, zero_phase, FresnelPropagator, fresnel_lens_imaging_solution +) import numpy as np import jax.numpy as jnp import jax @@ -293,7 +295,9 @@ def test_gaussian_free_space_vs_fresnel(): gauss_input.shape[1] // 2, ) - fresnel_gauss_image = FresnelPropagator(gauss_input, det_edge_x, wavelength, propagation_distance) + fresnel_gauss_image = FresnelPropagator( + gauss_input, det_edge_x, wavelength, propagation_distance + ) # Normalize amplitude so the maximum magnitude is 1 analytic_gauss_image /= np.max(np.abs(analytic_gauss_image)) @@ -402,8 +406,10 @@ def test_gaussian_lens_vs_fresnel(): gauss_input.shape[1] // 2, ) - fresnel_gauss_image = fresnel_lens_imaging_solution(gauss_input, Y, X, pixel_size[0], wavelength, - defocus+np.abs(z1), f, z2) + fresnel_gauss_image = fresnel_lens_imaging_solution( + gauss_input, Y, X, pixel_size[0], wavelength, + defocus+np.abs(z1), f, z2 + ) fresnel_gauss_image = zero_phase( fresnel_gauss_image, @@ -563,11 +569,17 @@ def test_gaussian_two_beam_interference_vs_fresnel(): tilted_shifted_plane_wave1 = np.exp(1j * k * dot[0]) tilted_shifted_plane_wave2 = np.exp(1j * k * dot[1]) - gaussian_misaligned1 = (gaussian_shifted_1 * tilted_shifted_plane_wave1).reshape(shape[0], shape[1]) - gaussian_misaligned2 = (gaussian_shifted_2 * tilted_shifted_plane_wave2).reshape(shape[0], shape[1]) + gaussian_misaligned1 = ( + gaussian_shifted_1 * tilted_shifted_plane_wave1 + ).reshape(shape[0], shape[1]) + gaussian_misaligned2 = ( + gaussian_shifted_2 * tilted_shifted_plane_wave2 + ).reshape(shape[0], shape[1]) gaussian_misaligned = gaussian_misaligned1 + gaussian_misaligned2 - fresnel_gauss_image = fresnel_lens_imaging_solution(gaussian_misaligned, Y, X, pixel_size[0], wavelength, 0.0, f, z2) + fresnel_gauss_image = fresnel_lens_imaging_solution( + gaussian_misaligned, Y, X, pixel_size[0], wavelength, 0.0, f, z2 + ) fresnel_gauss_image = zero_phase(fresnel_gauss_image, shape[0]//2, shape[1]//2) # Normalize amplitude so the maximum magnitude is 1 diff --git a/tests/test_run.py b/tests/test_run.py new file mode 100644 index 0000000..7165b2c --- /dev/null +++ b/tests/test_run.py @@ -0,0 +1,92 @@ +from numpy.testing import assert_allclose + +from temgym_core.components import Scanner, Plane, Descanner +from temgym_core.source import PointSource +from temgym_core.ray import Ray +from temgym_core.run import run_iter +from temgym_core.propagator import Propagator, FreeSpaceParaxial + + +def test_run_iter(): + # These components shouldn't change the ray as it passes through + components = ( + PointSource(z=0., semi_conv=0.023), + Scanner(z=1.2, scan_pos_x=0., scan_pos_y=0.), + Plane(z=1.2), + Descanner(z=1.2, scan_pos_x=0., scan_pos_y=0.), + Plane(z=3.1) + ) + ray = Ray( + x=0.12, + y=0.23, + dx=0.34, + dy=0.45, + z=3.14, + pathlength=0.34 + ) + res = list(run_iter(ray=ray, components=components)) + + prev_ray = ray + + for i, component in enumerate(components): + prop_index = 2*i + comp_index = 2*i + 1 + prop, prop_r = res[prop_index] + comp, comp_r = res[comp_index] + assert isinstance(prop, Propagator) + assert isinstance(prop.propagator, FreeSpaceParaxial) + assert prop.distance == component.z - prev_ray.z + assert_allclose(prop_r.z, comp.z) + assert_allclose(comp_r.z, comp.z) + assert prev_ray.dx == prop_r.dx + assert prev_ray.dy == prop_r.dy + assert_allclose(prop_r.x, prev_ray.x + prev_ray.dx*prop.distance) + assert_allclose(prop_r.y, prev_ray.y + prev_ray.dy*prop.distance) + # FIXME add test for correct path length + + prev_ray = comp_r + + +def test_run_iter_noprop(): + # everything at the same z level + z = 1.2 + # These components do change the ray, i.e. we + # test that run_iter() actually passes the ray through the components + components = ( + PointSource(z=z, semi_conv=0.023), + Scanner(z=z, scan_pos_x=23., scan_pos_y=42.), + Plane(z=z), + Descanner(z=z, scan_pos_x=13., scan_pos_y=11.), + Plane(z=z) + ) + ray = Ray( + x=0.12, + y=0.23, + dx=0.34, + dy=0.45, + z=z, + pathlength=0.34 + ) + res = list(run_iter(ray=ray, components=components)) + + # Reference result: Compose the components without propagation + res_ref = ray + for comp in components: + res_ref = comp(res_ref) + + prev_ray = ray + for i, component in enumerate(components): + prop_index = 2*i + comp_index = 2*i + 1 + prop, prop_r = res[prop_index] + comp, comp_r = res[comp_index] + assert isinstance(prop, Propagator) + assert isinstance(prop.propagator, FreeSpaceParaxial) + assert prop.distance == component.z - prev_ray.z + assert_allclose(prop_r.z, comp.z) + assert_allclose(comp_r.z, comp.z) + prev_ray = comp_r + + final_ray = res[-1][1] + for attr in ('x', 'y', 'dx', 'dy', '_one', 'z', 'pathlength'): + assert_allclose(getattr(final_ray, attr), getattr(res_ref, attr)) From aec96027bcb0f90e9837319bcd320c7f87ee0070 Mon Sep 17 00:00:00 2001 From: Dieter Weber Date: Tue, 9 Sep 2025 12:08:51 +0200 Subject: [PATCH 3/4] Introduce type for single coordinates or pixels; allow float pixels The types work well to make clear what x and y are, and to distinguish pixel vs physical coordinates in code. * Introduce types for single values as opposed to arrays, which correspond well to single rays etc * Make Grid center a single type * Allow floats for pixel coordinates since conversion to discrete values should happen as late as possible to avoid rounding issues * Test for conversion helpers that reduce boilerplate --- src/temgym_core/__init__.py | 58 +++++++++++++++++++++++++++++++++---- src/temgym_core/grid.py | 16 ++++++---- tests/test_coordinates.py | 39 +++++++++++++++++++++++++ 3 files changed, 102 insertions(+), 11 deletions(-) create mode 100644 tests/test_coordinates.py diff --git a/src/temgym_core/__init__.py b/src/temgym_core/__init__.py index b772618..8cf5803 100755 --- a/src/temgym_core/__init__.py +++ b/src/temgym_core/__init__.py @@ -1,5 +1,7 @@ from typing_extensions import TypeAlias -from typing import NamedTuple +from typing import NamedTuple, Union + +import jax.numpy as jnp import numpy as np from numpy.typing import NDArray @@ -46,6 +48,26 @@ class ScaleYX(NamedTuple): x: float +class CoordXY(NamedTuple): + """Continuous coordinates in the optical frame. + + Parameters + ---------- + x : float + X position, metres. + y : float + Y position, metres. + """ + x: float + y: float + + def to_coords(self) -> 'CoordsXY': + return CoordsXY( + x=jnp.array((self.x,)), + y=jnp.array((self.y,)) + ) + + class CoordsXY(NamedTuple): """Continuous coordinates in the optical frame. @@ -60,19 +82,43 @@ class CoordsXY(NamedTuple): y: NDArray[np.floating] +class PixelYX(NamedTuple): + """Pixel coordinates for images. + + Parameters + ---------- + y : Union[int, float] + Pixel row indices + x : Union[int, float] + Pixel column indices + + Notes + ----- + Pixel indices are 0-based. + """ + y: Union[int, float] + x: Union[int, float] + + def to_pixels(self) -> 'PixelsYX': + return PixelsYX( + x=jnp.array((self.x,)), + y=jnp.array((self.y,)) + ) + + class PixelsYX(NamedTuple): - """Discrete pixel coordinates for images. + """Pixel coordinates for images. Parameters ---------- y : numpy.ndarray - Pixel row indices. Integer dtype. + Pixel row indices. Integer or floating dtype. x : numpy.ndarray - Pixel column indices. Integer dtype. + Pixel column indices. Integer or floating dtype. Notes ----- Pixel indices are 0-based. """ - y: NDArray[np.integer] - x: NDArray[np.integer] + y: NDArray[Union[np.integer, np.floating]] + x: NDArray[Union[np.integer, np.floating]] diff --git a/src/temgym_core/grid.py b/src/temgym_core/grid.py index 5f77eb7..830ac7a 100755 --- a/src/temgym_core/grid.py +++ b/src/temgym_core/grid.py @@ -2,7 +2,7 @@ import numpy as np import jax.numpy as jnp -from . import Degrees, ShapeYX, CoordsXY, ScaleYX, PixelsYX +from . import Degrees, ShapeYX, CoordXY, CoordsXY, ScaleYX, PixelsYX from .ray import Ray from .utils import inplace_sum, try_ravel, try_reshape from .coordinate_transforms import pixels_to_metres_transform, apply_transformation @@ -15,7 +15,7 @@ class Grid: ---------- z : float Axial position in metres. - centre : CoordsXY + centre : CoordXY Grid centre in metres (x, y). shape : ShapeYX Grid shape (y, x) in pixels. @@ -27,7 +27,7 @@ class Grid: If True, apply an additional vertical flip. """ z: float - centre: CoordsXY + centre: CoordXY shape: ShapeYX pixel_size: ScaleYX rotation: Degrees @@ -146,7 +146,10 @@ def metres_to_pixels(self, coords: CoordsXY, cast: bool = True) -> PixelsYX: if cast: pixels_y = jnp.round(pixels_y).astype(jnp.int32) pixels_x = jnp.round(pixels_x).astype(jnp.int32) - return try_reshape(pixels_y, coords_y), try_reshape(pixels_x, coords_x) + return PixelsYX( + y=try_reshape(pixels_y, coords_y), + x=try_reshape(pixels_x, coords_x) + ) def pixels_to_metres(self, pixels: PixelsYX) -> CoordsXY: """Convert pixel indices to metric coordinates. @@ -172,7 +175,10 @@ def pixels_to_metres(self, pixels: PixelsYX) -> CoordsXY: metres_y, metres_x = apply_transformation( try_ravel(pixels_y), try_ravel(pixels_x), pixels_to_metres_mat ) - return try_reshape(metres_x, pixels_x), try_reshape(metres_y, pixels_y) + return CoordsXY( + x=try_reshape(metres_x, pixels_x), + y=try_reshape(metres_y, pixels_y) + ) def ray_at_grid( self, px_y: float, px_x: float, dx: float = 0., dy: float = 0., z: float | None = None diff --git a/tests/test_coordinates.py b/tests/test_coordinates.py new file mode 100644 index 0000000..c37948d --- /dev/null +++ b/tests/test_coordinates.py @@ -0,0 +1,39 @@ +import jax.numpy as jnp + +from temgym_core import PixelsYX, PixelYX, CoordsXY, CoordXY + + +def test_to_pixels_int(): + px = PixelYX(x=17, y=23) + pxs = px.to_pixels() + assert isinstance(pxs, PixelsYX) + assert len(pxs.x) == 1 + assert len(pxs.y) == 1 + assert jnp.all(px.x == pxs.x) + assert jnp.all(px.y == pxs.y) + assert pxs.x.dtype.kind == 'i' + assert pxs.y.dtype.kind == 'i' + + +def test_to_pixels_float(): + px = PixelYX(x=17., y=23.) + pxs = px.to_pixels() + assert isinstance(pxs, PixelsYX) + assert len(pxs.x) == 1 + assert len(pxs.y) == 1 + assert jnp.all(px.x == pxs.x) + assert jnp.all(px.y == pxs.y) + assert pxs.x.dtype.kind == 'f' + assert pxs.y.dtype.kind == 'f' + + +def test_to_coords(): + coord = CoordXY(x=17., y=23.) + coords = coord.to_coords() + assert isinstance(coords, CoordsXY) + assert len(coords.x) == 1 + assert len(coords.y) == 1 + assert jnp.all(coord.x == coords.x) + assert jnp.all(coord.y == coords.y) + assert coords.x.dtype.kind == 'f' + assert coords.y.dtype.kind == 'f' From 13d3c965050f097222b416012a00e41eb77b25d8 Mon Sep 17 00:00:00 2001 From: Dieter Weber Date: Thu, 11 Sep 2025 11:55:22 +0200 Subject: [PATCH 4/4] Set float64 support everywhere --- conftest.py | 1 + src/temgym_core/__init__.py | 1 + src/temgym_core/components.py | 2 ++ src/temgym_core/coordinate_transforms.py | 1 + src/temgym_core/gaussian.py | 9 +++++---- src/temgym_core/grid.py | 1 + src/temgym_core/ray.py | 1 + src/temgym_core/run.py | 2 +- src/temgym_core/source.py | 1 + src/temgym_core/transfer.py | 1 + src/temgym_core/tree_utils.py | 2 +- src/temgym_core/utils.py | 1 + tests/test_component.py | 1 + tests/test_coordinates.py | 1 + tests/test_gaussians.py | 3 +-- tests/test_rays.py | 1 + tests/test_transfer.py | 1 + tests/transfer_matrices.py | 1 + 18 files changed, 23 insertions(+), 8 deletions(-) diff --git a/conftest.py b/conftest.py index ed6de7e..b67681a 100644 --- a/conftest.py +++ b/conftest.py @@ -1,3 +1,4 @@ import jax +jax.config.update("jax_enable_x64", True) # noqa: E702 jax.config.update("jax_platform_name", "cpu") diff --git a/src/temgym_core/__init__.py b/src/temgym_core/__init__.py index 8cf5803..a2ac04a 100755 --- a/src/temgym_core/__init__.py +++ b/src/temgym_core/__init__.py @@ -1,6 +1,7 @@ from typing_extensions import TypeAlias from typing import NamedTuple, Union +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax.numpy as jnp import numpy as np from numpy.typing import NDArray diff --git a/src/temgym_core/components.py b/src/temgym_core/components.py index f615e76..aaed567 100755 --- a/src/temgym_core/components.py +++ b/src/temgym_core/components.py @@ -1,4 +1,6 @@ from typing import NamedTuple + +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax_dataclasses as jdc import jax.numpy as jnp diff --git a/src/temgym_core/coordinate_transforms.py b/src/temgym_core/coordinate_transforms.py index a7b18f5..16708d8 100755 --- a/src/temgym_core/coordinate_transforms.py +++ b/src/temgym_core/coordinate_transforms.py @@ -1,3 +1,4 @@ +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax.numpy as jnp import jax.lax as lax from . import Degrees, Radians, ShapeYX, CoordsXY, ScaleYX diff --git a/src/temgym_core/gaussian.py b/src/temgym_core/gaussian.py index 8a73937..ad6cfc3 100644 --- a/src/temgym_core/gaussian.py +++ b/src/temgym_core/gaussian.py @@ -1,12 +1,13 @@ +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax.numpy as jnp -import jax +import jax_dataclasses as jdc +from jax._src.lax.control_flow.loops import _batch_and_remainder +from jax import lax + from .grid import Grid from .run import run_to_end from .utils import custom_jacobian_matrix from .ray import Ray -import jax_dataclasses as jdc -from jax._src.lax.control_flow.loops import _batch_and_remainder -from jax import lax def w_z(w0, z, z_r): diff --git a/src/temgym_core/grid.py b/src/temgym_core/grid.py index 830ac7a..6e959fa 100755 --- a/src/temgym_core/grid.py +++ b/src/temgym_core/grid.py @@ -1,5 +1,6 @@ from typing import Union import numpy as np +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax.numpy as jnp from . import Degrees, ShapeYX, CoordXY, CoordsXY, ScaleYX, PixelsYX diff --git a/src/temgym_core/ray.py b/src/temgym_core/ray.py index f19fd4c..24dce53 100755 --- a/src/temgym_core/ray.py +++ b/src/temgym_core/ray.py @@ -1,4 +1,5 @@ import dataclasses +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax_dataclasses as jdc import jax.numpy as jnp from .tree_utils import HasParamsMixin diff --git a/src/temgym_core/run.py b/src/temgym_core/run.py index 92489e2..899c6e2 100755 --- a/src/temgym_core/run.py +++ b/src/temgym_core/run.py @@ -2,7 +2,7 @@ import dataclasses from typing import TYPE_CHECKING, Sequence, Union, Any, Callable, Generator -import jax +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax.numpy as jnp from .utils import custom_jacobian_matrix from .propagator import FreeSpaceParaxial, BasePropagator, Propagator diff --git a/src/temgym_core/source.py b/src/temgym_core/source.py index a0a6879..873df0a 100755 --- a/src/temgym_core/source.py +++ b/src/temgym_core/source.py @@ -1,4 +1,5 @@ import numpy as np +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax_dataclasses as jdc from .tree_utils import HasParamsMixin diff --git a/src/temgym_core/transfer.py b/src/temgym_core/transfer.py index 027230e..0e6435f 100755 --- a/src/temgym_core/transfer.py +++ b/src/temgym_core/transfer.py @@ -1,3 +1,4 @@ +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax.numpy as jnp import numpy as np diff --git a/src/temgym_core/tree_utils.py b/src/temgym_core/tree_utils.py index fbbb4f8..ecf5608 100644 --- a/src/temgym_core/tree_utils.py +++ b/src/temgym_core/tree_utils.py @@ -2,7 +2,7 @@ from typing_extensions import get_type_hints import dataclasses -import jax +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 from jax_dataclasses._dataclasses import ( FieldInfo, JDC_STATIC_MARKER, diff --git a/src/temgym_core/utils.py b/src/temgym_core/utils.py index 0a451af..0c44530 100755 --- a/src/temgym_core/utils.py +++ b/src/temgym_core/utils.py @@ -1,3 +1,4 @@ +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax.numpy as jnp import numpy as np from numba import njit diff --git a/tests/test_component.py b/tests/test_component.py index 942b349..f56ad77 100755 --- a/tests/test_component.py +++ b/tests/test_component.py @@ -2,6 +2,7 @@ from numpy.testing import assert_allclose import numpy as np +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 from jax import jacobian import jax.numpy as jnp import jax_dataclasses as jdc diff --git a/tests/test_coordinates.py b/tests/test_coordinates.py index c37948d..40640d8 100644 --- a/tests/test_coordinates.py +++ b/tests/test_coordinates.py @@ -1,3 +1,4 @@ +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax.numpy as jnp from temgym_core import PixelsYX, PixelYX, CoordsXY, CoordXY diff --git a/tests/test_gaussians.py b/tests/test_gaussians.py index e65d599..65d5af3 100644 --- a/tests/test_gaussians.py +++ b/tests/test_gaussians.py @@ -21,10 +21,9 @@ make_aperture, zero_phase, FresnelPropagator, fresnel_lens_imaging_solution ) import numpy as np +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax.numpy as jnp -import jax import matplotlib.pyplot as plt -jax.config.update("jax_enable_x64", True) def plot_cross_sections( diff --git a/tests/test_rays.py b/tests/test_rays.py index 847b9ec..8851a98 100644 --- a/tests/test_rays.py +++ b/tests/test_rays.py @@ -1,6 +1,7 @@ import pytest import numpy as np import math +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 from jax import jacobian import jax.numpy as jnp diff --git a/tests/test_transfer.py b/tests/test_transfer.py index 145fc2e..011851f 100755 --- a/tests/test_transfer.py +++ b/tests/test_transfer.py @@ -1,3 +1,4 @@ +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax.numpy as jnp import numpy as np from temgym_core.transfer import transfer_rays, transfer_rays_pt_src, accumulate_matrices diff --git a/tests/transfer_matrices.py b/tests/transfer_matrices.py index 90c7e36..f32ae70 100644 --- a/tests/transfer_matrices.py +++ b/tests/transfer_matrices.py @@ -1,5 +1,6 @@ import sympy as sp import numpy as np +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax.numpy as jnp