diff --git a/docs/user-guide/tof/dream.ipynb b/docs/user-guide/tof/dream.ipynb index 174d8ecc..cde14768 100644 --- a/docs/user-guide/tof/dream.ipynb +++ b/docs/user-guide/tof/dream.ipynb @@ -743,8 +743,8 @@ "metadata": {}, "outputs": [], "source": [ - "table = lut_wf.compute(TofLookupTable).array\n", - "table.plot() / (sc.stddevs(table) / sc.values(table)).plot(norm=\"log\")" + "table = lut_wf.compute(TofLookupTable)\n", + "table.plot() / (sc.stddevs(table.array) / sc.values(table.array)).plot(norm=\"log\")" ] }, { @@ -767,10 +767,12 @@ "metadata": {}, "outputs": [], "source": [ - "lut_wf[LookupTableRelativeErrorThreshold] = 0.01\n", + "wf[TofLookupTable] = table\n", "\n", - "table = lut_wf.compute(TofLookupTable)\n", - "table.plot()" + "wf[LookupTableRelativeErrorThreshold] = 8.0e-3\n", + "\n", + "masked_table = wf.compute(ErrorLimitedTofLookupTable)\n", + "masked_table.plot()" ] }, { @@ -797,8 +799,6 @@ "wf[RawDetector[SampleRun]] = ess_beamline.get_monitor(\"detector\")[0]\n", "wf[DetectorLtotal[SampleRun]] = Ltotal\n", "\n", - "wf[TofLookupTable] = table\n", - "\n", "# Compute time-of-flight\n", "tofs = wf.compute(TofDetector[SampleRun])\n", "# Compute wavelength\n", @@ -833,7 +833,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.7" + "version": "3.12.12" } }, "nbformat": 4, diff --git a/docs/user-guide/tof/wfm.ipynb b/docs/user-guide/tof/wfm.ipynb index ae11d428..6e2b19be 100644 --- a/docs/user-guide/tof/wfm.ipynb +++ b/docs/user-guide/tof/wfm.ipynb @@ -364,7 +364,6 @@ "lut_wf[DiskChoppers[AnyRun]] = disk_choppers\n", "lut_wf[SourcePosition] = source_position\n", "lut_wf[LtotalRange] = Ltotal, Ltotal\n", - "lut_wf[LookupTableRelativeErrorThreshold] = 0.1\n", "lut_wf.visualize(TofLookupTable)" ] }, diff --git a/src/ess/reduce/time_of_flight/__init__.py b/src/ess/reduce/time_of_flight/__init__.py index 499bbff5..ba5233b6 100644 --- a/src/ess/reduce/time_of_flight/__init__.py +++ b/src/ess/reduce/time_of_flight/__init__.py @@ -10,7 +10,6 @@ from .eto_to_tof import providers from .lut import ( DistanceResolution, - LookupTableRelativeErrorThreshold, LtotalRange, NumberOfSimulatedNeutrons, PulsePeriod, @@ -24,6 +23,8 @@ ) from .types import ( DetectorLtotal, + ErrorLimitedTofLookupTable, + LookupTableRelativeErrorThreshold, MonitorLtotal, PulseStrideOffset, TimeOfFlightLookupTable, @@ -42,6 +43,7 @@ "DetectorLtotal", "DiskChoppers", "DistanceResolution", + "ErrorLimitedTofLookupTable", "GenericTofWorkflow", "LookupTableRelativeErrorThreshold", "LtotalRange", diff --git a/src/ess/reduce/time_of_flight/eto_to_tof.py b/src/ess/reduce/time_of_flight/eto_to_tof.py index 6fd8668c..56dbf0f7 100644 --- a/src/ess/reduce/time_of_flight/eto_to_tof.py +++ b/src/ess/reduce/time_of_flight/eto_to_tof.py @@ -8,6 +8,7 @@ """ from collections.abc import Callable +from dataclasses import asdict import numpy as np import scipp as sc @@ -34,6 +35,8 @@ from .resample import rebin_strictly_increasing from .types import ( DetectorLtotal, + ErrorLimitedTofLookupTable, + LookupTableRelativeErrorThreshold, MonitorLtotal, PulseStrideOffset, ToaDetector, @@ -99,7 +102,7 @@ def __call__( def _time_of_flight_data_histogram( - da: sc.DataArray, lookup: TofLookupTable, ltotal: sc.Variable + da: sc.DataArray, lookup: ErrorLimitedTofLookupTable, ltotal: sc.Variable ) -> sc.DataArray: # In NeXus, 'time_of_flight' is the canonical name in NXmonitor, but in some files, # it may be called 'tof' or 'frame_time'. @@ -204,7 +207,7 @@ def _guess_pulse_stride_offset( def _prepare_tof_interpolation_inputs( da: sc.DataArray, - lookup: TofLookupTable, + lookup: ErrorLimitedTofLookupTable, ltotal: sc.Variable, pulse_stride_offset: int | None, ) -> dict: @@ -298,7 +301,7 @@ def _prepare_tof_interpolation_inputs( def _time_of_flight_data_events( da: sc.DataArray, - lookup: TofLookupTable, + lookup: ErrorLimitedTofLookupTable, ltotal: sc.Variable, pulse_stride_offset: int | None, ) -> sc.DataArray: @@ -396,9 +399,36 @@ def monitor_ltotal_from_straight_line_approximation( ) +def mask_large_uncertainty_in_lut( + table: TofLookupTable, error_threshold: LookupTableRelativeErrorThreshold +) -> ErrorLimitedTofLookupTable: + """ + Mask regions in the time-of-flight lookup table with large uncertainty using NaNs. + + Parameters + ---------- + table: + Lookup table with time-of-flight as a function of distance and time-of-arrival. + error_threshold: + Threshold for the relative standard deviation (coefficient of variation) of the + projected time-of-flight above which values are masked. + """ + # TODO: The error threshold could be made dependent on the time-of-flight or + # distance, instead of being a single value for the whole table. + da = table.array + relative_error = sc.stddevs(da.data) / sc.values(da.data) + mask = relative_error > sc.scalar(error_threshold) + return ErrorLimitedTofLookupTable( + **{ + **asdict(table), + "array": sc.where(mask, sc.scalar(np.nan, unit=da.unit), da), + } + ) + + def _compute_tof_data( da: sc.DataArray, - lookup: TofLookupTable, + lookup: ErrorLimitedTofLookupTable, ltotal: sc.Variable, pulse_stride_offset: int, ) -> sc.DataArray: @@ -417,7 +447,7 @@ def _compute_tof_data( def detector_time_of_flight_data( detector_data: RawDetector[RunType], - lookup: TofLookupTable, + lookup: ErrorLimitedTofLookupTable, ltotal: DetectorLtotal[RunType], pulse_stride_offset: PulseStrideOffset, ) -> TofDetector[RunType]: @@ -452,7 +482,7 @@ def detector_time_of_flight_data( def monitor_time_of_flight_data( monitor_data: RawMonitor[RunType, MonitorType], - lookup: TofLookupTable, + lookup: ErrorLimitedTofLookupTable, ltotal: MonitorLtotal[RunType, MonitorType], pulse_stride_offset: PulseStrideOffset, ) -> TofMonitor[RunType, MonitorType]: @@ -487,7 +517,7 @@ def monitor_time_of_flight_data( def detector_time_of_arrival_data( detector_data: RawDetector[RunType], - lookup: TofLookupTable, + lookup: ErrorLimitedTofLookupTable, ltotal: DetectorLtotal[RunType], pulse_stride_offset: PulseStrideOffset, ) -> ToaDetector[RunType]: @@ -585,4 +615,5 @@ def providers() -> tuple[Callable]: detector_time_of_arrival_data, detector_wavelength_data, monitor_wavelength_data, + mask_large_uncertainty_in_lut, ) diff --git a/src/ess/reduce/time_of_flight/lut.py b/src/ess/reduce/time_of_flight/lut.py index 8bddbf59..92078feb 100644 --- a/src/ess/reduce/time_of_flight/lut.py +++ b/src/ess/reduce/time_of_flight/lut.py @@ -7,7 +7,6 @@ from dataclasses import dataclass from typing import NewType -import numpy as np import sciline as sl import scipp as sc @@ -112,13 +111,6 @@ class SimulationResults: smaller if the pulse period is not an integer multiple of the time resolution. """ - -LookupTableRelativeErrorThreshold = NewType("LookupTableRelativeErrorThreshold", float) -""" -Threshold for the relative standard deviation (coefficient of variation) of the -projected time-of-flight above which values are masked. -""" - PulsePeriod = NewType("PulsePeriod", sc.Variable) """ Period of the source pulses, i.e., time between consecutive pulse starts. @@ -144,26 +136,6 @@ class SimulationResults: """ -def _mask_large_uncertainty(table: sc.DataArray, error_threshold: float): - """ - Mask regions with large uncertainty with NaNs. - The values are modified in place in the input table. - - Parameters - ---------- - table: - Lookup table with time-of-flight as a function of distance and time-of-arrival. - error_threshold: - Threshold for the relative standard deviation (coefficient of variation) of the - projected time-of-flight above which values are masked. - """ - # Finally, mask regions with large uncertainty with NaNs. - relative_error = sc.stddevs(table.data) / sc.values(table.data) - mask = relative_error > sc.scalar(error_threshold) - # Use numpy for indexing as table is 2D - table.values[mask.values] = np.nan - - def _compute_mean_tof( simulation: BeamlineComponentReading, distance: sc.Variable, @@ -235,7 +207,6 @@ def make_tof_lookup_table( time_resolution: TimeResolution, pulse_period: PulsePeriod, pulse_stride: PulseStride, - error_threshold: LookupTableRelativeErrorThreshold, ) -> TofLookupTable: """ Compute a lookup table for time-of-flight as a function of distance and @@ -258,9 +229,6 @@ def make_tof_lookup_table( pulse_stride: Stride of used pulses. Usually 1, but may be a small integer when pulse-skipping. - error_threshold: - Threshold for the relative standard deviation (coefficient of variation) of the - projected time-of-flight above which values are masked. Notes ----- @@ -387,9 +355,6 @@ def make_tof_lookup_table( }, ) - # In-place masking for better performance - _mask_large_uncertainty(table, error_threshold) - return TofLookupTable( array=table, pulse_period=pulse_period, @@ -397,7 +362,6 @@ def make_tof_lookup_table( distance_resolution=table.coords["distance"][1] - table.coords["distance"][0], time_resolution=table.coords["event_time_offset"][1] - table.coords["event_time_offset"][0], - error_threshold=error_threshold, choppers=sc.DataGroup( {k: sc.DataGroup(ch.as_dict()) for k, ch in simulation.choppers.items()} ) @@ -490,7 +454,6 @@ def TofLookupTableWorkflow(): PulseStride: 1, DistanceResolution: sc.scalar(0.1, unit="m"), TimeResolution: sc.scalar(250.0, unit='us'), - LookupTableRelativeErrorThreshold: 0.1, NumberOfSimulatedNeutrons: 1_000_000, SimulationSeed: None, SimulationFacility: 'ess', diff --git a/src/ess/reduce/time_of_flight/types.py b/src/ess/reduce/time_of_flight/types.py index b5ccd2bc..c3221ab7 100644 --- a/src/ess/reduce/time_of_flight/types.py +++ b/src/ess/reduce/time_of_flight/types.py @@ -34,9 +34,6 @@ class TofLookupTable: """Resolution of the distance coordinate in the lookup table.""" time_resolution: sc.Variable """Resolution of the time_of_arrival coordinate in the lookup table.""" - error_threshold: float - """The table is masked with NaNs in regions where the standard deviation of the - time-of-flight is above this threshold.""" choppers: sc.DataGroup | None = None """Chopper parameters used when generating the lookup table, if any. This is made optional so we can still support old lookup tables without chopper info.""" @@ -54,12 +51,24 @@ def plot(self, *args, **kwargs) -> Any: """Lookup table giving time-of-flight as a function of distance and time of arrival (alias).""" + +class ErrorLimitedTofLookupTable(TofLookupTable): + """Lookup table that is masked with NaNs in regions where the standard deviation of + the time-of-flight is above a certain threshold.""" + + PulseStrideOffset = NewType("PulseStrideOffset", int | None) """ When pulse-skipping, the offset of the first pulse in the stride. This is typically zero but can be a small integer < pulse_stride. If None, a guess is made. """ +LookupTableRelativeErrorThreshold = NewType("LookupTableRelativeErrorThreshold", float) +""" +Threshold for the relative standard deviation (coefficient of variation) of the +projected time-of-flight above which values are masked. +""" + class DetectorLtotal(sl.Scope[RunType, sc.Variable], sc.Variable): """Total path length of neutrons from source to detector (L1 + L2).""" diff --git a/src/ess/reduce/time_of_flight/workflow.py b/src/ess/reduce/time_of_flight/workflow.py index f3843864..88777301 100644 --- a/src/ess/reduce/time_of_flight/workflow.py +++ b/src/ess/reduce/time_of_flight/workflow.py @@ -7,33 +7,34 @@ from ..nexus import GenericNeXusWorkflow from . import eto_to_tof -from .types import PulseStrideOffset, TofLookupTable, TofLookupTableFilename +from .types import ( + LookupTableRelativeErrorThreshold, + PulseStrideOffset, + TofLookupTable, + TofLookupTableFilename, +) -def load_tof_lookup_table( - filename: TofLookupTableFilename, -) -> TofLookupTable: +def load_tof_lookup_table(filename: TofLookupTableFilename) -> TofLookupTable: """Load a time-of-flight lookup table from an HDF5 file.""" table = sc.io.load_hdf5(filename) # Support old format where the metadata were stored as coordinates of the DataArray. # Note that no chopper info was saved in the old format. if isinstance(table, sc.DataArray): + to_be_dropped = { + "pulse_period", + "pulse_stride", + "distance_resolution", + "time_resolution", + "error_threshold", + } & set(table.coords) table = { - "array": table.drop_coords( - [ - "pulse_period", - "pulse_stride", - "distance_resolution", - "time_resolution", - "error_threshold", - ] - ), + "array": table.drop_coords(to_be_dropped), "pulse_period": table.coords["pulse_period"], "pulse_stride": table.coords["pulse_stride"].value, "distance_resolution": table.coords["distance_resolution"], "time_resolution": table.coords["time_resolution"], - "error_threshold": table.coords["error_threshold"].value, } return TofLookupTable(**table) @@ -87,5 +88,6 @@ def GenericTofWorkflow( # Default parameters wf[PulseStrideOffset] = None + wf[LookupTableRelativeErrorThreshold] = 1.0 return wf diff --git a/tests/time_of_flight/lut_test.py b/tests/time_of_flight/lut_test.py index 694cd141..ee118dcb 100644 --- a/tests/time_of_flight/lut_test.py +++ b/tests/time_of_flight/lut_test.py @@ -159,7 +159,6 @@ def test_lut_workflow_computes_table_with_choppers(): ) wf[time_of_flight.DistanceResolution] = sc.scalar(0.1, unit='m') wf[time_of_flight.TimeResolution] = sc.scalar(250.0, unit='us') - wf[time_of_flight.LookupTableRelativeErrorThreshold] = 2e3 table = wf.compute(time_of_flight.TofLookupTable) @@ -194,7 +193,6 @@ def test_lut_workflow_computes_table_with_choppers_full_beamline_range(): ) wf[time_of_flight.DistanceResolution] = sc.scalar(0.1, unit='m') wf[time_of_flight.TimeResolution] = sc.scalar(250.0, unit='us') - wf[time_of_flight.LookupTableRelativeErrorThreshold] = 2e3 table = wf.compute(time_of_flight.TofLookupTable) diff --git a/tests/time_of_flight/unwrap_test.py b/tests/time_of_flight/unwrap_test.py index fa3cb1fe..46001e05 100644 --- a/tests/time_of_flight/unwrap_test.py +++ b/tests/time_of_flight/unwrap_test.py @@ -70,10 +70,10 @@ def _make_workflow_event_mode( pl[RawDetector[SampleRun]] = mon pl[time_of_flight.DetectorLtotal[SampleRun]] = distance pl[time_of_flight.PulseStrideOffset] = pulse_stride_offset + pl[time_of_flight.LookupTableRelativeErrorThreshold] = error_threshold lut_wf = lut_workflow.copy() lut_wf[time_of_flight.LtotalRange] = distance, distance - lut_wf[time_of_flight.LookupTableRelativeErrorThreshold] = error_threshold pl[time_of_flight.TofLookupTable] = lut_wf.compute(time_of_flight.TofLookupTable) diff --git a/tests/time_of_flight/wfm_test.py b/tests/time_of_flight/wfm_test.py index 2b05441f..8468cb75 100644 --- a/tests/time_of_flight/wfm_test.py +++ b/tests/time_of_flight/wfm_test.py @@ -132,10 +132,10 @@ def setup_workflow( pl = GenericTofWorkflow(run_types=[SampleRun], monitor_types=[]) pl[RawDetector[SampleRun]] = raw_data pl[time_of_flight.DetectorLtotal[SampleRun]] = ltotal + pl[time_of_flight.LookupTableRelativeErrorThreshold] = error_threshold lut_wf = lut_workflow.copy() lut_wf[time_of_flight.LtotalRange] = ltotal.min(), ltotal.max() - lut_wf[time_of_flight.LookupTableRelativeErrorThreshold] = error_threshold pl[time_of_flight.TofLookupTable] = lut_wf.compute(time_of_flight.TofLookupTable) return pl diff --git a/tests/time_of_flight/workflow_test.py b/tests/time_of_flight/workflow_test.py index b3e165a4..372af771 100644 --- a/tests/time_of_flight/workflow_test.py +++ b/tests/time_of_flight/workflow_test.py @@ -86,7 +86,6 @@ def test_TofLookupTableWorkflow_can_compute_tof_lut(): assert lut.time_resolution is not None assert lut.pulse_stride is not None assert lut.pulse_period is not None - assert lut.error_threshold is not None assert lut.choppers is not None @@ -144,7 +143,6 @@ def test_GenericTofWorkflow_with_tof_lut_from_file( assert lut.pulse_stride == loaded_lut.pulse_stride assert_identical(lut.distance_resolution, loaded_lut.distance_resolution) assert_identical(lut.time_resolution, loaded_lut.time_resolution) - assert lut.error_threshold == loaded_lut.error_threshold assert_identical(lut.choppers, loaded_lut.choppers) if coord == "tof": @@ -176,7 +174,6 @@ def test_GenericTofWorkflow_with_tof_lut_from_file_old_format( "pulse_stride": sc.scalar(lut.pulse_stride, unit=None), "distance_resolution": lut.distance_resolution, "time_resolution": lut.time_resolution, - "error_threshold": sc.scalar(lut.error_threshold, unit=None), }, ) old_lut.save_hdf5(filename=tmp_path / "lut.h5") @@ -188,7 +185,6 @@ def test_GenericTofWorkflow_with_tof_lut_from_file_old_format( assert lut.pulse_stride == loaded_lut.pulse_stride assert_identical(lut.distance_resolution, loaded_lut.distance_resolution) assert_identical(lut.time_resolution, loaded_lut.time_resolution) - assert lut.error_threshold == loaded_lut.error_threshold assert loaded_lut.choppers is None # No chopper info in old format detector = workflow.compute(time_of_flight.TofDetector[SampleRun]) @@ -239,7 +235,6 @@ def test_GenericTofWorkflow_with_tof_lut_from_file_using_alias( assert lut.pulse_stride == loaded_lut.pulse_stride assert_identical(lut.distance_resolution, loaded_lut.distance_resolution) assert_identical(lut.time_resolution, loaded_lut.time_resolution) - assert lut.error_threshold == loaded_lut.error_threshold assert_identical(lut.choppers, loaded_lut.choppers) detector = workflow.compute(time_of_flight.TofDetector[SampleRun])