Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax

jax.config.update("jax_enable_x64", True) # noqa: E702
jax.config.update("jax_platform_name", "cpu")
59 changes: 53 additions & 6 deletions src/temgym_core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing_extensions import TypeAlias
from typing import NamedTuple
from typing import NamedTuple, Union

import jax; jax.config.update("jax_enable_x64", True) # noqa: E702
import jax.numpy as jnp
import numpy as np
from numpy.typing import NDArray

Expand Down Expand Up @@ -46,6 +49,26 @@ class ScaleYX(NamedTuple):
x: float


class CoordXY(NamedTuple):
"""Continuous coordinates in the optical frame.

Parameters
----------
x : float
X position, metres.
y : float
Y position, metres.
"""
x: float
y: float

def to_coords(self) -> 'CoordsXY':
return CoordsXY(
x=jnp.array((self.x,)),
y=jnp.array((self.y,))
)


class CoordsXY(NamedTuple):
"""Continuous coordinates in the optical frame.

Expand All @@ -60,22 +83,46 @@ class CoordsXY(NamedTuple):
y: NDArray[np.floating]


class PixelYX(NamedTuple):
"""Pixel coordinates for images.

Parameters
----------
y : Union[int, float]
Pixel row indices
x : Union[int, float]
Pixel column indices

Notes
-----
Pixel indices are 0-based.
"""
y: Union[int, float]
x: Union[int, float]

def to_pixels(self) -> 'PixelsYX':
return PixelsYX(
x=jnp.array((self.x,)),
y=jnp.array((self.y,))
)


class PixelsYX(NamedTuple):
"""Discrete pixel coordinates for images.
"""Pixel coordinates for images.

Parameters
----------
y : numpy.ndarray
Pixel row indices. Integer dtype.
Pixel row indices. Integer or floating dtype.
x : numpy.ndarray
Pixel column indices. Integer dtype.
Pixel column indices. Integer or floating dtype.

Notes
-----
Pixel indices are 0-based.
"""
y: NDArray[np.integer]
x: NDArray[np.integer]
y: NDArray[Union[np.integer, np.floating]]
x: NDArray[Union[np.integer, np.floating]]


# Convenience re-exports
Expand Down
1 change: 1 addition & 0 deletions src/temgym_core/components.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import NamedTuple, Dict
import jax; jax.config.update("jax_enable_x64", True) # noqa: E702
import jax_dataclasses as jdc
import jax.numpy as jnp

Expand Down
1 change: 1 addition & 0 deletions src/temgym_core/coordinate_transforms.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax; jax.config.update("jax_enable_x64", True) # noqa: E702
import jax.numpy as jnp
import jax.lax as lax
from . import Degrees, Radians, ShapeYX, CoordsXY, ScaleYX
Expand Down
9 changes: 5 additions & 4 deletions src/temgym_core/gaussian.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import jax; jax.config.update("jax_enable_x64", True) # noqa: E702
import jax.numpy as jnp
import jax
import jax_dataclasses as jdc
from jax._src.lax.control_flow.loops import _batch_and_remainder
from jax import lax

from .grid import Grid
from .run import run_to_end
from .utils import custom_jacobian_matrix
from .ray import Ray
import jax_dataclasses as jdc
from jax._src.lax.control_flow.loops import _batch_and_remainder
from jax import lax


def w_z(w0, z, z_r):
Expand Down
17 changes: 12 additions & 5 deletions src/temgym_core/grid.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Union
import numpy as np
import jax; jax.config.update("jax_enable_x64", True) # noqa: E702
import jax.numpy as jnp

from . import Degrees, ShapeYX, CoordsXY, ScaleYX, PixelsYX
from . import Degrees, ShapeYX, CoordXY, CoordsXY, ScaleYX, PixelsYX
from .ray import Ray
from .utils import inplace_sum, try_ravel, try_reshape
from .coordinate_transforms import pixels_to_metres_transform, apply_transformation
Expand All @@ -15,7 +16,7 @@ class Grid:
----------
z : float
Axial position in metres.
centre : CoordsXY
centre : CoordXY
Grid centre in metres (x, y).
shape : ShapeYX
Grid shape (y, x) in pixels.
Expand All @@ -27,7 +28,7 @@ class Grid:
If True, apply an additional vertical flip.
"""
z: float
centre: CoordsXY
centre: CoordXY
shape: ShapeYX
pixel_size: ScaleYX
rotation: Degrees
Expand Down Expand Up @@ -146,7 +147,10 @@ def metres_to_pixels(self, coords: CoordsXY, cast: bool = True) -> PixelsYX:
if cast:
pixels_y = jnp.round(pixels_y).astype(jnp.int32)
pixels_x = jnp.round(pixels_x).astype(jnp.int32)
return try_reshape(pixels_y, coords_y), try_reshape(pixels_x, coords_x)
return PixelsYX(
y=try_reshape(pixels_y, coords_y),
x=try_reshape(pixels_x, coords_x)
)

def pixels_to_metres(self, pixels: PixelsYX) -> CoordsXY:
"""Convert pixel indices to metric coordinates.
Expand All @@ -172,7 +176,10 @@ def pixels_to_metres(self, pixels: PixelsYX) -> CoordsXY:
metres_y, metres_x = apply_transformation(
try_ravel(pixels_y), try_ravel(pixels_x), pixels_to_metres_mat
)
return try_reshape(metres_x, pixels_x), try_reshape(metres_y, pixels_y)
return CoordsXY(
x=try_reshape(metres_x, pixels_x),
y=try_reshape(metres_y, pixels_y)
)

def ray_at_grid(
self, px_y: float, px_x: float, dx: float = 0., dy: float = 0., z: float | None = None
Expand Down
1 change: 1 addition & 0 deletions src/temgym_core/ray.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dataclasses
import jax; jax.config.update("jax_enable_x64", True) # noqa: E702
import jax_dataclasses as jdc
import jax.numpy as jnp
from .tree_utils import HasParamsMixin
Expand Down
2 changes: 1 addition & 1 deletion src/temgym_core/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import dataclasses
from typing import TYPE_CHECKING, Sequence, Union, Any, Callable, Generator

import jax
import jax; jax.config.update("jax_enable_x64", True) # noqa: E702
import jax.numpy as jnp
from .utils import custom_jacobian_matrix
from .propagator import FreeSpaceParaxial, BasePropagator, Propagator
Expand Down
3 changes: 2 additions & 1 deletion src/temgym_core/source.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import jax; jax.config.update("jax_enable_x64", True) # noqa: E702
import jax_dataclasses as jdc

from .tree_utils import HasParamsMixin
Expand Down Expand Up @@ -98,7 +99,7 @@ class PointSource(Source):
"""
z: float
semi_conv: float
offset_xy: CoordsXY = (0.0, 0.0)
offset_xy: CoordsXY = CoordsXY(x=0.0, y=0.0)

def generate_array(self, num: int, random: bool = False) -> np.ndarray:
"""Generate rays with varying slopes within a cone of semi-convergence.
Expand Down
1 change: 1 addition & 0 deletions src/temgym_core/transfer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax; jax.config.update("jax_enable_x64", True) # noqa: E702
import jax.numpy as jnp
import numpy as np

Expand Down
2 changes: 1 addition & 1 deletion src/temgym_core/tree_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing_extensions import get_type_hints
import dataclasses

import jax
import jax; jax.config.update("jax_enable_x64", True) # noqa: E702
from jax_dataclasses._dataclasses import (
FieldInfo,
JDC_STATIC_MARKER,
Expand Down
1 change: 1 addition & 0 deletions src/temgym_core/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax; jax.config.update("jax_enable_x64", True) # noqa: E702
import jax.numpy as jnp
import numpy as np
from numba import njit
Expand Down
Loading
Loading