diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 3d6bcb47885..1bcd2016324 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -26,6 +26,8 @@ Deprecations Bug Fixes ~~~~~~~~~ +- Coerce masked dask arrays to filled (:issue:`9374` :pull:`11157`). + By `Julia Signell <https://github.com/jsignell>`_ Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 85ef75f352c..e10ab5a3558 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -127,6 +127,10 @@ def fail_on_dask_array_input(values, msg=None, func_name=None): "masked_invalid", eager_module=np.ma, dask_module="dask.array.ma" ) +getmaskarray = _dask_or_eager_func( + "getmaskarray", eager_module=np.ma, dask_module="dask.array.ma" +) + def sliding_window_view(array, window_shape, axis=None, **kwargs): # TODO: some libraries (e.g. jax) don't have this, implement an alternative? diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 7478d48bb23..ab8b4be775b 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -50,6 +50,7 @@ from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions from xarray.namedarray.parallelcompat import get_chunked_array_type from xarray.namedarray.pycompat import ( + array_type, async_to_duck_array, integer_types, is_0d_dask_array, @@ -292,13 +293,12 @@ def convert_non_numpy_type(data): else: data = pandas_data - if isinstance(data, np.ma.MaskedArray): - mask = np.ma.getmaskarray(data) - if mask.any(): - _dtype, fill_value = dtypes.maybe_promote(data.dtype) - data = duck_array_ops.where_method(data, ~mask, fill_value) - else: - data = np.asarray(data) + if isinstance(data, np.ma.MaskedArray) or ( + isinstance(data, array_type("dask")) + and isinstance(getattr(data, "_meta", None), np.ma.MaskedArray) + ): + mask = duck_array_ops.getmaskarray(data) + data = duck_array_ops.where_method(data, ~mask) if isinstance(data, np.matrix): data = np.asarray(data) diff --git 
a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index f7d2b516a05..ff5470094ec 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -85,6 +85,7 @@ mock, network, parametrize_zarr_format, + raise_if_dask_computes, requires_cftime, requires_dask, requires_fsspec, @@ -2231,6 +2232,28 @@ def test_create_default_indexes(self, tmp_path, create_default_indexes) -> None: else: assert len(loaded_ds.xindexes) == 0 + @requires_dask + def test_encoding_masked_arrays(self, tmp_path) -> None: + store_path = tmp_path / "tmp.nc" + + with raise_if_dask_computes(): + ds = xr.DataArray( + dask.array.from_array( + np.ma.masked_array( + np.array([[np.nan, np.nan], [np.nan, 2]]), + np.array([[True, True], [True, False]]), + ) + ).astype("float32"), + dims=("x", "y"), + ).to_dataset(name="mydata") + + expected = ds.mean("x") + expected.to_netcdf( + store_path, encoding=dict(mydata=dict(_FillValue=np.float32(1e20))) + ) + with open_dataset(store_path, engine=self.engine) as actual: + assert_identical(expected.compute(), actual.compute()) + @requires_netCDF4 class TestNetCDF4Data(NetCDF4Base): diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 44a044275e2..1ef6adc9caf 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -2757,7 +2757,7 @@ def test_masked_array(self): expected = np.arange(5) actual: Any = as_compatible_data(original) assert_array_equal(expected, actual) - assert np.dtype(int) == actual.dtype + assert np.dtype(float) == actual.dtype original1: Any = np.ma.MaskedArray(np.arange(5), mask=4 * [False] + [True]) expected1: Any = np.arange(5.0)