diff --git a/dabench/dacycler/_dacycler.py b/dabench/dacycler/_dacycler.py index 6afc605..34ac8f8 100644 --- a/dabench/dacycler/_dacycler.py +++ b/dabench/dacycler/_dacycler.py @@ -5,43 +5,49 @@ import jax import xarray as xr import xarray_jax as xj +from typing import Callable import dabench.dacycler._utils as dac_utils +from dabench.model import Model + + +# For typing +ArrayLike = np.ndarray | jax.Array +XarrayDatasetLike = xr.Dataset | xj.XjDataset class DACycler(): """Base class for DACycler object Attributes: - system_dim (int): System dimension - delta_t (float): The timestep of the model (assumed uniform) - model_obj (dabench.Model): Forecast model object. - in_4d (bool): True for 4D data assimilation techniques (e.g. 4DVar). + system_dim: System dimension + delta_t: The timestep of the model (assumed uniform) + model_obj: Forecast model object. + in_4d: True for 4D data assimilation techniques (e.g. 4DVar). Default is False. - ensemble (bool): True for ensemble-based data assimilation techniques + ensemble: True for ensemble-based data assimilation techniques (ETKF). Default is False - B (ndarray): Initial / static background error covariance. Shape: + B: Initial / static background error covariance. Shape: (system_dim, system_dim). If not provided, will be calculated automatically. - R (ndarray): Observation error covariance matrix. Shape + R: Observation error covariance matrix. Shape (obs_dim, obs_dim). If not provided, will be calculated automatically. - H (ndarray): Observation operator with shape: (obs_dim, system_dim). + H: Observation operator with shape: (obs_dim, system_dim). If not provided will be calculated automatically. - h (function): Optional observation operator as function. More flexible + h: Optional observation operator as function. More flexible (allows for more complex observation operator). Default is None. 
""" def __init__(self, - system_dim=None, - delta_t=None, - model_obj=None, - in_4d=False, - ensemble=False, - B=None, - R=None, - H=None, - h=None, - analysis_time_in_window=None + system_dim: int, + delta_t: float, + model_obj: Model, + in_4d: bool = False, + ensemble: bool = False, + B: ArrayLike | None = None, + R: ArrayLike | None = None, + H: ArrayLike | None = None, + h: Callable | None = None, ): self.h = h @@ -53,43 +59,64 @@ def __init__(self, self.system_dim = system_dim self.delta_t = delta_t self.model_obj = model_obj - self.analysis_time_in_window = analysis_time_in_window - def _calc_default_H(self, obs_values, obs_loc_indices): + def _calc_default_H(self, + obs_values: ArrayLike, + obs_loc_indices: ArrayLike + ) -> jax.Array: H = jnp.zeros((obs_values.flatten().shape[0], self.system_dim)) H = H.at[jnp.arange(H.shape[0]), obs_loc_indices.flatten(), ].set(1) return H - def _calc_default_R(self, obs_values, obs_error_sd): + def _calc_default_R(self, + obs_values: ArrayLike, + obs_error_sd: float + ) -> jax.Array: return jnp.identity(obs_values.flatten().shape[0])*(obs_error_sd**2) - def _calc_default_B(self): + def _calc_default_B(self) -> jax.Array: """If B is not provided, identity matrix with shape (system_dim, system_dim.""" return jnp.identity(self.system_dim) - def _step_forecast(self, xa, n_steps=1): + def _step_forecast(self, + xa: XarrayDatasetLike, + n_steps: int = 1 + ) -> XarrayDatasetLike: """Perform forecast using model object""" return self.model_obj.forecast(xa, n_steps=n_steps) - def _step_cycle(self, xb, obs_vals, obs_locs, obs_time_mask, obs_loc_mask, - H=None, h=None, R=None, B=None, **kwargs): + def _step_cycle(self, + cur_state: XarrayDatasetLike, + obs_vals: ArrayLike, + obs_locs: ArrayLike, + obs_time_mask: ArrayLike, + obs_loc_mask: ArrayLike, + H: ArrayLike | None = None, + h: Callable | None =None, + R: ArrayLike | None = None, + B:ArrayLike | None = None, + **kwargs + ) -> XarrayDatasetLike: if H is not None or h is None: vals = self._cycle_obsop( - xb, obs_vals, obs_locs, obs_time_mask, + cur_state, obs_vals, obs_locs, obs_time_mask, obs_loc_mask, H, R, B, **kwargs) return vals else: raise ValueError( 'Only linear obs operators (H) are supported right now.') vals = self._cycle_general_obsop( - xb, obs_vals, obs_locs, obs_time_mask, + cur_state, obs_vals, obs_locs, obs_time_mask, obs_loc_mask, h, R, B, **kwargs) return vals - def _cycle_and_forecast(self, cur_state, filtered_idx): + def _cycle_and_forecast(self, + cur_state: xj.XjDataset, + filtered_idx: ArrayLike + ) -> tuple[xj.XjDataset, XarrayDatasetLike]: # 1. Get data # 1-b. Calculate obs_time_mask and restore filtered_idx to original values cur_state = cur_state.to_xarray() @@ -119,7 +146,10 @@ def _cycle_and_forecast(self, cur_state, filtered_idx): return xj.from_xarray(next_state), forecast_states - def _cycle_and_forecast_4d(self, cur_state, filtered_idx): + def _cycle_and_forecast_4d(self, + cur_state: xj.XjDataset, + filtered_idx: ArrayLike + ) -> tuple[xj.XjDataset, XarrayDatasetLike]: # 1. Get data # 1-b. 
@@ -160,35 +190,32 @@ def _cycle_and_forecast_4d(self, cur_state, filtered_idx):
        return xj.from_xarray(next_state), forecast_states

    def cycle(self,
-              input_state,
-              start_time,
-              obs_vector,
-              n_cycles,
-              obs_error_sd=None,
-              analysis_window=0.2,
-              analysis_time_in_window=None,
-              return_forecast=False
-              ):
+              input_state: XarrayDatasetLike,
+              start_time: float | np.datetime64,
+              obs_vector: XarrayDatasetLike,
+              n_cycles: int,
+              obs_error_sd: float | ArrayLike | None = None,
+              analysis_window: float = 0.2,
+              analysis_time_in_window: float | None = None,
+              return_forecast: bool = False
+              ) -> XarrayDatasetLike:
        """Perform DA cycle repeatedly, including analysis and forecast

        Args:
-            input_state (vector.StateVector): Input state.
-            start_time (float or datetime-like): Starting time.
-            obs_vector (vector.ObsVector): Observations vector.
-            n_cycles (int): Number of analysis cycles to run, each of length
+            input_state: Input state as an Xarray Dataset.
+            start_time: Starting time.
+            obs_vector: Observations vector.
+            n_cycles: Number of analysis cycles to run, each of length
                analysis_window.
-            analysis_window (float): Time window from which to gather
+            analysis_window: Time window from which to gather
                observations for DA Cycle.
-            analysis_time_in_window (float): Where within analysis_window
+            analysis_time_in_window: Where within analysis_window
                to perform analysis. For example, 0.0 is the start of the
                window. Default is None, which selects the middle of the
                window.
-            return_forecast (bool): If True, returns forecast at each model
+            return_forecast: If True, returns forecast at each model
                timestep. If False, returns only analyses, one per analysis
-                cycle. Default is False.
-
-        Returns:
-            vector.StateVector of analyses and times.
+                cycle.
""" # These could be different if observer doesn't observe all variables @@ -202,10 +229,11 @@ def cycle(self, self.analysis_window = analysis_window # If don't specify analysis_time_in_window, is assumed to be middle - if self.analysis_time_in_window is None and analysis_time_in_window is None: - analysis_time_in_window = self.analysis_window/2 - else: - analysis_time_in_window = self.analysis_time_in_window + if analysis_time_in_window is None: + if self.in_4d: + analysis_time_in_window = 0 + else: + analysis_time_in_window = self.analysis_window/2 # Steps per window + 1 to include start self.steps_per_window = round(analysis_window/self.delta_t) + 1 @@ -256,13 +284,13 @@ def cycle(self, xj.from_xarray(input_state), all_filtered_padded) - all_vals_xr = xr.Dataset( + all_vals_ds = xr.Dataset( {var: (('cycle',) + tuple(all_values[var].dims), all_values[var].data) for var in all_values.data_vars} ).rename_dims({'time': 'cycle_timestep'}) if return_forecast: - return all_vals_xr.drop_isel(cycle_timestep=-1) + return all_vals_ds.drop_isel(cycle_timestep=-1) else: - return all_vals_xr.isel(cycle_timestep=0) + return all_vals_ds.isel(cycle_timestep=0) diff --git a/dabench/dacycler/_etkf.py b/dabench/dacycler/_etkf.py index b5514f5..f90a67c 100644 --- a/dabench/dacycler/_etkf.py +++ b/dabench/dacycler/_etkf.py @@ -1,58 +1,57 @@ """Class for Ensemble Transform Kalman Filter (ETKF) DA Class""" -from functools import partial import numpy as np import jax import jax.numpy as jnp from jax.scipy import linalg import xarray as xr import xarray_jax as xj +from typing import Callable from dabench import dacycler +from dabench.model import Model +# For typing +ArrayLike = np.ndarray | jax.Array +XarrayDatasetLike = xr.Dataset | xj.XjDataset + class ETKF(dacycler.DACycler): """Class for building ETKF DA Cycler Attributes: - system_dim (int): System dimension. - ensemble_dim (int): Number of ensemble instances for ETKF. Default is - 4. Higher ensemble_dim increases accuracy but has performance cost. - delta_t (float): The timestep of the model (assumed uniform) - model_obj (dabench.Model): Forecast model object. - in_4d (bool): True for 4D data assimilation techniques (e.g. 4DVar). - Always False for ETKF. - ensemble (bool): True for ensemble-based data assimilation techniques - (ETKF). Always True for ETKF. - B (ndarray): Initial / static background error covariance. Shape: + system_dim: System dimension. + delta_t: The timestep of the model (assumed uniform) + model_obj: Forecast model object. + B: Initial / static background error covariance. Shape: (system_dim, system_dim). If not provided, will be calculated automatically. - R (ndarray): Observation error covariance matrix. Shape + R: Observation error covariance matrix. Shape (obs_dim, obs_dim). If not provided, will be calculated automatically. - H (ndarray): Observation operator with shape: (obs_dim, system_dim). + H: Observation operator with shape: (obs_dim, system_dim). If not provided will be calculated automatically. - h (function): Optional observation operator as function. More flexible - (allows for more complex observation operator). + h: Optional observation operator as function. More flexible + (allows for more complex observation operator). Default is None. + ensemble_dim: Number of ensemble instances for ETKF. Default is + 4. Higher ensemble_dim increases accuracy but has performance cost. + multiplicative_inflation: Scaling factor by which to multiply ensemble + deviation. Default is 1.0 (no inflation). 
""" def __init__(self, - system_dim=None, - ensemble_dim=4, - delta_t=None, - model_obj=None, - multiplicative_inflation=1.0, - B=None, - R=None, - H=None, - h=None, - random_seed=99, - **kwargs + system_dim: int, + delta_t: float, + model_obj: Model, + B: ArrayLike | None = None, + R: ArrayLike | None = None, + H: ArrayLike | None = None, + h: Callable | None = None, + ensemble_dim: int = 4, + multiplicative_inflation: float = 1.0 ): self.ensemble_dim = ensemble_dim - self.random_seed = random_seed - self._rng = np.random.default_rng(self.random_seed) self.multiplicative_inflation = multiplicative_inflation super().__init__(system_dim=system_dim, @@ -62,13 +61,16 @@ def __init__(self, ensemble=True, B=B, R=R, H=H, h=h) - def _step_forecast(self, xa, n_steps): + def _step_forecast(self, + Xa: XarrayDatasetLike, + n_steps: int = 1 + ) -> XarrayDatasetLike: """Ensemble method needs a slightly different _step_forecast method""" ensemble_forecasts = [] ensemble_inputs = [] for i in range(self.ensemble_dim): cur_inputs, cur_forecast = self.model_obj.forecast( - xa.isel(ensemble=i), + Xa.isel(ensemble=i), n_steps=n_steps ) ensemble_inputs.append(cur_inputs) @@ -77,59 +79,70 @@ def _step_forecast(self, xa, n_steps): return (xr.concat(ensemble_inputs, dim='ensemble'), xr.concat(ensemble_forecasts, dim='ensemble')) - def _apply_obsop(self, xb, H, h): + def _apply_obsop(self, + Xb: ArrayLike, + H: ArrayLike | None, + h: Callable | None + ) -> ArrayLike: if H is not None: - yb = H @ xb + Yb = H @ Xb else: - yb = h(xb) - - return yb - - def _compute_analysis(self, xb, y, H, h, R, rho=1.0, yb=None): + Yb = h(Xb) + + return Yb + + def _compute_analysis(self, + Xb: ArrayLike, + Y: ArrayLike, + H: ArrayLike | None, + h: Callable | None, + R: ArrayLike, + rho: float = 1.0 + ) -> ArrayLike: """ETKF analysis algorithm Args: - xb (ndarray): Forecast/background ensemble with shape + Xb: Forecast/background ensemble with shape (system_dim, ensemble_dim). - y (ndarray): Observation array with shape (observation_dim,) - H (ndarray): Observation operator with shape (observation_dim, + Y: Observation array with shape (obs_time_dim, observation_dim) + H: Linear observation operator with shape (observation_dim, system_dim). - R (ndarray): Observation error covariance matrix with shape + h: Callable observation operator (optional). + R: Observation error covariance matrix with shape (observation_dim, observation_dim) - rho (float): Multiplicative inflation factor. Default=1.0, + rho: Multiplicative inflation factor. Default=1.0, (i.e. 

        Returns:
-            xa (ndarray): Analysis ensemble [size: (system_dim, ensemble_dim)]
+            Xa: Analysis ensemble [size: (system_dim, ensemble_dim)]
        """

-        # Number of state variables, ensemble members and observations
-        system_dim, ensemble_dim = xb.shape
-        observation_dim = y.shape[0]
+        # Number of state variables and ensemble members
+        system_dim, ensemble_dim = Xb.shape

        # Auxiliary matrices that will ease the computations
        U = jnp.ones((ensemble_dim, ensemble_dim))/ensemble_dim
        I = jnp.identity(ensemble_dim)

        # The ensemble is inflated (rho=1.0 is no inflation)
-        xb_pert = xb @ (I-U)
-        xb = xb_pert + xb @ U
+        Xb_pert = Xb @ (I-U)
+        Xb = Xb_pert + Xb @ U

        # Map every ensemble member into observation space
-        yb = self._apply_obsop(xb, H, h)
+        Yb = self._apply_obsop(Xb, H, h)

        # Get ensemble means and perturbations
-        xb_bar = jnp.mean(xb, axis=1)
-        xb_pert = xb @ (I-U)
+        Xb_bar = jnp.mean(Xb, axis=1)
+        Xb_pert = Xb @ (I-U)

-        yb_bar = jnp.mean(yb, axis=1)
-        yb_pert = yb @ (I-U)
+        Yb_bar = jnp.mean(Yb, axis=1)
+        Yb_pert = Yb @ (I-U)

        # Compute the analysis
        if len(R) > 0:
            Rinv = jnp.linalg.pinv(R, rtol=1e-15)
            Pa_ens = jnp.linalg.pinv((ensemble_dim-1)/rho*I
-                                     + yb_pert.T @ Rinv @ yb_pert,
+                                     + Yb_pert.T @ Rinv @ Yb_pert,
                                     rtol=1e-15)
            Wa = linalg.sqrtm((ensemble_dim-1) * Pa_ens)
            Wa = Wa.real
@@ -138,20 +151,28 @@ def _compute_analysis(self, xb, y, H, h, R, rho=1.0, yb=None):
            Pa_ens = jnp.zeros((ensemble_dim, ensemble_dim), dtype=R.dtype)
            Wa = jnp.zeros((ensemble_dim, ensemble_dim), dtype=R.dtype)

-        wa = Pa_ens @ yb_pert.T @ Rinv @ (y.flatten()-yb_bar)
+        wa = Pa_ens @ Yb_pert.T @ Rinv @ (Y.flatten()-Yb_bar)

-        xa_pert = xb_pert @ Wa
+        Xa_pert = Xb_pert @ Wa

-        xa_bar = xb_bar + jnp.ravel(xb_pert @ wa)
+        Xa_bar = Xb_bar + jnp.ravel(Xb_pert @ wa)

        v = jnp.ones((1, ensemble_dim))
-        xa = xa_pert + xa_bar[:, None] @ v
-
-        return xa
-
-    def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices,
-                     obs_time_mask, obs_loc_mask,
-                     H=None, h=None, R=None, B=None):
+        Xa = Xa_pert + Xa_bar[:, None] @ v
+
+        return Xa
+
+    def _cycle_obsop(self,
+                     Xb_ds: XarrayDatasetLike,
+                     obs_values: ArrayLike,
+                     obs_loc_indices: ArrayLike,
+                     obs_time_mask: ArrayLike,
+                     obs_loc_mask: ArrayLike,
+                     H: ArrayLike | None = None,
+                     h: Callable | None = None,
+                     R: ArrayLike | None = None,
+                     B: ArrayLike | None = None
+                     ) -> XarrayDatasetLike:
        if H is None and h is None:
            if self.H is None:
                if self.h is None:
@@ -171,8 +192,8 @@ def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices,
        else:
            B = self.B

-        xb = x0_xarray.to_stacked_array('system',['ensemble']).data.T
-        n_sys, n_ens = xb.shape
+        Xb = Xb_ds.to_stacked_array('system',['ensemble']).data.T
+        n_sys, n_ens = Xb.shape
        assert n_ens == self.ensemble_dim, (
            'cycle:: model_forecast must have dimension {}x{}').format(
                self.ensemble_dim, self.system_dim)
@@ -182,11 +203,11 @@ def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices,
            H = jnp.where(obs_loc_mask.flatten(), H.T, 0).T

        # Analysis cycles over all obs in data_obs
-        xa = self._compute_analysis(xb=xb,
-                                    y=obs_values,
+        Xa = self._compute_analysis(Xb=Xb,
+                                    Y=obs_values,
                                    H=H, h=h, R=R,
                                    rho=self.multiplicative_inflation)

-        return x0_xarray.assign(x=(['ensemble','i'], xa.T))
+        return Xb_ds.assign(x=(['ensemble','i'], Xa.T))
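Reviewer note: a self-contained toy run of the same ETKF transform as `_compute_analysis` (single observation, no inflation; all values illustrative), useful for checking the renamed variables line up:

```python
import jax.numpy as jnp
from jax.scipy import linalg

Xb = jnp.array([[1.0, 1.2, 0.9, 1.1],
                [2.0, 2.1, 1.8, 2.2],
                [0.5, 0.4, 0.6, 0.5]])   # (system_dim=3, ensemble_dim=4)
H = jnp.array([[1.0, 0.0, 0.0]])         # observe the first variable only
R = jnp.array([[0.1]])
y = jnp.array([1.3])

m = Xb.shape[1]
U = jnp.ones((m, m)) / m
I = jnp.identity(m)
Xb_pert = Xb @ (I - U)                   # ensemble perturbations
Yb = H @ Xb                              # ensemble mapped to obs space
Yb_pert = Yb @ (I - U)
Rinv = jnp.linalg.pinv(R)
Pa_ens = jnp.linalg.pinv((m - 1) * I + Yb_pert.T @ Rinv @ Yb_pert)
Wa = linalg.sqrtm((m - 1) * Pa_ens).real          # perturbation transform
wa = Pa_ens @ Yb_pert.T @ Rinv @ (y - jnp.mean(Yb, axis=1))
Xa = Xb_pert @ Wa + (jnp.mean(Xb, axis=1) + Xb_pert @ wa)[:, None]
```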
diff --git a/dabench/dacycler/_utils.py b/dabench/dacycler/_utils.py
index 147eec3..c26b911 100644
--- a/dabench/dacycler/_utils.py
+++ b/dabench/dacycler/_utils.py
@@ -1,24 +1,31 @@
 """Utils for data assimilation cyclers"""

 import jax.numpy as jnp
+import jax
 import numpy as np
+import xarray as xr
+import xarray_jax as xj
+
+# For typing
+ArrayLike = list | np.ndarray | jax.Array
+XarrayDatasetLike = xr.Dataset | xj.XjDataset
+

 def _get_all_times(
-        start_time,
-        analysis_window,
-        analysis_cycles,
-        ):
+        start_time: float,
+        analysis_window: float,
+        analysis_cycles: int
+        ) -> jax.Array:
    """Calculate times of the centers of all analysis windows.

    Args:
-        start_time (float): Start time of DA experiment in model time units.
-        analysis_window (float): Length of analysis window, in model time
+        start_time: Start time of DA experiment in model time units.
+        analysis_window: Length of analysis window, in model time
            units.
-        analysis_cycles (int): Number of analysis cycles to perform.
+        analysis_cycles: Number of analysis cycles to perform.

    Returns:
-        array of all analysis window center-times.
+        Array of all analysis window center-times.
    """
@@ -32,26 +39,26 @@ def _get_all_times(

 def _get_obs_indices(
-        analysis_times,
-        obs_times,
-        analysis_window,
-        start_inclusive=True,
-        end_inclusive=False
-        ):
+        analysis_times: ArrayLike,
+        obs_times: ArrayLike,
+        analysis_window: float,
+        start_inclusive: bool = True,
+        end_inclusive: bool = False
+        ) -> list:
    """Get indices of obs times for each analysis cycle to pass to jax.lax.scan

    Args:
-        analysis_times (list): List of times for all analysis window, centered
+        analysis_times: List of times for all analysis windows, centered
            in middle of time window. Output of _get_all_times().
-        obs_times (list): List of times for all observations.
-        analysis_window (float): Length of analysis window.
-        start_inclusive (bool): Include obs times equal to beginning of
+        obs_times: List of times for all observations.
+        analysis_window: Length of analysis window.
+        start_inclusive: Include obs times equal to beginning of
            analysis window. Default is True.
-        end_inclusive (bool): Include obs times equal to end of
+        end_inclusive: Include obs times equal to end of
            analysis window. Default is False.

    Returns:
-        list with each element containing array of obs indices for the
+        List with each element containing array of obs indices for the
            corresponding analysis cycle.
    """
    # Get the obs vectors for each analysis window
@@ -77,60 +84,69 @@ def _get_obs_indices(
    return all_filtered_idx


+def _time_resize(
+        row: ArrayLike,
+        size: int,
+        add_one: bool
+        ) -> np.ndarray:
+    new = np.array(row) + add_one
+    new.resize(size)
+    return new
+
+
 def _pad_time_indices(
-        obs_indices,
-        add_one=True
-        ):
+        obs_indices: ArrayLike,
+        add_one: bool = True
+        ) -> ArrayLike:
    """Pad observation indices for each analysis window.

    Args:
-        obs_indices (list): List of arrays where each array contains
+        obs_indices: List of arrays where each array contains
            obs indices for an analysis cycle. Result of _get_obs_indices.
-        add_one (bool): If True, will add one to all index values to encode
+        add_one: If True, will add one to all index values to encode
            indices to be masked out for DA (i.e. zeros represent indices
            to be masked out). Default is True.

    Returns:
-        padded_indices (array): Array of padded obs_indices, with shape:
+        padded_indices: Array of padded obs_indices, with shape:
            (num_analysis_cycles, max_obs_per_cycle).
""" - - def resize(row, size, add_one): - new = np.array(row) + add_one - new.resize(size) - return new - # find longest row length row_length = max(obs_indices, key=len).__len__() - padded_indices = np.array([resize(row, row_length, add_one) for row in obs_indices]) + padded_indices = np.array([_time_resize(row, row_length, add_one) + for row in obs_indices]) return padded_indices -def _pad_obs_locs(obs_vec): +def _obs_resize( + row: ArrayLike, + size: float + ) -> np.ndarray: + new_vals_locs = np.array(np.stack(row), order='F') + new_vals_locs.resize((new_vals_locs.shape[0], size)) + mask = np.ones_like(new_vals_locs[0]).astype(int) + if size > len(row[0]): + mask[-(size-len(row[0])):] = 0 + return np.vstack([new_vals_locs, mask]).T + + +def _pad_obs_locs( + obs_vec: XarrayDatasetLike + ) -> tuple[ArrayLike, ArrayLike, ArrayLike]: """Pad observation location indices to equal spacing Args: - obs_vec (dabench.vector.ObsVector): Observation vector - object containing times, locations, and values of obs. + obs_vec: Xarray containing times, locations, and values of obs. Returns: - (vals, locs, masks): Tuple containing padded arrays of obs + Tuple containing padded arrays of obs values and locations, and binary array masks where 1 is a valid observation value/location and 0 is not. """ - - def resize(row, size): - new_vals_locs = np.array(np.stack(row), order='F') - new_vals_locs.resize((new_vals_locs.shape[0], size)) - mask = np.ones_like(new_vals_locs[0]).astype(int) - if size > len(row[0]): - mask[-(size-len(row[0])):] = 0 - return np.vstack([new_vals_locs, mask]).T - # Find longest row length row_length = max(obs_vec.values, key=len).__len__() - padded_arrays_masks = np.array([resize(row, row_length) for row in + padded_arrays_masks = np.array([_obs_resize(row, row_length) for row in np.stack([obs_vec.values, obs_vec.location_indices], axis=1)], dtype=float) diff --git a/dabench/dacycler/_var3d.py b/dabench/dacycler/_var3d.py index 271b969..533b85c 100644 --- a/dabench/dacycler/_var3d.py +++ b/dabench/dacycler/_var3d.py @@ -2,43 +2,46 @@ import numpy as np import jax.numpy as jnp +import jax import jax.scipy as jscipy +import xarray as xr +import xarray_jax as xj +from typing import Callable from dabench import dacycler +from dabench.model import Model +# For typing +ArrayLike = np.ndarray | jax.Array +XarrayDatasetLike = xr.Dataset | xj.XjDataset class Var3D(dacycler.DACycler): """Class for building 3DVar DA Cycler Attributes: - system_dim (int): System dimension. - delta_t (float): The timestep of the model (assumed uniform) - model_obj (dabench.Model): Forecast model object. - in_4d (bool): True for 4D data assimilation techniques (e.g. 4DVar). - Always False for Var3D. - ensemble (bool): True for ensemble-based data assimilation techniques - (ETKF). Always False for Var3D - B (ndarray): Initial / static background error covariance. Shape: + system_dim: System dimension. + delta_t: The timestep of the model (assumed uniform) + model_obj: Forecast model object. + B: Initial / static background error covariance. Shape: (system_dim, system_dim). If not provided, will be calculated automatically. - R (ndarray): Observation error covariance matrix. Shape + R: Observation error covariance matrix. Shape (obs_dim, obs_dim). If not provided, will be calculated automatically. - H (ndarray): Observation operator with shape: (obs_dim, system_dim). + H: Observation operator with shape: (obs_dim, system_dim). If not provided will be calculated automatically. 
diff --git a/dabench/dacycler/_var3d.py b/dabench/dacycler/_var3d.py
index 271b969..533b85c 100644
--- a/dabench/dacycler/_var3d.py
+++ b/dabench/dacycler/_var3d.py
@@ -2,43 +2,46 @@

 import numpy as np
 import jax.numpy as jnp
+import jax
 import jax.scipy as jscipy
+import xarray as xr
+import xarray_jax as xj
+from typing import Callable

 from dabench import dacycler
+from dabench.model import Model

+# For typing
+ArrayLike = np.ndarray | jax.Array
+XarrayDatasetLike = xr.Dataset | xj.XjDataset
+

 class Var3D(dacycler.DACycler):
    """Class for building 3DVar DA Cycler

    Attributes:
-        system_dim (int): System dimension.
-        delta_t (float): The timestep of the model (assumed uniform)
-        model_obj (dabench.Model): Forecast model object.
-        in_4d (bool): True for 4D data assimilation techniques (e.g. 4DVar).
-            Always False for Var3D.
-        ensemble (bool): True for ensemble-based data assimilation techniques
-            (ETKF). Always False for Var3D
-        B (ndarray): Initial / static background error covariance. Shape:
+        system_dim: System dimension.
+        delta_t: The timestep of the model (assumed uniform)
+        model_obj: Forecast model object.
+        B: Initial / static background error covariance. Shape:
            (system_dim, system_dim). If not provided, will be calculated
            automatically.
-        R (ndarray): Observation error covariance matrix. Shape
+        R: Observation error covariance matrix. Shape
            (obs_dim, obs_dim). If not provided, will be calculated
            automatically.
-        H (ndarray): Observation operator with shape: (obs_dim, system_dim).
+        H: Observation operator with shape: (obs_dim, system_dim).
            If not provided will be calculated automatically.
-        h (function): Optional observation operator as function. More flexible
+        h: Optional observation operator as function. More flexible
            (allows for more complex observation operator). Default is None.
    """

    def __init__(self,
-                 system_dim=None,
-                 delta_t=None,
-                 in_4d=False,
-                 model_obj=None,
-                 B=None,
-                 R=None,
-                 H=None,
-                 h=None,
+                 system_dim: int,
+                 delta_t: float,
+                 model_obj: Model,
+                 B: ArrayLike | None = None,
+                 R: ArrayLike | None = None,
+                 H: ArrayLike | None = None,
+                 h: Callable | None = None,
                 ):

        super().__init__(system_dim=system_dim,
@@ -48,9 +51,16 @@ def __init__(self,
                         ensemble=False,
                         B=B, R=R, H=H, h=h)

-    def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices,
-                     obs_time_mask, obs_loc_mask,
-                     H=None, h=None, R=None, B=None):
+    def _cycle_obsop(self,
+                     xb_ds: XarrayDatasetLike,
+                     obs_values: ArrayLike,
+                     obs_loc_indices: ArrayLike,
+                     obs_time_mask: ArrayLike,
+                     obs_loc_mask: ArrayLike,
+                     H: ArrayLike | None = None,
+                     h: Callable | None = None,
+                     R: ArrayLike | None = None,
+                     B: ArrayLike | None = None) -> XarrayDatasetLike:
        """When obsop (H) is linear"""
        if H is None and h is None:
            if self.H is None:
@@ -71,8 +81,8 @@ def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices,
        else:
            B = self.B

-        xb = x0_xarray.to_stacked_array('system',[]).data.flatten()
-        yo = obs_values.flatten()
+        xb = xb_ds.to_stacked_array('system',[]).data.flatten()
+        y = obs_values.flatten()

        # Apply masks to H
        H = jnp.where(obs_time_mask.flatten(), H.T, 0).T
@@ -87,10 +97,10 @@ def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices,
        BHt = jnp.dot(B, H.T)
        BHtRinv = jnp.dot(BHt, Rinv)
        A = I + jnp.dot(BHtRinv, H)
-        b1 = xb + jnp.dot(BHtRinv, yo)
+        b1 = xb + jnp.dot(BHtRinv, y)

        # Use minimization algorithm to minimize cost function:
        xa, ierr = jscipy.sparse.linalg.cg(A, b1, x0=xb, tol=1e-05,
                                           maxiter=1000)

-        return x0_xarray.assign(x=(x0_xarray.dims, xa.T))
+        return xb_ds.assign(x=(xb_ds.dims, xa.T))
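Reviewer note: the linear 3DVar solve above reduces to one CG call. A runnable toy version of the same system (I + B Hᵀ R⁻¹ H) x = xb + B Hᵀ R⁻¹ y, with illustrative values:

```python
import jax.numpy as jnp
import jax.scipy as jscipy

xb = jnp.array([1.0, 2.0, 3.0])       # background state
y = jnp.array([1.5])                   # single observation of x[0]
H = jnp.array([[1.0, 0.0, 0.0]])
B = jnp.identity(3)
Rinv = jnp.linalg.inv(jnp.array([[0.25]]))

I = jnp.identity(3)
BHtRinv = B @ H.T @ Rinv
A = I + BHtRinv @ H
b1 = xb + BHtRinv @ y
xa, _ = jscipy.sparse.linalg.cg(A, b1, x0=xb, tol=1e-05, maxiter=1000)
# xa[0] == 1.4: the analysis nudges x[0] from 1.0 toward the obs at 1.5
```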
diff --git a/dabench/dacycler/_var4d.py b/dabench/dacycler/_var4d.py
index 19a8d30..52439b2 100644
--- a/dabench/dacycler/_var4d.py
+++ b/dabench/dacycler/_var4d.py
@@ -11,42 +11,48 @@

 from jax.scipy.sparse.linalg import bicgstab
 from copy import deepcopy
 from functools import partial
+from typing import Callable
 import xarray as xr
 import xarray_jax as xj

 from dabench import dacycler
+from dabench.model import Model
 import dabench.dacycler._utils as dac_utils

+# For typing
+ArrayLike = np.ndarray | jax.Array
+XarrayDatasetLike = xr.Dataset | xj.XjDataset
+

 class Var4D(dacycler.DACycler):
    """Class for building 4D DA Cycler

    Attributes:
-        system_dim (int): System dimension.
-        delta_t (float): The timestep of the model (assumed uniform)
-        model_obj (dabench.Model): Forecast model object.
-        in_4d (bool): True for 4D data assimilation techniques (e.g. 4DVar).
+        system_dim: System dimension.
+        delta_t: The timestep of the model (assumed uniform)
+        model_obj: Forecast model object.
+        in_4d: True for 4D data assimilation techniques (e.g. 4DVar).
            Always True for Var4D.
-        ensemble (bool): True for ensemble-based data assimilation techniques
+        ensemble: True for ensemble-based data assimilation techniques
            (ETKF). Always False for Var4D.
-        B (ndarray): Initial / static background error covariance. Shape:
+        B: Initial / static background error covariance. Shape:
            (system_dim, system_dim). If not provided, will be calculated
            automatically.
-        R (ndarray): Observation error covariance matrix. Shape
+        R: Observation error covariance matrix. Shape
            (obs_dim, obs_dim). If not provided, will be calculated
            automatically.
-        H (ndarray): Observation operator with shape: (obs_dim, system_dim).
+        H: Observation operator with shape: (obs_dim, system_dim).
            If not provided will be calculated automatically.
-        h (function): Optional observation operator as function. More flexible
+        h: Optional observation operator as function. More flexible
            (allows for more complex observation operator). Default is None.
-        solver (str): Name of solver to use. Default is 'bicgstab'.
-        n_outer_loops (int): Number of times to run through outer loop over
+        solver: Name of solver to use. Default is 'bicgstab'.
+        n_outer_loops: Number of times to run through outer loop over
            4DVar. Increasing this may result in higher accuracy but slower
            performance. Default is 1.
-        steps_per_window (int): Number of timesteps per analysis window.
-            If None (default), will calculate automatically based on
-            delta_t and .cycle() analysis_window length.
+        steps_per_window: Number of timesteps per analysis window.
+            Default is 1. If None, will calculate automatically based on
+            delta_t and .cycle() analysis_window length.
-        obs_window_indices (list): Timestep indices where observations fall
+        obs_window_indices: Timestep indices where observations fall
            within each analysis window. For example, if analysis window is
            0 - 0.05 with delta_t = 0.01 and observations fall at 0, 0.01,
            0.02, 0.03, 0.04, and 0.05, obs_window_indices =
@@ -55,18 +61,17 @@ class Var4D(dacycler.DACycler):
    """

    def __init__(self,
-                 system_dim=None,
-                 delta_t=None,
-                 model_obj=None,
-                 B=None,
-                 R=None,
-                 H=None,
-                 h=None,
-                 solver='bicgstab',
-                 n_outer_loops=1,
-                 steps_per_window=1,
-                 obs_window_indices=None,
-                 analysis_time_in_window=0,
+                 system_dim: int,
+                 delta_t: float,
+                 model_obj: Model,
+                 B: ArrayLike | None = None,
+                 R: ArrayLike | None = None,
+                 H: ArrayLike | None = None,
+                 h: Callable | None = None,
+                 solver: str = 'bicgstab',
+                 n_outer_loops: int = 1,
+                 steps_per_window: int = 1,
+                 obs_window_indices: ArrayLike | None = None,
                 **kwargs
                 ):

@@ -84,10 +89,11 @@ def __init__(self,
                         model_obj=model_obj,
                         in_4d=True,
                         ensemble=False,
-                         B=B, R=R, H=H, h=h,
-                         analysis_time_in_window=analysis_time_in_window)
+                         B=B, R=R, H=H, h=h)

-    def _calc_default_H(self, obs_loc_indices):
+    def _calc_default_H(self,
+                        obs_loc_indices: ArrayLike
+                        ) -> jax.Array:
        Hs = jnp.zeros((obs_loc_indices.shape[0],
                        obs_loc_indices.shape[1],
                        self.system_dim),
                       dtype=int)
@@ -96,10 +102,19 @@
                    ].set(1)
        return Hs

-    def _calc_default_R(self, obs_values, obs_error_sd):
+    def _calc_default_R(self,
+                        obs_values: ArrayLike,
+                        obs_error_sd: float
+                        ) -> jax.Array:
        return jnp.identity(obs_values[0].shape[0])*(obs_error_sd**2)

-    def _calc_J_term(self, H, M, Rinv, y, x):
+    def _calc_J_term(self,
+                     H: ArrayLike,
+                     M: ArrayLike,
+                     Rinv: ArrayLike,
+                     y: ArrayLike,
+                     x: ArrayLike
+                     ) -> tuple[jax.Array, jax.Array]:
        # The Jb Term (A)
        HM = H @ M
        MtHtRinv = HM.T @ Rinv
@@ -109,11 +124,21 @@ def _calc_J_term(self, H, M, Rinv, y, x):
        return MtHtRinv @ HM,  MtHtRinv @ D[:, None]

    @partial(jax.jit, static_argnums=[0, 1])
-    def _innerloop_4d(self, system_dim, x, xb0, obs_vals, Hs, B, Rinv, M,
-                      obs_window_indices, obs_time_mask):
+    def _innerloop_4d(self,
+                      system_dim: int,
+                      X_ds: XarrayDatasetLike,
+                      xb0_ds: XarrayDatasetLike,
+                      obs_vals: ArrayLike,
+                      Hs: ArrayLike,
+                      B: ArrayLike,
+                      Rinv: ArrayLike,
+                      M: ArrayLike,
+                      obs_window_indices: ArrayLike | list,
+                      obs_time_mask: ArrayLike
+                      ) -> XarrayDatasetLike:
        """4DVar innerloop"""
-        x0_last = x.isel(time=0)
-        x = x.to_stacked_array('system',['time'])
+        x0_ds = X_ds.isel(time=0)
+        X_ar = X_ds.to_stacked_array('system',['time'])

        # Set up Variables
        SumMtHtRinvHM = jnp.zeros_like(B)       # A input
@@ -123,50 +148,65 @@ def _innerloop_4d(self, system_dim, x, xb0, obs_vals, Hs, B, Rinv, M,
        for i, j in enumerate(obs_window_indices):
            Jb, Jo = jax.lax.cond(
                obs_time_mask.at[i].get(mode='fill', fill_value=0),
-                lambda: self._calc_J_term(Hs.at[i].get(mode='clip'), M.data[j],
-                                          Rinv, obs_vals[i], x.data[j]),
+                lambda: self._calc_J_term(
+                    Hs.at[i].get(mode='clip'),
+                    M.data[j],
+                    Rinv, obs_vals[i], X_ar.data[j]),
                lambda: (jnp.zeros_like(SumMtHtRinvHM),
                         jnp.zeros_like(SumMtHtRinvD))
                )
            SumMtHtRinvHM += Jb
            SumMtHtRinvD += Jo

        # Compute initial departure
-        db0 = (xb0 - x0_last).to_stacked_array('system',[]).data
+        db0 = (xb0_ds - x0_ds).to_stacked_array('system',[]).data

        # Solve Ax=b for the initial perturbation
        dx0 = self._solve(db0, SumMtHtRinvHM, SumMtHtRinvD, B)

        # New x0 guess is the last guess plus the analyzed delta
-        x0_new = x0_last + dx0.ravel()
-
-        return x0_new
-
-    def _make_outerloop_4d(self, xb0, Hs, B, Rinv,
-                           obs_values, obs_window_indices, obs_time_mask,
-                           n_steps):
-
-        def _outerloop_4d(x0, _):
+        xa0_ds = x0_ds + dx0.ravel()
+
+        return xa0_ds
+
+    def _make_outerloop_4d(self,
+                           xb0_ds: XarrayDatasetLike,
+                           Hs: ArrayLike,
+                           B: ArrayLike,
+                           Rinv: ArrayLike,
+                           obs_values: ArrayLike,
+                           obs_window_indices: ArrayLike | list,
+                           obs_time_mask: ArrayLike,
+                           n_steps: int
+                           ) -> Callable:
+
+        def _outerloop_4d(x0_ds: XarrayDatasetLike,
+                          _: None
+                          ) -> tuple[XarrayDatasetLike, XarrayDatasetLike]:
            # Get TLM and current forecast trajectory
            # Based on current best guess for x0
-            x0 = x0.to_xarray()
-            x, M = self.model_obj.compute_tlm(
+            x0_ds = x0_ds.to_xarray()
+            X_ds, M = self.model_obj.compute_tlm(
                n_steps=n_steps,
-                state_vec=x0
+                state_vec=x0_ds
            )

            # 4D-Var inner loop
-            x0_new = self._innerloop_4d(self.system_dim,
-                                        x, xb0, obs_values,
-                                        Hs, B, Rinv, M,
-                                        obs_window_indices,
-                                        obs_time_mask)
+            xa0_ds = self._innerloop_4d(
+                self.system_dim, X_ds, xb0_ds, obs_values,
+                Hs, B, Rinv, M, obs_window_indices, obs_time_mask
            )

-            return xj.from_xarray(x0_new.assign_coords(x0.coords)), x0
+            return xj.from_xarray(xa0_ds.assign_coords(x0_ds.coords)), x0_ds

        return _outerloop_4d

    @partial(jax.jit, static_argnums=0)
-    def _solve(self, db0, SumMtHtRinvHM, SumMtHtRinvD, B):
+    def _solve(self,
+               db0: ArrayLike,
+               SumMtHtRinvHM: ArrayLike,
+               SumMtHtRinvD: ArrayLike,
+               B: ArrayLike
+               ) -> jax.Array:
        """Solve the 4D-Var linear optimization

        Notes:
@@ -195,9 +235,18 @@

        return dx0

-    def _cycle_obsop(self, xb0, obs_values, obs_loc_indices,
-                     obs_time_mask, obs_loc_mask,
-                     H=None, h=None, R=None, B=None, obs_window_indices=None):
+    def _cycle_obsop(self,
+                     xb0_ds: XarrayDatasetLike,
+                     obs_values: ArrayLike,
+                     obs_loc_indices: ArrayLike,
+                     obs_time_mask: ArrayLike,
+                     obs_loc_mask: ArrayLike,
+                     H: ArrayLike | None = None,
+                     h: Callable | None = None,
+                     R: ArrayLike | None = None,
+                     B: ArrayLike | None = None,
+                     obs_window_indices: ArrayLike | list | None = None
+                     ) -> XarrayDatasetLike:
        if H is None and h is None:
            if self.H is None:
                if self.h is None:
@@ -235,14 +284,12 @@ def _cycle_obsop(self, xb0, obs_values, obs_loc_indices,
        # Static Variables
        Rinv = jscipy.linalg.inv(R)

-        # Best guess for x0 starts as background
-        x0_new = deepcopy(xb0)
-
        outerloop_4d_func = self._make_outerloop_4d(
-            xb0, Hs, B, Rinv, obs_values, obs_window_indices,
+            xb0_ds, Hs, B, Rinv, obs_values, obs_window_indices,
            obs_time_mask, self.steps_per_window)

-        x0_new, all_x0s = jax.lax.scan(outerloop_4d_func, init=xj.from_xarray(x0_new),
+        xa0_ds, all_x0s = jax.lax.scan(outerloop_4d_func,
+                                       init=xj.from_xarray(xb0_ds),
                                       xs=None, length=self.n_outer_loops)

-        return x0_new.to_xarray()
+        return xa0_ds.to_xarray()
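Reviewer note: for readers new to the Jb/Jo split in `_innerloop_4d`, here is a minimal sketch of the 4DVar cost the inner loop is minimizing; the function and argument names are illustrative, not repo API:

```python
import jax.numpy as jnp

def j_4dvar(x0, xb0, Ms, Hs, ys, Binv, Rinv):
    """J(x0) = (x0-xb0)^T Binv (x0-xb0)
               + sum_j (H_j M_j x0 - y_j)^T Rinv (H_j M_j x0 - y_j)"""
    db = x0 - xb0
    cost = db @ Binv @ db                 # background (Jb) term
    for M, H, y in zip(Ms, Hs, ys):       # one term per obs window index
        d = H @ (M @ x0) - y              # departure in obs space
        cost += d @ Rinv @ d              # observation (Jo) term
    return cost
```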
diff --git a/dabench/dacycler/_var4d_backprop.py b/dabench/dacycler/_var4d_backprop.py
index 7cd9a38..25e4274 100644
--- a/dabench/dacycler/_var4d_backprop.py
+++ b/dabench/dacycler/_var4d_backprop.py
@@ -13,68 +13,73 @@

 from functools import partial
 import xarray as xr
 import xarray_jax as xj
+from typing import Callable, Any

 from dabench import dacycler
 import dabench.dacycler._utils as dac_utils
+from dabench.model import Model

+# For typing
+ArrayLike = np.ndarray | jax.Array
+XarrayDatasetLike = xr.Dataset | xj.XjDataset
+ScheduleState = Any
+

 class Var4DBackprop(dacycler.DACycler):
    """Class for building Backpropagation 4D DA Cycler

    Attributes:
-        system_dim (int): System dimension.
-        delta_t (float): The timestep of the model (assumed uniform)
-        model_obj (dabench.Model): Forecast model object.
-        in_4d (bool): True for 4D data assimilation techniques (e.g. 4DVar).
+        system_dim: System dimension.
+        delta_t: The timestep of the model (assumed uniform)
+        model_obj: Forecast model object.
+        in_4d: True for 4D data assimilation techniques (e.g. 4DVar).
            Always True for Var4DBackprop.
-        ensemble (bool): True for ensemble-based data assimilation techniques
+        ensemble: True for ensemble-based data assimilation techniques
            (ETKF). Always False for Var4DBackprop.
-        B (ndarray): Initial / static background error covariance. Shape:
+        B: Initial / static background error covariance. Shape:
            (system_dim, system_dim). If not provided, will be calculated
            automatically.
-        R (ndarray): Observation error covariance matrix. Shape
+        R: Observation error covariance matrix. Shape
            (obs_dim, obs_dim). If not provided, will be calculated
            automatically.
-        H (ndarray): Observation operator with shape: (obs_dim, system_dim).
+        H: Observation operator with shape: (obs_dim, system_dim).
            If not provided will be calculated automatically.
-        h (function): Optional observation operator as function. More flexible
+        h: Optional observation operator as function. More flexible
            (allows for more complex observation operator). Default is None.
-        num_iters (int): Number of iterations for backpropagation per analysis
+        num_iters: Number of iterations for backpropagation per analysis
            cycle. Default is 3.
-        steps_per_window (int): Number of timesteps per analysis window.
+        steps_per_window: Number of timesteps per analysis window.
            If None (default), will calculate automatically based on
            delta_t and .cycle() analysis_window length.
-        learning_rate (float): LR for backpropogation. Default is 0.5, but
+        learning_rate: LR for backpropagation. Default is 0.5, but
            DA results can be quite sensitive to this parameter.
-        lr_decay (float): Exponential learning rate decay. If set to 1,
+        lr_decay: Exponential learning rate decay. If set to 1,
            no decay. Default is 0.5.
-        obs_window_indices (list): Timestep indices where observations fall
+        obs_window_indices: Timestep indices where observations fall
            within each analysis window. For example, if analysis window is
            0 - 0.05 with delta_t = 0.01 and observations fall at 0, 0.01,
            0.02, 0.03, 0.04, and 0.05, obs_window_indices =
            [0, 1, 2, 3, 4, 5]. If None (default), will calculate
            automatically.
-        loss_growth_limit (float): If loss grows by more than this factor
+        loss_growth_limit: If loss grows by more than this factor
            during one analysis cycle, JAX will cut off computation and
            return an error. This prevents it from hanging indefinitely
-            when loss grows exponentionally. Default is 10.
+            when loss grows exponentially. Default is 10.
    """

    def __init__(self,
-                 system_dim=None,
-                 delta_t=None,
-                 model_obj=None,
-                 B=None,
-                 R=None,
-                 H=None,
-                 h=None,
-                 learning_rate=0.5,
-                 lr_decay=0.5,
-                 num_iters=3,
-                 steps_per_window=None,
-                 obs_window_indices=None,
-                 loss_growth_limit=10,
-                 analysis_time_in_window=0,
+                 system_dim: int,
+                 delta_t: float,
+                 model_obj: Model,
+                 B: ArrayLike | None = None,
+                 R: ArrayLike | None = None,
+                 H: ArrayLike | None = None,
+                 h: Callable | None = None,
+                 learning_rate: float = 0.5,
+                 lr_decay: float = 0.5,
+                 num_iters: int = 3,
+                 steps_per_window: int | None = None,
+                 obs_window_indices: ArrayLike | list | None = None,
+                 loss_growth_limit: float = 10,
                 **kwargs
                 ):

@@ -94,10 +99,11 @@ def __init__(self,
                         model_obj=model_obj,
                         in_4d=True,
                         ensemble=False,
-                         B=B, R=R, H=H, h=h,
-                         analysis_time_in_window=analysis_time_in_window)
+                         B=B, R=R, H=H, h=h)

-    def _calc_default_H(self, obs_loc_indices):
+    def _calc_default_H(self,
+                        obs_loc_indices: ArrayLike
+                        ) -> jax.Array:
        Hs = jnp.zeros((obs_loc_indices.shape[0],
                        obs_loc_indices.shape[1],
                        self.system_dim),
                       dtype=int)
@@ -107,7 +113,10 @@ def __init__(self,
        return Hs

-    def _calc_default_R(self, obs_values, obs_error_sd):
+    def _calc_default_R(self,
+                        obs_values: ArrayLike,
+                        obs_error_sd: float
+                        ) -> jax.Array:
        return jnp.identity(obs_values[0].shape[0])*(obs_error_sd**2)

    def _raise_nan_error(self):
@@ -116,30 +125,45 @@ def _raise_nan_error(self):
    def _raise_loss_growth_error(self):
        raise ValueError('Loss value has exceeded self.loss_growth_limit, exiting optimization')

-    def _callback_raise_error(self, error_method, loss_val):
+    def _callback_raise_error(self,
+                              error_method: Callable,
+                              loss_val: float
+                              ) -> float:
        jax.debug.callback(error_method)
        return loss_val

    # @partial(jax.jit, static_argnums=[0])
-    def _calc_obs_term(self, pred_x, obs_vals, Ht, Rinv):
-        pred_obs = pred_x @ Ht
-        resid = pred_obs.ravel() - obs_vals.ravel()
+    def _calc_obs_term(self,
+                       X: ArrayLike,
+                       obs_vals: ArrayLike,
+                       Ht: ArrayLike,
+                       Rinv: ArrayLike
+                       ) -> jax.Array:
+        Y = X @ Ht
+        resid = Y.ravel() - obs_vals.ravel()

        return jnp.sum(resid.T @ Rinv @ resid)

-    def _make_loss(self, xb0, obs_vals, Hs, Binv, Rinv,
-                   obs_window_indices,
-                   obs_time_mask, n_steps):
+    def _make_loss(self,
+                   xb0: XarrayDatasetLike,
+                   obs_vals: ArrayLike,
+                   Hs: ArrayLike,
+                   Binv: ArrayLike,
+                   Rinv: ArrayLike,
+                   obs_window_indices: ArrayLike | list,
+                   obs_time_mask: ArrayLike,
+                   n_steps: int
+                   ) -> Callable:
        """Define loss function based on 4dvar cost"""

        # @jax.jit
-        def loss_4dvarcost(x0):
+        def loss_4dvarcost(x0: XarrayDatasetLike) -> jax.Array:
            # Get initial departure
            db0 = (x0.to_array().data.ravel() - xb0.to_array().data.ravel())

            # Make new prediction
            # NOTE: [1] selects the full forecast instead of last timestep only
-            pred_x = self._step_forecast(
+            X = self._step_forecast(
                x0, n_steps)[1].to_stacked_array('system',['time']).data

            # Calculate observation term of J_0
@@ -147,7 +171,7 @@ def loss_4dvarcost(x0):
            for i, j in enumerate(obs_window_indices):
                obs_term += jax.lax.cond(
                    obs_time_mask.at[i].get(mode='fill', fill_value=0),
-                    lambda: self._calc_obs_term(pred_x[j], obs_vals[i],
+                    lambda: self._calc_obs_term(X[j], obs_vals[i],
                                                Hs.at[i].get(mode='clip').T,
                                                Rinv),
                    lambda: 0.0
@@ -166,16 +190,22 @@ def loss_4dvarcost(x0):

        return loss_4dvarcost

-    def _make_backprop_epoch(self, loss_func, optimizer, hessian_inv):
+    def _make_backprop_epoch(self,
+                             loss_func: Callable,
+                             optimizer: optax.GradientTransformation,
+                             hessian_inv: ArrayLike) -> Callable:

        loss_value_grad = value_and_grad(loss_func, argnums=0)

        # @jax.jit
-        def _backprop_epoch(epoch_state_tuple, i):
-            x0, init_loss, opt_state = epoch_state_tuple
-            x0 = x0.to_xarray()
-            loss_val, dx0 = loss_value_grad(x0)
-            x0_array = x0.to_stacked_array('system', [])
+        def _backprop_epoch(
+                epoch_state_tuple: tuple[XarrayDatasetLike, ArrayLike, ScheduleState],
+                i: int
+                ) -> tuple[tuple[XarrayDatasetLike, ArrayLike, ScheduleState], ArrayLike]:
+            x0_ds, init_loss, opt_state = epoch_state_tuple
+            x0_ds = x0_ds.to_xarray()
+            loss_val, dx0 = loss_value_grad(x0_ds)
+            x0_ar = x0_ds.to_stacked_array('system', [])
            dx0_hess = hessian_inv @ dx0.to_stacked_array('system',[]).data
            init_loss = jax.lax.cond(
                i == 0,
@@ -188,18 +218,27 @@ def _backprop_epoch(epoch_state_tuple, i):
                lambda: loss_val)

            updates, opt_state = optimizer.update(dx0_hess, opt_state)
-            x0_array.data = optax.apply_updates(
-                x0_array.data, updates)
-            x0_new = x0_array.to_unstacked_dataset('system').assign_attrs(
-                x0.attrs
+            x0_ar.data = optax.apply_updates(
+                x0_ar.data, updates)
+            xa0_ds = x0_ar.to_unstacked_dataset('system').assign_attrs(
+                x0_ds.attrs
            )

-            return (xj.from_xarray(x0_new), init_loss, opt_state), loss_val
+            return (xj.from_xarray(xa0_ds), init_loss, opt_state), loss_val

        return _backprop_epoch

-    def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices,
-                     obs_time_mask, obs_loc_mask,
-                     H=None, h=None, R=None, B=None, obs_window_indices=None):
+    def _cycle_obsop(self,
+                     xb0_ds: XarrayDatasetLike,
+                     obs_values: ArrayLike,
+                     obs_loc_indices: ArrayLike,
+                     obs_time_mask: ArrayLike,
+                     obs_loc_mask: ArrayLike,
+                     H: ArrayLike | None = None,
+                     h: Callable | None = None,
+                     R: ArrayLike | None = None,
+                     B: ArrayLike | None = None,
+                     obs_window_indices: ArrayLike | list | None = None
+                     ) -> XarrayDatasetLike:
        if H is None and h is None:
            if self.H is None:
                if self.h is None:
@@ -242,7 +281,7 @@ def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices,
            Binv + Hs.at[0].get().T @ Rinv @ Hs.at[0].get())

        loss_func = self._make_loss(
-            x0_xarray,
+            xb0_ds,
            obs_values,
            Hs,
            Binv,
            Rinv,
@@ -256,16 +295,15 @@
            1,
            self.lr_decay)
        optimizer = optax.sgd(lr)
-        opt_state = optimizer.init(x0_xarray.to_stacked_array('system',[]).data)
+        opt_state = optimizer.init(xb0_ds.to_stacked_array('system',[]).data)

        # Make initial forecast and calculate loss
        backprop_epoch_func = self._make_backprop_epoch(loss_func, optimizer,
                                                        hessian_inv)
-        # epoch_state_tuple, loss_vals = backprop_epoch_func((xj.from_xarray(x0_xarray), 0., opt_state),0)
        epoch_state_tuple, loss_vals = jax.lax.scan(
-            backprop_epoch_func, init=(xj.from_xarray(x0_xarray), 0., opt_state),
+            backprop_epoch_func, init=(xj.from_xarray(xb0_ds), 0., opt_state),
            xs=jnp.arange(self.num_iters))

-        x0_new = epoch_state_tuple[0].to_xarray()
+        xa0_ds = epoch_state_tuple[0].to_xarray()

-        return x0_new
+        return xa0_ds
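Reviewer note: the core of one backprop epoch above is a plain optax SGD step over the 4DVar loss. A runnable toy version with the same update pattern (a quadratic loss stands in for `loss_4dvarcost`; everything here is illustrative):

```python
import jax
import jax.numpy as jnp
import optax

def loss(x0):
    # Stand-in for the 4DVar cost; minimum at x0 == 1.0
    return jnp.sum((x0 - 1.0) ** 2)

optimizer = optax.sgd(0.5)
x0 = jnp.zeros(3)
opt_state = optimizer.init(x0)

loss_val, grad = jax.value_and_grad(loss)(x0)
updates, opt_state = optimizer.update(grad, opt_state)
x0 = optax.apply_updates(x0, updates)   # moves x0 toward the minimum at 1.0
```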
diff --git a/dabench/data/_data.py b/dabench/data/_data.py
index 9b9d4a1..e1c16a3 100644
--- a/dabench/data/_data.py
+++ b/dabench/data/_data.py
@@ -3,6 +3,7 @@

 import copy
 import numpy as np
 import jax.numpy as jnp
+import jax
 import xarray as xr
 import warnings
 from importlib import resources
@@ -11,29 +12,32 @@

 from dabench import _suppl_data

+# For typing
+ArrayLike = np.ndarray | jax.Array
+

 class Data():
    """Generic class for data generator objects.

    Attributes:
-        system_dim (int): system dimension
-        time_dim (int): total time steps
-        original_dim (tuple): dimensions in original space, e.g. could be 3x3
+        system_dim: system dimension
+        time_dim: total time steps
+        original_dim: dimensions in original space, e.g. could be 3x3
            for a 2d system with system_dim = 9. Defaults to (system_dim),
            i.e. 1d.
-        random_seed (int): random seed, defaults to 37
-        delta_t (float): the timestep of the data (assumed uniform)
-        store_as_jax (bool): Store values as jax array instead of numpy array.
+        random_seed: random seed, defaults to 37
+        delta_t: the timestep of the data (assumed uniform)
+        store_as_jax: Store values as jax array instead of numpy array.
            Default is False (store as numpy).
    """

    def __init__(self,
-                 system_dim=3,
-                 time_dim=1,
-                 original_dim=None,
-                 random_seed=37,
-                 delta_t=0.01,
-                 store_as_jax=False,
-                 x0=None,
+                 system_dim: int = 3,
+                 time_dim: int = 1,
+                 original_dim: tuple[int, ...] | None = None,
+                 random_seed: int = 37,
+                 delta_t: float = 0.01,
+                 store_as_jax: bool = False,
+                 x0: ArrayLike | None = None,
                 **kwargs):
        """Initializes the base data object"""

@@ -82,8 +86,14 @@ def x0_gridded(self):
        else:
            return self._x0.reshape(self.original_dim)

-    def generate(self, n_steps=None, t_final=None, x0=None, M0=None,
-                 return_tlm=False, stride=None, **kwargs):
+    def generate(self,
+                 n_steps: int | None = None,
+                 t_final: float | None = None,
+                 x0: ArrayLike | None = None,
+                 M0: ArrayLike | None = None,
+                 return_tlm: bool = False,
+                 stride: int | None = None,
+                 **kwargs) -> xr.Dataset | tuple[xr.Dataset, xr.DataArray]:
        """Generates a dataset and returns xarray state vector.

        Notes:
@@ -92,24 +102,24 @@ def generate(self, n_steps=None, t_final=None, x0=None, M0=None,
            time_dim attributes.

        Args:
-            n_steps (int): Number of timesteps. One of n_steps OR
+            n_steps: Number of timesteps. One of n_steps OR
                t_final must be specified.
-            t_final (float): Final time of trajectory. One of n_steps OR
+            t_final: Final time of trajectory. One of n_steps OR
                t_final must be specified.
-            M0 (ndarray): the initial condition of the TLM matrix computation
+            x0: initial conditions state vector of shape (system_dim)
+            M0: the initial condition of the TLM matrix computation
                shape (system_dim, system_dim).
-            return_tlm (bool): specifies whether to compute and return the
+            return_tlm: specifies whether to compute and return the
                integrated Jacobian as a Tangent Linear Model for each
                timestep.
-            x0 (ndarray): initial conditions state vector of shape (system_dim)
-            stride (int): specify how many steps to skip in the output data
+            stride: specify how many steps to skip in the output data
                versus the model timestep (delta_t).
            **kwargs: arguments to the integrate function (permits changes in
                convergence tolerance, etc.).

        Returns:
-            Nothing if return_tlm=False. If return_tlm=True, a list
-            of TLMs corresponding to the system trajectory.
+            Xarray Dataset of output vector and (if return_tlm=True)
+            Xarray DataArray of TLMs corresponding to the system trajectory.
        """

        # Check that n_steps or t_final is supplied
@@ -204,12 +214,15 @@ def generate(self, n_steps=None, t_final=None, x0=None, M0=None,
        else:
            return out_vec

-    def rhs_aux(self, x, t):
+    def rhs_aux(self,
+                x: ArrayLike,
+                t: ArrayLike
+                ) -> jax.Array:
        """The auxiliary model used to compute the TLM.
Args: - x (ndarray): State vector with size (system_dim) - t (ndarray): Array of times with size (time_dim) + x: State vector with size (system_dim) + t: Array of times with size (time_dim) Returns: dxaux (ndarray): State vector [size: (system_dim,)] @@ -228,8 +241,13 @@ def rhs_aux(self, x, t): return dxaux - def calc_lyapunov_exponents_series(self, total_time=None, rescale_time=1, - convergence=0.01, x0=None): + def calc_lyapunov_exponents_series( + self, + total_time: float | None = None, + rescale_time: float = 1, + convergence: float = 0.01, + x0: ArrayLike | None = None + ) -> ArrayLike: """Computes the spectrum of Lyapunov Exponents. Notes: @@ -246,19 +264,19 @@ def calc_lyapunov_exponents_series(self, total_time=None, rescale_time=1, Lyapunov Exponent, use self.calc_lyapunov_exponents. Args: - total_time (float) : Time to integrate over to compute LEs. + total_time: Time to integrate over to compute LEs. Usually there's a tradeoff between accuracy and computation time (more total_time leads to higher accuracy but more computation time). Default depends on model type and are based roughly on how long it takes for satisfactory convergence: For Lorenz63: n_steps=15000 (total_time=150 for delta_t=0.01) For Lorenz96: n_steps=50000 (total_time=500 for delta_t=0.01) - rescale_time (float) : Time for when the algorithm rescales the + rescale_time: Time for when the algorithm rescales the propagator to reduce the exponential growth in errors. Default is 1 (i.e. 100 timesteps when delta_t = 0.01). - convergence (float) : Prints warning if LE convergence is below + convergence: Prints warning if LE convergence is below this number. Default is 0.01. - x0 (array) : initial condition to start computing LE. Needs + x0: initial condition to start computing LE. Needs to be on the attractor (i.e., remove transients). Default is None, which will fallback to use the x0 set during model object initialization. @@ -317,27 +335,32 @@ def calc_lyapunov_exponents_series(self, total_time=None, rescale_time=1, return LE - def calc_lyapunov_exponents_final(self, total_time=None, rescale_time=1, - convergence=0.05, x0=None): + def calc_lyapunov_exponents_final( + self, + total_time: float | None = None, + rescale_time: float = 1, + convergence: float = 0.01, + x0: ArrayLike | None = None + ) -> ArrayLike: """Computes the final Lyapunov Exponents Notes: See self.calc_lyapunov_exponents_series for full info Args: - total_time (float) : Time to integrate over to compute LEs. + total_time: Time to integrate over to compute LEs. Usually there's a tradeoff between accuracy and computation time (more total_time leads to higher accuracy but more computation time). Default depends on model type and are based roughly on how long it takes for satisfactory convergence: For Lorenz63: n_steps=15000 (total_time=150 for delta_t=0.01) For Lorenz96: n_steps=50000 (total_time=500 for delta_t=0.01) - rescale_time (float) : Time for when the algorithm rescales the + rescale_time: Time for when the algorithm rescales the propagator to reduce the exponential growth in errors. Default is 1 (i.e. 100 timesteps when delta_t = 0.01). - convergence (float) : Prints warning if LE convergence is below + convergence: Prints warning if LE convergence is below this number. Default is 0.01. - x0 (array) : initial condition to start computing LE. Needs + x0: initial condition to start computing LE. Needs to be on the attractor (i.e., remove transients). Default is None, which will fallback to use the x0 set during model object initialization. 
@@ -346,31 +369,37 @@ def calc_lyapunov_exponents_final(self, total_time=None, rescale_time=1,
            Lyapunov exponents array of size (system_dim)
        """

-        return self.calc_lyapunov_exponents_series(total_time=total_time,
-                                                   rescale_time=rescale_time,
-                                                   x0=x0,
-                                                   convergence=convergence)[-1]
-
-    def load_netcdf(self, filepath=None, include_vars=None, exclude_vars=None,
-                    years_select=None, dates_select=None):
+        return self.calc_lyapunov_exponents_series(
+            total_time=total_time,
+            rescale_time=rescale_time,
+            x0=x0,
+            convergence=convergence)[-1]
+
+    def load_netcdf(self,
+                    filepath: str | None = None,
+                    include_vars: list | ArrayLike | None = None,
+                    exclude_vars: list | ArrayLike | None = None,
+                    years_select: list | ArrayLike | None = None,
+                    dates_select: list | ArrayLike | None = None
+                    ) -> xr.Dataset:
        """Loads values from netCDF file, saves them in values attribute

        Args:
-            filepath (str): Path to netCDF file to load. If not given,
+            filepath: Path to netCDF file to load. If not given,
                defaults to loading ERA5 ECMWF SLP data over Japan
                from 2018 to 2021.
-            include_vars (list-like): Data variables to load from NetCDF. If
+            include_vars: Data variables to load from NetCDF. If
                None (default), loads all variables. Can be used to exclude
                bad variables.
-            exclude_vars (list-like): Data variabes to exclude from NetCDF
+            exclude_vars: Data variables to exclude from NetCDF
                loading. If None (default), loads all vars (or only those
                specified in include_vars). It's recommended to only specify
                include_vars OR exclude_vars (unless you want to do extra
                typing).
-            years_select (list-like): Years to load (ints). If None, loads all
+            years_select: Years to load (ints). If None, loads all
                timesteps.
-            dates_select (list-like): Dates to load. Elements must be
+            dates_select: Dates to load. Elements must be
                datetime date or datetime objects, depending on type of
                time indices in NetCDF. If both years_select and dates_select
                are specified, time_stamps overwrites "years" argument. If
@@ -414,12 +443,14 @@
            ds = ds[ds.data_vars[ds.data_vars == exclude_vars]]

        return ds

-    def save_netcdf(self, ds, filename):
+    def save_netcdf(self,
+                    ds: xr.Dataset,
+                    filename: str):
        """Saves values in values attribute to netCDF file

        Args:
-            ds (Xarray Dataset): Xarray dataset
-            filepath (str): Path to netCDF file to save
+            ds: Xarray dataset
+            filename: Path to netCDF file to save
        """

        ds.to_netcdf(filename, mode='w')
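Reviewer note: a hypothetical round trip through the netCDF helpers above; the `Data()` instance and output filename are placeholders, and the default file is the bundled ERA5 SLP sample described in the docstring:

```python
from dabench.data import _data

data = _data.Data()
ds = data.load_netcdf()                          # bundled ERA5 SLP sample
subset = data.load_netcdf(years_select=[2018, 2019])
data.save_netcdf(subset, 'era5_slp_subset.nc')   # hypothetical output path
```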
diff --git a/dabench/data/_utils.py b/dabench/data/_utils.py
index 6034237..8c0d06c 100644
--- a/dabench/data/_utils.py
+++ b/dabench/data/_utils.py
@@ -2,27 +2,37 @@

 import logging
 import numpy as np
+import jax
 import jax.numpy as jnp
 from scipy.integrate import odeint as spodeint
 from jax.experimental.ode import odeint
+from typing import Callable

-logging.basicConfig(filename='logfile.log', level=logging.DEBUG)
+# For typing
+ArrayLike = np.ndarray | jax.Array
+
+logging.basicConfig(filename='logfile.log', level=logging.DEBUG)


-def integrate(function, x0, t_final, delta_t, method='odeint', stride=None,
-              jax_comps=True,
-              **kwargs):
-    """ Integrate forward in time.
+def integrate(function: Callable,
+              x0: ArrayLike,
+              t_final: float,
+              delta_t: float,
+              method: str = 'odeint',
+              stride: float | None = None,
+              jax_comps: bool = True,
+              **kwargs
+              ) -> tuple[ArrayLike, ArrayLike]:
+    """Integrate forward in time.

    Args:
-        function (ndarray): the model equations to integrate
-        x0 (ndarray): initial conditions state vector with shape (system_dim)
-        t_final (float): the final absolute time
+        function: the model equations to integrate
+        x0: initial conditions state vector with shape (system_dim)
+        t_final: the final absolute time
-        delta_t (float): timestep size
+        delta_t: timestep size
-        method (str): Integration method, one of 'odeint', 'euler',
+        method: Integration method, one of 'odeint', 'euler',
            'adambash2', 'ode_adambash2', 'rk2'. Right now, only odeint is
            implemented
-        stride (float): stride for output data
+        stride: stride for output data
        **kwargs: keyword arguments for the integrator

    Returns:
diff --git a/dabench/data/_xarray_accessor.py b/dabench/data/_xarray_accessor.py
index 8e2ec3a..d0fe028 100644
--- a/dabench/data/_xarray_accessor.py
+++ b/dabench/data/_xarray_accessor.py
@@ -3,7 +3,9 @@

 import warnings


-def _check_split_lengths(xr_obj, split_lengths):
+def _check_split_lengths(
+        xr_obj: xr.Dataset | xr.DataArray,
+        split_lengths: list | np.ndarray):
    total_length = np.sum(split_lengths)
    xr_timedim = xr_obj.sizes['time']
    if xr_timedim < total_length:
@@ -21,17 +23,20 @@ def _check_split_lengths(xr_obj, split_lengths):

 @xr.register_dataset_accessor('dab')
 class DABenchDatasetAccessor:
    """Helper methods for manipulating xarray Datasets"""
-    def __init__(self, xarray_obj):
+    def __init__(self,
+                 xarray_obj: xr.Dataset):
        self._obj = xarray_obj

-    def flatten(self):
+    def flatten(self) -> xr.DataArray:
        if 'time' in self._obj.coords:
            remaining_dim = ['time']
        else:
            remaining_dim = []
        return self._obj.to_stacked_array('system', remaining_dim)

-    def split_train_val_test(self, split_lengths):
+    def split_train_val_test(self,
+                             split_lengths: list | np.ndarray
+                             ) -> tuple[xr.Dataset, ...]:
        if (np.array(split_lengths) > 1.0).sum() == 0:
            # Assuming split_lengths is provided as fraction
            split_lengths = np.round(
@@ -49,13 +54,16 @@

 @xr.register_dataarray_accessor('dab')
 class DABenchDataArrayAccessor:
    """Helper methods for manipulating xarray DataArrays"""
-    def __init__(self, xarray_obj):
+    def __init__(self,
+                 xarray_obj: xr.DataArray):
        self._obj = xarray_obj

-    def unflatten(self):
+    def unflatten(self) -> xr.Dataset:
        return self._obj.to_unstacked_dataset('system')

-    def split_train_val_test(self, split_lengths):
+    def split_train_val_test(self,
+                             split_lengths: list | np.ndarray
+                             ) -> tuple[xr.DataArray, ...]:
        if (np.array(split_lengths) > 1.0).sum() == 0:
            # Assuming split_lengths is provided as fraction
            split_lengths = np.round(
@@ -68,5 +76,3 @@
            end_i = start_i + sl
            out_ds.append(self._obj.isel(time=slice(start_i, end_i)))
        return tuple(out_ds)
-
-
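Reviewer note: a short sketch of how the `dab` accessor registered above is meant to be used (fractional splits assumed to be converted to timestep counts internally, per the `split_lengths` branch):

```python
import numpy as np
import xarray as xr

import dabench.data._xarray_accessor  # noqa: F401  (registers the 'dab' accessor)

# 100 timesteps, split 60/20/20 by fraction
ds = xr.Dataset({'x': ('time', np.arange(100.0))})
train, val, test = ds.dab.split_train_val_test([0.6, 0.2, 0.2])
flat = ds.dab.flatten()   # DataArray stacked along 'system', 'time' retained
```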
Units: meters^-1 * + system_dim: system dimension + beta: Gradient of coriolis parameter. Units: meters^-1 * seconds^-1. Default is 0. - rek (float): Linear drag in lower layer. Units: seconds^-1. + rek: Linear drag in lower layer. Units: seconds^-1. Default is 0. - rd (float): Deformation radius. Units: meters. Default is 0. - H (float): Layer thickness. Units: meters. Default is 1. - nx (int): Number of grid points in the x direction. Default is 256. - ny (int): Number of grid points in the y direction. Default: nx. - L (float): Domain length in x direction. Units: meters. Default is + rd: Deformation radius. Units: meters. Default is 0. + H: Layer thickness. Units: meters. Default is 1. + nx: Number of grid points in the x direction. Default is 256. + ny: Number of grid points in the y direction. Default: nx. + L: Domain length in x direction. Units: meters. Default is 2*pi. - W (float): Domain width in y direction. Units: meters. Default: L. + x0: the initial conditions. Can also be + provided when initializing model object. If provided by + both, the generate() arg takes precedence. + W: Domain width in y direction. Units: meters. Default: L. filterfac (float): amplitdue of the spectral spherical filter. Default is 23.6. - delta_t (float): Numerical timestep. Units: seconds. - taveint (float): Time interval for accumulation of diagnostic averages. + delta_t: Numerical timestep. Units: seconds. + taveint: Time interval for accumulation of diagnostic averages. For performance purposes, averaging does not have to occur every timestep. Units: seconds. Default is 1 (i.e. every 1000 timesteps when delta_t = 0.001) - ntd (int): Number of threads to use. Should not exceed the number of + ntd: Number of threads to use. Should not exceed the number of cores on your machine. - store_as_jax (bool): Store values as jax array instead of numpy array. + store_as_jax: Store values as jax array instead of numpy array. Default is False (store as numpy). """ def __init__(self, - beta=0., - rek=0., - rd=0., - H=1., - L=2*np.pi, - x0=None, - nx=256, - ny=None, - delta_t=0.001, - taveint=1, - ntd=1, - time_dim=None, - values=None, - times=None, - store_as_jax=False, + beta: float = 0., + rek: float = 0., + rd: float = 0., + H: float = 1., + nx: int = 256, + ny: int | None = None, + L: float = 2*np.pi, + x0: ArrayLike | None = None, + delta_t: float = 0.001, + taveint: float = 1, + ntd: int = 1, + time_dim: int | None = None, + store_as_jax: bool = False, **kwargs): """ Initializes Barotropic object, subclass of Data @@ -101,11 +107,16 @@ def __init__(self, system_dim = self.m.q.size super().__init__(system_dim=system_dim, time_dim=time_dim, - values=values, times=times, delta_t=delta_t, + delta_t=delta_t, store_as_jax=store_as_jax, x0=x0, **kwargs) - def generate(self, n_steps=None, t_final=40, x0=None): + # TODO: Change to produce xarray dataset instead of updating values att. + def generate(self, + n_steps: int | None = None, + t_final: float = 40, + x0: ArrayLike | None = None + ) -> xr.Dataset: """Generates values and times, saves them to the data object Notes: @@ -114,11 +125,11 @@ def generate(self, n_steps=None, t_final=40, x0=None): time_dim attributes. Args: - n_steps (int): Number of timesteps. Default is None, which sets + n_steps: Number of timesteps. Default is None, which sets n_steps to t_final/delta_t - t_final (float): Final time of trajectory. Default is 40, which + t_final: Final time of trajectory. 
Default is 40, which results in n_steps = 40000 - x0 (ndarray, optional): the initial conditions. Can also be + x0: the initial conditions. Can also be provided when initializing model object. If provided by both, the generate() arg takes precedence. """ @@ -185,6 +196,7 @@ def generate(self, n_steps=None, t_final=40, x0=None): self.time_dim = qs.shape[0] self.values = qs.reshape((self.time_dim, -1)) + # TODO: Remove? I believe this is deprecated def forecast(self, n_steps=None, t_final=None, x0=None): """Alias for self.generate(), except returns values as output""" self.generate(n_steps, t_final, x0) diff --git a/dabench/data/enso_indices.py b/dabench/data/enso_indices.py index e86d957..edfb02a 100644 --- a/dabench/data/enso_indices.py +++ b/dabench/data/enso_indices.py @@ -4,6 +4,7 @@ import ssl import logging import warnings +import jax import jax.numpy as jnp import numpy as np import textwrap @@ -19,11 +20,11 @@ class ENSOIndices(_data.Data): Source: https://www.cpc.ncep.noaa.gov/data/indices/ Attributes: - system_dim (int): system dimension - time_dim (int): total time steps - store_as_jax (bool): Store values as jax array instead of numpy array. + system_dim: system dimension + time_dim: total time steps + store_as_jax: Store values as jax array instead of numpy array. Default is False (store as numpy). - file_dict (dict): Lists of files to get. Dict keys are type of data: + file_dict: Lists of files to get. Dict keys are type of data: 'wnd': Wind 'slp': Sea level pressure 'soi': Southern Oscillation Index @@ -37,7 +38,7 @@ class ENSOIndices(_data.Data): Dict values are individual files from the website, see full list at https://www.cpc.ncep.noaa.gov/data/indices/ Default is {'wnd': ['zwnd200'], 'slp': ['darwin']} - var_types (dict): List of variables within file to get. Dict keys are + var_types: List of variables within file to get. Dict keys are type of data (see list in file_dict description). Dict values are type of variable: 'ori' = Original @@ -53,8 +54,13 @@ class ENSOIndices(_data.Data): 'olr'='ori', 'cpolr'='ano' """ - def __init__(self, file_dict=None, var_types=None, system_dim=None, - time_dim=None, store_as_jax=False, **kwargs): + def __init__(self, + file_dict: dict | None = None, + var_types: dict | None = None, + system_dim: int | None = None, + time_dim: int | None = None, + store_as_jax: bool = False, + **kwargs): """Initialize ENSOIndices object, subclass of Base""" @@ -64,13 +70,13 @@ def __init__(self, file_dict=None, var_types=None, system_dim=None, values=None, delta_t=None, **kwargs, store_as_jax=store_as_jax) - def generate(self): + def generate(self) -> xr.Dataset: """Alias for _load_gcp_era5""" warnings.warn('ENSOIndices.generate() is an alias for the load() method. ' 'Proceeding with downloading ENSO Indices data...') return self.load() - def load(self): + def load(self) -> xr.Dataset: # Full list of file names at bottom of this page: # https://www.cpc.ncep.noaa.gov/data/indices/Readme.index.shtml @@ -159,21 +165,28 @@ def load(self): return ds - def _download_cpc_vals(self, file_name, var, var_types, var_types_full, - all_vals, all_years): + def _download_cpc_vals( + self, + file_name: str, + var: str, + var_types: dict, + var_types_full: dict, + all_vals: dict, + all_years: dict + ) -> tuple[dict, dict]: """Downloads data for one file_name and variable pair Args: - file_name (str): CPC file name. - var (str): Variable name, e.g. 'wnd', 'slp', etc. - var_types (dict): Types of variables to get for each variable name, + file_name: CPC file name. 
+ var: Variable name, e.g. 'wnd', 'slp', etc. + var_types: Types of variables to get for each variable name, e.g. 'ori' (original), 'ano' (anomaly), etc. - var_types_full (dict): Types of variables available for each + var_types_full: Types of variables available for each variable name, used to help with parsing. - all_vals (dict): Dictionary of variable names and corresponding + all_vals: Dictionary of variable names and corresponding values downloaded so far. This method adds new variables and returns. - all_years (dict): Dictionary of variable names and corresponding + all_years: Dictionary of variable names and corresponding years downloaded so far. This method adds new variables and returns. @@ -256,14 +269,17 @@ def _download_cpc_vals(self, file_name, var, var_types, var_types_full, return all_vals, all_years - def _combine_vals_years(self, all_vals, all_years): + def _combine_vals_years( + self, + all_vals: dict, + all_years: dict + ) -> tuple[jax.Array, jax.Array, list]: """Merges all_vals and all_years dicts into ndarrays Args: - all_vals (dict): Dictionary of downloaded variable names and + all_vals: Dictionary of downloaded variable names and corresponding values. - all_years (dict): Dictionary of variable names and corresponding - all_years (dict): Dictionary of downloaded variable names and + all_years: Dictionary of downloaded variable names and corresponding years. Returns: @@ -290,12 +306,15 @@ def _combine_vals_years(self, all_vals, all_years): common_years)] return common_vals, common_years - def _get_vals(self, tmp, n_header): + def _get_vals(self, + tmp: list, + n_header: int + ) -> tuple[jax.Array, jax.Array]: """Parses text lines from files of most data types Args: - tmp (list): List of lines from text file - n_header (int): Number of lines in header at top of file and + tmp: List of lines from text file + n_header: Number of lines in header at top of file and between blocks. Returns: @@ -327,11 +346,13 @@ def _get_vals(self, tmp, n_header): return vals, years - def _get_eqsoi(self, tmp): + def _get_eqsoi(self, + tmp: list + ) -> tuple[jax.Array, jax.Array]: """Parses text lines from eqsoi file. Args: - tmp (list): List of lines from text file + tmp: List of lines from text file Returns: Tuple of data values (ndarray) and years (ndarray). @@ -348,12 +369,15 @@ def _get_eqsoi(self, tmp): return vals, years - def _get_sst(self, tmp, var_types_indices): + def _get_sst(self, + tmp: list, + var_types_indices: list + ) -> tuple[jax.Array, jax.Array]: """Parses text lines from sst file. Args: - tmp (list): List of lines from text file - var_types_indices (list): List of variable type indices. Variable + tmp: List of lines from text file + var_types_indices: List of variable type indices. Variable types are: ['nino12', 'nino12_ano', 'nino3', 'nino3_ano', 'nino4', 'nino4_ano', 'nino34', 'nino34_ano'] [0] is 'nino12' only, [1] is 'nino12_ano', [0, 2] is 'nino12' diff --git a/dabench/data/gcp.py b/dabench/data/gcp.py index 3e8727c..36508c3 100644 --- a/dabench/data/gcp.py +++ b/dabench/data/gcp.py @@ -25,39 +25,35 @@ class GCP(_data.Data): Data is hourly Attributes: - system_dim (int): System dimension - time_dim (int): Total time steps - variables (list of strings): Names of ERA5 variables to + variables: Names of ERA5 variables to load. 
                For description of variables, see:
                https://github.com/google-research/arco-era5?tab=readme-ov-file#full_37-1h-0p25deg-chunk-1zarr-v3
                Default is ['2m_temperature'] (Air temperature at 2 metres)
-        date_start (string): Start of time range to download, in 'yyyy-mm-dd'
+        date_start: Start of time range to download, in 'yyyy-mm-dd'
            format. Can also just specify year ('yyyy') or year and month
            ('yyyy-mm'). Default is '2020-06-01'.
-        date_end (string): End of time range to download, in 'yyyy-mm-dd'
+        date_end: End of time range to download, in 'yyyy-mm-dd'
            format. Can also just specify year ('yyyy') or year and month
            ('yyyy-mm'). Default is '2020-9-30'.
-        min_lat (float): Minimum latitude for bounding box. If None, loads
+        min_lat: Minimum latitude for bounding box. If None, loads
            global data (which can be VERY large). Bounding box default
            covers Cuba.
-        max_lat (float): Max latitude for bounding box (see min_lat for info).
-        min_lon (float): Min latitude for bounding box (see min_lat for info).
-        max_lon (float): Max latitude for bounding box (see min_lat for info).
-        store_as_jax (bool): Store values as jax array instead of numpy array.
+        max_lat: Max latitude for bounding box (see min_lat for info).
+        min_lon: Min longitude for bounding box (see min_lat for info).
+        max_lon: Max longitude for bounding box (see min_lat for info).
+        store_as_jax: Store values as jax array instead of numpy array.
            Default is False (store as numpy).
    """

    def __init__(
            self,
-            variables=['2m_temperature'],
-            date_start='2020-01-01',
-            date_end='2020-12-31',
-            min_lat=19.8554808619,
-            max_lat=23.1886107447,
-            min_lon=-84.9749110583,
-            max_lon=-74.1780248685,
-            system_dim=None,
-            time_dim=None,
-            store_as_jax=False,
+            variables: list[str] = ['2m_temperature'],
+            date_start: str = '2020-01-01',
+            date_end: str = '2020-12-31',
+            min_lat: float = 19.8554808619,
+            max_lat: float = 23.1886107447,
+            min_lon: float = -84.9749110583,
+            max_lon: float = -74.1780248685,
+            store_as_jax: bool = False,
            **kwargs
            ):

@@ -69,13 +65,12 @@ def __init__(
        self.min_lon = min_lon
        self.max_lon = max_lon

-        super().__init__(system_dim=system_dim, time_dim=time_dim,
-                         values=None, delta_t=None, store_as_jax=store_as_jax,
-                         x0=None,
+        super().__init__(values=None, delta_t=None,
+                         store_as_jax=store_as_jax, x0=None,
                         **kwargs)

-    def _load_gcp_era5(self):
+    def _load_gcp_era5(self) -> xr.Dataset:
        """Load ERA5 data from Google Cloud Platform"""

        url = 'http://storage.googleapis.com/gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3'
@@ -117,12 +112,12 @@ def _load_gcp_era5(self):

        return ds

-    def generate(self):
+    def generate(self) -> xr.Dataset:
        """Alias for _load_gcp_era5"""
        warnings.warn('GCP.generate() is an alias for the load() method. '
                      'Proceeding with downloading ERA5 data from GCP...')
        return self._load_gcp_era5()

-    def load(self):
+    def load(self) -> xr.Dataset:
        """Alias for _load_gcp_era5"""
        return self._load_gcp_era5()
diff --git a/dabench/data/lorenz63.py b/dabench/data/lorenz63.py
index 1021a62..48c1423 100644
--- a/dabench/data/lorenz63.py
+++ b/dabench/data/lorenz63.py
@@ -1,46 +1,50 @@
 """Lorenz 1963 3-variable model data generation"""

 import logging
+import numpy as np
+import jax
 import jax.numpy as jnp

 from dabench.data import _data

 logging.basicConfig(filename='logfile.log', level=logging.DEBUG)

+# For typing
+ArrayLike = np.ndarray | jax.Array

 class Lorenz63(_data.Data):
    """ Class to set up Lorenz 63 model data

    Attributes:
-        sigma (float): Lorenz 63 param. Default is 10., the original value
+        sigma: Lorenz 63 param.
Default is 10., the original value used in Lorenz, 1963. https://doi.org/10.1175/1520-0469(1963)020<0130:DNF>2.0.CO;2 - rho (float): Lorenz 63 param. Default is 28., the value used in + rho: Lorenz 63 param. Default is 28., the value used in Lorenz, 1963 (see DOI above) - beta (float): Lorenz 63 param. Default is 8./3., the value used in + beta: Lorenz 63 param. Default is 8./3., the value used in Lorenz, 1963 (see DOI above) - x0 (ndarray, float): Initial state, array of floats of size + delta_t: length of one time step + x0: Initial state, array of floats of size (system_dim). Default is jnp.array([-10.0, -15.0, 21.3]), which is the system state after a 6000 step spinup with delta_t=0.01 and initial conditions [0., 1., 0.], a spinup which replicates the simulation described in Lorenz, 1963. - system_dim (int): system dimension. Must be 3 for Lorenz63. - time_dim (int): total time steps - delta_t (float): length of one time step - store_as_jax (bool): Store values as jax array instead of numpy array. + system_dim: system dimension. Must be 3 for Lorenz63. + time_dim: total time steps + store_as_jax: Store values as jax array instead of numpy array. Default is False (store as numpy). """ def __init__(self, - sigma=10., - rho=28., - beta=8./3., - delta_t=0.01, - x0=jnp.array([-10.0, -15.0, 21.3]), - system_dim=3, - time_dim=None, - values=None, - store_as_jax=False, + sigma: float = 10., + rho: float = 28., + beta: float = 8./3., + delta_t: float = 0.01, + x0: ArrayLike | None = jnp.array([-10.0, -15.0, 21.3]), + system_dim: int = 3, + time_dim: int | None = None, + values: ArrayLike | None = None, + store_as_jax: bool = False, **kwargs): """Initialize Lorenz63 object, subclass of Base""" @@ -65,7 +69,10 @@ def __init__(self, # Initial conditions self.x0 = x0 - def rhs(self, x, t=None): + def rhs(self, + x: ArrayLike, + t: ArrayLike | None = None + ) -> jax.Array: """vector field of Lorenz 63 Args: @@ -84,8 +91,10 @@ def rhs(self, x, t=None): return dx - def Jacobian(self, x): - """ Jacobian of the L63 system + def Jacobian(self, + x: ArrayLike + ) -> jax.Array: + """Jacobian of the L63 system Args: x: state vector with shape (system_dim) diff --git a/dabench/data/lorenz96.py b/dabench/data/lorenz96.py index 9f8ec84..991b21e 100644 --- a/dabench/data/lorenz96.py +++ b/dabench/data/lorenz96.py @@ -1,11 +1,16 @@ """Lorenz 96 model data generator""" import logging +import jax import jax.numpy as jnp +import numpy as np from dabench.data import _data logging.basicConfig(filename='logfile.log', level=logging.DEBUG) +# For typing +ArrayLike = np.ndarray | jax.Array + class Lorenz96(_data.Data): """Class to set up Lorenz 96 model data. @@ -15,9 +20,9 @@ class Lorenz96(_data.Data): eapsweb.mit.edu/sites/default/files/Predicability_a_Problem_2006.pdf Attributes: - forcing_term (float): Forcing constant for Lorenz96, prevents energy + forcing_term: Forcing constant for Lorenz96, prevents energy from decaying to 0. Default is 8.0. - x0 (ndarray, float): Initial state vector, array of floats of size + x0: Initial state vector, array of floats of size (system_dim). For system_dim of 5, 6, or 36, defaults are the final state of a 14400 timestep spinup starting with an initial state (x0) of all 0s except the first element, which is set to 0.01. @@ -26,23 +31,23 @@ class Lorenz96(_data.Data): delta_t = 0.01 (more frequently used today) For all other system_dim settings, default is all 0s except the first element, which is set to 0.01. - system_dim (int): System dimension, must be between 4 and 40. 
+        system_dim: System dimension, must be between 4 and 40.
            Default is 36.
-        time_dim (int): Total time steps
-        delta_t (float): Length of one time step. Default is 0.05 from
+        time_dim: Total time steps
+        delta_t: Length of one time step. Default is 0.05 from
            Lorenz, 1996, but on modern computers 0.01 is often used.
-        store_as_jax (bool): Store values as jax array instead of numpy array.
+        store_as_jax: Store values as jax array instead of numpy array.
            Default is False (store as numpy).
    """

    def __init__(self,
-                 forcing_term=8.,
-                 delta_t=0.05,
-                 x0=None,
-                 system_dim=36,
-                 time_dim=None,
-                 values=None,
-                 store_as_jax=False,
+                 forcing_term: float = 8.,
+                 delta_t: float = 0.05,
+                 x0: ArrayLike | None = None,
+                 system_dim: int = 36,
+                 time_dim: int | None = None,
+                 values: ArrayLike | None = None,
+                 store_as_jax: bool = False,
                 **kwargs):
        """Initialize Lorenz96 object, subclass of Base"""

@@ -113,7 +118,10 @@ def __init__(self,
                           jnp.zeros(system_dim-1)])
        self.x0 = x0

-    def rhs(self, x, t=None):
+    def rhs(self,
+            x: ArrayLike,
+            t: ArrayLike | None = None
+            ) -> jax.Array:
        """Computes vector field of Lorenz 96

        Args:
@@ -144,7 +152,9 @@ def rhs(self, x, t=None):

        return dx

-    def Jacobian(self, x):
+    def Jacobian(self,
+                 x: ArrayLike
+                 ) -> jax.Array:
        """Computes the Jacobian of the Lorenz96 system

        Args:
diff --git a/dabench/data/pyqg_jax.py b/dabench/data/pyqg_jax.py
index b118fb0..812250a 100644
--- a/dabench/data/pyqg_jax.py
+++ b/dabench/data/pyqg_jax.py
@@ -6,14 +6,19 @@
 import logging
 import numpy as np
 from copy import deepcopy
-import jax
 import functools
+import jax
+import xarray as xr
 import jax.numpy as jnp

 from dabench.data import _data

 logging.basicConfig(filename='logfile.log', level=logging.DEBUG)

+# For typing
+ArrayLike = np.ndarray | jax.Array
+
 try:
     import pyqg_jax
 except ImportError:
@@ -37,45 +42,38 @@ class PyQGJax(_data.Data):
    https://pyqg.readthedocs.io/en/latest/api.html#pyqg.QGModel

    Attributes:
-        beta (float): Gradient of coriolis parameter. Units: meters^-1 *
+        beta: Gradient of coriolis parameter. Units: meters^-1 *
            seconds^-1
-        rek (float): Linear drag in lower layer. Units: seconds^-1
-        rd (float): Deformation radius. Units: meters.
-        delta (float): Layer thickness ratio (H1/H2)
-        U1 (float): Upper layer flow. Units: m/s
-        U2 (float): Lower layer flow. Units: m/s
-        H1 (float): Layer thickness (sets both H1 and H2).
-        nx (int): Number of grid points in the x direction.
-        ny (int): Number of grid points in the y direction (default: nx).
-        L (float): Domain length in x direction. Units: meters.
-        W (float): Domain width in y direction. Units: meters (default: L).
-        filterfac (float): amplitdue of the spectral spherical filter
-            (originally 18.4, later changed to 23.6).
-        delta_t (float): Numerical timestep. Units: seconds.
-        tmax (float): Total time of integration (overwritten by t_final).
-            Units: seconds.
-        ntd (int): Number of threads to use. Should not exceed the number of
-            cores on your machine.
-        store_as_jax (bool): Store values as jax array instead of numpy array.
+        rd: Deformation radius. Units: meters.
+        delta: Layer thickness ratio (H1/H2)
+        H1: Layer thickness (sets both H1 and H2 if H2 not specified).
+        H2: Layer 2 thickness.
+        U1: Upper layer flow. Units: m/s
+        U2: Lower layer flow. Units: m/s
+        x0: the initial conditions. Can also be
+            provided when initializing model object. If provided by
+            both, the generate() arg takes precedence.
+        nx: Number of grid points in the x direction.
+ ny: Number of grid points in the y direction (default: nx). + delta_t: Numerical timestep. Units: seconds. + store_as_jax: Store values as jax array instead of numpy array. Default is False (store as numpy). """ def __init__(self, - beta=1.5e-11, - rd=15000.0, - delta=0.25, - H1=500, - H2=None, - U1=0.025, - U2=0.0, - x0=None, - nx=64, - ny=None, - delta_t=7200, - random_seed=37, - time_dim=None, - values=None, - times=None, - store_as_jax=False, + beta: float = 1.5e-11, + rd: float = 15000.0, + delta: float = 0.25, + H1: float = 500, + H2: float | None = None, + U1: float = 0.025, + U2: float = 0.0, + x0: ArrayLike | None = None, + nx: int = 64, + ny: int | None = None, + delta_t: float = 7200, + random_seed: int = 37, + time_dim: int | None = None, + store_as_jax: bool = False, **kwargs): """ Initialize PyQGJax QGModel object, subclass of Base @@ -112,8 +110,8 @@ def __init__(self, jax.random.PRNGKey(0) ) super().__init__(system_dim=system_dim, original_dim=original_dim, - time_dim=time_dim, values=values, times=times, - delta_t=delta_t, store_as_jax=store_as_jax, x0=x0, + time_dim=time_dim, delta_t=delta_t, + store_as_jax=store_as_jax, x0=x0, **kwargs) @functools.partial(jax.jit, static_argnames=["self", "num_steps"]) @@ -133,7 +131,9 @@ def loop_fn(carry, _x): return traj_steps - def _spec_var(self, ph): + def _spec_var(self, + ph: np.ndarray + ) -> float: """Compute variance of p from Fourier coefficients ph Note: Taken from original pyqg package: @@ -147,7 +147,12 @@ def _spec_var(self, ph): return var_dens.sum() - def generate(self, n_steps=None, t_final=None, x0=None): + # TODO: Change to produce xarray dataset instead of updating values att. + def generate(self, + n_steps: int | None = None, + t_final: float = 40, + x0: ArrayLike | None = None + ) -> xr.Dataset: """Generates values and times, saves them to the data object Notes: @@ -156,11 +161,11 @@ def generate(self, n_steps=None, t_final=None, x0=None): time_dim attributes. Args: - n_steps (int): Number of timesteps. One of n_steps OR + n_steps: Number of timesteps. One of n_steps OR t_final must be specified. - t_final (float): Final time of trajectory. One of n_steps OR + t_final: Final time of trajectory. One of n_steps OR t_final must be specified. - x0 (ndarray, optional): the initial conditions. Can also be + x0: the initial conditions. Can also be provided when initializing model object. If provided by both, the generate() arg takes precedence. """ @@ -223,6 +228,7 @@ def generate(self, n_steps=None, t_final=None, x0=None): self.time_dim = qs.shape[0] self.values = qs.reshape((self.time_dim, -1)) + # TODO: Remove? Believe this is deprecated def forecast(self, n_steps=None, t_final=None, x0=None): """Alias for self.generate(), except returns values as output""" self.generate(n_steps, t_final, x0) diff --git a/dabench/data/qgs.py b/dabench/data/qgs.py index 54f1bbc..f256612 100644 --- a/dabench/data/qgs.py +++ b/dabench/data/qgs.py @@ -7,6 +7,7 @@ import numpy as np from copy import deepcopy import xarray as xr +import jax import jax.numpy as jnp from dabench.data._utils import integrate @@ -27,6 +28,9 @@ 'For more information: https://qgs.readthedocs.io/en/latest/files/general_information.html' ) +# For typing +ArrayLike = np.ndarray | jax.Array + class QGS(_data.Data): """ Class to set up QGS quasi-geostrophic model @@ -35,25 +39,23 @@ class QGS(_data.Data): See https://qgs.readthedocs.io/ Attributes: - model_params (QgParams): qgs parameter object. See: + model_params: qgs parameter object. 
            See:
            https://qgs.readthedocs.io/en/latest/files/technical/configuration.html#qgs.params.params.QgParams
            If None, will use defaults specified by:
            De Cruz, et al. (2016). Geosci. Model Dev., 9, 2793-2808.
-        delta_t (float): Numerical timestep. Units: seconds.
-        store_as_jax (bool): Store values as jax array instead of numpy array.
+        x0: Initial state vector, array of floats. If None, defaults to
+            small random values (uniform on [0, 0.001), see __init__).
+        delta_t: Numerical timestep. Units: seconds.
+        store_as_jax: Store values as jax array instead of numpy array.
            Default is False (store as numpy).
-        x0 (ndarray): Initial state vector, array of floats. Default is:
    """

    def __init__(self,
-                 model_params=None,
-                 x0=None,
-                 delta_t=0.1,
-                 system_dim=None,
-                 time_dim=None,
-                 values=None,
-                 times=None,
-                 store_as_jax=False,
-                 random_seed=37,
+                 model_params: QgParams | None = None,
+                 x0: ArrayLike | None = None,
+                 delta_t: float = 0.1,
+                 system_dim: int | None = None,
+                 time_dim: int | None = None,
+                 store_as_jax: bool = False,
+                 random_seed: int = 37,
                 **kwargs):
        """ Initialize qgs object, subclass of Base
@@ -85,13 +87,12 @@ def __init__(self,
            x0 = self._rng.random(system_dim)*0.001

        super().__init__(system_dim=system_dim, time_dim=time_dim,
-                         values=values, times=times, delta_t=delta_t,
-                         store_as_jax=store_as_jax, x0=x0,
+                         delta_t=delta_t, store_as_jax=store_as_jax, x0=x0,
                         **kwargs)

        self.f, self.Df = create_tendencies(self.model_params)

-    def _create_default_qgparams(self):
+    def _create_default_qgparams(self) -> QgParams:
        model_params = QgParams()

        # Mode truncation at the wavenumber 2 in both x and y spatial
@@ -117,11 +118,14 @@ def _create_default_qgparams(self):

        return model_params

-    def rhs(self, x, t=0):
+    def rhs(self,
+            x: ArrayLike,
+            t: float | None = 0
+            ) -> np.ndarray:
        """Vector field (tendencies) of qgs system

        Arg:
-            x (ndarray): State vector, shape: (system_dim)
+            x: State vector, shape: (system_dim)
            t: times vector. Required as argument slot for some numerical
                integrators but unused.

        Returns:
@@ -133,11 +137,14 @@ def rhs(self, x, t=0):

        return dx

-    def Jacobian(self, x, t=0):
+    def Jacobian(self,
+                 x: ArrayLike,
+                 t: float | None = 0
+                 ) -> np.ndarray:
        """Jacobian of the qgs system

        Arg:
-            x (ndarray): State vector, shape: (system_dim)
+            x: State vector, shape: (system_dim)
            t: times vector. Required as argument slot for some numerical
                integrators but unused.

@@ -150,8 +157,14 @@ def Jacobian(self, x, t=0):

        return J

-    def generate(self, n_steps=None, t_final=None, x0=None, M0=None,
-                 return_tlm=False, stride=None, **kwargs):
+    def generate(self,
+                 n_steps: int | None = None,
+                 t_final: float | None = None,
+                 x0: ArrayLike | None = None,
+                 M0: ArrayLike | None = None,
+                 return_tlm: bool = False,
+                 stride: int | None = None,
+                 **kwargs) -> xr.Dataset | tuple[xr.Dataset, xr.DataArray]:
        """Generates a dataset and assigns values and times to the data object.

        Notes:
@@ -176,8 +189,8 @@ def generate(self, n_steps=None, t_final=None, x0=None, M0=None,
                convergence tolerance, etc.).

        Returns:
-            Nothing if return_tlm=False. If return_tlm=True, a list
-            of TLMs corresponding to the system trajectory.
+            Xarray Dataset of output vector and (if return_tlm=True)
+            Xarray DataArray of TLMs corresponding to the system trajectory.
""" # Check that n_steps or t_final is supplied @@ -271,34 +284,19 @@ def generate(self, n_steps=None, t_final=None, x0=None, M0=None, return out_vec, M else: return out_vec - # self.values = y[:, :self.system_dim] - # self.times = t - # self.time_dim = len(t) - - # # Return the data series and associated TLMs if requested - # if return_tlm: - # # Reshape M matrix - # M = np.reshape(y[:, self.system_dim:], - # (self.time_dim, - # self.system_dim, - # self.system_dim) - # ) - - # if self.store_as_jax: - # return M - # else: - # return np.array(M) - - def rhs_aux(self, x, t): + + def rhs_aux(self, + x: ArrayLike, + t: ArrayLike + ) -> jax.Array: """The auxiliary model used to compute the TLM. Args: - x (ndarray): State vector with size (system_dim) - t (ndarray): Array of times with size (time_dim) + x: State vector with size (system_dim) + t: Array of times with size (time_dim) Returns: dxaux (ndarray): State vector [size: (system_dim,)] - """ # Compute M dxdt = self.rhs(x[:self.system_dim], t) @@ -313,8 +311,13 @@ def rhs_aux(self, x, t): return dxaux - def calc_lyapunov_exponents_series(self, total_time=None, rescale_time=1, - convergence=0.01, x0=None): + def calc_lyapunov_exponents_series( + self, + total_time: float | None = None, + rescale_time: float = 1, + convergence: float = 0.01, + x0: ArrayLike | None = None + ) -> ArrayLike: """Computes the spectrum of Lyapunov Exponents. Notes: @@ -331,19 +334,19 @@ def calc_lyapunov_exponents_series(self, total_time=None, rescale_time=1, Lyapunov Exponent, use self.calc_lyapunov_exponents. Args: - total_time (float) : Time to integrate over to compute LEs. + total_time: Time to integrate over to compute LEs. Usually there's a tradeoff between accuracy and computation time (more total_time leads to higher accuracy but more computation time). Default depends on model type and are based roughly on how long it takes for satisfactory convergence: For Lorenz63: n_steps=15000 (total_time=150 for delta_t=0.01) For Lorenz96: n_steps=50000 (total_time=500 for delta_t=0.01) - rescale_time (float) : Time for when the algorithm rescales the + rescale_time: Time for when the algorithm rescales the propagator to reduce the exponential growth in errors. Default is 1 (i.e. 100 timesteps when delta_t = 0.01). - convergence (float) : Prints warning if LE convergence is below + convergence: Prints warning if LE convergence is below this number. Default is 0.01. - x0 (array) : initial condition to start computing LE. Needs + x0: initial condition to start computing LE. Needs to be on the attractor (i.e., remove transients). Default is None, which will fallback to use the x0 set during model object initialization. @@ -352,7 +355,6 @@ def calc_lyapunov_exponents_series(self, total_time=None, rescale_time=1, Lyapunov exponents for all timesteps, array of size (total_time/rescale_time - 1, system_dim) """ - # Set total_time if total_time is None: subclass_name = self.__class__.__name__ diff --git a/dabench/data/sqgturb.py b/dabench/data/sqgturb.py index efae09b..0cc2738 100644 --- a/dabench/data/sqgturb.py +++ b/dabench/data/sqgturb.py @@ -45,74 +45,66 @@ # Set to enable 64bit floats in Jax jax.config.update('jax_enable_x64', True) +# For typing +ArrayLike = np.ndarray | jax.Array + class SQGTurb(_data.Data): """Class to set up SQGTurb model and manage data. Attributes: - pv (ndarray): Potential vorticity array. If None (default), + pv: Potential vorticity array. 
            If None (default),
            loads data from 57600 step spinup with initial conditions
            taken from Jeff Whitaker's original implementation:
            https://github.com/jswhit/sqgturb. 57600 steps matches the
            "nature run" spin up in that repository.
-        system_dim (int): The dimension of the system state
-        time_dim (int): The dimension of the timeseries (not used)
-        delta_t (float): model time step (seconds)
-        x0 (ndarray, float): Initial state, array of floats of size
+        system_dim: The dimension of the system state
+        time_dim: The dimension of the timeseries (not used)
+        delta_t: model time step (seconds)
+        x0: Initial state, array of floats of size
            (system_dim).
-        f (float): coriolis
-        nqr (float): Brunt-Vaisalla (buoyancy) freq squared
-        L (float): size of square domain
-        H (float): height of upper boundary
-        U (float): basic state velocity at z = H
-        r (float): Ekman damping (at z=0)
-        tdiab (float): thermal relaxation damping
-        diff_order (int): hyperdiffusion order
-        diff_efold (float): hyperdiff time scale
-        symmetric (bool): symmetric jet, or jet with U=0 at sf
-        dealias (bool): if True, dealiasing applied using 2/3 rule
-        precision (char): 'single' or 'double'. Default is 'single'
-        tstart (float): initialize time counter
-        store_as_jax (bool): Store values as jax array instead of numpy array.
+        f: coriolis
+        nsq: Brunt-Vaisala (buoyancy) freq squared
+        L: size of square domain
+        H: height of upper boundary
+        U: basic state velocity at z = H
+        r: Ekman damping (at z=0)
+        tdiab: thermal relaxation damping
+        diff_order: hyperdiffusion order
+        diff_efold: hyperdiff time scale
+        symmetric: symmetric jet, or jet with U=0 at sf
+        dealias: if True, dealiasing applied using 2/3 rule
+        precision: 'single' or 'double'. Default is 'single'
+        tstart: initialize time counter
+        store_as_jax: Store values as jax array instead of numpy array.
            Default is False (store as numpy).
-        is_spectral (bool): Attribute to track which generators store values
-            in spectral space. Is automatically set to True for SQGTurb.
""" def __init__(self, - pv=None, - f=1.0e-4, - nsq=1.0e-4, - L=20.0e6, - H=10.0e3, - U=30.0, - r=0.0, - tdiab=10.0 * 86400, - diff_order=8, - diff_efold=86400./3, - symmetric=True, - dealias=True, - precision='single', - tstart=0, - system_dim=None, - input_dim=None, - output_dim=None, - time_dim=None, - values=None, - times=None, - delta_t=900, - store_as_jax=False, + pv: ArrayLike | None = None, + f: float = 1.0e-4, + nsq: float = 1.0e-4, + L: float = 20.0e6, + H: float = 10.0e3, + U: float = 30.0, + r: float = 0.0, + tdiab: float = 10.0 * 86400, + diff_order: int = 8, + diff_efold: float = 86400./3, + symmetric: bool = True, + dealias: bool = True, + precision: str = 'single', + tstart: float = 0, + delta_t: float = 900, + store_as_jax: bool = False, **kwargs, ): # Attribute to track which generators store spectral values by default self.is_spectral = True - super().__init__(system_dim=system_dim, input_dim=input_dim, - output_dim=output_dim, time_dim=time_dim, - values=values, times=times, delta_t=delta_t, - store_as_jax=store_as_jax, **kwargs) - + super().__init__(delta_t=delta_t, store_as_jax=store_as_jax, **kwargs) self.coord_names = ['level','x','y'] self.var_names=['pv'] @@ -300,7 +292,7 @@ def _invert_inverse(self, psispec): return pvspec @partial(jax.jit, static_argnums=(0,)) - def _specpad(self, specarr): + def _specpad(self, specarr: ArrayLike) -> jax.Array: """Pads spectral arrays with zeros to interpolate to 3/2 larger grid using inverse fft.""" # Take care of normalization factor for inverse transform. @@ -319,7 +311,7 @@ def _specpad(self, specarr): return specarr_pad @partial(jax.jit, static_argnums=(0,)) - def _spectrunc(self, specarr): + def _spectrunc(self, specarr: ArrayLike) -> jax.Array: """Truncates spectral array to 2/3 size""" specarr_trunc = jnp.zeros((2, self.N, self.N // 2 + 1), dtype=specarr.dtype) @@ -332,21 +324,26 @@ def _spectrunc(self, specarr): return specarr_trunc @partial(jax.jit, static_argnums=(0,)) - def _xyderiv(self, specarr): + def _xyderiv(self, specarr: ArrayLike) -> tuple[jax.Array, jax.Array]: """Calculates x and y derivatives""" xderiv = self.ifft2(self.ik * specarr) yderiv = self.ifft2(self.il * specarr) return xderiv, yderiv @partial(jax.jit, static_argnums=(0,)) - def _xyderiv_dealias(self, specarr): + def _xyderiv_dealias(self, + specarr: ArrayLike + ) -> tuple[jax.Array, jax.Array]: """Calculates x and y derivatives""" specarr_pad = self._specpad(specarr) xderiv = self.ifft2(self.ik_pad * specarr_pad) yderiv = self.ifft2(self.il_pad * specarr_pad) return xderiv, yderiv - def _rk4(self, x, all_x): + def _rk4(self, + x: ArrayLike, + all_x: ArrayLike | None + ) -> tuple[ArrayLike, ArrayLike]: """Updates pv using 4th order runge-kutta time step with implicit "integrating factor" treatment of hyperdiffusion. 
@@ -363,7 +360,13 @@ def _rk4(self, x, all_x): return self.hyperdiff*y, self.hyperdiff*y @partial(jax.jit, static_argnums=(0,)) - def _symmetric_pvbar(self, mu, U, l, H, y): + def _symmetric_pvbar(self, + mu: jax.Array, + U: jax.Array, + l: jax.Array, + H: jax.Array, + y: jax.Array + ) -> jax.Array: """Computes symmetric pvbar""" pvbar = ( -(mu * 0.5 * U / (l * H)) @@ -374,7 +377,13 @@ def _symmetric_pvbar(self, mu, U, l, H, y): return pvbar @partial(jax.jit, static_argnums=(0,)) - def _asymmetric_pvbar(self, mu, U, l, H, y): + def _asymmetric_pvbar(self, + mu: jax.Array, + U: jax.Array, + l: jax.Array, + H: jax.Array, + y: jax.Array + ) -> jax.Array: """Computes asymmetric pvbar""" pvbar = (-(mu * U / (l * H)) * jnp.cos(l * y) / jnp.sinh(mu)) @@ -383,52 +392,75 @@ def _asymmetric_pvbar(self, mu, U, l, H, y): # Public support methods @partial(jax.jit, static_argnums=(0,)) - def fft2(self, pv): + def fft2(self, + pv: jax.Array + ) -> jax.Array: """Alias method for FFT of PV""" return rfft2(pv) @partial(jax.jit, static_argnums=(0,)) - def ifft2(self, pvspec): + def ifft2(self, + pvspec: jax.Array + ) -> jax.Array: """Alias method for inverse FFT of PV Spectral""" return irfft2(pvspec) @partial(jax.jit, static_argnums=(0,)) - def map2dto1d(self, pv): + def map2dto1d(self, + pv: jax.Array + ) -> jax.Array: """Maps 2D PV to 1D system state""" return pv.ravel() @partial(jax.jit, static_argnums=(0,)) - def map1dto2d(self, x): + def map1dto2d(self, + x: jax.Array + ) -> jax.Array: """Maps 1D state vector to 2D PV""" return jnp.reshape(x, (self.Nv, self.Nx, self.Ny)) @partial(jax.jit, static_argnums=(0,)) - def fft2_2dto1d(self, pv): + def fft2_2dto1d(self, + pv: jax.Array + ) -> jax.Array: """Runs FFT then maps from 2D to 1D""" pvspec = self.fft2(pv) return self.map2dto1d(pvspec) @partial(jax.jit, static_argnums=(0,)) - def ifft2_2dto1d(self, pvspec): + def ifft2_2dto1d(self, + pvspec: jax.Array + ) -> jax.Array: """Runs inverse FFT then maps from 2D to 1D""" pv = self.ifft2(pvspec) return self.map2dto1d(pv) @partial(jax.jit, static_argnums=(0,)) - def map1dto2d_fft2(self, x): + def map1dto2d_fft2(self, + x: jax.Array + ) -> jax.Array: """Maps for 1D to 2D then runs FFT""" pv = self.map1dto2d(x) return self.fft2(pv) @partial(jax.jit, static_argnums=(0,)) - def map1dto2d_ifft2(self, x): + def map1dto2d_ifft2(self, + x: jax.Array + ) -> jax.Array: """Maps for 1D to 2D then runs inverse FFT""" pvspec = self.map1dto2d(x) return self.ifft2(pvspec) # Integration methods - def integrate(self, f, x0, t_final, delta_t=None, include_x0=True, - t=None, **kwargs): + def integrate(self, + f: None, + x0: ArrayLike, + t_final: float, + delta_t: float | None = None, + include_x0: bool = True, + t: float | None = None, + **kwargs + ) -> tuple[jax.Array, jax.Array]: """Advances pv forward number of timesteps given by self.n_steps. 
Note: @@ -483,7 +515,9 @@ def integrate(self, f, x0, t_final, delta_t=None, include_x0=True, return values, times - def rhs(self, pvspec): + def rhs(self, + pvspec: ArrayLike + ) -> jax.Array: """Computes pv deriv on z=0, inverts pv to get streamfunction.""" psispec = self._invert(pvspec) @@ -516,7 +550,7 @@ def rhs(self, pvspec): self.v = psix return dpvspecdt - def _to_original_dim(self): + def _to_original_dim(self) -> np.ndarray: """Going back to 2D is a bit trickier for sqgturb""" gridded_vals = np.zeros((self.time_dim, self.Nv, self.Nx, self.Nx)) diff --git a/dabench/metrics/_deterministic.py b/dabench/metrics/_deterministic.py index 375761b..4667bd1 100644 --- a/dabench/metrics/_deterministic.py +++ b/dabench/metrics/_deterministic.py @@ -1,8 +1,14 @@ """Deterministic metrics""" import jax.numpy as jnp +import numpy as np +import jax + from dabench.metrics import _utils +# For typing +ArrayLike = np.ndarray | jax.Array + __all__ = [ 'pearson_r', @@ -12,7 +18,10 @@ ] -def pearson_r(predictions, targets): +def pearson_r( + predictions: ArrayLike, + targets: ArrayLike + ) -> jax.Array: """JAX array implementation of Pearson R Args: @@ -31,7 +40,10 @@ def pearson_r(predictions, targets): return top/bottom -def mse(predictions, targets): +def mse( + predictions: ArrayLike, + targets: ArrayLike + ) -> jax.Array: """JAX array implementation of Mean Squared Error Args: @@ -44,7 +56,10 @@ def mse(predictions, targets): """ return jnp.mean(jnp.square(predictions - targets)) -def rmse(predictions, targets): +def rmse( + predictions: ArrayLike, + targets: ArrayLike + ) -> jax.Array: """JAX array implementation of Root Mean Squared Error Args: @@ -57,7 +72,10 @@ def rmse(predictions, targets): """ return jnp.sqrt(mse(predictions, targets)) -def mae(predictions, targets): +def mae( + predictions: ArrayLike, + targets: ArrayLike + ) -> jax.Array: """JAX array implementation of Mean Absolute Error Args: diff --git a/dabench/metrics/_utils.py b/dabench/metrics/_utils.py index 50c25c4..f51e087 100644 --- a/dabench/metrics/_utils.py +++ b/dabench/metrics/_utils.py @@ -1,9 +1,17 @@ """Helper functions for metrics""" import jax.numpy as jnp +import numpy as np +import jax +# For typing +ArrayLike = np.ndarray | jax.Array -def _cov(a, b): + +def _cov( + a: ArrayLike, + b: ArrayLike + ) -> jax.Array: """Covariance""" a_mean = jnp.mean(a) b_mean = jnp.mean(b) diff --git a/dabench/model/_model.py b/dabench/model/_model.py index afa0e21..c319ce2 100644 --- a/dabench/model/_model.py +++ b/dabench/model/_model.py @@ -4,6 +4,8 @@ inherits from dabench.model.Model, with an forecast() method. """ +from typing import Any +import xarray as xr class Model(): """Base class for Model object @@ -15,10 +17,11 @@ class Model(): model_obj (obj): underlying model object, e.g. pytorch neural network. 
""" def __init__(self, - system_dim=None, - time_dim=None, - delta_t=None, - model_obj=None): + system_dim: int | None = None, + time_dim: int | None = None, + delta_t: int | None = None, + model_obj: int | None = None + ): self.system_dim = system_dim self.time_dim = time_dim @@ -31,7 +34,10 @@ def __init__(self, raise ValueError('Model object does not have a defined forecast() ' 'method.') - def _default_forecast(self, state_vec, timesteps=1, other_inputs=None): + def _default_forecast(self, + state_vec: xr.Dataset, + timesteps: int = 1, + other_inputs: Any = None): """Default method for forecasting""" new_state_vec = state_vec for i in range(timesteps): diff --git a/dabench/observer/_observer.py b/dabench/observer/_observer.py index 29aeb28..cd52096 100644 --- a/dabench/observer/_observer.py +++ b/dabench/observer/_observer.py @@ -6,43 +6,49 @@ import warnings import numpy as np +import jax import jax.numpy as jnp import xarray as xr +from dabench.data import Data + +# For typing +ArrayLike = np.ndarray | jax.Array + class Observer(): """Base class for Observer objects Attributes: - data_obj (dabench.data.Data): Data generator/loader object from which + data_obj: Data generator/loader object from which to gather observations. - random_location_density (float or tuple): Fraction of locations in + random_location_density: Fraction of locations in system_dim to randomly select for observing, must be value between 0 and 1. Default is 1. - random_time_density (float): Fraction of times to randomly select + random_time_density: Fraction of times to randomly select for observing must be value between 0 and 1. Default is 1. - random_location_count (int): Number of locations in data_obj's + random_location_count: Number of locations in data_obj's system_dim to randomly select for observing. Default is None. User should specify one of: random_location_count, random_location_density, or location_indices. If random_location_count is specified, it takes precedent over random_location_density. - random_time_count (int): Number of times to randomly select for + random_time_count: Number of times to randomly select for observing. Default is None. User should specify one of: random_time_count, random_time_density, or time_indices. If random_time_count is specified, it takes precedent over random_time_density. - location_indices (ndarray): Manually specified indices for observing. + locations: Manually specified indices for observing. If 1D array provided, assumed to be for flattened system_dim. If >1D, must have same dimensionality as data generator's original_dim (e.g. (x, y, z)). If stationary_observers=False, expects leading time dimension. If not specified, will be randomly generated according to random_location_density OR random_location_count. Default is None. - time_indices (ndarray): Indices of times to gather observations from. + times: Indices of times to gather observations from. If not specified, randomly generate according to random_time_density OR random_time_count. Default is None. - stationary_observers (bool): If True, samples from same indices at + stationary_observers: If True, samples from same indices at each time step. If False, randomly generates/expects new observation indices at each timestep. Default is True. If False: @@ -53,37 +59,37 @@ class Observer(): of locations at each times step.. If using location_indices, expects indices to either be 2D (time_dim, system_dim) or >2D (time_dim, original_dim). 
-        error_bias (float or array): Mean of normal distribution of
+        error_bias: Mean of normal distribution of
            observation errors. If provided as an array, it is taken to be
            variable-specific and the length must be equal to
            data_obj.system_dim. Default is 0.
-        error_sd (float or array): Standard deviation of observation errors.
-            observation errors. If provided as an array, it is taken to be
-            variable-specific and the length be equal to
-            data_obj.system_dim. Default is 0.
+        error_sd: Standard deviation of observation errors.
+            If provided as an array, it is taken to be
+            variable-specific and the length must be equal to
+            data_obj.system_dim. Default is 0.
-        error_positive_only (bool): Clip errors to be positive only. Default is
+        error_positive_only: Clip errors to be positive only. Default is
            False.
-        random_seed (int): Random seed for sampling times and locations.
+        random_seed: Random seed for sampling times and locations.
            Default is 99.
-        store_as_jax (bool): Store values as jax array instead of numpy array.
+        store_as_jax: Store values as jax array instead of numpy array.
            Default is False (store as numpy).
    """

    def __init__(self,
-                 state_vec,
-                 random_time_density=1.,
-                 random_location_density=1.,
-                 random_time_count=None,
-                 random_location_count=None,
-                 times=None,
-                 locations=None,
-                 stationary_observers=True,
-                 error_bias=0.,
-                 error_sd=0.,
-                 error_positive_only=False,
-                 random_seed=99,
-                 store_as_jax=False,
+                 state_vec: xr.Dataset,
+                 random_time_density: float = 1.,
+                 random_location_density: float | tuple[float, ...] = 1.,
+                 random_time_count: int | None = None,
+                 random_location_count: int | tuple[int, ...] | None = None,
+                 times: ArrayLike | None = None,
+                 locations: ArrayLike | None = None,
+                 stationary_observers: bool = True,
+                 error_bias: ArrayLike | float = 0.,
+                 error_sd: ArrayLike | float = 0.,
+                 error_positive_only: bool = False,
+                 random_seed: int = 99,
+                 store_as_jax: bool = False,
                 ):

        self.state_vec = state_vec
@@ -173,7 +179,10 @@ def __init__(self,

        self.error_positive_only = error_positive_only

-    def _generate_times(self, rng):
+    def _generate_times(
+            self,
+            rng: np.random.Generator
+            ):
        if self.random_time_count is not None:
            self.times = np.sort(rng.choice(
                self.state_vec['time'],
@@ -187,7 +196,10 @@
                ).astype('bool')
                )[0]]

-    def _generate_stationary_locs(self, rng):
+    def _generate_stationary_locs(
+            self,
+            rng: np.random.Generator
+            ):
        if self.random_location_count is not None:
            location_count = self.random_location_count
        else:
@@ -211,7 +223,10 @@
            }
        self.location_dim = location_count

-    def _generate_nonstationary_locs(self, rng):
+    def _generate_nonstationary_locs(
+            self,
+            rng: np.random.Generator
+            ):
        """Generate different locations for each observation time"""
        if self.random_location_count is not None:
            self._location_counts = np.repeat(
@@ -245,7 +260,7 @@

        self.location_dim = np.max(self._location_counts)

-    def observe(self):
+    def observe(self) -> xr.Dataset:
        """Generate observations.

        Returns:
diff --git a/pyproject.toml b/pyproject.toml
index 6714bc2..15d694f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -57,7 +57,8 @@ pyqg = [
 ]

 qgs = [
-    "qgs"
+    "sparse == 0.15.4",
+    "qgs == 0.2.8"
 ]

 full = [
diff --git a/tests/dacycler_base_test.py b/tests/dacycler_base_test.py
index f55a4bb..ba33419 100644
--- a/tests/dacycler_base_test.py
+++ b/tests/dacycler_base_test.py
@@ -9,7 +9,8 @@ def test_dacycler_init():
     params = {'system_dim': 6,
               'delta_t': 0.5,
-              'ensemble': True}
+              'ensemble': True,
+              'model_obj': dab.model.RCModel(6, 10)}

     test_dac = dab.dacycler.DACycler(**params)
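
Reviewer note (not part of the patch): a minimal usage sketch of the API as typed in this diff. The DACycler construction mirrors the updated test verbatim; the `import dabench as dab` alias, and the assumption that importing the package registers the `dab` xarray accessors from _xarray_accessor.py, are inferred from the test suite rather than confirmed here.

    import numpy as np
    import xarray as xr
    import dabench as dab

    # DACycler now requires system_dim, delta_t, and model_obj (no defaults);
    # this mirrors tests/dacycler_base_test.py.
    cycler = dab.dacycler.DACycler(system_dim=6,
                                   delta_t=0.5,
                                   ensemble=True,
                                   model_obj=dab.model.RCModel(6, 10))

    # The 'dab' accessors: split lengths <= 1.0 are treated as fractions of
    # the time dimension and rounded to timestep counts.
    ds = xr.Dataset({'x': ('time', np.arange(100.))},
                    coords={'time': np.arange(100)})
    train, val, test = ds.dab.split_train_val_test([0.8, 0.1, 0.1])
    flat = ds.dab.flatten()  # stacks data_vars into a 'system' dimension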