From bdc477b8f4ddcc8c3da60a26c342c63db154d921 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Mon, 5 May 2025 10:18:50 -0600 Subject: [PATCH 01/18] sqgturb tracer error fix: delta_t shouldn't be traced, messes up eventual xarray dataset building --- dabench/data/_data.py | 27 +++++++++++++++----------- dabench/data/_sqgturb.py | 42 ++++++++++------------------------------ 2 files changed, 26 insertions(+), 43 deletions(-) diff --git a/dabench/data/_data.py b/dabench/data/_data.py index fb92159..d92cea1 100644 --- a/dabench/data/_data.py +++ b/dabench/data/_data.py @@ -5,9 +5,12 @@ import jax.numpy as jnp import jax import xarray as xr +import xarray_jax as xj import warnings from importlib import resources +from functools import partial + from dabench.data._utils import integrate from dabench import _suppl_data @@ -159,7 +162,8 @@ def generate(self, # Integrate and store values and times # If data object has its own integration method, use that if hasattr(self, 'integrate') and callable(getattr(self, 'integrate')): - y, t = self.integrate(f, x0, t_final, self.delta_t, stride=stride, + y, t = self.integrate(f, x0, n_steps=n_steps, t_final=t_final, + delta_t=self.delta_t, stride=stride, **kwargs) # Otherwise, use integrate from dabench.support.utils else: @@ -167,25 +171,26 @@ def generate(self, jax_comps=self.store_as_jax, **kwargs) - # Convert to JAX if necessary - time_dim = t.shape[0] - out_dim = (time_dim,) + self.original_dim - if self.store_as_jax: - y_out = jnp.array(y[:,:self.system_dim].reshape(out_dim)) - else: - y_out = np.array(y[:,:self.system_dim].reshape(out_dim)) # Build Xarray object for output coord_dict = dict(zip( ['time'] + self.coord_names, [t] + [np.arange(dim) for dim in self.original_dim] )) + time_dim = t.shape[0] + out_dim = (time_dim,) + self.original_dim + + # Convert to JAX if necessary + if self.store_as_jax or isinstance(y, jax.core.Tracer): + y_out = jnp.array(y[:, :self.system_dim].reshape(out_dim)) + else: + y_out = np.array(y[:, :self.system_dim].reshape(out_dim)) out_vec = xr.Dataset( - {self.var_names[0]: (coord_dict.keys(),y_out)}, + {self.var_names[0]: (coord_dict.keys(), y_out)}, coords=coord_dict, - attrs={'store_as_jax':self.store_as_jax, + attrs={'store_as_jax': self.store_as_jax, 'system_dim': self.system_dim, 'delta_t': self.delta_t - } + } ) # Return the data series and associated TLMs if requested diff --git a/dabench/data/_sqgturb.py b/dabench/data/_sqgturb.py index f0ca83e..5e2601d 100644 --- a/dabench/data/_sqgturb.py +++ b/dabench/data/_sqgturb.py @@ -158,7 +158,7 @@ def __init__(self, self.H = jnp.array(H, dtype) self.U = jnp.array(U, dtype) self.L = jnp.array(L, dtype) - self.delta_t = jnp.array(delta_t, dtype) + self.delta_t = delta_t self.dealias = dealias if r < 1.0e-10: self.ekman = False @@ -416,7 +416,8 @@ def map1dto2d(self, x: jax.Array ) -> jax.Array: """Maps 1D state vector to 2D PV""" - return jnp.reshape(x, (self.Nv, self.Nx, self.Ny)) + # return jnp.reshape(x, (self.Nv, self.Nx, self.Ny)) + return np.reshape(x, (self.Nv, self.Nx, self.Ny)) @partial(jax.jit, static_argnums=(0,)) def fft2_2dto1d(self, @@ -452,12 +453,10 @@ def map1dto2d_ifft2(self, # Integration methods def integrate(self, - f: None, + function: None, x0: ArrayLike, - t_final: float, + t_final: float | None = None, delta_t: float | None = None, - include_x0: bool = True, - t: float | None = None, **kwargs ) -> tuple[jax.Array, jax.Array]: """Advances pv forward number of timesteps given by self.n_steps. 
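A note on why the delta_t change above matters: generate() in _data.py stores self.delta_t in the output Dataset's attrs, and a value that flows through jit-traced code as a jax array becomes an abstract Tracer there, which (per the commit message) can break the eventual xarray dataset building. A minimal standalone sketch of the distinction (illustrative, not dabench code):

    import jax
    import jax.numpy as jnp

    def report(dt):
        # Under jit, jax-array arguments are Tracers; a plain Python float
        # kept out of the traced computation stays a concrete value.
        print(type(dt))
        return jnp.sin(dt)

    report(0.1)                      # <class 'float'>
    jax.jit(report)(jnp.array(0.1))  # a Tracer type, printed at trace time

Keeping delta_t as a builtin float means the value written into attrs={'delta_t': self.delta_t} is always a plain scalar.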
@@ -466,43 +465,22 @@ def integrate(self,
 
         Note:
             If pv not specified, use pvspec instance variable.
 
         Args:
-            f (function): right hand side (rhs) of the ODE. Not used, but
+            function (function): right hand side (rhs) of the ODE. Not used, but
                 needed for compatibility with generate() from _data.Data().
             x0 (ndarray): potential vorticity (pvspec) initial condition in
                 spectral space
 
         """
         # Convert input state vector to a 2D spectral array
         pvspec = self.map1dto2d(x0)
 
-        # Get number of time steps
-        n_steps = int(t_final/self.delta_t)
-
-        # Checks
-        # Make sure that there is no remainder
-        if not n_steps * delta_t == t_final:
-            raise ValueError('Cannot have remainder in nsteps = {}, '
-                             'delta_t = {}, t_final = {}, and n_steps * '
-                             'delta_t = {}'.format(n_steps, delta_t, t_final,
-                                                   n_steps*delta_t))
-
         # If delta_t not specified as arg for method, use delta_t from object
         if delta_t is None:
             delta_t = self.delta_t
+        times = np.arange(0.0, t_final - delta_t/2, delta_t)
 
-        # If t not specified as arg for method, use t from object
-        if t is None:
-            t = self.t
-
-        # If including initial state, add 1 to n_steps
-        if include_x0:
-            n_steps = n_steps + 1
-
-        times = t + jnp.arange(n_steps)*delta_t
-
-        # Integrate in spectral space
-        pvspec, values = jax.lax.scan(self._rk4, pvspec, xs=None,
-                                      length=n_steps)
+        # Run integration
+        pvspec, values = jax.lax.scan(self._rk4, pvspec, xs=times)
 
         # Apply inverse fft to return to physical space
         values = self.ifft2(values)
@@ -546,4 +524,4 @@ def rhs(self,
         # save wind field
         self.u = -psiy
         self.v = psix
-        return dpvspecdt
\ No newline at end of file
+        return dpvspecdt

From d5a43d02f700f41c0e43bf581841dfb1977e5888 Mon Sep 17 00:00:00 2001
From: Kylen Solvik
Date: Mon, 12 May 2025 15:21:18 -0600
Subject: [PATCH 02/18] Remove static B covariance arg from etkf

---
 dabench/dacycler/_etkf.py | 14 ++------------
 1 file changed, 2 insertions(+), 12 deletions(-)

diff --git a/dabench/dacycler/_etkf.py b/dabench/dacycler/_etkf.py
index 7ad8aa9..559c100 100644
--- a/dabench/dacycler/_etkf.py
+++ b/dabench/dacycler/_etkf.py
@@ -23,9 +23,6 @@ class ETKF(dacycler.DACycler):
         system_dim: System dimension.
         delta_t: The timestep of the model (assumed uniform)
         model_obj: Forecast model object.
-        B: Initial / static background error covariance. Shape:
-            (system_dim, system_dim). If not provided, will be calculated
-            automatically.
         R: Observation error covariance matrix. Shape
             (obs_dim, obs_dim). If not provided, will be calculated
             automatically.
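For reference, "calculated automatically" here means a diagonal R built from the observation error standard deviation, following the convention visible in _var4d_backprop.py later in this series. A minimal sketch (function name assumed for illustration):

    import jax.numpy as jnp

    def calc_default_R(obs_values, obs_error_sd):
        # Diagonal observation error covariance: sd^2 for each observation
        return jnp.identity(obs_values.shape[0]) * (obs_error_sd ** 2)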
@@ -45,7 +42,6 @@ def __init__(self, system_dim: int, delta_t: float, model_obj: Model, - B: ArrayLike | None = None, R: ArrayLike | None = None, H: ArrayLike | None = None, h: Callable | None = None, @@ -59,7 +55,7 @@ def __init__(self, super().__init__(system_dim=system_dim, delta_t=delta_t, model_obj=model_obj, - B=B, R=R, H=H, h=h) + R=R, H=H, h=h) def _step_forecast(self, Xa: XarrayDatasetLike, @@ -170,8 +166,7 @@ def _cycle_obsop(self, obs_loc_mask: ArrayLike, H: ArrayLike | None = None, h: Callable | None = None, - R: ArrayLike | None = None, - B: ArrayLike | None = None + R: ArrayLike | None = None ) -> XarrayDatasetLike: if H is None and h is None: if self.H is None: @@ -186,11 +181,6 @@ def _cycle_obsop(self, R = self._calc_default_R(obs_values, self.obs_error_sd) else: R = self.R - if B is None: - if self.B is None: - B = self._calc_default_B() - else: - B = self.B Xb = Xb_ds.to_stacked_array('system',['ensemble']).data.T n_sys, n_ens = Xb.shape From b7bfc2ac7db2fb5cdb35b0c655bb98832d8ec78b Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Mon, 12 May 2025 15:28:11 -0600 Subject: [PATCH 03/18] Move var3d analysis computation to separate function --- dabench/dacycler/_var3d.py | 38 +++++++++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/dabench/dacycler/_var3d.py b/dabench/dacycler/_var3d.py index c5ff024..8b204df 100644 --- a/dabench/dacycler/_var3d.py +++ b/dabench/dacycler/_var3d.py @@ -51,6 +51,27 @@ def __init__(self, model_obj=model_obj, B=B, R=R, H=H, h=h) + + def _compute_analysis(self, + xb, + y, + B, + H, + Rinv + ): + # 'preconditioning with B' + xdim = xb.size + I = jnp.identity(xdim) + BHt = jnp.dot(B, H.T) + BHtRinv = jnp.dot(BHt, Rinv) + A = I + jnp.dot(BHtRinv, H) + b1 = xb + jnp.dot(BHtRinv, y) + + # Use minimization algorithm to minimize cost function: + xa, ierr = jscipy.sparse.linalg.cg(A, b1, x0=xb, tol=1e-05, + maxiter=1000) + return xa + def _cycle_obsop(self, xb_ds: XarrayDatasetLike, obs_values: ArrayLike, @@ -89,18 +110,13 @@ def _cycle_obsop(self, H = jnp.where(obs_loc_mask.flatten(), H.T, 0).T # Set parameters - xdim = xb.size # Size or get one of the shape params? 
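        # _compute_analysis (defined above) solves the B-preconditioned
        # linear system (I + B H^T R^-1 H) xa = xb + B H^T R^-1 y
        # with conjugate gradients, which is exactly the inline computation
        # this patch removes below.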
Rinv = jnp.linalg.inv(R) - # 'preconditioning with B' - I = jnp.identity(xdim) - BHt = jnp.dot(B, H.T) - BHtRinv = jnp.dot(BHt, Rinv) - A = I + jnp.dot(BHtRinv, H) - b1 = xb + jnp.dot(BHtRinv, y) - - # Use minimization algorithm to minimize cost function: - xa, ierr = jscipy.sparse.linalg.cg(A, b1, x0=xb, tol=1e-05, - maxiter=1000) + xa, ierr = self._compute_analysis( + xb, + B, + H, + Rinv, + ) return xb_ds.assign(x=(xb_ds.dims, xa.T)) From 96526aaab0cc7dc599987be8faf9099d6742ba94 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Mon, 12 May 2025 15:28:37 -0600 Subject: [PATCH 04/18] New hybrid gain DA method --- dabench/dacycler/__init__.py | 2 + dabench/dacycler/_hybrid_gain.py | 156 +++++++++++++++++++++++++++++++ 2 files changed, 158 insertions(+) create mode 100644 dabench/dacycler/_hybrid_gain.py diff --git a/dabench/dacycler/__init__.py b/dabench/dacycler/__init__.py index 2478491..bbd8e28 100644 --- a/dabench/dacycler/__init__.py +++ b/dabench/dacycler/__init__.py @@ -5,6 +5,7 @@ from ._etkf import ETKF from ._var4d_backprop import Var4DBackprop from ._var4d import Var4D +from ._hybrid_gain import HybridGain __all__ = [ 'DACycler', @@ -12,4 +13,5 @@ 'ETKF', 'Var4DBackprop', 'Var4D', + 'HybridGain' ] diff --git a/dabench/dacycler/_hybrid_gain.py b/dabench/dacycler/_hybrid_gain.py new file mode 100644 index 0000000..831cca9 --- /dev/null +++ b/dabench/dacycler/_hybrid_gain.py @@ -0,0 +1,156 @@ +"""Class for Hybrid Gain (ETKF + 3DVar) Data Assimilation""" + +import numpy as np +import jax +import jax.numpy as jnp +from jax.scipy import linalg +import xarray as xr +import xarray_jax as xj +from typing import Callable + +from dabench import dacycler +from dabench.model import Model + + +# For typing +ArrayLike = np.ndarray | jax.Array +XarrayDatasetLike = xr.Dataset | xj.XjDataset + +class HybridGain(dacycler.DACycler): + """HybridGain DA, combining ETKF with 3DVar + + Args: + system_dim: System dimension. + delta_t: The timestep of the model (assumed uniform) + model_obj: Forecast model object. + B: Initial / static background error covariance. Shape: + (system_dim, system_dim). If not provided, will be calculated + automatically. + R: Observation error covariance matrix. Shape + (obs_dim, obs_dim). If not provided, will be calculated + automatically. + H: Observation operator with shape: (obs_dim, system_dim). + If not provided will be calculated automatically. + h: Optional observation operator as function. More flexible + (allows for more complex observation operator). Default is None. + alpha: Weight for 3DVar DA analysis. If 0.0, runs pure ETKF. If 1.0, + runs pure 3DVar. Default is 0.2. + ensemble_dim: Number of ensemble instances for ETKF. Default is + 4. Higher ensemble_dim increases accuracy but has performance cost. + multiplicative_inflation: Scaling factor by which to multiply ensemble + deviation. Default is 1.0 (no inflation). 
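
    Note:
        The hybrid analysis mean is the convex combination
        xa = alpha * xa_3dvar + (1 - alpha) * mean(Xa_etkf),
        and the ETKF analysis ensemble is then recentered on this
        hybrid mean (see _cycle_obsop below).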
+ """ + _in_4d: bool = False + _uses_ensemble: bool = True + + def __init__(self, + system_dim: int, + delta_t: float, + model_obj: Model, + B: ArrayLike | None = None, + R: ArrayLike | None = None, + H: ArrayLike | None = None, + h: Callable | None = None, + alpha: float = 0.2, + ensemble_dim: int = 4, + multiplicative_inflation: float = 1.0 + ): + + self.ensemble_dim = ensemble_dim + self.multiplicative_inflation = multiplicative_inflation + self.alpha = alpha + + # Create ETKF DA Cycler + self._etkf_da = dacycler.ETKF( + system_dim=system_dim, + delta_t=delta_t, + model_obj=model_obj, + R=R, + H=H, + h=h, + ensemble_dim=ensemble_dim, + multiplicative_inflation=multiplicative_inflation + ) + # Create 3D-Var DA Cycler + self._var3d_da = dacycler.Var3D( + system_dim=system_dim, + delta_t=delta_t, + model_obj=model_obj, + R=R, + H=H, + h=h, + B=B + ) + + super().__init__(system_dim=system_dim, + delta_t=delta_t, + model_obj=model_obj, + B=B, R=R, H=H, h=h) + + def _step_forecast(self, + Xa: XarrayDatasetLike, + n_steps: int = 1 + ) -> XarrayDatasetLike: + """Ensemble method needs a slightly different _step_forecast method""" + return self._etkf_da._step_forecast(Xa, n_steps) + + def _cycle_obsop(self, + Xb_ds: XarrayDatasetLike, + obs_values: ArrayLike, + obs_loc_indices: ArrayLike, + obs_time_mask: ArrayLike, + obs_loc_mask: ArrayLike, + H: ArrayLike | None = None, + h: Callable | None = None, + R: ArrayLike | None = None, + B: ArrayLike | None = None + ) -> XarrayDatasetLike: + if H is None and h is None: + if self.H is None: + if self.h is None: + H = self._calc_default_H(obs_values, obs_loc_indices) + else: + h = self.h + else: + H = self.H + if R is None: + if self.R is None: + R = self._calc_default_R(obs_values, self.obs_error_sd) + else: + R = self.R + if B is None: + if self.B is None: + B = self._calc_default_B() + else: + B = self.B + + Xb = Xb_ds.to_stacked_array('system',['ensemble']).data.T + n_sys, n_ens = Xb.shape + assert n_ens == self.ensemble_dim, ( + 'cycle:: model_forecast must have dimension {}x{}').format( + self.ensemble_dim, self.system_dim) + + # Apply obs masks to H + H = jnp.where(obs_time_mask.flatten(), H.T, 0).T + H = jnp.where(obs_loc_mask.flatten(), H.T, 0).T + + # Compute ETKF analysis + Xa_etkf = self._etkf_da._compute_analysis(Xb=Xb, + Y=obs_values, + H=H, + h=h, + R=R, + rho=self.multiplicative_inflation) + + # Compute Var3D Analysis + xa_var3d = self._var3d_da._compute_analysis(xb=jnp.mean(Xb, axis=1).flatten(), + y=obs_values.flatten(), + H=H, + B=B, + Rinv=jnp.linalg.inv(R)) + + xa_etkf_mean = jnp.mean(Xa_etkf, axis=1) + xa_final = self.alpha*xa_var3d + (1-self.alpha)*xa_etkf_mean + Xa_final = Xa_etkf.T - (xa_etkf_mean - xa_final) + + return Xb_ds.assign(x=(['ensemble','i'], Xa_final)) From e60ec440acb36964378981b8ecf2b89e7e90f305 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Wed, 14 May 2025 12:10:08 -0600 Subject: [PATCH 05/18] Enable nearest obs selection and fix multi-dimensional selection (without replacement) --- dabench/observer/_observer.py | 46 ++++++++++++++++++++++++----------- 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/dabench/observer/_observer.py b/dabench/observer/_observer.py index decb844..eda6be5 100644 --- a/dabench/observer/_observer.py +++ b/dabench/observer/_observer.py @@ -73,6 +73,9 @@ class Observer(): Default is 99. store_as_jax: Store values as jax array instead of numpy array. Default is False (store as numpy). + sel_method: Xarray selection indexing method (e.g. 'nearest', 'pad'). 
+ See https://docs.xarray.dev/en/latest/generated/xarray.Dataset.sel.html. + Default is 'nearest', which selects nearest neighbor. Attributes: locations (ArrayLike): Location indices for making @@ -100,6 +103,7 @@ def __init__(self, error_positive_only: bool = False, random_seed: int = 99, store_as_jax: bool = False, + sel_method: str = 'nearest' ): self.state_vec = state_vec @@ -136,6 +140,7 @@ def __init__(self, self.random_location_density = random_location_density self.random_location_count = random_location_count self.stationary_observers = stationary_observers + self.sel_method = sel_method self.random_seed = random_seed if (store_as_jax and self.random_location_density != 1. and @@ -217,19 +222,26 @@ def _generate_stationary_locs( rng.binomial(1, p=self.random_location_density, size=self.state_vec.system_dim)) - if len(self._nontime_coord_names) > 1: - sample_w_replace=True - else: - sample_w_replace=False + # Sample from flattened dimension + sizes = tuple( + self.state_vec.sizes[cn] for cn in self._nontime_coord_names + ) + flat_locs = rng.choice( + np.prod(sizes), + size=location_count, + replace=False, + shuffle=False + ) + full_locs = np.unravel_index( + flat_locs, + sizes + ) + loc_dict = dict(zip(self._nontime_coord_names, full_locs)) self.locations = { - coord_name: xr.DataArray( - rng.choice( - self.state_vec[coord_name], - size=location_count, - replace=sample_w_replace, - shuffle=False), + coord: xr.DataArray( + locs, dims=['observations']) - for coord_name in self._nontime_coord_names + for coord, locs in loc_dict.items() } self.location_dim = location_count @@ -297,13 +309,19 @@ def observe(self) -> xr.Dataset: # Sample - obs_vec = self.state_vec.sel(time=self.times).sel(self.locations) + obs_vec = self.state_vec.sel( + time=self.times, method=self.sel_method + ).sel( + self.locations, method=self.sel_method + ) # If NON-stationary observer else: # Generate location_indices if not specified if self.locations is None: self._generate_nonstationary_locs(rng) + else: + self.location_dim = next(iter(self.locations.items()))[1]['observations'].size # If there's an unequal number of obs, will pad pad_widths = self.location_dim - np.array(self._location_counts) @@ -312,10 +330,10 @@ def observe(self) -> xr.Dataset: obs_vec = xr.concat([ # Select by time self.state_vec.sel( - time=t + time=t, method=self.sel_method # Select locations ).sel( - self.locations[i] + self.locations[i], method=self.sel_method # Pad observations to max number ).pad( observations=(0, pad_widths[i]) From fbdd86787acadaab2fd64b23684ed8e11b0ed75f Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Wed, 14 May 2025 13:11:37 -0600 Subject: [PATCH 06/18] Fixing multi-dim sampling for nonstationary sampler --- dabench/observer/_observer.py | 65 +++++++++++++++++------------------ 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/dabench/observer/_observer.py b/dabench/observer/_observer.py index eda6be5..564280c 100644 --- a/dabench/observer/_observer.py +++ b/dabench/observer/_observer.py @@ -211,21 +211,13 @@ def _generate_times( ).astype('bool') )[0]] - def _generate_stationary_locs( + def _sample_multi_dim( self, + sizes: tuple[int], + location_count: int, rng: np.random.Generator ): - if self.random_location_count is not None: - location_count = self.random_location_count - else: - location_count = np.sum( - rng.binomial(1, - p=self.random_location_density, - size=self.state_vec.system_dim)) - # Sample from flattened dimension - sizes = tuple( - self.state_vec.sizes[cn] for cn in 
self._nontime_coord_names - ) + """Select locations randomly without replacement""" flat_locs = rng.choice( np.prod(sizes), size=location_count, @@ -237,12 +229,29 @@ def _generate_stationary_locs( sizes ) loc_dict = dict(zip(self._nontime_coord_names, full_locs)) - self.locations = { + loc_xr_dict = { coord: xr.DataArray( - locs, - dims=['observations']) - for coord, locs in loc_dict.items() - } + locs, + dims=['observations']) + for coord, locs in loc_dict.items()} + return loc_xr_dict + + def _generate_stationary_locs( + self, + rng: np.random.Generator + ): + if self.random_location_count is not None: + location_count = self.random_location_count + else: + location_count = np.sum( + rng.binomial(1, + p=self.random_location_density, + size=self.state_vec.system_dim)) + # Get sizes of state vector as tuple + sizes = tuple( + self.state_vec.sizes[cn] for cn in self._nontime_coord_names + ) + self.locations = self._sample_multi_dim(sizes, location_count, rng) self.location_dim = location_count def _generate_nonstationary_locs( @@ -263,23 +272,13 @@ def _generate_nonstationary_locs( ) for i in range(self.times.shape[0])] - if len(self._nontime_coord_names) > 1: - sample_w_replace=True - else: - sample_w_replace=False - - self.locations = [{ - coord_name: xr.DataArray( - rng.choice( - self.state_vec[coord_name], - size=lc, - replace=sample_w_replace, - shuffle=False), - dims=['observations']) - for coord_name in self._nontime_coord_names - } - for lc in self._location_counts] + # Get sizes of state vector as tuple + sizes = tuple( + self.state_vec.sizes[cn] for cn in self._nontime_coord_names + ) + self.locations = [self._sample_multi_dim(sizes, lc, rng) + for lc in self._location_counts] self.location_dim = np.max(self._location_counts) def observe(self) -> xr.Dataset: From a9c5d0a5b4095b6a951d715923657e9bc8c938f0 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Wed, 14 May 2025 14:27:19 -0600 Subject: [PATCH 07/18] Observer separate case for non-stationary but regular (same num per timestep) locations --- dabench/observer/_observer.py | 52 +++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/dabench/observer/_observer.py b/dabench/observer/_observer.py index 564280c..bf8dda5 100644 --- a/dabench/observer/_observer.py +++ b/dabench/observer/_observer.py @@ -306,7 +306,6 @@ def observe(self) -> xr.Dataset: else: self.location_dim = next(iter(self.locations.items()))[1]['observations'].size - # Sample obs_vec = self.state_vec.sel( time=self.times, method=self.sel_method @@ -321,24 +320,41 @@ def observe(self) -> xr.Dataset: self._generate_nonstationary_locs(rng) else: self.location_dim = next(iter(self.locations.items()))[1]['observations'].size - - # If there's an unequal number of obs, will pad - pad_widths = self.location_dim - np.array(self._location_counts) - - # Sample - obs_vec = xr.concat([ - # Select by time - self.state_vec.sel( - time=t, method=self.sel_method - # Select locations + self._location_counts = np.repeat(self.location_dim, self.times.shape[0]) + + # Special case: user-specified and same number of obs per timestep + # In this case, self.locations is a dict. 
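+            # Hypothetical example of the expected dict structure for a 2D
+            # grid (coordinate names and values are illustrative):
+            #   self.locations = {
+            #       'x': xr.DataArray([3, 7, 12], dims=['observations']),
+            #       'y': xr.DataArray([1, 4, 9], dims=['observations'])}
+            # i.e. one entry per non-time coordinate, with paired indices
+            # along a shared 'observations' dimension.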
+            if isinstance(self.locations, dict):
+                # Sample
+                obs_vec = self.state_vec.sel(
+                    time=self.times, method=self.sel_method
+                ).sel(
+                    self.locations, method=self.sel_method
+                )
+
+            # Randomly generated observation locations
+            # self.locations is a list of dicts.
+            else:
+                # If there's an unequal number of obs, will pad
+                # NOTE: This may fail if user specifies nonstationary obs
+                #       with varying number of obs per time step, since
+                #       self._location_counts would never be set properly.
+                pad_widths = self.location_dim - np.array(self._location_counts)
+
+                # Sample
+                obs_vec = xr.concat([
+                    # Select by time
+                    self.state_vec.sel(
+                        time=t, method=self.sel_method
+                    # Select locations
+                    ).sel(
+                        self.locations[i], method=self.sel_method
+                    # Pad observations to max number
+                    ).pad(
+                        observations=(0, pad_widths[i])
+                    )
+                    for i, t in enumerate(self.times)],
+                    dim='time')
 
         # Transpose system_index to ensure consistency with flattened data
         obs_vec['system_index'] = obs_vec['system_index'].transpose('variable','time','observations').fillna(

From d074eb4a664ff6edf81e11e9084e53c73219421a Mon Sep 17 00:00:00 2001
From: Kylen Solvik
Date: Wed, 21 May 2025 11:41:08 -0600
Subject: [PATCH 08/18] observer fix: use dims, not coords attr, for finding
 nontime dimensions, and make system_dim consistently include all variables
 (not sure if this is the best way)

---
 dabench/observer/_observer.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/dabench/observer/_observer.py b/dabench/observer/_observer.py
index bf8dda5..d5fcde6 100644
--- a/dabench/observer/_observer.py
+++ b/dabench/observer/_observer.py
@@ -107,7 +107,7 @@ def __init__(self,
                  ):
 
         self.state_vec = state_vec
-        self._coord_names = list(self.state_vec.coords.keys())
+        self._coord_names = list(self.state_vec.dims)
         self._nontime_coord_names = [coord for coord in self._coord_names
                                      if coord != 'time']
         self.state_vec = self.state_vec.assign_coords(
@@ -246,8 +246,11 @@ def _generate_stationary_locs(
             location_count = np.sum(
                 rng.binomial(1,
                              p=self.random_location_density,
-                             size=self.state_vec.system_dim))
+                             size=int(self.state_vec.system_dim
+                                      / self.state_vec.sizes['variable'])
+                             )
+                )
-        # Get sizes of state vector as tuple
+        # Get sizes of state vector as tuple
         sizes = tuple(
             self.state_vec.sizes[cn] for cn in self._nontime_coord_names
         )
@@ -268,11 +271,13 @@ def _generate_nonstationary_locs(
         self._location_counts = [np.sum(
             rng.binomial(1,
                          p=self.random_location_density,
-                         size=self.state_vec.system_dim)
+                         size=int(self.state_vec.system_dim
+                                  / self.state_vec.sizes['variable'])
             )
-            for i in range(self.times.shape[0])]
+            )
+            for i in range(self.times.shape[0])]
 
-        # Get sizes of state vector as tuple
+        # Get sizes of state vector as tuple
         sizes = tuple(
             self.state_vec.sizes[cn] for cn in self._nontime_coord_names
         )

From a8fc31310266d60ec4652c8daf384e611dbf0e26 Mon Sep 17 00:00:00 2001
From: Kylen Solvik
Date: Thu, 5 Jun 2025 12:04:01 -0400
Subject: [PATCH 09/18] Add _rebuild_dataset to properly reconstruct
 multi-variable datasets, and allow obs_error_sd to be an array

---
 dabench/dacycler/_dacycler.py | 32 +++++++++++++++++++++++++++++---
 1 file changed, 29 insertions(+), 3 deletions(-)

diff --git a/dabench/dacycler/_dacycler.py b/dabench/dacycler/_dacycler.py
index feb3f2e..5955c57 100644
--- a/dabench/dacycler/_dacycler.py
+++ 
b/dabench/dacycler/_dacycler.py
@@ -75,6 +75,16 @@ def _calc_default_B(self) -> jax.Array:
         """If B is not provided, identity matrix with shape (system_dim,
         system_dim)."""
         return jnp.identity(self.system_dim)
 
+    def _rebuild_dataset(self,
+                         xb: XarrayDatasetLike,
+                         xa: ArrayLike,
+                         ) -> XarrayDatasetLike:
+        xb_as_array = xb.to_array()
+        xa = xa.reshape(tuple(xb_as_array.sizes[s] for s in xb_as_array.sizes))
+        xb_as_array.values = xa
+        xa_ds = xb_as_array.to_dataset(dim='variable')
+        return xa_ds
+
     def _step_forecast(self,
                        xa: XarrayDatasetLike,
                        n_steps: int = 1
@@ -201,6 +211,10 @@ def cycle(self,
             obs_vector: Observations vector.
             n_cycles: Number of analysis cycles to run, each of length
                 analysis_window.
+            obs_error_sd: Estimated observation error standard deviation,
+                used for calculating the observation covariance matrix (R).
+                If float, all observations will have the same estimated error.
+                If ArrayLike, must be of size system_dim.
             analysis_window: Time window from which to gather observations
                 for DA Cycle.
             analysis_time_in_window: Where within analysis_window
@@ -217,8 +231,21 @@ def cycle(self,
         self._observed_vars = obs_vector['variable'].values
         self._data_vars = list(input_state.data_vars)
 
-        if obs_error_sd is None:
-            obs_error_sd = obs_vector.error_sd
+        # NOTE: Consider removing this. It may cause problems if the obs_vector
+        # error_sd is provided as an array of size obs_dim.
+        # if obs_error_sd is None:
+        #     obs_error_sd = obs_vector.error_sd
+        # Check whether obs_error_sd is a scalar or an array
+        if jnp.isscalar(obs_error_sd):
+            self._scalar_obs_error = True
+        elif len(obs_error_sd) == self.system_dim:
+            obs_error_sd = jnp.array(obs_error_sd)
+            self._scalar_obs_error = False
+        else:
+            raise ValueError((
+                'obs_error_sd must be either scalar or array with length '
+                'system_dim. 
Currently is: {}'.format(obs_error_sd) + )) self.analysis_window = analysis_window @@ -241,7 +268,6 @@ def cycle(self, start_time, analysis_window, n_cycles) - if self.steps_per_window is None: self.steps_per_window = round(analysis_window/self.delta_t) + 1 From 75674fca183143be7fb25449cfc67119ed0158a3 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Thu, 5 Jun 2025 12:05:55 -0400 Subject: [PATCH 10/18] 3D-Var and ETKF: fix H shape, obs_error_sd array support, and use rebuild dataset --- dabench/dacycler/_etkf.py | 9 +++++++-- dabench/dacycler/_var3d.py | 17 ++++++++++++----- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/dabench/dacycler/_etkf.py b/dabench/dacycler/_etkf.py index 559c100..94ac370 100644 --- a/dabench/dacycler/_etkf.py +++ b/dabench/dacycler/_etkf.py @@ -178,7 +178,12 @@ def _cycle_obsop(self, H = self.H if R is None: if self.R is None: - R = self._calc_default_R(obs_values, self.obs_error_sd) + if self._scalar_obs_error: + R = self._calc_default_R(obs_values, self.obs_error_sd) + else: + R = self._calc_default_R( + obs_values, + self.obs_error_sd[obs_loc_indices.flatten()]) else: R = self.R @@ -200,4 +205,4 @@ def _cycle_obsop(self, R=R, rho=self.multiplicative_inflation) - return Xb_ds.assign(x=(['ensemble','i'], Xa.T)) + return self._rebuild_dataset(Xb_ds, Xa.T) \ No newline at end of file diff --git a/dabench/dacycler/_var3d.py b/dabench/dacycler/_var3d.py index 8b204df..2c0e9ad 100644 --- a/dabench/dacycler/_var3d.py +++ b/dabench/dacycler/_var3d.py @@ -68,7 +68,7 @@ def _compute_analysis(self, b1 = xb + jnp.dot(BHtRinv, y) # Use minimization algorithm to minimize cost function: - xa, ierr = jscipy.sparse.linalg.cg(A, b1, x0=xb, tol=1e-05, + xa, ierr = jscipy.sparse.linalg.cg(A, b1, x0=xb.astype(float), tol=1e-05, maxiter=1000) return xa @@ -93,7 +93,12 @@ def _cycle_obsop(self, H = self.H if R is None: if self.R is None: - R = self._calc_default_R(obs_values, self.obs_error_sd) + if self._scalar_obs_error: + R = self._calc_default_R(obs_values, self.obs_error_sd) + else: + R = self._calc_default_R( + obs_values, + self.obs_error_sd[obs_loc_indices.flatten()]) else: R = self.R if B is None: @@ -106,17 +111,19 @@ def _cycle_obsop(self, y = obs_values.flatten() # Apply masks to H - H = jnp.where(obs_time_mask.flatten(), H.T, 0).T + H = jnp.where(jnp.tile(obs_time_mask.flatten(), obs_loc_mask.shape[0]), H.T, 0).T H = jnp.where(obs_loc_mask.flatten(), H.T, 0).T # Set parameters Rinv = jnp.linalg.inv(R) - xa, ierr = self._compute_analysis( + xa = self._compute_analysis( xb, + y, B, H, Rinv, ) - return xb_ds.assign(x=(xb_ds.dims, xa.T)) + # Reshape + return self._rebuild_dataset(xb_ds, xa) From 6c4213647fa4af77caaa8535e85830af8e2ae1d4 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Thu, 5 Jun 2025 12:09:37 -0400 Subject: [PATCH 11/18] 4D-Var and Backprop: Support array obs_error_sds (warning message for non-stationary observers, may not function), and fix issues with multi-dimensional datasets --- dabench/dacycler/_var4d.py | 18 +++++++++++++++--- dabench/dacycler/_var4d_backprop.py | 26 +++++++++++++++++--------- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/dabench/dacycler/_var4d.py b/dabench/dacycler/_var4d.py index 1b2974e..87e8faa 100644 --- a/dabench/dacycler/_var4d.py +++ b/dabench/dacycler/_var4d.py @@ -154,13 +154,15 @@ def _innerloop_4d(self, SumMtHtRinvHM += Jb SumMtHtRinvD += Jo # Compute initial departure - db0 = (xb0_ds - x0_ds).to_stacked_array('system',[]).data + db0 = (xb0_ds - x0_ds).to_stacked_array('system', []).data 
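        # As the variable names indicate, SumMtHtRinvHM accumulates
        # sum_i M_i^T H_i^T R^-1 H_i M_i and SumMtHtRinvD accumulates
        # sum_i M_i^T H_i^T R^-1 d_i, where M_i is the tangent-linear model,
        # H_i the observation operator, and d_i the departure at obs time i.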
        # Solve Ax=b for the initial perturbation
        dx0 = self._solve(db0, SumMtHtRinvHM, SumMtHtRinvD, B)
 
         # New x0 guess is the last guess plus the analyzed delta
-        xa0_ds = x0_ds + dx0.ravel()
+        # NOTE: This works, but there may be a better way to get
+        # multi-dim systems into the same shape
+        xa0_ds = x0_ds + dx0.reshape(x0_ds.to_array().shape[1:])
 
         return xa0_ds
 
@@ -267,7 +269,17 @@ def _cycle_obsop(self,
 
         if R is None:
             if self.R is None:
-                R = self._calc_default_R(obs_values, self.obs_error_sd)
+                if self._scalar_obs_error:
+                    R = self._calc_default_R(obs_values, self.obs_error_sd)
+                else:
+                    warnings.warn((
+                        'Using array-like obs_error_sd with 4D DA methods is '
+                        'not fully supported. If observations are not stationary, '
+                        'this will likely produce incorrect results.'
+                    ))
+                    R = self._calc_default_R(
+                        obs_values,
+                        self.obs_error_sd[obs_loc_indices[0].flatten()])
             else:
                 R = self.R

diff --git a/dabench/dacycler/_var4d_backprop.py b/dabench/dacycler/_var4d_backprop.py
index 944f980..f8309b9 100644
--- a/dabench/dacycler/_var4d_backprop.py
+++ b/dabench/dacycler/_var4d_backprop.py
@@ -111,7 +111,7 @@ def _calc_default_H(self,
 
     def _calc_default_R(self,
                         obs_values: ArrayLike,
-                        obs_error_sd: float
+                        obs_error_sd: float | ArrayLike
                         ) -> jax.Array:
         return jnp.identity(obs_values[0].shape[0])*(obs_error_sd**2)
 
@@ -160,7 +160,7 @@ def loss_4dvarcost(x0: XarrayDatasetLike) -> jax.Array:
             # Make new prediction
             # NOTE: [1] selects the full forecast instead of last timestep only
             X = self._step_forecast(
-                x0, n_steps)[1].to_stacked_array('system',['time']).data
+                x0, n_steps)[1].dab.flatten().data
 
             # Calculate observation term of J_0
             obs_term = 0
@@ -201,8 +201,8 @@ def _backprop_epoch(
             x0_ds, init_loss, opt_state = epoch_state_tuple
             x0_ds = x0_ds.to_xarray()
             loss_val, dx0 = loss_value_grad(x0_ds)
-            x0_ar = x0_ds.to_stacked_array('system', [])
-            dx0_hess = hessian_inv @ dx0.to_stacked_array('system',[]).data
+            x0_ar = x0_ds.dab.flatten()
+            dx0_hess = hessian_inv @ dx0.dab.flatten().data
             init_loss = jax.lax.cond(
                 i == 0,
                 lambda: loss_val,
@@ -216,9 +216,7 @@ def _backprop_epoch(
             updates, opt_state = optimizer.update(dx0_hess, opt_state)
             x0_ar.data = optax.apply_updates(
                 x0_ar.data, updates)
-            xa0_ds = x0_ar.to_unstacked_dataset('system').assign_attrs(
-                x0_ds.attrs
-            )
+            xa0_ds = x0_ar.dab.unflatten().assign_attrs(x0_ds.attrs)
             return (xj.from_xarray(xa0_ds), init_loss, opt_state), loss_val
         return _backprop_epoch
 
@@ -259,7 +257,17 @@ def _cycle_obsop(self,
 
         if R is None:
             if self.R is None:
-                R = self._calc_default_R(obs_values, self.obs_error_sd)
+                if self._scalar_obs_error:
+                    R = self._calc_default_R(obs_values, self.obs_error_sd)
+                else:
+                    warnings.warn((
+                        'Using array-like obs_error_sd with 4D DA methods is '
+                        'not fully supported. 
If observations are not stationary, '
+                        'this will likely produce incorrect results.'
+                    ))
+                    R = self._calc_default_R(
+                        obs_values,
+                        self.obs_error_sd[obs_loc_indices[0].flatten()])
             else:
                 R = self.R
 
@@ -291,7 +299,7 @@ def _cycle_obsop(self,
                                          1,
                                          self.lr_decay)
         optimizer = optax.sgd(lr)
-        opt_state = optimizer.init(xb0_ds.to_stacked_array('system',[]).data)
+        opt_state = optimizer.init(xb0_ds.dab.flatten().data)
 
         # Make initial forecast and calculate loss
         backprop_epoch_func = self._make_backprop_epoch(loss_func, optimizer,

From c941404e230ffccbdd12aad999f347ab263798cc Mon Sep 17 00:00:00 2001
From: Kylen Solvik
Date: Thu, 5 Jun 2025 12:11:05 -0400
Subject: [PATCH 12/18] Fixes for dimension issues in xarray accessors

---
 dabench/data/_xarray_accessor.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/dabench/data/_xarray_accessor.py b/dabench/data/_xarray_accessor.py
index d0fe028..da46031 100644
--- a/dabench/data/_xarray_accessor.py
+++ b/dabench/data/_xarray_accessor.py
@@ -28,11 +28,13 @@ def __init__(self,
         self._obj = xarray_obj
 
     def flatten(self) -> xr.DataArray:
-        if 'time' in self._obj.coords:
+        if 'time' in self._obj.coords and 'time' in self._obj.sizes:
             remaining_dim = ['time']
         else:
             remaining_dim = []
-        return self._obj.to_stacked_array('system', remaining_dim)
+        return self._obj.to_stacked_array(
+            'system', remaining_dim
+        )
 
     def split_train_val_test(self,
                              split_lengths: list | np.ndarray
@@ -59,7 +61,10 @@ def __init__(self,
         self._obj = xarray_obj
 
     def unflatten(self) -> xr.Dataset:
-        return self._obj.to_unstacked_dataset('system')
+        ds = self._obj.to_unstacked_dataset('system')
+        if 'system' in ds.dims:
+            ds = ds.unstack('system')
+        return ds
 
     def split_train_val_test(self,
                              split_lengths: list | np.ndarray

From c6a34100fdec521efa056c9bd9cd52741eb2652d Mon Sep 17 00:00:00 2001
From: Kylen Solvik
Date: Thu, 5 Jun 2025 12:11:52 -0400
Subject: [PATCH 13/18] Pyqg jax with xarray output

---
 dabench/data/_pyqg_jax.py | 35 +++++++++++++++++++++++++++++------
 1 file changed, 29 insertions(+), 6 deletions(-)

diff --git a/dabench/data/_pyqg_jax.py b/dabench/data/_pyqg_jax.py
index e26282d..328c681 100644
--- a/dabench/data/_pyqg_jax.py
+++ b/dabench/data/_pyqg_jax.py
@@ -113,6 +113,9 @@ def __init__(self,
                          store_as_jax=store_as_jax,
                          x0=x0,
                          **kwargs)
 
+        self.coord_names = ['level','x','y']
+        self.var_names=['q']
+
     @functools.partial(jax.jit, static_argnames=["self", "num_steps"])
     def _roll_out_state(self, state, num_steps):
         """Helper method taken from pyqg-jax docs:
@@ -122,7 +125,7 @@ def _roll_out_state(self, state, num_steps):
         def loop_fn(carry, _x):
             current_state = carry
             next_state = self.m.step_model(current_state)
-            return next_state, next_state
+            return next_state, current_state
 
         _final_carry, traj_steps = jax.lax.scan(
             loop_fn, state, None, length=num_steps
@@ -213,18 +216,38 @@ def generate(self,
                 )
             )
 
-        self.x0 = x0.flatten()
-
         # Store step times
-        self.times = jnp.arange(0, t_final, self.delta_t)
+        times = np.arange(0, t_final, self.delta_t)
 
         # Run simulation
         traj = self._roll_out_state(init_state, num_steps=n_steps)
         qs = traj.state.q
 
+        # Build Xarray object for output
+        coord_dict = dict(zip(
+            ['time'] + self.coord_names,
+            [times] + [np.arange(dim) for dim in self.original_dim]
+        ))
+        time_dim = times.shape[0]
+        out_dim = (time_dim,) + self.original_dim
+
+        # Convert to JAX if necessary
+        y = qs
+        if self.store_as_jax or isinstance(y, jax.core.Tracer):
+            y_out = jnp.array(y[:, :self.system_dim].reshape(out_dim))
+        else:
+            y_out = np.array(y[:, :self.system_dim].reshape(out_dim))
+        out_vec = xr.Dataset(
+            {self.var_names[0]: (coord_dict.keys(), y_out)},
+            coords=coord_dict,
+            attrs={'store_as_jax': self.store_as_jax,
+                   'system_dim': self.system_dim,
+                   'delta_t': self.delta_t
+                   }
+        )
 
-        # Save values
-        self.time_dim = qs.shape[0]
-        self.values = qs.reshape((self.time_dim, -1))
+        # Return the output dataset
+        return out_vec
 
     # TODO: Remove? Believe this is deprecated
     def forecast(self, n_steps=None, t_final=None, x0=None):

From e7ca10459d009eeb57ce4c96e5d07889590e6227 Mon Sep 17 00:00:00 2001
From: Kylen Solvik
Date: Thu, 5 Jun 2025 12:12:22 -0400
Subject: [PATCH 14/18] sqgturb with proper xarray outputs

---
 dabench/data/_sqgturb.py | 16 +++++-----------
 1 file changed, 5 insertions(+), 11 deletions(-)

diff --git a/dabench/data/_sqgturb.py b/dabench/data/_sqgturb.py
index 5e2601d..1ea9583 100644
--- a/dabench/data/_sqgturb.py
+++ b/dabench/data/_sqgturb.py
@@ -167,9 +167,6 @@ def __init__(self,
         self.r = jnp.array(r, dtype)  # Ekman damping (at z=0)
         self.tdiab = jnp.array(tdiab, dtype)  # thermal relaxation damping.
 
-        # Initialize time counter
-        self.t = tstart
-
         # Setup basic state pv (for thermal relaxation)
         self.symmetric = symmetric
         y = jnp.arange(0, self.L, self.L / self.N, dtype=dtype)
@@ -199,6 +196,7 @@ def __init__(self,
             pvbar = pvbar * jnp.ones((2, N, N), dtype)
         self.pvbar = pvbar
         # state to relax to with timescale tdiab
+        # NOTE: Is this an error? It is never updated.
         self.pvspec_eq = rfft2(pvbar)
         # initial pv field (spectral)
         self.pvspec = rfft2(pv)
@@ -461,9 +459,6 @@ def integrate(self,
                   ) -> tuple[jax.Array, jax.Array]:
         """Advances pv forward number of timesteps given by self.n_steps.
 
-        Note:
-            If pv not specified, use pvspec instance variable.
-
         Args:
             function (function): right hand side (rhs) of the ODE. Not used, but
                 needed for compatibility with generate() from _data.Data().
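The scan-then-prepend pattern introduced in the next hunk, as a self-contained sketch (with a toy step function standing in for SQGTurb._rk4):

    import jax
    import jax.numpy as jnp

    def step(carry, _t):
        # stand-in for the RK4 step: advance the carried state once
        new_carry = carry * 0.99
        return new_carry, new_carry

    x0 = jnp.ones(3)
    times = jnp.arange(0.0, 1.0, 0.1)                # 10 times, starting at 0
    _, traj = jax.lax.scan(step, x0, xs=times[:-1])  # 9 steps
    traj = jnp.insert(traj, 0, x0, axis=0)           # prepend initial state
    assert traj.shape == (times.shape[0], 3)         # one state per time

This is what makes generate() produce exactly n_steps states including x0 (see patch 18).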
@@ -480,16 +475,15 @@ def integrate(self,
             delta_t = self.delta_t
         times = np.arange(0.0, t_final - delta_t/2, delta_t)
 
         # Run integration
-        pvspec, values = jax.lax.scan(self._rk4, pvspec, xs=times)
+        pvspec_updated, values = jax.lax.scan(self._rk4, pvspec, xs=times[:-1])
+
+        # Prepend input to values so x0 is included
+        values = jnp.insert(values, 0, pvspec, axis=0)
 
         # Apply inverse fft to return to physical space
         values = self.ifft2(values)
 
-        # Update internal states
-        self.pvspec = pvspec
-        self.t = times[-1]
-
         return values, times
 
     def rhs(self,

From 86e32e7e7353b5f5c051f2ef6029cc554e328d31 Mon Sep 17 00:00:00 2001
From: Kylen Solvik
Date: Thu, 5 Jun 2025 14:25:24 -0400
Subject: [PATCH 15/18] Remove defaults from obs_error_sd and analysis_window

---
 dabench/dacycler/_dacycler.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/dabench/dacycler/_dacycler.py b/dabench/dacycler/_dacycler.py
index 5955c57..b612cf3 100644
--- a/dabench/dacycler/_dacycler.py
+++ b/dabench/dacycler/_dacycler.py
@@ -198,8 +198,8 @@ def cycle(self,
               start_time: float | np.datetime64,
               obs_vector: XarrayDatasetLike,
               n_cycles: int,
-              obs_error_sd: float | ArrayLike | None = None,
-              analysis_window: float = 0.2,
+              obs_error_sd: float | ArrayLike,
+              analysis_window: float,
               analysis_time_in_window: float | None = None,
               return_forecast: bool = False
               ) -> XarrayDatasetLike:

From 943fce4ddf0457b4d8743daba97c86da83abcdb1 Mon Sep 17 00:00:00 2001
From: Kylen Solvik
Date: Thu, 5 Jun 2025 15:35:42 -0400
Subject: [PATCH 16/18] Var3d test fix to include obs_error_sd, which is now a
 required arg for cycle()

---
 tests/dacycler_var3d_test.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/dacycler_var3d_test.py b/tests/dacycler_var3d_test.py
index fb0094e..33ddc37 100644
--- a/tests/dacycler_var3d_test.py
+++ b/tests/dacycler_var3d_test.py
@@ -67,6 +67,7 @@ def test_var3d_l96(l96_nature_run, obs_vec_l96, var3d_cycler):
         start_time = start_time,
         obs_vector = obs_vec_l96,
         n_cycles=10,
+        obs_error_sd=0.7,
         analysis_window=0.25,
         return_forecast=False)

From 6eb5aeb859852f3ec63966d787346e66305000f1 Mon Sep 17 00:00:00 2001
From: Kylen Solvik
Date: Thu, 5 Jun 2025 16:01:51 -0400
Subject: [PATCH 17/18] Observation sampling was updated to fix multi-dim
 problems, which broke the tests

---
 tests/observer_base_test.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/tests/observer_base_test.py b/tests/observer_base_test.py
index 4ee1c7c..6a4f487 100644
--- a/tests/observer_base_test.py
+++ b/tests/observer_base_test.py
@@ -226,12 +226,12 @@ def test_obs_gcp():
 
     assert obs_vec_flat.shape == (3, 50)
     assert obs_vec['time'].shape[0] == 3
-    assert obs_vec['system_index'].values[0, 0, 0] == 326
+    assert obs_vec['system_index'].values[0, 0, 0] == 516
     assert obs_vec_flat.values[0, 0] == pytest.approx(
         ds.sel(time=obs_vec_flat['time'][0]
-               ).drop_vars('time').dab.flatten().values[326] + 5)
+               ).drop_vars('time').dab.flatten().values[516] + 5)
-    assert obs_vec_flat[2, 42].values == pytest.approx(304.60122681)
+    assert obs_vec_flat[2, 42].values == pytest.approx(304.5305481)
     assert np.array_equal(obs_vec['errors'],
                           np.repeat(5, 3*50).reshape(1, 3, 50))
     assert obs_vec['time'][1] == np.datetime64('2010-01-01T18:00:00.000000000')
@@ -244,18 +244,18 @@ def test_obs_sqgturb():
 
     obs = observer.Observer(
         ds,
-        random_time_density=0.3,
+        random_time_density=0.4,
         random_location_density=0.01,
         error_sd=25.)
obs_vec = obs.observe() - assert obs_vec['pv'].shape == (2, 204) - assert obs_vec['time'].shape[0] == 2 - assert obs_vec['system_index'][0, 0, 0] == 10130 + assert obs_vec['pv'].shape == (3, 204) + assert obs_vec['time'].shape[0] == 3 + assert obs_vec['system_index'][0, 0, 0] == 16302 assert obs_vec['pv'].values[1, 123] == pytest.approx( (ds.dab.flatten().sel(time=obs_vec['time'][1])[ obs_vec['system_index'][0, 1, 123]] + obs_vec['errors'][0, 1, 123])) - assert obs_vec['pv'].values[0, 44] == pytest.approx(2128.20749725) - assert obs_vec['errors'].values[0, 1, 187] == pytest.approx(25.714826330507496) - assert np.allclose(obs_vec['time'], np.array([2700., 9000.])) + assert obs_vec['pv'].values[0, 44] == pytest.approx(4166.347415908463) + assert obs_vec['errors'].values[0, 1, 187] == pytest.approx(-14.59475742413906) + assert np.allclose(obs_vec['time'], np.array([2700., 3600., 8100.])) From b2930febb1cbcc3133d54b5a7689ff71fc023911 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Thu, 5 Jun 2025 16:03:27 -0400 Subject: [PATCH 18/18] Sqgturb generate produces exactly n_steps now, not n_steps+1, in order to match other generators --- tests/data_sqgturb_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/data_sqgturb_test.py b/tests/data_sqgturb_test.py index cae5c4d..2a3f2d3 100644 --- a/tests/data_sqgturb_test.py +++ b/tests/data_sqgturb_test.py @@ -40,5 +40,5 @@ def test_variable_sizes(sqgturb): traj = sqgturb.generate(n_steps=n_steps) assert traj.system_dim == 18432 - assert traj.sizes['time'] == n_steps+1 - assert traj.dab.flatten().shape == (n_steps+1, 18432) + assert traj.sizes['time'] == n_steps + assert traj.dab.flatten().shape == (n_steps, 18432)
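Taken together, patches 04 and 15 imply a cycling call shaped like the following sketch (model_obj, init_state, and obs_vec are placeholders for a dabench Model wrapper, an initial-state Dataset, and Observer output; the numbers are illustrative):

    import dabench.dacycler as dac

    hg = dac.HybridGain(
        system_dim=36,
        delta_t=0.05,
        model_obj=model_obj,
        alpha=0.2,             # 0.0 -> pure ETKF, 1.0 -> pure 3D-Var
        ensemble_dim=8,
    )
    analysis = hg.cycle(
        input_state=init_state,
        start_time=0.0,
        obs_vector=obs_vec,
        n_cycles=10,
        obs_error_sd=0.7,      # now a required argument (patch 15)
        analysis_window=0.25,
    )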