diff --git a/conftest.py b/conftest.py index ed6de7e..b67681a 100644 --- a/conftest.py +++ b/conftest.py @@ -1,3 +1,4 @@ import jax +jax.config.update("jax_enable_x64", True) # noqa: E702 jax.config.update("jax_platform_name", "cpu") diff --git a/src/temgym_core/__init__.py b/src/temgym_core/__init__.py index 44ad727..6f1c9ce 100755 --- a/src/temgym_core/__init__.py +++ b/src/temgym_core/__init__.py @@ -1,5 +1,8 @@ from typing_extensions import TypeAlias -from typing import NamedTuple +from typing import NamedTuple, Union + +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 +import jax.numpy as jnp import numpy as np from numpy.typing import NDArray @@ -46,6 +49,26 @@ class ScaleYX(NamedTuple): x: float +class CoordXY(NamedTuple): + """Continuous coordinates in the optical frame. + + Parameters + ---------- + x : float + X position, metres. + y : float + Y position, metres. + """ + x: float + y: float + + def to_coords(self) -> 'CoordsXY': + return CoordsXY( + x=jnp.array((self.x,)), + y=jnp.array((self.y,)) + ) + + class CoordsXY(NamedTuple): """Continuous coordinates in the optical frame. @@ -60,22 +83,46 @@ class CoordsXY(NamedTuple): y: NDArray[np.floating] +class PixelYX(NamedTuple): + """Pixel coordinates for images. + + Parameters + ---------- + y : Union[int, float] + Pixel row indices + x : Union[int, float] + Pixel column indices + + Notes + ----- + Pixel indices are 0-based. + """ + y: Union[int, float] + x: Union[int, float] + + def to_pixels(self) -> 'PixelsYX': + return PixelsYX( + x=jnp.array((self.x,)), + y=jnp.array((self.y,)) + ) + + class PixelsYX(NamedTuple): - """Discrete pixel coordinates for images. + """Pixel coordinates for images. Parameters ---------- y : numpy.ndarray - Pixel row indices. Integer dtype. + Pixel row indices. Integer or floating dtype. x : numpy.ndarray - Pixel column indices. Integer dtype. + Pixel column indices. Integer or floating dtype. Notes ----- Pixel indices are 0-based. """ - y: NDArray[np.integer] - x: NDArray[np.integer] + y: NDArray[Union[np.integer, np.floating]] + x: NDArray[Union[np.integer, np.floating]] # Convenience re-exports diff --git a/src/temgym_core/components.py b/src/temgym_core/components.py index 6f8fafa..0396cca 100755 --- a/src/temgym_core/components.py +++ b/src/temgym_core/components.py @@ -1,4 +1,5 @@ from typing import NamedTuple, Dict +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax_dataclasses as jdc import jax.numpy as jnp diff --git a/src/temgym_core/coordinate_transforms.py b/src/temgym_core/coordinate_transforms.py index a7b18f5..16708d8 100755 --- a/src/temgym_core/coordinate_transforms.py +++ b/src/temgym_core/coordinate_transforms.py @@ -1,3 +1,4 @@ +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax.numpy as jnp import jax.lax as lax from . import Degrees, Radians, ShapeYX, CoordsXY, ScaleYX diff --git a/src/temgym_core/gaussian.py b/src/temgym_core/gaussian.py index 4593f26..10168e9 100644 --- a/src/temgym_core/gaussian.py +++ b/src/temgym_core/gaussian.py @@ -1,12 +1,13 @@ +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax.numpy as jnp -import jax +import jax_dataclasses as jdc +from jax._src.lax.control_flow.loops import _batch_and_remainder +from jax import lax + from .grid import Grid from .run import run_to_end from .utils import custom_jacobian_matrix from .ray import Ray -import jax_dataclasses as jdc -from jax._src.lax.control_flow.loops import _batch_and_remainder -from jax import lax def w_z(w0, z, z_r): diff --git a/src/temgym_core/grid.py b/src/temgym_core/grid.py index 5f77eb7..6e959fa 100755 --- a/src/temgym_core/grid.py +++ b/src/temgym_core/grid.py @@ -1,8 +1,9 @@ from typing import Union import numpy as np +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax.numpy as jnp -from . import Degrees, ShapeYX, CoordsXY, ScaleYX, PixelsYX +from . import Degrees, ShapeYX, CoordXY, CoordsXY, ScaleYX, PixelsYX from .ray import Ray from .utils import inplace_sum, try_ravel, try_reshape from .coordinate_transforms import pixels_to_metres_transform, apply_transformation @@ -15,7 +16,7 @@ class Grid: ---------- z : float Axial position in metres. - centre : CoordsXY + centre : CoordXY Grid centre in metres (x, y). shape : ShapeYX Grid shape (y, x) in pixels. @@ -27,7 +28,7 @@ class Grid: If True, apply an additional vertical flip. """ z: float - centre: CoordsXY + centre: CoordXY shape: ShapeYX pixel_size: ScaleYX rotation: Degrees @@ -146,7 +147,10 @@ def metres_to_pixels(self, coords: CoordsXY, cast: bool = True) -> PixelsYX: if cast: pixels_y = jnp.round(pixels_y).astype(jnp.int32) pixels_x = jnp.round(pixels_x).astype(jnp.int32) - return try_reshape(pixels_y, coords_y), try_reshape(pixels_x, coords_x) + return PixelsYX( + y=try_reshape(pixels_y, coords_y), + x=try_reshape(pixels_x, coords_x) + ) def pixels_to_metres(self, pixels: PixelsYX) -> CoordsXY: """Convert pixel indices to metric coordinates. @@ -172,7 +176,10 @@ def pixels_to_metres(self, pixels: PixelsYX) -> CoordsXY: metres_y, metres_x = apply_transformation( try_ravel(pixels_y), try_ravel(pixels_x), pixels_to_metres_mat ) - return try_reshape(metres_x, pixels_x), try_reshape(metres_y, pixels_y) + return CoordsXY( + x=try_reshape(metres_x, pixels_x), + y=try_reshape(metres_y, pixels_y) + ) def ray_at_grid( self, px_y: float, px_x: float, dx: float = 0., dy: float = 0., z: float | None = None diff --git a/src/temgym_core/ray.py b/src/temgym_core/ray.py index f19fd4c..24dce53 100755 --- a/src/temgym_core/ray.py +++ b/src/temgym_core/ray.py @@ -1,4 +1,5 @@ import dataclasses +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax_dataclasses as jdc import jax.numpy as jnp from .tree_utils import HasParamsMixin diff --git a/src/temgym_core/run.py b/src/temgym_core/run.py index 92489e2..899c6e2 100755 --- a/src/temgym_core/run.py +++ b/src/temgym_core/run.py @@ -2,7 +2,7 @@ import dataclasses from typing import TYPE_CHECKING, Sequence, Union, Any, Callable, Generator -import jax +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax.numpy as jnp from .utils import custom_jacobian_matrix from .propagator import FreeSpaceParaxial, BasePropagator, Propagator diff --git a/src/temgym_core/source.py b/src/temgym_core/source.py index 670c9d2..873df0a 100755 --- a/src/temgym_core/source.py +++ b/src/temgym_core/source.py @@ -1,4 +1,5 @@ import numpy as np +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax_dataclasses as jdc from .tree_utils import HasParamsMixin @@ -98,7 +99,7 @@ class PointSource(Source): """ z: float semi_conv: float - offset_xy: CoordsXY = (0.0, 0.0) + offset_xy: CoordsXY = CoordsXY(x=0.0, y=0.0) def generate_array(self, num: int, random: bool = False) -> np.ndarray: """Generate rays with varying slopes within a cone of semi-convergence. diff --git a/src/temgym_core/transfer.py b/src/temgym_core/transfer.py index 027230e..0e6435f 100755 --- a/src/temgym_core/transfer.py +++ b/src/temgym_core/transfer.py @@ -1,3 +1,4 @@ +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax.numpy as jnp import numpy as np diff --git a/src/temgym_core/tree_utils.py b/src/temgym_core/tree_utils.py index fbbb4f8..ecf5608 100644 --- a/src/temgym_core/tree_utils.py +++ b/src/temgym_core/tree_utils.py @@ -2,7 +2,7 @@ from typing_extensions import get_type_hints import dataclasses -import jax +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 from jax_dataclasses._dataclasses import ( FieldInfo, JDC_STATIC_MARKER, diff --git a/src/temgym_core/utils.py b/src/temgym_core/utils.py index f7cbb0a..508609c 100755 --- a/src/temgym_core/utils.py +++ b/src/temgym_core/utils.py @@ -1,3 +1,4 @@ +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax.numpy as jnp import numpy as np from numba import njit diff --git a/tests/test_component.py b/tests/test_component.py index 6644468..4a74fa2 100755 --- a/tests/test_component.py +++ b/tests/test_component.py @@ -1,12 +1,16 @@ import pytest +from numpy.testing import assert_allclose + import numpy as np -import jax +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 from jax import jacobian import jax.numpy as jnp import jax_dataclasses as jdc +from temgym_core.components import ( + ScanGrid, Detector, Descanner, Scanner, DescanError, Component, Plane, Biprism, Lens +) from temgym_core.source import ParallelBeam -from temgym_core.components import ScanGrid, Detector, Descanner, DescanError, Component, Biprism, Lens from temgym_core.ray import Ray from temgym_core.utils import custom_jacobian_matrix from temgym_core.run import run_to_end @@ -91,8 +95,8 @@ def test_scan_grid_metres_to_pixels(xy, rotation, expected_pixel_coords): shape=(11, 11), ) pixel_coords_y, pixel_coords_x = scan_grid.metres_to_pixels(xy) - np.testing.assert_allclose(pixel_coords_y, expected_pixel_coords[0], atol=1e-6) - np.testing.assert_allclose(pixel_coords_x, expected_pixel_coords[1], atol=1e-6) + assert_allclose(pixel_coords_y, expected_pixel_coords[0], atol=1e-6) + assert_allclose(pixel_coords_x, expected_pixel_coords[1], atol=1e-6) @pytest.mark.parametrize( @@ -120,8 +124,8 @@ def test_scan_grid_pixels_to_metres(pixel_coords, rotation, expected_xy): shape=(11, 11), ) metres_coords_x, metres_coords_y = scan_grid.pixels_to_metres(pixel_coords) - np.testing.assert_allclose(metres_coords_x, expected_xy[0], atol=1e-6) - np.testing.assert_allclose(metres_coords_y, expected_xy[1], atol=1e-6) + assert_allclose(metres_coords_x, expected_xy[0], atol=1e-6) + assert_allclose(metres_coords_y, expected_xy[1], atol=1e-6) @pytest.mark.parametrize( @@ -142,8 +146,8 @@ def test_detector_metres_to_pixels(xy, expected_pixel_coords): flip_y=False, ) pixel_coords_y, pixel_coords_x = detector.metres_to_pixels(xy) - np.testing.assert_allclose(pixel_coords_y, expected_pixel_coords[0], atol=1e-6) - np.testing.assert_allclose(pixel_coords_x, expected_pixel_coords[1], atol=1e-6) + assert_allclose(pixel_coords_y, expected_pixel_coords[0], atol=1e-6) + assert_allclose(pixel_coords_x, expected_pixel_coords[1], atol=1e-6) # Test cases for Detector: @@ -166,46 +170,93 @@ def test_detector_pixels_to_metres(pixel_coords, expected_xy): flip_y=False, ) metres_coords_x, metres_coords_y = detector.pixels_to_metres(pixel_coords) - np.testing.assert_allclose(metres_coords_x, expected_xy[0], atol=1e-6) - np.testing.assert_allclose(metres_coords_y, expected_xy[1], atol=1e-6) + assert_allclose(metres_coords_x, expected_xy[0], atol=1e-6) + assert_allclose(metres_coords_y, expected_xy[1], atol=1e-6) + + +def test_plane(): + x, y, dx, dy, z, pathlength = np.random.uniform(-5.0, 5.0, size=6) + ray = Ray(x=x, y=y, dx=dx, dy=dy, _one=1.0, z=0.0, pathlength=0.0) + comp = Plane(z=23) + out = comp(ray) + for attr in ('x', 'y', 'dx', 'dy', '_one', 'z', 'pathlength'): + assert getattr(ray, attr) == getattr(out, attr) + + +def test_scanner_random(): + # Randomly chosen scan position and ray parameters + sp_x, sp_y = np.random.uniform(-5.0, 5.0), np.random.uniform(-5.0, 5.0) + st_x, st_y = np.random.uniform(-0.5, 0.5), np.random.uniform(0.5, 0.5) + x, y, dx, dy, z, pathlength = np.random.uniform(-5.0, 5.0, size=6) + + sc = Scanner( + z=23, + scan_pos_x=sp_x, scan_pos_y=sp_y, + scan_tilt_y=st_y, scan_tilt_x=st_x, + ) + ray = Ray(x=x, y=y, dx=dx, dy=dy, _one=1.0, z=z, pathlength=pathlength) + out = sc(ray) + + # Expected values computed using the same formula as in the implementation + exp_x = x + sp_x + exp_y = y + sp_y + exp_dx = dx + st_x + exp_dy = dy + st_y + + assert_allclose(out.x, exp_x, atol=1e-6) + assert_allclose(out.y, exp_y, atol=1e-6) + assert_allclose(out.dx, exp_dx, atol=1e-6) + assert_allclose(out.dy, exp_dy, atol=1e-6) + for attr in ('_one', 'z', 'pathlength'): + assert getattr(ray, attr) == getattr(out, attr) def test_descanner_random_descan_error(): # Randomly chosen scan position and ray parameters sp_x, sp_y = np.random.uniform(-5.0, 5.0), np.random.uniform(-5.0, 5.0) - x, y, dx, dy = np.random.uniform(-5.0, 5.0, size=4) + st_x, st_y = np.random.uniform(-0.5, 0.5), np.random.uniform(0.5, 0.5) + x, y, dx, dy, z, pathlength = np.random.uniform(-5.0, 5.0, size=6) # Randomly chosen non-zero descan error (length 12) - err = np.random.rand(12) + (pxo_pxi, pxo_pyi, pyo_pxi, pyo_pyi, + sxo_pxi, sxo_pyi, syo_pxi, syo_pyi, + offpxi, offpyi, offsxi, offsyi) = np.random.rand(12) err = DescanError( - pxo_pxi=err[0], - pxo_pyi=err[1], - pyo_pxi=err[2], - pyo_pyi=err[3], - sxo_pxi=err[4], - sxo_pyi=err[5], - syo_pxi=err[6], - syo_pyi=err[7], - offpxi=err[8], - offpyi=err[9], - offsxi=err[10], - offsyi=err[11], + pxo_pxi=pxo_pxi, + pxo_pyi=pxo_pyi, + pyo_pxi=pyo_pxi, + pyo_pyi=pyo_pyi, + sxo_pxi=sxo_pxi, + sxo_pyi=sxo_pyi, + syo_pxi=syo_pxi, + syo_pyi=syo_pyi, + offpxi=offpxi, + offpyi=offpyi, + offsxi=offsxi, + offsyi=offsyi, ) - desc = Descanner(z=0.0, scan_pos_x=sp_x, scan_pos_y=sp_y, descan_error=err) - ray = Ray(x=x, y=y, dx=dx, dy=dy, _one=1.0, z=0.0, pathlength=0.0) + desc = Descanner( + z=23, + scan_pos_x=sp_x, scan_pos_y=sp_y, + scan_tilt_y=st_y, scan_tilt_x=st_x, + descan_error=err + ) + ray = Ray(x=x, y=y, dx=dx, dy=dy, _one=1.0, z=z, pathlength=pathlength) out = desc(ray) # Expected values computed using the same formula as in the implementation - exp_x = x + sp_x * err[0] + sp_y * err[1] + err[8] - sp_x - exp_y = y + sp_x * err[2] + sp_y * err[3] + err[9] - sp_y - exp_dx = dx + sp_x * err[4] + sp_y * err[5] + err[10] - exp_dy = dy + sp_x * err[6] + sp_y * err[7] + err[11] + exp_x = x + sp_x * pxo_pxi + sp_y * pxo_pyi + offpxi - sp_x + exp_y = y + sp_x * pyo_pxi + sp_y * pyo_pyi + offpyi - sp_y + exp_dx = dx + sp_x * sxo_pxi + sp_y * sxo_pyi + offsxi - st_x + exp_dy = dy + sp_x * syo_pxi + sp_y * syo_pyi + offsyi - st_y - np.testing.assert_allclose(out.x, exp_x, atol=1e-6) - np.testing.assert_allclose(out.y, exp_y, atol=1e-6) - np.testing.assert_allclose(out.dx, exp_dx, atol=1e-6) - np.testing.assert_allclose(out.dy, exp_dy, atol=1e-6) + assert_allclose(out.x, exp_x, atol=1e-6) + assert_allclose(out.y, exp_y, atol=1e-6) + assert_allclose(out.dx, exp_dx, atol=1e-6) + assert_allclose(out.dy, exp_dy, atol=1e-6) + for attr in ('_one', 'z', 'pathlength'): + assert getattr(ray, attr) == getattr(out, attr) def test_descanner_offset_consistency(): @@ -228,7 +279,7 @@ def test_descanner_offset_consistency(): offsyi=err[11], ) desc = Descanner( - z=0.0, scan_pos_x=scan_pos_x, scan_pos_y=scan_pos_y, descan_error=err + z=11, scan_pos_x=scan_pos_x, scan_pos_y=scan_pos_y, descan_error=err ) # generate a batch of random rays @@ -256,7 +307,7 @@ def test_descanner_offset_consistency(): # assert that all rays have received the same offset first = offsets[0] for off in offsets: - np.testing.assert_allclose(off, first, atol=1e-6) + assert_allclose(off, first, atol=1e-6) def test_descanner_jacobian_matrix(): @@ -299,7 +350,7 @@ def test_descanner_jacobian_matrix(): [0.0, 0.0, 0.0, 0.0, 1.0], ] ) - np.testing.assert_allclose(J, T, atol=1e-6) + assert_allclose(J, T, atol=1e-6) @pytest.mark.parametrize("repeat", tuple(range(5))) @@ -324,7 +375,7 @@ def test_scan_grid_rotation_random(repeat): # expected rotated step vector = R(scan_rot) @ [step_x, 0] theta = np.deg2rad(scan_rot) exp_scan = np.array([np.cos(theta) * step[0], -np.sin(theta) * step[0]]) - np.testing.assert_allclose(vec_scan, exp_scan, atol=1e-6) + assert_allclose(vec_scan, exp_scan, atol=1e-6) def test_singular_component_jacobian(): diff --git a/tests/test_coordinates.py b/tests/test_coordinates.py new file mode 100644 index 0000000..40640d8 --- /dev/null +++ b/tests/test_coordinates.py @@ -0,0 +1,40 @@ +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 +import jax.numpy as jnp + +from temgym_core import PixelsYX, PixelYX, CoordsXY, CoordXY + + +def test_to_pixels_int(): + px = PixelYX(x=17, y=23) + pxs = px.to_pixels() + assert isinstance(pxs, PixelsYX) + assert len(pxs.x) == 1 + assert len(pxs.y) == 1 + assert jnp.all(px.x == pxs.x) + assert jnp.all(px.y == pxs.y) + assert pxs.x.dtype.kind == 'i' + assert pxs.y.dtype.kind == 'i' + + +def test_to_pixels_float(): + px = PixelYX(x=17., y=23.) + pxs = px.to_pixels() + assert isinstance(pxs, PixelsYX) + assert len(pxs.x) == 1 + assert len(pxs.y) == 1 + assert jnp.all(px.x == pxs.x) + assert jnp.all(px.y == pxs.y) + assert pxs.x.dtype.kind == 'f' + assert pxs.y.dtype.kind == 'f' + + +def test_to_coords(): + coord = CoordXY(x=17., y=23.) + coords = coord.to_coords() + assert isinstance(coords, CoordsXY) + assert len(coords.x) == 1 + assert len(coords.y) == 1 + assert jnp.all(coord.x == coords.x) + assert jnp.all(coord.y == coords.y) + assert coords.x.dtype.kind == 'f' + assert coords.y.dtype.kind == 'f' diff --git a/tests/test_gaussians.py b/tests/test_gaussians.py index 3cbc8c8..568b0db 100644 --- a/tests/test_gaussians.py +++ b/tests/test_gaussians.py @@ -19,10 +19,12 @@ calculate_z1_and_z2_from_M_and_f, ) from skimage.restoration import unwrap_phase -from temgym_core.utils import make_aperture, zero_phase, FresnelPropagator, fresnel_lens_imaging_solution +from temgym_core.utils import ( + make_aperture, zero_phase, FresnelPropagator, fresnel_lens_imaging_solution +) import numpy as np +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax.numpy as jnp -import jax import matplotlib.pyplot as plt jax.config.update("jax_enable_x64", True) @@ -327,7 +329,9 @@ def test_gaussian_free_space_vs_fresnel(): gauss_input.shape[1] // 2, ) - fresnel_gauss_image = FresnelPropagator(gauss_input, det_edge_x, wavelength, propagation_distance) + fresnel_gauss_image = FresnelPropagator( + gauss_input, det_edge_x, wavelength, propagation_distance + ) # Normalize amplitude so the maximum magnitude is 1 analytic_gauss_image /= np.max(np.abs(analytic_gauss_image)) @@ -436,8 +440,10 @@ def test_gaussian_lens_vs_fresnel(): gauss_input.shape[1] // 2, ) - fresnel_gauss_image = fresnel_lens_imaging_solution(gauss_input, Y, X, pixel_size[0], wavelength, - defocus+np.abs(z1), f, z2) + fresnel_gauss_image = fresnel_lens_imaging_solution( + gauss_input, Y, X, pixel_size[0], wavelength, + defocus+np.abs(z1), f, z2 + ) fresnel_gauss_image = zero_phase( fresnel_gauss_image, @@ -597,11 +603,17 @@ def test_gaussian_two_beam_interference_vs_fresnel(): tilted_shifted_plane_wave1 = np.exp(1j * k * dot[0]) tilted_shifted_plane_wave2 = np.exp(1j * k * dot[1]) - gaussian_misaligned1 = (gaussian_shifted_1 * tilted_shifted_plane_wave1).reshape(shape[0], shape[1]) - gaussian_misaligned2 = (gaussian_shifted_2 * tilted_shifted_plane_wave2).reshape(shape[0], shape[1]) + gaussian_misaligned1 = ( + gaussian_shifted_1 * tilted_shifted_plane_wave1 + ).reshape(shape[0], shape[1]) + gaussian_misaligned2 = ( + gaussian_shifted_2 * tilted_shifted_plane_wave2 + ).reshape(shape[0], shape[1]) gaussian_misaligned = gaussian_misaligned1 + gaussian_misaligned2 - fresnel_gauss_image = fresnel_lens_imaging_solution(gaussian_misaligned, Y, X, pixel_size[0], wavelength, 0.0, f, z2) + fresnel_gauss_image = fresnel_lens_imaging_solution( + gaussian_misaligned, Y, X, pixel_size[0], wavelength, 0.0, f, z2 + ) fresnel_gauss_image = zero_phase(fresnel_gauss_image, shape[0]//2, shape[1]//2) # Normalize amplitude so the maximum magnitude is 1 diff --git a/tests/test_rays.py b/tests/test_rays.py index 3fe5d06..e21e66f 100644 --- a/tests/test_rays.py +++ b/tests/test_rays.py @@ -1,6 +1,7 @@ import pytest import numpy as np import math +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 from jax import jacobian import jax.numpy as jnp from jax import value_and_grad diff --git a/tests/test_run.py b/tests/test_run.py new file mode 100644 index 0000000..7165b2c --- /dev/null +++ b/tests/test_run.py @@ -0,0 +1,92 @@ +from numpy.testing import assert_allclose + +from temgym_core.components import Scanner, Plane, Descanner +from temgym_core.source import PointSource +from temgym_core.ray import Ray +from temgym_core.run import run_iter +from temgym_core.propagator import Propagator, FreeSpaceParaxial + + +def test_run_iter(): + # These components shouldn't change the ray as it passes through + components = ( + PointSource(z=0., semi_conv=0.023), + Scanner(z=1.2, scan_pos_x=0., scan_pos_y=0.), + Plane(z=1.2), + Descanner(z=1.2, scan_pos_x=0., scan_pos_y=0.), + Plane(z=3.1) + ) + ray = Ray( + x=0.12, + y=0.23, + dx=0.34, + dy=0.45, + z=3.14, + pathlength=0.34 + ) + res = list(run_iter(ray=ray, components=components)) + + prev_ray = ray + + for i, component in enumerate(components): + prop_index = 2*i + comp_index = 2*i + 1 + prop, prop_r = res[prop_index] + comp, comp_r = res[comp_index] + assert isinstance(prop, Propagator) + assert isinstance(prop.propagator, FreeSpaceParaxial) + assert prop.distance == component.z - prev_ray.z + assert_allclose(prop_r.z, comp.z) + assert_allclose(comp_r.z, comp.z) + assert prev_ray.dx == prop_r.dx + assert prev_ray.dy == prop_r.dy + assert_allclose(prop_r.x, prev_ray.x + prev_ray.dx*prop.distance) + assert_allclose(prop_r.y, prev_ray.y + prev_ray.dy*prop.distance) + # FIXME add test for correct path length + + prev_ray = comp_r + + +def test_run_iter_noprop(): + # everything at the same z level + z = 1.2 + # These components do change the ray, i.e. we + # test that run_iter() actually passes the ray through the components + components = ( + PointSource(z=z, semi_conv=0.023), + Scanner(z=z, scan_pos_x=23., scan_pos_y=42.), + Plane(z=z), + Descanner(z=z, scan_pos_x=13., scan_pos_y=11.), + Plane(z=z) + ) + ray = Ray( + x=0.12, + y=0.23, + dx=0.34, + dy=0.45, + z=z, + pathlength=0.34 + ) + res = list(run_iter(ray=ray, components=components)) + + # Reference result: Compose the components without propagation + res_ref = ray + for comp in components: + res_ref = comp(res_ref) + + prev_ray = ray + for i, component in enumerate(components): + prop_index = 2*i + comp_index = 2*i + 1 + prop, prop_r = res[prop_index] + comp, comp_r = res[comp_index] + assert isinstance(prop, Propagator) + assert isinstance(prop.propagator, FreeSpaceParaxial) + assert prop.distance == component.z - prev_ray.z + assert_allclose(prop_r.z, comp.z) + assert_allclose(comp_r.z, comp.z) + prev_ray = comp_r + + final_ray = res[-1][1] + for attr in ('x', 'y', 'dx', 'dy', '_one', 'z', 'pathlength'): + assert_allclose(getattr(final_ray, attr), getattr(res_ref, attr)) diff --git a/tests/test_transfer.py b/tests/test_transfer.py index 145fc2e..011851f 100755 --- a/tests/test_transfer.py +++ b/tests/test_transfer.py @@ -1,3 +1,4 @@ +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax.numpy as jnp import numpy as np from temgym_core.transfer import transfer_rays, transfer_rays_pt_src, accumulate_matrices diff --git a/tests/transfer_matrices.py b/tests/transfer_matrices.py index cc234cc..56aaeee 100644 --- a/tests/transfer_matrices.py +++ b/tests/transfer_matrices.py @@ -1,5 +1,6 @@ import sympy as sp import numpy as np +import jax; jax.config.update("jax_enable_x64", True) # noqa: E702 import jax.numpy as jnp