Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
fd0bc7d
initial integration
DavidLanders95 Sep 2, 2025
d4894cf
Support
DavidLanders95 Sep 2, 2025
4d6a6c8
Added gaussian ray object.
DavidLanders95 Sep 2, 2025
19f6f03
Fixed gaussian ray that had wrong sign
DavidLanders95 Sep 3, 2025
777145e
Working for n rays
DavidLanders95 Sep 4, 2025
7c594c3
Added utility functions
DavidLanders95 Sep 4, 2025
cb42870
Added ugly checks to gaussians so we don't get divide by zero errors.…
DavidLanders95 Sep 4, 2025
4c14fa0
Added 1D coords option to Grid
DavidLanders95 Sep 4, 2025
9a732b3
Added roughly toleranced tests against fresnel solution which pass
DavidLanders95 Sep 4, 2025
a604e7c
Added transfer matrices py file for tests
DavidLanders95 Sep 4, 2025
b0c387c
Make sure rays was composed of rays going into the gaussian get_image…
DavidLanders95 Sep 4, 2025
4800c7c
Gaussian example with lots of rays runs
DavidLanders95 Sep 4, 2025
fd48a05
map reduce added to make it go brrrrrrr
DavidLanders95 Sep 4, 2025
2214055
Consistent stacking
DavidLanders95 Sep 4, 2025
d9ea9c9
Moved gaussian ray and and added to_vector in Ray
DavidLanders95 Sep 8, 2025
1d5ada2
Added test requirments
DavidLanders95 Sep 8, 2025
6867c56
Added to_vector test to test_rays
DavidLanders95 Sep 8, 2025
4bdf624
Added more gaussian tests for how their input tilts should correspond…
DavidLanders95 Sep 8, 2025
361724c
Trying to test small defocus values in image plane - 1e-6 works with …
DavidLanders95 Sep 10, 2025
a19fb80
Biprism with full magnification and small defocus really doesnt work,…
DavidLanders95 Sep 10, 2025
660eeff
Added extra phase factor to make biprism simulation work and fixed ra…
DavidLanders95 Sep 11, 2025
a78378c
Added test to check free space paraxial had not dependency on ray._one
DavidLanders95 Sep 12, 2025
acd530f
Added ray._one to deflector
DavidLanders95 Sep 12, 2025
d1e0f14
Added fit image to test fitting gaussians, but theta fitting struggles.
DavidLanders95 Sep 12, 2025
ad2335a
Image of smiley composed of gaussians works
DavidLanders95 Sep 12, 2025
98b8a94
Removing old notebooks
DavidLanders95 Sep 12, 2025
3a481fb
Removed unneccessary cells from biprism.ipynb
DavidLanders95 Sep 12, 2025
6607c47
Tidied up up aperture image notebook
DavidLanders95 Sep 12, 2025
59bff92
Added notebook to fit gaussians
DavidLanders95 Sep 12, 2025
92d71a6
added 5x5 matrices to transfer_matrix.py
DavidLanders95 Sep 12, 2025
04512c4
Added ray.derive to gaussian ray that uses jdc.replace (Maybe use thi…
DavidLanders95 Sep 12, 2025
a9b335a
Two beam interference notebook
DavidLanders95 Sep 12, 2025
c8da280
Fixed biprism tests
DavidLanders95 Sep 12, 2025
a5d2b43
Fixed gaussian tests
DavidLanders95 Sep 12, 2025
c710c0c
Simplified Aperture image notebook
DavidLanders95 Sep 12, 2025
6421349
Upscaled Smiley
DavidLanders95 Sep 12, 2025
61e57f4
Added aberratedKrivanekLens component and aberrations.py file with ex…
DavidLanders95 Sep 15, 2025
bb38e95
Energy is conserved in aperture diffraction solution, but need to thi…
DavidLanders95 Sep 15, 2025
bb0a435
Added random start to gaussian fitting
DavidLanders95 Sep 15, 2025
0ba5a9c
Trying to verify aberrated solution
DavidLanders95 Sep 15, 2025
f474224
Renamed aberrated image to aberrated probe, and fixed a test.
DavidLanders95 Sep 15, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
592 changes: 296 additions & 296 deletions README.md

Large diffs are not rendered by default.

598 changes: 598 additions & 0 deletions examples/aberrated_probe.ipynb

Large diffs are not rendered by default.

452 changes: 452 additions & 0 deletions examples/aperture_diffraction.ipynb

Large diffs are not rendered by default.

399 changes: 399 additions & 0 deletions examples/aperture_image.ipynb

Large diffs are not rendered by default.

603 changes: 603 additions & 0 deletions examples/biprism.ipynb

Large diffs are not rendered by default.

314 changes: 314 additions & 0 deletions examples/decompose_image_into_gaussians.ipynb

Large diffs are not rendered by default.

389 changes: 389 additions & 0 deletions examples/fit_three_gaussians.ipynb

Large diffs are not rendered by default.

250 changes: 250 additions & 0 deletions examples/single_gaussian_example.ipynb

Large diffs are not rendered by default.

507 changes: 507 additions & 0 deletions examples/two_beam_interference.ipynb

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions src/temgym_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,11 @@ class PixelsYX(NamedTuple):
"""
y: NDArray[np.integer]
x: NDArray[np.integer]


# Convenience re-exports
try:
from .plotting import plot_model, PlotParams # noqa: F401
except Exception:
# Plotting has optional dependencies (matplotlib); ignore import errors at package import time
pass
108 changes: 108 additions & 0 deletions src/temgym_core/aberrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import jax.numpy as jnp
from dataclasses import dataclass


@dataclass
class KrivanekCoeffs:
C10: float = 0.0
C12: float = 0.0
phi12: float = 0.0
C21: float = 0.0
phi21: float = 0.0
C23: float = 0.0
phi23: float = 0.0
C30: float = 0.0
C32: float = 0.0
phi32: float = 0.0
C34: float = 0.0
phi34: float = 0.0
C41: float = 0.0
phi41: float = 0.0
C43: float = 0.0
phi43: float = 0.0
C45: float = 0.0
phi45: float = 0.0
C50: float = 0.0
C52: float = 0.0
phi52: float = 0.0
C54: float = 0.0
phi54: float = 0.0
C56: float = 0.0
phi56: float = 0.0


def _cos_kriv(m, ph, ph0):
return jnp.cos(m * (ph - ph0))


def _sin_kriv(m, ph, ph0):
return jnp.sin(m * (ph - ph0))


def krivanek_coeff_brackets(phi, p: KrivanekCoeffs):
B2 = p.C10 + p.C12 * _cos_kriv(2, phi, p.phi12)
B3 = p.C21 * _cos_kriv(1, phi, p.phi21) + p.C23 * _cos_kriv(3, phi, p.phi23)
B4 = p.C30 + p.C32 * _cos_kriv(2, phi, p.phi32) + p.C34 * _cos_kriv(4, phi, p.phi34)
B5 = p.C41 * _cos_kriv(1, phi, p.phi41) + p.C43 * _cos_kriv(3, phi, p.phi43) + p.C45 * _cos_kriv(5, phi, p.phi45) # noqa: E501
B6 = p.C50 + p.C52 * _cos_kriv(2, phi, p.phi52) + p.C54 * _cos_kriv(4, phi, p.phi54) + p.C56 * _cos_kriv(6, phi, p.phi56) # noqa: E501
return B2, B3, B4, B5, B6


def W_krivanek(alpha, phi, p):

B2, B3, B4, B5, B6 = krivanek_coeff_brackets(phi, p)
a = alpha
a2 = a * a
a3 = a2 * a
a4 = a2 * a2
a6 = a3 * a3

return 0.5 * a2 * B2 + (a3 / 3.0) * B3 + 0.25 * a4 * B4 + 0.2 * a4 * a * B5 + (a6 / 6.0) * B6


def grad_W_krivanek(alpha_x, alpha_y, p):
ax, ay = alpha_x, alpha_y
alpha = jnp.hypot(ax, ay)
phi = jnp.arctan2(ay, ax)

B2, B3, B4, B5, B6 = krivanek_coeff_brackets(phi, p)

a = alpha
a2 = a * a
a3 = a2 * a
a4 = a2 * a2
a5 = a4 * a
a6 = a3 * a3

dW_dalpha = a * B2 + a2 * B3 + a3 * B4 + a4 * B5 + a5 * B6

dW_dphi = (0.5 * a2) * (
-2.0 * p.C12 * _sin_kriv(2, phi, p.phi12)
)
dW_dphi += (a3 / 3.0) * (
-1.0 * p.C21 * _sin_kriv(1, phi, p.phi21)
- 3.0 * p.C23 * _sin_kriv(3, phi, p.phi23)
)
dW_dphi += (0.25 * a4) * (
-2.0 * p.C32 * _sin_kriv(2, phi, p.phi32)
- 4.0 * p.C34 * _sin_kriv(4, phi, p.phi34)
)
dW_dphi += (0.2 * a4 * a) * (
-1.0 * p.C41 * _sin_kriv(1, phi, p.phi41)
- 3.0 * p.C43 * _sin_kriv(3, phi, p.phi43)
- 5.0 * p.C45 * _sin_kriv(5, phi, p.phi45)
)
dW_dphi += (a6 / 6.0) * (
-2.0 * p.C52 * _sin_kriv(2, phi, p.phi52)
- 4.0 * p.C54 * _sin_kriv(4, phi, p.phi54)
- 6.0 * p.C56 * _sin_kriv(6, phi, p.phi56)
)

eps = 1e-30
a_safe = jnp.where(a == 0, eps, a)
inv_a = 1.0 / a_safe
inv_a2 = inv_a * inv_a

dWx = dW_dalpha * (ax * inv_a) + dW_dphi * (-ay * inv_a2)
dWy = dW_dalpha * (ay * inv_a) + dW_dphi * (ax * inv_a2)
return dWx, dWy
105 changes: 53 additions & 52 deletions src/temgym_core/components.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import NamedTuple
from typing import NamedTuple, Dict
import jax_dataclasses as jdc
import jax.numpy as jnp

from .ray import Ray
from .grid import Grid
from . import Degrees, CoordsXY, ScaleYX, ShapeYX
from .tree_utils import HasParamsMixin

from .aberrations import grad_W_krivanek, W_krivanek

class Component(HasParamsMixin):
"""Base component that transforms a ray without side effects.
Expand Down Expand Up @@ -173,6 +173,47 @@ def __call__(self, ray: Ray):
)


@jdc.pytree_dataclass
class AberratedLensKrivanek(Lens):
"""Thin lens with Krivanek aberrations.

Parameters
----------
z : float
Axial position in metres.
focal_length : float
Focal length in metres.
aber_coeffs : jnp.ndarray

"""
coeffs: Dict

def __call__(self, ray: Ray):
f = self.focal_length
x, y, dx, dy = ray.x, ray.y, ray.dx, ray.dy
coeffs = self.coeffs

# Paraxial thin lens
ideal_dx = -x / f + dx
ideal_dy = -y / f + dy

alpha = jnp.hypot(ideal_dx, ideal_dy) # radians
phi = jnp.arctan2(ideal_dy, ideal_dx) # radians

dWx, dWy = grad_W_krivanek(ideal_dx, ideal_dy, coeffs)
dux, duy = -dWx / f, -dWy / f

aber_dx = ideal_dx + dux
aber_dy = ideal_dy + duy

pathlength = ray.pathlength - (x**2 + y**2) / (2 * f) + W_krivanek(alpha, phi, coeffs) / f
one = ray._one * 1.0

return Ray(
x=x, y=y, dx=aber_dx, dy=aber_dy, _one=one, pathlength=pathlength, z=ray.z
)


@jdc.pytree_dataclass
class ScanGrid(Component, Grid):
"""Scanning grid defining pixel-to-metre mapping at plane z.
Expand Down Expand Up @@ -434,8 +475,8 @@ class Deflector(Component):
def __call__(self, ray: Ray):
x, y, dx, dy = ray.x, ray.y, ray.dx, ray.dy
return ray.derive(
dx=dx + self.def_x,
dy=dy + self.def_y,
dx=dx + self.def_x * ray._one,
dy=dy + self.def_y * ray._one,
pathlength=ray.pathlength + dx * x + dy * y,
)

Expand Down Expand Up @@ -505,53 +546,13 @@ class Biprism(Component):
z: float
offset: float = 0.0
rotation: Degrees = 0.0
deflection: float = 0.0

def __call__(
self,
ray: Ray,
) -> Ray:
pos_x, pos_y, dx, dy = ray.x, ray.y, ray.dx, ray.dy

deflection = self.deflection
offset = self.offset
rot = jnp.deg2rad(self.rotation)

rays_v = jnp.array([pos_x, pos_y]).T

biprism_loc_v = jnp.array([offset * jnp.cos(rot), offset * jnp.sin(rot)])

biprism_v = jnp.array([-jnp.sin(rot), jnp.cos(rot)])
biprism_v /= jnp.linalg.norm(biprism_v)

rays_v_centred = rays_v - biprism_loc_v

dot_product = jnp.dot(rays_v_centred, biprism_v) / jnp.dot(biprism_v, biprism_v)
projection = jnp.outer(dot_product, biprism_v)

rejection = rays_v_centred - projection
rejection = rejection / jnp.linalg.norm(rejection, axis=1, keepdims=True)

# If the ray position is located at [zero, zero], rejection_norm returns a nan,
# so we convert it to a zero, zero.
rejection = jnp.nan_to_num(rejection)

xdeflection_mag = rejection[:, 0]
ydeflection_mag = rejection[:, 1]

new_dx = (dx + xdeflection_mag * deflection).squeeze()
new_dy = (dy + ydeflection_mag * deflection).squeeze()
def_x: float = 0.0
side: int = 1

pathlength = ray.pathlength + (
xdeflection_mag * deflection * pos_x + ydeflection_mag * deflection * pos_y
)

return Ray(
x=pos_x.squeeze(),
y=pos_y.squeeze(),
dx=new_dx,
dy=new_dy,
_one=ray._one,
pathlength=pathlength,
z=ray.z,
def __call__(self, ray: Ray):
x, y, dx, dy = ray.x, ray.y, ray.dx, ray.dy
return ray.derive(
dx=dx + self.def_x * ray._one * jnp.sign(ray.x),
dy=dy,
pathlength=ray.pathlength + dx * x + dy * y,
)
Loading
Loading