diff --git a/pyproject.toml b/pyproject.toml index 43296e4..8dfdb10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,6 +120,7 @@ ignore = [ "PLR09", # Too many <...> "PLR2004", # Magic value used in comparison "ISC001", # Conflicts with formatter + "RET504", # Assignment before return ] isort.required-imports = ["from __future__ import annotations"] # Uncomment if using a _compat.typing backport diff --git a/src/pscpy/psc.py b/src/pscpy/psc.py index ecf5b31..d2654cc 100644 --- a/src/pscpy/psc.py +++ b/src/pscpy/psc.py @@ -27,6 +27,13 @@ def __init__( self.length = ds.attrs.get("length", length) self.corner = ds.attrs.get("corner", corner) + if self.length is None: + message = "Dataset is missing length. A value must be manually provided." + raise ValueError(message) + if self.corner is None: + message = "Dataset is missing corner. A value must be manually provided." + raise ValueError(message) + self.x = self._get_coord(0) self.y = self._get_coord(1) self.z = self._get_coord(2) @@ -121,13 +128,12 @@ def decode_psc( ds = ds.drop_vars([var_name]) ds = ds.assign(data_vars) - if length is not None: - run_info = RunInfo(ds, length=length, corner=corner) - coords = { - "x": ("x", run_info.x), - "y": ("y", run_info.y), - "z": ("z", run_info.z), - } - ds = ds.assign_coords(coords) + run_info = RunInfo(ds, length=length, corner=corner) + coords = { + "x": ("x", run_info.x), + "y": ("y", run_info.y), + "z": ("z", run_info.z), + } + ds = ds.assign_coords(coords) return ds diff --git a/tests/test_xarray_adios2.py b/tests/test_xarray_adios2.py index f852bf3..c1ac651 100644 --- a/tests/test_xarray_adios2.py +++ b/tests/test_xarray_adios2.py @@ -117,6 +117,26 @@ def test_nbytes(): assert ds_decoded.nbytes == ds_decoded.nbytes +def test_missing_length(): + ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") + with pytest.raises(ValueError, match=r".*length.*"): + pscpy.decode_psc( + ds_raw, + species_names=["e", "i"], + corner=[0, -6.4, -25.6], + ) + + +def test_missing_corner(): + ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") + with pytest.raises(ValueError, match=r".*corner.*"): + pscpy.decode_psc( + ds_raw, + species_names=["e", "i"], + length=[1, 12.8, 51.2], + ) + + def test_computed(): ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") ds_decoded = _decode_dataset(ds_raw)