From bb9b3009e9e07e1aeb3dfd496620da6510652ee6 Mon Sep 17 00:00:00 2001 From: kobebryant432 Date: Fri, 13 Feb 2026 18:50:44 +0100 Subject: [PATCH 1/9] add datatree manipulation functionality --- src/valenspy/__init__.py | 2 +- src/valenspy/_utilities/__init__.py | 3 +- src/valenspy/_utilities/_datatree.py | 46 ++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 2 deletions(-) create mode 100644 src/valenspy/_utilities/_datatree.py diff --git a/src/valenspy/__init__.py b/src/valenspy/__init__.py index 8d7260cc..32e1be5c 100644 --- a/src/valenspy/__init__.py +++ b/src/valenspy/__init__.py @@ -11,7 +11,7 @@ from valenspy.diagnostic import Diagnostic, Model2Self, Model2Ref, Ensemble2Ref, Ensemble2Self from valenspy.diagnostic.visualizations import * #Utility -from valenspy._utilities import is_cf_compliant, cf_status +from valenspy._utilities import is_cf_compliant, cf_status, datatree_to_dataset, datatree_to_dataframe # ============================================================================= # Version diff --git a/src/valenspy/_utilities/__init__.py b/src/valenspy/_utilities/__init__.py index 0e1cba49..d4c64dcb 100644 --- a/src/valenspy/_utilities/__init__.py +++ b/src/valenspy/_utilities/__init__.py @@ -5,6 +5,7 @@ load_yml, generate_parameters_doc ) -from._formatting import create_named_regex, parse_string_to_time_period +from ._formatting import create_named_regex, parse_string_to_time_period from .cf_checks import is_cf_compliant, cf_status from .unit_converter import CORDEX_VARIABLES, _convert_all_units_to_CF +from ._datatree import datatree_to_dataset, datatree_to_dataframe diff --git a/src/valenspy/_utilities/_datatree.py b/src/valenspy/_utilities/_datatree.py new file mode 100644 index 00000000..ad963aec --- /dev/null +++ b/src/valenspy/_utilities/_datatree.py @@ -0,0 +1,46 @@ +import pandas as pd +import xarray as xr + +def datatree_to_dataset(dt: xr.DataTree, **kwargs): + """ + Convert a DataTree to a xarray Dataset. + + Parameters + ---------- + dt : xr.DataTree + The DataTree to convert to a xarray Dataset. + **kwargs : dict + Keyword arguments to pass to the xarray concat function. + """ + datasets = [] + for key, ds in dt.to_dict().items(): + if ds: + ds_copy = ds.copy() + ds_copy = ds_copy.expand_dims({"id": [str(key)]}) + datasets.append(ds_copy) + + return xr.concat(datasets, dim="id", **kwargs) + +def datatree_to_dataframe(dt: xr.DataTree, add_attributes=False): + """ + Convert a DataTree to a pandas DataFrame. + """ + data_frames = [] + for key, ds in dt.to_dict().items(): + if ds: + if not ds.dims: # Non-dimensional datasets + df = pd.DataFrame({var: float(ds[var].values) for var in ds.data_vars}, index=[key]) + else: # Dimensional datasets + df = ds.to_dataframe() + df["id"] = str(key) + + if add_attributes: + for attr in add_attributes: + #if attr is a substring of any attribute in ds.attrs: + for ds_attr in ds.attrs: + if attr in ds_attr: + df[attr] = ds.attrs[ds_attr] + break + data_frames.append(df) + + return pd.concat(data_frames, axis=0).reset_index() \ No newline at end of file From 3b07b46d73290f1d054b318c9bbc02254fe51b1b Mon Sep 17 00:00:00 2001 From: kobebryant432 Date: Fri, 13 Feb 2026 18:51:12 +0100 Subject: [PATCH 2/9] New ensemble to self structure --- src/valenspy/diagnostic/_ensemble2self.py | 14 ++ src/valenspy/diagnostic/diagnostic.py | 281 +++++++++++++--------- 2 files changed, 181 insertions(+), 114 deletions(-) diff --git a/src/valenspy/diagnostic/_ensemble2self.py b/src/valenspy/diagnostic/_ensemble2self.py index e69de29b..e00aca12 100644 --- a/src/valenspy/diagnostic/_ensemble2self.py +++ b/src/valenspy/diagnostic/_ensemble2self.py @@ -0,0 +1,14 @@ +from valenspy.diagnostic.diagnostic import Ensemble2Self +from valenspy.diagnostic.functions import * +from valenspy.diagnostic.visualizations import * + +__all__ = [ + "Ensemble_Quantile_Spatial_Mean" +] + +Ensemble_Quantile_Spatial_Mean = Ensemble2Self( + ensemble_quantile_of_spatial_mean, + plot_map_per_dimension, + "Ensemble quantiles of spatial mean", + "The quantiles accross the ensembles spatial mean." +) diff --git a/src/valenspy/diagnostic/diagnostic.py b/src/valenspy/diagnostic/diagnostic.py index 926915dd..dd133ea7 100644 --- a/src/valenspy/diagnostic/diagnostic.py +++ b/src/valenspy/diagnostic/diagnostic.py @@ -61,8 +61,8 @@ def apply(self, data): """ pass - def plot(self, result, title=None, **kwargs): - """Plot the diagnostic. Single ax plots. + def plot(self, result, **kwargs): + """Plot the diagnostic. Parameters ---------- @@ -78,11 +78,90 @@ def plot(self, result, title=None, **kwargs): ax : matplotlib.axis.Axis The axis (singular) of the plot. """ - ax = self.plotting_function(result, **kwargs) - if not title: - title = self.name - ax.set_title(title) + return self.plotting_function(result, **kwargs) + + def plot_dt_single(self, dt, var, ax, label="name", colors=None, **kwargs): + """ + Plot the diagnostic by iterating over the leaves of a DataTree. + + Parameters + ---------- + dt : DataTree + The DataTree to plot. + var : str + The variable to plot. + ax : matplotlib.axis.Axis + The axis to plot on. + label : str + The attribute of the DataTree nodes to use as a title for the plots. + colors : dict or list + The colors to use for the different leaves of the DataTree. + Either a dictionary with the colors as values and the DataTree paths as keys or a list of colors. + **kwargs + Keyword arguments to pass to the plotting function. + + Returns + ------- + ax : matplotlib.axis.Axis + The axis of the plot. + """ + if colors: + if isinstance(colors, list): + colors = {dt_leave.path: color for dt_leave, color in zip(dt.leaves, colors)} + + for dt_leave in dt.leaves: + if label: + kwargs["label"] = getattr(dt_leave, label) + if colors: + kwargs["color"] = colors[dt_leave.path] + self.plot(dt_leave[var], ax=ax, **kwargs) + return ax + + def plot_dt_facetted(self, dt, var, axes, label="name", shared_cbar=None, **kwargs): + """ + Plot the diagnostic by iterating over the leaves of a DataTree. + + Parameters + ---------- + dt : DataTree + The DataTree to plot. + var : str + The variable to plot. + axes : np.ndarray + The axes to plot on. + label : str + The attribute of the DataTree nodes to use as a title for the plots. + shared_cbar : str + How to handle the vmin and vmax of the plot. Options are None, "min_max", "abs". + If None, the vmin and vmax are not automatically set. Passing the vmin and vmax as kwargs will still result in shared colorbars. + If "min_max", the vmin and vmax are set respectively to the minimum and maximum over all the leaves of the DataTree. + If "abs", the vmin and vmax are set to the maximum of the absolute value of the minimum and maximum over all the leaves of the DataTree. + **kwargs + Keyword arguments to pass to the plotting function. + + Returns + ------- + axes : np.ndarray + The axes of the plot. + """ + #Flatten the axes if needed + #Add option if axes is not provided to create new axes + + if shared_cbar: + max = np.max([ds[var].values for ds in dt.max().leaves]) + min = np.min([ds[var].values for ds in dt.min().leaves]) + if shared_cbar == "min_max": + kwargs = _augment_kwargs({"vmin": min, "vmax": max}, **kwargs) + elif shared_cbar == "abs": + abs_max = np.max([np.abs(min), np.abs(max)]) + kwargs = _augment_kwargs({"vmin": -abs_max, "vmax": abs_max}, **kwargs) + + for ax, dt_leave in zip(axes, dt.leaves): + if label: + kwargs["title"] = getattr(dt_leave, label) + self.plot(dt_leave[var], ax=ax, **kwargs) + return axes @property def description(self): @@ -112,8 +191,11 @@ def __init__( If "single", plot_dt will plot all the leaves of the DataTree on the same axis. If "facetted", plot_dt will plot all the leaves of the DataTree on different axes. """ - super().__init__(diagnostic_function, plotting_function, name, description) + if plot_type not in ["single", "facetted"]: + raise ValueError("Invalid plot_type provided. Options are 'single' or 'facetted'.") self.plot_type = plot_type + super().__init__(diagnostic_function, plotting_function, name, description) + def __call__(self, data, *args, **kwargs): if isinstance(data, DataTree): @@ -164,94 +246,109 @@ def apply(self, ds: xr.Dataset, *args, **kwargs): """ return self.diagnostic_function(ds, *args, **kwargs) + def plot(self, result, title=None, **kwargs): + """Plot the diagnostic. Single ax plots. + + Parameters + ---------- + result : xr.Dataset or xr.DataArray or DataTree + The output of the diagnostic function. + title : str + The title of the plot. + **kwargs + Keyword arguments to pass to the plotting function. + + Returns + ------- + ax : matplotlib.axis.Axis + The axis (singular) of the plot. + """ + ax = super().plot(result, **kwargs) + if not title: + title = self.name + ax.set_title(title) + return ax + def plot_dt(self, dt, *args, **kwargs): if self.plot_type == "single": return self.plot_dt_single(dt, *args, **kwargs) elif self.plot_type == "facetted": return self.plot_dt_facetted(dt, *args, **kwargs) - def plot_dt_single(self, dt, var, ax, label="name", colors=None, **kwargs): +class DataTreeDiagnostic(Diagnostic): + """A class representing a diagnostic that operates on the level of DataTrees.""" + + def __init__( + self, diagnostic_function, plotting_function, name=None, description=None, plot_type=None + ): + """Initialize the DataTreeDiagnostic. + Parameters + ---------- + plot_type : str, optional + The type of plotting function to use. Default is None, which means the plotting function will be used as is. + Options are "single" or "facetted". + If "single", plot_dt will plot all the leaves of the DataTree on the same axis. + If "facetted", plot_dt will plot all the leaves of the DataTree on different axes. + """ - Plot the diagnostic by iterating over the leaves of a DataTree. + if plot_type not in [None, "single", "facetted"]: + raise ValueError("Invalid plot_type provided. Options are None, 'single', or 'facetted'.") + self.plot_type = plot_type + super().__init__(diagnostic_function, plotting_function, name, description) + def __call__(self, data, *args, **kwargs): + if not isinstance(data, DataTree): + raise ValueError("Data must be a DataTree.") + return self.apply(data, *args, **kwargs) + + def apply(self, dt: DataTree, *args, **kwargs): + """Apply the diagnostic to a DataTree. + Parameters ---------- dt : DataTree - The DataTree to plot. - var : str - The variable to plot. - ax : matplotlib.axis.Axis - The axis to plot on. - label : str - The attribute of the DataTree nodes to use as a title for the plots. - colors : dict or list - The colors to use for the different leaves of the DataTree. - Either a dictionary with the colors as values and the DataTree paths as keys or a list of colors. + The data to apply the diagnostic to. + *args + Positional arguments to pass to the diagnostic function. **kwargs - Keyword arguments to pass to the plotting function. + Keyword arguments to pass to the diagnostic function. Returns ------- - ax : matplotlib.axis.Axis - The axis of the plot. + DataTree or dict + The data after applying the diagnostic as a DataTree or a dictionary of results with the tree nodes as keys. """ - if colors: - if isinstance(colors, list): - colors = {dt_leave.path: color for dt_leave, color in zip(dt.leaves, colors)} + return self.diagnostic_function(dt, *args, **kwargs) - for dt_leave in dt.leaves: - if label: - kwargs["label"] = getattr(dt_leave, label) - if colors: - kwargs["color"] = colors[dt_leave.path] - self.plot(dt_leave[var], ax=ax, **kwargs) + def plot_dt(self, dt, *args, **kwargs): + """Plot the diagnostic by iterating over the leaves of a DataTree. - return ax - - def plot_dt_facetted(self, dt, var, axes, label="name", shared_cbar=None, **kwargs): - """ - Plot the diagnostic by iterating over the leaves of a DataTree. - Parameters ---------- dt : DataTree The DataTree to plot. - var : str - The variable to plot. - axes : np.ndarray - The axes to plot on. - label : str - The attribute of the DataTree nodes to use as a title for the plots. - shared_cbar : str - How to handle the vmin and vmax of the plot. Options are None, "min_max", "abs". - If None, the vmin and vmax are not automatically set. Passing the vmin and vmax as kwargs will still result in shared colorbars. - If "min_max", the vmin and vmax are set respectively to the minimum and maximum over all the leaves of the DataTree. - If "abs", the vmin and vmax are set to the maximum of the absolute value of the minimum and maximum over all the leaves of the DataTree. + *args + Positional arguments to pass to the plotting function. **kwargs Keyword arguments to pass to the plotting function. Returns ------- - axes : np.ndarray - The axes of the plot. + Figure + The figure representing the diagnostic. """ - #Flatten the axes if needed - #Add option if axes is not provided to create new axes - - if shared_cbar: - max = np.max([ds[var].values for ds in dt.max().leaves]) - min = np.min([ds[var].values for ds in dt.min().leaves]) - if shared_cbar == "min_max": - kwargs = _augment_kwargs({"vmin": min, "vmax": max}, **kwargs) - elif shared_cbar == "abs": - abs_max = np.max([np.abs(min), np.abs(max)]) - kwargs = _augment_kwargs({"vmin": -abs_max, "vmax": abs_max}, **kwargs) - - for ax, dt_leave in zip(axes, dt.leaves): - if label: - kwargs["title"] = getattr(dt_leave, label) - self.plot(dt_leave[var], ax=ax, **kwargs) - return axes + #Check if the dt is a DataTree and if not raise an error + if not isinstance(dt, DataTree): + raise ValueError("Data must be a DataTree. Use self.plot to plot non-DataTree data results.") + if not self.plot_type: + warnings.warn("No plot_type specified, using the default plotting function. It is recommended to use self.plot instead of self.plot_dt when no plot_type is specified.") + return self.plotting_function(dt, *args, **kwargs) + elif self.plot_type == "single": + return self.plot_dt_single(dt, *args, **kwargs) + elif self.plot_type == "facetted": + return self.plot_dt_facetted(dt, *args, **kwargs) + else: + raise ValueError("Invalid plot_type specified. Options are 'single', 'facetted', or None.") class Model2Self(DataSetDiagnostic): """A class representing a diagnostic that compares a model to itself.""" @@ -292,58 +389,14 @@ def apply(self, ds: xr.Dataset, ref: xr.Dataset, **kwargs): return super().apply(ds, ref, **kwargs) -class Ensemble2Self(Diagnostic): +class Ensemble2Self(DataTreeDiagnostic): """A class representing a diagnostic that compares an ensemble to itself.""" def __init__( - self, diagnostic_function, plotting_function, name=None, description=None, iterative_plotting=False + self, diagnostic_function, plotting_function, name=None, description=None, plot_type=None ): """Initialize the Ensemble2Self diagnostic.""" - self.iterative_plotting = iterative_plotting - super().__init__(diagnostic_function, plotting_function, name, description) - - - def apply(self, dt: DataTree, mask=None, **kwargs): - """Apply the diagnostic to the data. - - Parameters - ---------- - dt : DataTree - The data to apply the diagnostic to. - - Returns - ------- - DataTree or dict - The data after applying the diagnostic as a DataTree or a dictionary of results with the tree nodes as keys. - """ - if mask == "prudence": - dt = dt.map_over_datasets(add_prudence_regions) - - return self.diagnostic_function(dt, **kwargs) - - def plot(self, result, variables=None, title=None, facetted=None, **kwargs): - """Plot the diagnostic. - - If facetted multiple plots on different axes are created. If not facetted, the plots are created on the same axis. - - Parameters - ---------- - result : DataTree - The result of applying the ensemble diagnostic to a DataTree. - - Returns - ------- - Figure - The figure representing the diagnostic. - """ - if not self.iterative_plotting: - if facetted is not None: - warnings.warn("facetted is ignored when using a non-iterative plotting function.") - return self._plot_non_iterative(result, title=title, **kwargs) - else: - if variables is None: - raise ValueError("variables must be provided when using an iterative plotting function. The variables can be a list of variables to plot or a single variable to plot.") - return self._plot_iterative(result, title=title, variables=variables, facetted=facetted, **kwargs) + super().__init__(diagnostic_function, plotting_function, name, description, plot_type) class Ensemble2Ref(Diagnostic): """A class representing a diagnostic that compares an ensemble to a reference.""" From 58851e572274ff61efda6a74c1f7a1af7b7fc7df Mon Sep 17 00:00:00 2001 From: kobebryant432 Date: Fri, 13 Feb 2026 18:51:38 +0100 Subject: [PATCH 3/9] Ensemble quantile of spatial mean --- src/valenspy/diagnostic/functions.py | 21 ++++++++ src/valenspy/diagnostic/visualizations.py | 60 +++++++++++++++++++++++ 2 files changed, 81 insertions(+) diff --git a/src/valenspy/diagnostic/functions.py b/src/valenspy/diagnostic/functions.py index 37855f90..eec7ddd9 100644 --- a/src/valenspy/diagnostic/functions.py +++ b/src/valenspy/diagnostic/functions.py @@ -7,6 +7,7 @@ from valenspy.processing import select_point from valenspy.diagnostic.wrappers import acceptable_variables, required_variables +from valenspy._utilities import datatree_to_dataframe, datatree_to_dataset # make sure attributes are passed through xr.set_options(keep_attrs=True) @@ -261,6 +262,26 @@ def calc_metrics_ds(ds_mod: xr.Dataset, ds_obs: xr.Dataset, metrics=None, pss_bi """ return {variable: calc_metrics_da(ds_mod[variable], ds_obs[variable], metrics, pss_binwidth=pss_binwidth) for variable in ds_mod.data_vars} +###################################### +# Ensemble2Self diagnostic functions # +###################################### + +def ensemble_quantile_of_spatial_mean(dt: DataTree, quantile: float | list[float]): + """ + Calculate the ensemble quantile of the spatial mean of the data. If the time dimension is present, the data is averaged over the time dimension before calculating the percentiles. + + Parameters + ---------- + dt : DataTree + The data to calculate the ensemble quantiles of the spatial mean of. + quantile : float or list of float + The quantiles to calculate. Value(s) between 0 and 1. + """ + + ds_m = datatree_to_dataset(dt.mean(dim="time"),compat="override",coords="minimal") #Compat is set to override to avoid issues height conflicts between the different datatrees. To be checked why this is needed. + return ds_m.quantile(quantile, dim="id") + + ##################################### # Ensemble2Ref diagnostic functions # ##################################### diff --git a/src/valenspy/diagnostic/visualizations.py b/src/valenspy/diagnostic/visualizations.py index cd702907..9300cadc 100644 --- a/src/valenspy/diagnostic/visualizations.py +++ b/src/valenspy/diagnostic/visualizations.py @@ -541,6 +541,40 @@ def plot_metric_ranking(df_metric, ax=None, plot_colorbar=True, hex_color1 = Non return ax +###################################### +# Ensemble2Self diagnostic visuals # +###################################### + +def plot_map_per_dimension(ds: xr.Dataset, var: str, dim: str, axes=None, **kwargs): + """ + Plots a map for each unique value along a specified dimension in an xarray Dataset. + + Parameters + ---------- + ds : xr.Dataset + The xarray Dataset containing the data to be plotted. It should have the variable of interest and the specified dimension. + var : str + The name of the variable in the Dataset to be plotted. + dim : str + The name of the dimension along which to create separate plots for each unique value. + axes : array-like of matplotlib.axes.Axes, optional + An array of axes to plot on. If None, new axes will be created for each unique value along the specified dimension. + **kwargs : dict + Additional keyword arguments to pass to the plot_map function for each plot. + + Returns + ------- + list of matplotlib.axes.Axes + A list of axes objects corresponding to each unique value along the specified dimension, with the respective maps plotted. + """ + unique_values = ds[dim].values + + axes = _get_axes(n_axes=len(unique_values), axes=axes, **kwargs) + for i, value in enumerate(unique_values): + plot_map(ds[var].sel({dim: value}), ax=axes[i], **kwargs) + + return axes + ################################## # Helper functions # ################################## @@ -554,6 +588,32 @@ def _get_gca(**kwargs): else: return plt.gca() +def _get_axes(n_axes=1, **kwargs): + """ + Get axes for a multi-axes plot. + + If 'axes' is provided in kwargs, return it. + Otherwise, create a new figure with `n_axes` subplots. + + Parameters + ---------- + n_axes : int, default=1 + Number of axes to create if none are provided. + **kwargs + May contain 'axes'. + + Returns + ------- + np.ndarray + 1D array of matplotlib.axes.Axes + """ + if "axes" in kwargs: + axes = kwargs["axes"] + return np.atleast_1d(axes).ravel() + else: + _, axes = plt.subplots(n_axes) + return np.atleast_1d(axes).ravel() + # Define a function to add borders, coastlines to the axes def _add_features(ax, region=None): """ From e13caf6caec0d8f3bc79e0cde661f4507ab7d398 Mon Sep 17 00:00:00 2001 From: kobebryant432 Date: Tue, 24 Feb 2026 10:31:24 +0100 Subject: [PATCH 4/9] Add shared color bar option --- src/valenspy/diagnostic/visualizations.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/valenspy/diagnostic/visualizations.py b/src/valenspy/diagnostic/visualizations.py index 9300cadc..c36829b1 100644 --- a/src/valenspy/diagnostic/visualizations.py +++ b/src/valenspy/diagnostic/visualizations.py @@ -545,7 +545,7 @@ def plot_metric_ranking(df_metric, ax=None, plot_colorbar=True, hex_color1 = Non # Ensemble2Self diagnostic visuals # ###################################### -def plot_map_per_dimension(ds: xr.Dataset, var: str, dim: str, axes=None, **kwargs): +def plot_map_per_dimension(ds: xr.Dataset, var: str, dim: str, axes=None, shared_cbar=None, **kwargs): """ Plots a map for each unique value along a specified dimension in an xarray Dataset. @@ -559,6 +559,11 @@ def plot_map_per_dimension(ds: xr.Dataset, var: str, dim: str, axes=None, **kwar The name of the dimension along which to create separate plots for each unique value. axes : array-like of matplotlib.axes.Axes, optional An array of axes to plot on. If None, new axes will be created for each unique value along the specified dimension. + shared_cbar : str + How to handle the vmin and vmax of the plot. Options are None, "min_max", "abs". + If None, the vmin and vmax are not automatically set. Passing the vmin and vmax as kwargs will still result in shared colorbars. + If "min_max", the vmin and vmax are set respectively to the minimum and maximum over all the leaves of the DataTree. + If "abs", the vmin and vmax are set to the maximum of the absolute value of the minimum and maximum over all the leaves of the DataTree. **kwargs : dict Additional keyword arguments to pass to the plot_map function for each plot. @@ -568,6 +573,15 @@ def plot_map_per_dimension(ds: xr.Dataset, var: str, dim: str, axes=None, **kwar A list of axes objects corresponding to each unique value along the specified dimension, with the respective maps plotted. """ unique_values = ds[dim].values + + if shared_cbar: + max = ds[var].max().values + min = ds[var].min().values + if shared_cbar == "min_max": + kwargs = _augment_kwargs({"vmin": min, "vmax": max}, **kwargs) + elif shared_cbar == "abs": + abs_max = np.max([np.abs(min), np.abs(max)]) + kwargs = _augment_kwargs({"vmin": -abs_max, "vmax": abs_max}, **kwargs) axes = _get_axes(n_axes=len(unique_values), axes=axes, **kwargs) for i, value in enumerate(unique_values): From 082aa9ea84f3778335a57404019ca590a90f4982 Mon Sep 17 00:00:00 2001 From: kobebryant432 Date: Tue, 24 Feb 2026 16:08:43 +0100 Subject: [PATCH 5/9] Deal with edge case where time is no longer a dimension (e.g. the spatial mean has already been taken. --- src/valenspy/diagnostic/functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/valenspy/diagnostic/functions.py b/src/valenspy/diagnostic/functions.py index eec7ddd9..f0af723c 100644 --- a/src/valenspy/diagnostic/functions.py +++ b/src/valenspy/diagnostic/functions.py @@ -277,8 +277,8 @@ def ensemble_quantile_of_spatial_mean(dt: DataTree, quantile: float | list[float quantile : float or list of float The quantiles to calculate. Value(s) between 0 and 1. """ - - ds_m = datatree_to_dataset(dt.mean(dim="time"),compat="override",coords="minimal") #Compat is set to override to avoid issues height conflicts between the different datatrees. To be checked why this is needed. + dt_m = dt.map_over_datasets(_average_over_dims, "time") + ds_m = datatree_to_dataset(dt_m, compat="override",coords="minimal") #Compat is set to override to avoid issues height conflicts between the different datatrees. To be checked why this is needed. return ds_m.quantile(quantile, dim="id") From 2821600ba0d268bebfffba503a73b496abced8e2 Mon Sep 17 00:00:00 2001 From: kobebryant432 Date: Tue, 24 Feb 2026 18:11:58 +0100 Subject: [PATCH 6/9] Add datatree functionality --- src/valenspy/__init__.py | 2 +- src/valenspy/_utilities/__init__.py | 2 +- src/valenspy/_utilities/_datatree.py | 52 ++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 2 deletions(-) diff --git a/src/valenspy/__init__.py b/src/valenspy/__init__.py index 32e1be5c..5d62c6aa 100644 --- a/src/valenspy/__init__.py +++ b/src/valenspy/__init__.py @@ -11,7 +11,7 @@ from valenspy.diagnostic import Diagnostic, Model2Self, Model2Ref, Ensemble2Ref, Ensemble2Self from valenspy.diagnostic.visualizations import * #Utility -from valenspy._utilities import is_cf_compliant, cf_status, datatree_to_dataset, datatree_to_dataframe +from valenspy._utilities import is_cf_compliant, cf_status, datatree_to_dataset, datatree_to_dataframe, restructure_by_level, split_by_level # ============================================================================= # Version diff --git a/src/valenspy/_utilities/__init__.py b/src/valenspy/_utilities/__init__.py index d4c64dcb..d7fe53b0 100644 --- a/src/valenspy/_utilities/__init__.py +++ b/src/valenspy/_utilities/__init__.py @@ -8,4 +8,4 @@ from ._formatting import create_named_regex, parse_string_to_time_period from .cf_checks import is_cf_compliant, cf_status from .unit_converter import CORDEX_VARIABLES, _convert_all_units_to_CF -from ._datatree import datatree_to_dataset, datatree_to_dataframe +from ._datatree import datatree_to_dataset, datatree_to_dataframe, restructure_by_level, split_by_level diff --git a/src/valenspy/_utilities/_datatree.py b/src/valenspy/_utilities/_datatree.py index ad963aec..c8b3a6a3 100644 --- a/src/valenspy/_utilities/_datatree.py +++ b/src/valenspy/_utilities/_datatree.py @@ -1,5 +1,57 @@ import pandas as pd import xarray as xr +from copy import deepcopy + +def split_by_level(dt: xr.DataTree, level: int): + """ + Split a DataTree into multiple DataTrees based on the unique values at a given level in the node paths. + + Parameters + ---------- + dt : xr.DataTree + The DataTree to split. + level : int + The level in the node paths to split the DataTree by. Level 0 is the root level. + + Returns + ------- + dict + A dictionary where the keys are the unique values at the specified level in the node paths, and the values are the corresponding DataTrees containing only the nodes with that value at the specified level. + """ + #Make a deep copy of the datatree to avoid modifying the original one. This is needed as we are going to orphan the datasets in the new datatrees, which would also orphan them in the original datatree if we don't make a copy. + dt = deepcopy(dt) + if level >= 2: + dt = restructure_by_level(dt, level) + result = {value: dt[value] for value in set(dt.children)} + for value in result: + result[value].orphan() + return result + +def restructure_by_level(dt: xr.DataTree, level: int): + """ + Restructure a DataTree such that level n in the node paths becomes the new root level. + + Parameters + ---------- + dt : xr.DataTree + The DataTree to restructure. + level : int + The level in the node paths to restructure the DataTree by. Level 0 is the root level. + + Returns + ------- + xr.DataTree + A restructured DataTree where level n in the node paths becomes the new root level. + """ + if level < 2: + raise ValueError("Level must be greater than or equal to 2 as reording the first level leaves the tree unchanged") + level = level - 1 #As the root level is not included in the path split, we need to subtract 1 from the level to get the correct index. + reorganized_nodes = { + "/".join([path.split("/")[level]] + path.split("/")[:level] + path.split("/")[level+1:]): node.dataset + for path, node in dt.subtree_with_keys + if len(path.split("/")) > level #Strict + } + return xr.DataTree.from_dict(reorganized_nodes) def datatree_to_dataset(dt: xr.DataTree, **kwargs): """ From dd532019408be41a9da0319d3799034ed62c4d95 Mon Sep 17 00:00:00 2001 From: kobebryant432 Date: Tue, 24 Feb 2026 18:26:49 +0100 Subject: [PATCH 7/9] Add ClimateChangeSignal Diagnostic --- src/valenspy/diagnostic/_ensemble2ref.py | 13 +++++++++- src/valenspy/diagnostic/diagnostic.py | 30 +++--------------------- src/valenspy/diagnostic/functions.py | 27 +++++++++++++++++++++ 3 files changed, 42 insertions(+), 28 deletions(-) diff --git a/src/valenspy/diagnostic/_ensemble2ref.py b/src/valenspy/diagnostic/_ensemble2ref.py index 5cf51100..b96051ce 100644 --- a/src/valenspy/diagnostic/_ensemble2ref.py +++ b/src/valenspy/diagnostic/_ensemble2ref.py @@ -2,7 +2,18 @@ from valenspy.diagnostic.functions import * from valenspy.diagnostic.visualizations import * -__all__ = ["MetricsRankings"] +__all__ = [ + "ClimateChangeSignal", + "MetricsRankings" + ] + +ClimateChangeSignal = Ensemble2Ref( + climate_change_signal, + plot_map, + "Climate Change Signal", + "The climate change signal as the average difference between the GWL and the reference period.", + plot_type="facetted" +) MetricsRankings = Ensemble2Ref( calc_metrics_dt, diff --git a/src/valenspy/diagnostic/diagnostic.py b/src/valenspy/diagnostic/diagnostic.py index dd133ea7..1dba5905 100644 --- a/src/valenspy/diagnostic/diagnostic.py +++ b/src/valenspy/diagnostic/diagnostic.py @@ -398,14 +398,14 @@ def __init__( """Initialize the Ensemble2Self diagnostic.""" super().__init__(diagnostic_function, plotting_function, name, description, plot_type) -class Ensemble2Ref(Diagnostic): +class Ensemble2Ref(DataTreeDiagnostic): """A class representing a diagnostic that compares an ensemble to a reference.""" def __init__( - self, diagnostic_function, plotting_function, name=None, description=None + self, diagnostic_function, plotting_function, name=None, description=None, plot_type=None ): """Initialize the Ensemble2Ref diagnostic.""" - super().__init__(diagnostic_function, plotting_function, name, description) + super().__init__(diagnostic_function, plotting_function, name, description, plot_type) def apply(self, dt: DataTree, ref, **kwargs): """Apply the diagnostic to the data. @@ -425,30 +425,6 @@ def apply(self, dt: DataTree, ref, **kwargs): # TODO: Add some checks to make sure the reference is a DataTree or a Dataset and contain common variables with the data. return self.diagnostic_function(dt, ref, **kwargs) - def plot(self, result, facetted=True, **kwargs): - """Plot the diagnostic. - - If axes are provided, the diagnostic is plotted facetted. If ax is provided, the diagnostic is plotted non-facetted. - If neither axes nor ax are provided, the diagnostic is plotted on the current axis and no facetting is applied. - - Parameters - ---------- - result : DataTree - The result of applying the ensemble diagnostic to a DataTree. - - Returns - ------- - Figure - The figure representing the diagnostic. - """ - if "ax" in kwargs and "axes" in kwargs: - raise ValueError("Either ax or axes can be provided, not both.") - elif "ax" not in kwargs and "axes" not in kwargs: - ax = plt.gca() - return self.plotting_function(result, ax=ax, **kwargs) - else: - return self.plotting_function(result, **kwargs) - def _common_vars(ds1, ds2): """Return the common variables in two datasets.""" return set(ds1.data_vars).intersection(set(ds2.data_vars)) diff --git a/src/valenspy/diagnostic/functions.py b/src/valenspy/diagnostic/functions.py index f0af723c..77dde7df 100644 --- a/src/valenspy/diagnostic/functions.py +++ b/src/valenspy/diagnostic/functions.py @@ -286,6 +286,33 @@ def ensemble_quantile_of_spatial_mean(dt: DataTree, quantile: float | list[float # Ensemble2Ref diagnostic functions # ##################################### +def climate_change_signal(fut: DataTree, ref: DataTree, abs_diff=True): + """ + Calculate the climate change signal as the difference between the spatial mean of the fut and ref datatree. + The difference is only taken for members which are both in the fut and ref datatree with exactly the same path. If abs_diff is True, the absolute difference is calculated, otherwise the relative difference is calculated. + + Parameters + ---------- + fut : DataTree + The future data to calculate the climate change signal of. + ref : DataTree + The reference data to compare the future data to. + abs_diff : bool, optional + If True, calculate the absolute difference, if False calculate the relative difference, by default True + + Returns + ------- + xr.Datatree + The climate change signal as the difference between the spatial mean of the fut and ref datatree. + """ + fut = fut.filter_like(ref).map_over_datasets(_average_over_dims, "time") #For the "members" that are also in the ref datatree, calculate the spatial mean of the fut datatree. + ref = ref.filter_like(fut).map_over_datasets(_average_over_dims, "time") #For the "members" that are also in the fut datatree, calculate the spatial mean of the ref datatree. + if abs_diff: + return fut - ref + else: + return (fut - ref) / ref + + def calc_metrics_dt(dt_mod: DataTree, da_obs: xr.Dataset, metrics=None, pss_binwidth=None): """ Calculate statistical performance metrics for model data against observed data. From 122a9c12793e166a75ba1fc3e592baf6f3af562a Mon Sep 17 00:00:00 2001 From: kobebryant432 Date: Wed, 25 Feb 2026 10:48:23 +0100 Subject: [PATCH 8/9] Added ClimateChangeSignal for whole domain --- src/valenspy/diagnostic/_ensemble2ref.py | 17 +++++-- src/valenspy/diagnostic/functions.py | 61 ++++++++++++++++++++++-- 2 files changed, 69 insertions(+), 9 deletions(-) diff --git a/src/valenspy/diagnostic/_ensemble2ref.py b/src/valenspy/diagnostic/_ensemble2ref.py index b96051ce..3254db58 100644 --- a/src/valenspy/diagnostic/_ensemble2ref.py +++ b/src/valenspy/diagnostic/_ensemble2ref.py @@ -3,18 +3,25 @@ from valenspy.diagnostic.visualizations import * __all__ = [ - "ClimateChangeSignal", + "ClimateChangeSignalOfSpatialMean", "MetricsRankings" ] -ClimateChangeSignal = Ensemble2Ref( - climate_change_signal, +ClimateChangeSignalOfSpatialMean = Ensemble2Ref( + climate_change_signal_of_spatial_mean, plot_map, - "Climate Change Signal", - "The climate change signal as the average difference between the GWL and the reference period.", + "Climate Change Signal of the spatial means", + "The spatial climate change signal as the difference between the temporal average of two periods", plot_type="facetted" ) +ClimateChangeSignal = Ensemble2Ref( + mean_climate_change_signal, + None, + "Climate Change Signal", + "The climate change signal as the difference between the spatial and temporal average of two periods." +) + MetricsRankings = Ensemble2Ref( calc_metrics_dt, plot_metric_ranking, diff --git a/src/valenspy/diagnostic/functions.py b/src/valenspy/diagnostic/functions.py index 77dde7df..ae36c508 100644 --- a/src/valenspy/diagnostic/functions.py +++ b/src/valenspy/diagnostic/functions.py @@ -286,7 +286,7 @@ def ensemble_quantile_of_spatial_mean(dt: DataTree, quantile: float | list[float # Ensemble2Ref diagnostic functions # ##################################### -def climate_change_signal(fut: DataTree, ref: DataTree, abs_diff=True): +def climate_change_signal_of_spatial_mean(fut: DataTree, ref: DataTree, abs_diff=True): """ Calculate the climate change signal as the difference between the spatial mean of the fut and ref datatree. The difference is only taken for members which are both in the fut and ref datatree with exactly the same path. If abs_diff is True, the absolute difference is calculated, otherwise the relative difference is calculated. @@ -305,13 +305,66 @@ def climate_change_signal(fut: DataTree, ref: DataTree, abs_diff=True): xr.Datatree The climate change signal as the difference between the spatial mean of the fut and ref datatree. """ - fut = fut.filter_like(ref).map_over_datasets(_average_over_dims, "time") #For the "members" that are also in the ref datatree, calculate the spatial mean of the fut datatree. - ref = ref.filter_like(fut).map_over_datasets(_average_over_dims, "time") #For the "members" that are also in the fut datatree, calculate the spatial mean of the ref datatree. + return _climate_change_signal(fut, ref, abs_diff=abs_diff, mean_over_dims="time") + +def mean_climate_change_signal(fut: DataTree, ref: DataTree, abs_diff=True, add_attributes=False): + """ + Calculate the mean climate change signal as the difference between the mean of the fut and ref datatree. + The difference is only taken for members which are both in the fut and ref datatree with exactly the same path. If abs_diff is True, the absolute difference is calculated, otherwise the relative difference is calculated. + + Parameters + ---------- + fut : DataTree + The future data to calculate the mean climate change signal of. + ref : DataTree + The reference data to compare the future data to. + abs_diff : bool, optional + If True, calculate the absolute difference, if False calculate the relative difference, by default True + add_attributes : bool, optional + If True, add attributes to the resulting dataframe, by default False + + Returns + ------- + pd.DataFrame + A dataframe with the mean climate change signal for each member in the datatree along with its unique path as an identifier. If add_attributes is True, the dataframe also contains the attributes of the datasets in the datatree. + """ + dt_diff = _climate_change_signal(fut, ref, abs_diff=abs_diff, mean_over_dims=None) + return datatree_to_dataframe(dt_diff, add_attributes=add_attributes) + +def _climate_change_signal(fut: DataTree, ref: DataTree, abs_diff=True, mean_over_dims=None): + """ + Calculate the climate change signal as the difference between the fut and ref. + The mean is taken over the specified dimension(s) before calculating the difference. + The difference is only taken for members which are both in the fut and ref datatree with exactly the same path. If abs_diff is True, the absolute difference is calculated, otherwise the relative difference is calculated. + + Parameters + ---------- + fut : DataTree + The future data to calculate the climate change signal of. + ref : DataTree + The reference data to compare the future data to. + abs_diff : bool, optional + If True, calculate the absolute difference, if False calculate the relative difference, by default True + mean_over_dims : str or list of str, optional + The dimension(s) to calculate the mean over before calculating the difference. If None, no mean is calculated, by default None + + Returns + ------- + xr.Datatree + The climate change signal as the difference between the fut and ref datatree. + """ + fut = fut.filter_like(ref) #For the "members" that are also in the ref datatree + ref = ref.filter_like(fut) #For the "members" that are also in the fut datatree + if mean_over_dims: + fut = fut.map_over_datasets(_average_over_dims, mean_over_dims) + ref = ref.map_over_datasets(_average_over_dims, mean_over_dims) + else: + fut = fut.mean() #Mean over all dimensions + ref = ref.mean() #Mean over all dimensions if abs_diff: return fut - ref else: return (fut - ref) / ref - def calc_metrics_dt(dt_mod: DataTree, da_obs: xr.Dataset, metrics=None, pss_binwidth=None): """ From 80c59ea12896cee17d52406da1032232326acc3f Mon Sep 17 00:00:00 2001 From: kobebryant432 Date: Thu, 26 Feb 2026 13:38:10 +0100 Subject: [PATCH 9/9] Small bug fixes --- src/valenspy/diagnostic/_ensemble2ref.py | 5 +++-- src/valenspy/diagnostic/diagnostic.py | 9 +++++---- src/valenspy/diagnostic/visualizations.py | 10 +++++----- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/valenspy/diagnostic/_ensemble2ref.py b/src/valenspy/diagnostic/_ensemble2ref.py index 3254db58..4a8fa704 100644 --- a/src/valenspy/diagnostic/_ensemble2ref.py +++ b/src/valenspy/diagnostic/_ensemble2ref.py @@ -3,8 +3,9 @@ from valenspy.diagnostic.visualizations import * __all__ = [ + "ClimateChangeSignal", "ClimateChangeSignalOfSpatialMean", - "MetricsRankings" + "MetricsRankings", ] ClimateChangeSignalOfSpatialMean = Ensemble2Ref( @@ -17,7 +18,7 @@ ClimateChangeSignal = Ensemble2Ref( mean_climate_change_signal, - None, + lambda ds : ds, "Climate Change Signal", "The climate change signal as the difference between the spatial and temporal average of two periods." ) diff --git a/src/valenspy/diagnostic/diagnostic.py b/src/valenspy/diagnostic/diagnostic.py index 1dba5905..25605de0 100644 --- a/src/valenspy/diagnostic/diagnostic.py +++ b/src/valenspy/diagnostic/diagnostic.py @@ -68,8 +68,6 @@ def plot(self, result, **kwargs): ---------- result : xr.Dataset or xr.DataArray or DataTree The output of the diagnostic function. - title : str - The title of the plot. **kwargs Keyword arguments to pass to the plotting function. @@ -147,6 +145,7 @@ def plot_dt_facetted(self, dt, var, axes, label="name", shared_cbar=None, **kwar """ #Flatten the axes if needed #Add option if axes is not provided to create new axes + #Check how to deal with shared_cbar (shared vmin and vmas - should this be named differently?) and should the cbar really be shared? if shared_cbar: max = np.max([ds[var].values for ds in dt.max().leaves]) @@ -158,9 +157,11 @@ def plot_dt_facetted(self, dt, var, axes, label="name", shared_cbar=None, **kwar kwargs = _augment_kwargs({"vmin": -abs_max, "vmax": abs_max}, **kwargs) for ax, dt_leave in zip(axes, dt.leaves): - if label: - kwargs["title"] = getattr(dt_leave, label) self.plot(dt_leave[var], ax=ax, **kwargs) + if label: + title = getattr(dt_leave, label) + ax.set_title(title) + return axes @property diff --git a/src/valenspy/diagnostic/visualizations.py b/src/valenspy/diagnostic/visualizations.py index c36829b1..f9dc856a 100644 --- a/src/valenspy/diagnostic/visualizations.py +++ b/src/valenspy/diagnostic/visualizations.py @@ -143,11 +143,11 @@ def plot_map(da: xr.DataArray, max_chars=25, **kwargs): ax : matplotlib.axes.Axes The matplotlib Axes with the plot. """ - - label = f"{da.attrs.get('long_name', 'Data')} ({da.units})" - - label_wrapped = "\n".join(textwrap.wrap(label, width=max_chars)) - kwargs = _augment_kwargs({"cbar_kwargs": {"label":label_wrapped}}, **kwargs) + #Check if add_colorbar is set to False, if so, do not set the colorbar label + if kwargs.get("add_colorbar", True): + label = f"{da.attrs.get('long_name', 'Data')} ({da.units})" + label_wrapped = "\n".join(textwrap.wrap(label, width=max_chars)) + kwargs = _augment_kwargs({"cbar_kwargs": {"label":label_wrapped}}, **kwargs) da.plot(**kwargs)