Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions docs/user-guide/tof/dream.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -743,8 +743,8 @@
"metadata": {},
"outputs": [],
"source": [
"table = lut_wf.compute(TofLookupTable).array\n",
"table.plot() / (sc.stddevs(table) / sc.values(table)).plot(norm=\"log\")"
"table = lut_wf.compute(TofLookupTable)\n",
"table.plot() / (sc.stddevs(table.array) / sc.values(table.array)).plot(norm=\"log\")"
]
},
{
Expand All @@ -767,10 +767,12 @@
"metadata": {},
"outputs": [],
"source": [
"lut_wf[LookupTableRelativeErrorThreshold] = 0.01\n",
"wf[TofLookupTable] = table\n",
"\n",
"table = lut_wf.compute(TofLookupTable)\n",
"table.plot()"
"wf[LookupTableRelativeErrorThreshold] = 8.0e-3\n",
"\n",
"masked_table = wf.compute(ErrorLimitedTofLookupTable)\n",
"masked_table.plot()"
]
},
{
Expand All @@ -797,8 +799,6 @@
"wf[RawDetector[SampleRun]] = ess_beamline.get_monitor(\"detector\")[0]\n",
"wf[DetectorLtotal[SampleRun]] = Ltotal\n",
"\n",
"wf[TofLookupTable] = table\n",
"\n",
"# Compute time-of-flight\n",
"tofs = wf.compute(TofDetector[SampleRun])\n",
"# Compute wavelength\n",
Expand Down Expand Up @@ -833,7 +833,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
"version": "3.12.12"
}
},
"nbformat": 4,
Expand Down
1 change: 0 additions & 1 deletion docs/user-guide/tof/wfm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,6 @@
"lut_wf[DiskChoppers[AnyRun]] = disk_choppers\n",
"lut_wf[SourcePosition] = source_position\n",
"lut_wf[LtotalRange] = Ltotal, Ltotal\n",
"lut_wf[LookupTableRelativeErrorThreshold] = 0.1\n",
"lut_wf.visualize(TofLookupTable)"
]
},
Expand Down
4 changes: 3 additions & 1 deletion src/ess/reduce/time_of_flight/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from .eto_to_tof import providers
from .lut import (
DistanceResolution,
LookupTableRelativeErrorThreshold,
LtotalRange,
NumberOfSimulatedNeutrons,
PulsePeriod,
Expand All @@ -24,6 +23,8 @@
)
from .types import (
DetectorLtotal,
ErrorLimitedTofLookupTable,
LookupTableRelativeErrorThreshold,
MonitorLtotal,
PulseStrideOffset,
TimeOfFlightLookupTable,
Expand All @@ -42,6 +43,7 @@
"DetectorLtotal",
"DiskChoppers",
"DistanceResolution",
"ErrorLimitedTofLookupTable",
"GenericTofWorkflow",
"LookupTableRelativeErrorThreshold",
"LtotalRange",
Expand Down
45 changes: 38 additions & 7 deletions src/ess/reduce/time_of_flight/eto_to_tof.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

from collections.abc import Callable
from dataclasses import asdict

import numpy as np
import scipp as sc
Expand All @@ -34,6 +35,8 @@
from .resample import rebin_strictly_increasing
from .types import (
DetectorLtotal,
ErrorLimitedTofLookupTable,
LookupTableRelativeErrorThreshold,
MonitorLtotal,
PulseStrideOffset,
ToaDetector,
Expand Down Expand Up @@ -99,7 +102,7 @@ def __call__(


def _time_of_flight_data_histogram(
da: sc.DataArray, lookup: TofLookupTable, ltotal: sc.Variable
da: sc.DataArray, lookup: ErrorLimitedTofLookupTable, ltotal: sc.Variable
) -> sc.DataArray:
# In NeXus, 'time_of_flight' is the canonical name in NXmonitor, but in some files,
# it may be called 'tof' or 'frame_time'.
Expand Down Expand Up @@ -204,7 +207,7 @@ def _guess_pulse_stride_offset(

def _prepare_tof_interpolation_inputs(
da: sc.DataArray,
lookup: TofLookupTable,
lookup: ErrorLimitedTofLookupTable,
ltotal: sc.Variable,
pulse_stride_offset: int | None,
) -> dict:
Expand Down Expand Up @@ -298,7 +301,7 @@ def _prepare_tof_interpolation_inputs(

def _time_of_flight_data_events(
da: sc.DataArray,
lookup: TofLookupTable,
lookup: ErrorLimitedTofLookupTable,
ltotal: sc.Variable,
pulse_stride_offset: int | None,
) -> sc.DataArray:
Expand Down Expand Up @@ -396,9 +399,36 @@ def monitor_ltotal_from_straight_line_approximation(
)


def mask_large_uncertainty_in_lut(
table: TofLookupTable, error_threshold: LookupTableRelativeErrorThreshold
) -> ErrorLimitedTofLookupTable:
"""
Mask regions in the time-of-flight lookup table with large uncertainty using NaNs.

Parameters
----------
table:
Lookup table with time-of-flight as a function of distance and time-of-arrival.
error_threshold:
Threshold for the relative standard deviation (coefficient of variation) of the
projected time-of-flight above which values are masked.
"""
# TODO: The error threshold could be made dependent on the time-of-flight or
# distance, instead of being a single value for the whole table.
da = table.array
relative_error = sc.stddevs(da.data) / sc.values(da.data)
mask = relative_error > sc.scalar(error_threshold)
return ErrorLimitedTofLookupTable(
**{
**asdict(table),
"array": sc.where(mask, sc.scalar(np.nan, unit=da.unit), da),
}
)


def _compute_tof_data(
da: sc.DataArray,
lookup: TofLookupTable,
lookup: ErrorLimitedTofLookupTable,
ltotal: sc.Variable,
pulse_stride_offset: int,
) -> sc.DataArray:
Expand All @@ -417,7 +447,7 @@ def _compute_tof_data(

def detector_time_of_flight_data(
detector_data: RawDetector[RunType],
lookup: TofLookupTable,
lookup: ErrorLimitedTofLookupTable,
ltotal: DetectorLtotal[RunType],
pulse_stride_offset: PulseStrideOffset,
) -> TofDetector[RunType]:
Expand Down Expand Up @@ -452,7 +482,7 @@ def detector_time_of_flight_data(

def monitor_time_of_flight_data(
monitor_data: RawMonitor[RunType, MonitorType],
lookup: TofLookupTable,
lookup: ErrorLimitedTofLookupTable,
ltotal: MonitorLtotal[RunType, MonitorType],
pulse_stride_offset: PulseStrideOffset,
) -> TofMonitor[RunType, MonitorType]:
Expand Down Expand Up @@ -487,7 +517,7 @@ def monitor_time_of_flight_data(

def detector_time_of_arrival_data(
detector_data: RawDetector[RunType],
lookup: TofLookupTable,
lookup: ErrorLimitedTofLookupTable,
ltotal: DetectorLtotal[RunType],
pulse_stride_offset: PulseStrideOffset,
) -> ToaDetector[RunType]:
Expand Down Expand Up @@ -585,4 +615,5 @@ def providers() -> tuple[Callable]:
detector_time_of_arrival_data,
detector_wavelength_data,
monitor_wavelength_data,
mask_large_uncertainty_in_lut,
)
37 changes: 0 additions & 37 deletions src/ess/reduce/time_of_flight/lut.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from dataclasses import dataclass
from typing import NewType

import numpy as np
import sciline as sl
import scipp as sc

Expand Down Expand Up @@ -112,13 +111,6 @@ class SimulationResults:
smaller if the pulse period is not an integer multiple of the time resolution.
"""


LookupTableRelativeErrorThreshold = NewType("LookupTableRelativeErrorThreshold", float)
"""
Threshold for the relative standard deviation (coefficient of variation) of the
projected time-of-flight above which values are masked.
"""

PulsePeriod = NewType("PulsePeriod", sc.Variable)
"""
Period of the source pulses, i.e., time between consecutive pulse starts.
Expand All @@ -144,26 +136,6 @@ class SimulationResults:
"""


def _mask_large_uncertainty(table: sc.DataArray, error_threshold: float):
"""
Mask regions with large uncertainty with NaNs.
The values are modified in place in the input table.

Parameters
----------
table:
Lookup table with time-of-flight as a function of distance and time-of-arrival.
error_threshold:
Threshold for the relative standard deviation (coefficient of variation) of the
projected time-of-flight above which values are masked.
"""
# Finally, mask regions with large uncertainty with NaNs.
relative_error = sc.stddevs(table.data) / sc.values(table.data)
mask = relative_error > sc.scalar(error_threshold)
# Use numpy for indexing as table is 2D
table.values[mask.values] = np.nan


def _compute_mean_tof(
simulation: BeamlineComponentReading,
distance: sc.Variable,
Expand Down Expand Up @@ -235,7 +207,6 @@ def make_tof_lookup_table(
time_resolution: TimeResolution,
pulse_period: PulsePeriod,
pulse_stride: PulseStride,
error_threshold: LookupTableRelativeErrorThreshold,
) -> TofLookupTable:
"""
Compute a lookup table for time-of-flight as a function of distance and
Expand All @@ -258,9 +229,6 @@ def make_tof_lookup_table(
pulse_stride:
Stride of used pulses. Usually 1, but may be a small integer when
pulse-skipping.
error_threshold:
Threshold for the relative standard deviation (coefficient of variation) of the
projected time-of-flight above which values are masked.

Notes
-----
Expand Down Expand Up @@ -387,17 +355,13 @@ def make_tof_lookup_table(
},
)

# In-place masking for better performance
_mask_large_uncertainty(table, error_threshold)

return TofLookupTable(
array=table,
pulse_period=pulse_period,
pulse_stride=pulse_stride,
distance_resolution=table.coords["distance"][1] - table.coords["distance"][0],
time_resolution=table.coords["event_time_offset"][1]
- table.coords["event_time_offset"][0],
error_threshold=error_threshold,
choppers=sc.DataGroup(
{k: sc.DataGroup(ch.as_dict()) for k, ch in simulation.choppers.items()}
)
Expand Down Expand Up @@ -490,7 +454,6 @@ def TofLookupTableWorkflow():
PulseStride: 1,
DistanceResolution: sc.scalar(0.1, unit="m"),
TimeResolution: sc.scalar(250.0, unit='us'),
LookupTableRelativeErrorThreshold: 0.1,
NumberOfSimulatedNeutrons: 1_000_000,
SimulationSeed: None,
SimulationFacility: 'ess',
Expand Down
15 changes: 12 additions & 3 deletions src/ess/reduce/time_of_flight/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ class TofLookupTable:
"""Resolution of the distance coordinate in the lookup table."""
time_resolution: sc.Variable
"""Resolution of the time_of_arrival coordinate in the lookup table."""
error_threshold: float
"""The table is masked with NaNs in regions where the standard deviation of the
time-of-flight is above this threshold."""
choppers: sc.DataGroup | None = None
"""Chopper parameters used when generating the lookup table, if any. This is made
optional so we can still support old lookup tables without chopper info."""
Expand All @@ -54,12 +51,24 @@ def plot(self, *args, **kwargs) -> Any:
"""Lookup table giving time-of-flight as a function of distance and time of arrival
(alias)."""


class ErrorLimitedTofLookupTable(TofLookupTable):
"""Lookup table that is masked with NaNs in regions where the standard deviation of
the time-of-flight is above a certain threshold."""


PulseStrideOffset = NewType("PulseStrideOffset", int | None)
"""
When pulse-skipping, the offset of the first pulse in the stride. This is typically
zero but can be a small integer < pulse_stride. If None, a guess is made.
"""

LookupTableRelativeErrorThreshold = NewType("LookupTableRelativeErrorThreshold", float)
"""
Threshold for the relative standard deviation (coefficient of variation) of the
projected time-of-flight above which values are masked.
"""


class DetectorLtotal(sl.Scope[RunType, sc.Variable], sc.Variable):
"""Total path length of neutrons from source to detector (L1 + L2)."""
Expand Down
30 changes: 16 additions & 14 deletions src/ess/reduce/time_of_flight/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,34 @@

from ..nexus import GenericNeXusWorkflow
from . import eto_to_tof
from .types import PulseStrideOffset, TofLookupTable, TofLookupTableFilename
from .types import (
LookupTableRelativeErrorThreshold,
PulseStrideOffset,
TofLookupTable,
TofLookupTableFilename,
)


def load_tof_lookup_table(
filename: TofLookupTableFilename,
) -> TofLookupTable:
def load_tof_lookup_table(filename: TofLookupTableFilename) -> TofLookupTable:
"""Load a time-of-flight lookup table from an HDF5 file."""
table = sc.io.load_hdf5(filename)

# Support old format where the metadata were stored as coordinates of the DataArray.
# Note that no chopper info was saved in the old format.
if isinstance(table, sc.DataArray):
to_be_dropped = {
"pulse_period",
"pulse_stride",
"distance_resolution",
"time_resolution",
"error_threshold",
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that the table may or may not have a error_threshold coord. If it does not, trying to remove it using drop_coords would raise. So we filter it out here.

} & set(table.coords)
table = {
"array": table.drop_coords(
[
"pulse_period",
"pulse_stride",
"distance_resolution",
"time_resolution",
"error_threshold",
]
),
"array": table.drop_coords(to_be_dropped),
"pulse_period": table.coords["pulse_period"],
"pulse_stride": table.coords["pulse_stride"].value,
"distance_resolution": table.coords["distance_resolution"],
"time_resolution": table.coords["time_resolution"],
"error_threshold": table.coords["error_threshold"].value,
}

return TofLookupTable(**table)
Expand Down Expand Up @@ -87,5 +88,6 @@ def GenericTofWorkflow(

# Default parameters
wf[PulseStrideOffset] = None
wf[LookupTableRelativeErrorThreshold] = 1.0
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the default be something like np.inf?


return wf
2 changes: 0 additions & 2 deletions tests/time_of_flight/lut_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ def test_lut_workflow_computes_table_with_choppers():
)
wf[time_of_flight.DistanceResolution] = sc.scalar(0.1, unit='m')
wf[time_of_flight.TimeResolution] = sc.scalar(250.0, unit='us')
wf[time_of_flight.LookupTableRelativeErrorThreshold] = 2e3

table = wf.compute(time_of_flight.TofLookupTable)

Expand Down Expand Up @@ -194,7 +193,6 @@ def test_lut_workflow_computes_table_with_choppers_full_beamline_range():
)
wf[time_of_flight.DistanceResolution] = sc.scalar(0.1, unit='m')
wf[time_of_flight.TimeResolution] = sc.scalar(250.0, unit='us')
wf[time_of_flight.LookupTableRelativeErrorThreshold] = 2e3

table = wf.compute(time_of_flight.TofLookupTable)

Expand Down
Loading
Loading