Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
609cf74
ETKF with updated variable names differentiating vectors and matrices
kysolvik Oct 18, 2024
62cbc9d
Variable name updates to dacycler, replacing xb with x0
kysolvik Oct 18, 2024
fe33c14
Updated variable names for var3d
kysolvik Oct 18, 2024
8a077df
Input to step_cycle is always X0 instad of x0 since it is a multidime…
kysolvik Oct 18, 2024
f29cd00
Renaming X0_ds to more general cur_state with base dacycler
kysolvik Oct 18, 2024
e333d08
Rename var3d X0_ds to x0_ds to avoid confusion with ensemble
kysolvik Oct 18, 2024
0769a50
Updated variable names in var4d and var4dBP to match other dacyclers …
kysolvik Oct 18, 2024
e559ff1
Update variable name for general obs op
kysolvik Oct 18, 2024
68d8f30
Updated base dacycler docstrings and type hints
kysolvik Oct 24, 2024
79fd4b1
Remove analysis_time_in_window as a attribute
kysolvik Oct 24, 2024
be661f1
Missing space in dacycler base class
kysolvik Oct 24, 2024
0fba72c
ETKF with type hints and updated docstrings
kysolvik Oct 24, 2024
e7374d9
Dacycler utils with type hints and updated docstrings
kysolvik Oct 24, 2024
99fe8ec
Start var4d type hints
kysolvik Mar 12, 2025
a53115c
3D Var type hints
kysolvik Mar 20, 2025
2921905
Type hints for 4dvar backprop
kysolvik Mar 20, 2025
393a542
4DVar type hints
kysolvik Mar 20, 2025
8c8d23e
Type hints for base Data class
kysolvik Mar 21, 2025
0ec25bc
type hints for data/_utils.py
kysolvik Mar 21, 2025
1750a8b
type hints for xarray accessors
kysolvik Mar 21, 2025
5f97e02
Fix bad output type hint for base generate method
kysolvik Mar 21, 2025
a51dff2
Barotropic model type hints, but barotropic might need to removed sin…
kysolvik Mar 21, 2025
227056e
type hints for enso indices
kysolvik Mar 21, 2025
e74d10d
Type hints for gcp.py
kysolvik Mar 21, 2025
404c9f5
type hints for Lorenz63 and 96
kysolvik Mar 21, 2025
8ea2808
type hints for pyqg_jax
kysolvik Mar 21, 2025
c63c67d
Type hints for qgs
kysolvik Mar 21, 2025
6542daf
Type hints for sqgturb
kysolvik Mar 21, 2025
a92d19a
Type hints for metrics module (which is WIP)
kysolvik Mar 21, 2025
40dee55
Type hints for base model class (todo for rc and neuralGCM, but waiti…
kysolvik Mar 21, 2025
d27b811
Observer type hints
kysolvik Mar 21, 2025
537a590
analysis_time_in_window is arg passed to cycle instead of att on class
kysolvik Mar 21, 2025
b5b4ccc
Extra space in _dacycler
kysolvik Apr 4, 2025
cc89d5d
Update x0 -> xb in etkf and 3dvar
kysolvik Apr 4, 2025
9b1fec4
Updated var names for 4dvar. xb = bg, xa = analysis, x = temporary du…
kysolvik Apr 4, 2025
6ffeeb5
Fix QGS at 0.2.8 temporarily, 1.0.0 released last week and broke tests
kysolvik Apr 4, 2025
412a808
Pin qgs at 0.2.8 in pyproject (mistakenly added unused pip_constraint…
kysolvik Apr 4, 2025
ea9d8f6
Pin sparse to 0.15.4
kysolvik Apr 4, 2025
1c03929
Docstring fixes for var3d (remove extra items) and etkf (typo)
kysolvik Apr 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 85 additions & 57 deletions dabench/dacycler/_dacycler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,49 @@
import jax
import xarray as xr
import xarray_jax as xj
from typing import Callable

import dabench.dacycler._utils as dac_utils
from dabench.model import Model


# For typing
ArrayLike = np.ndarray | jax.Array
XarrayDatasetLike = xr.Dataset | xj.XjDataset

class DACycler():
"""Base class for DACycler object

Attributes:
system_dim (int): System dimension
delta_t (float): The timestep of the model (assumed uniform)
model_obj (dabench.Model): Forecast model object.
in_4d (bool): True for 4D data assimilation techniques (e.g. 4DVar).
system_dim: System dimension
delta_t: The timestep of the model (assumed uniform)
model_obj: Forecast model object.
in_4d: True for 4D data assimilation techniques (e.g. 4DVar).
Default is False.
ensemble (bool): True for ensemble-based data assimilation techniques
ensemble: True for ensemble-based data assimilation techniques
(ETKF). Default is False
B (ndarray): Initial / static background error covariance. Shape:
B: Initial / static background error covariance. Shape:
(system_dim, system_dim). If not provided, will be calculated
automatically.
R (ndarray): Observation error covariance matrix. Shape
R: Observation error covariance matrix. Shape
(obs_dim, obs_dim). If not provided, will be calculated
automatically.
H (ndarray): Observation operator with shape: (obs_dim, system_dim).
H: Observation operator with shape: (obs_dim, system_dim).
If not provided will be calculated automatically.
h (function): Optional observation operator as function. More flexible
h: Optional observation operator as function. More flexible
(allows for more complex observation operator). Default is None.
"""

def __init__(self,
system_dim=None,
delta_t=None,
model_obj=None,
in_4d=False,
ensemble=False,
B=None,
R=None,
H=None,
h=None,
analysis_time_in_window=None
system_dim: int,
delta_t: float,
model_obj: Model,
in_4d: bool = False,
ensemble: bool = False,
B: ArrayLike | None = None,
R: ArrayLike | None = None,
H: ArrayLike | None = None,
h: Callable | None = None,
):

self.h = h
Expand All @@ -53,43 +59,64 @@ def __init__(self,
self.system_dim = system_dim
self.delta_t = delta_t
self.model_obj = model_obj
self.analysis_time_in_window = analysis_time_in_window


def _calc_default_H(self, obs_values, obs_loc_indices):
def _calc_default_H(self,
obs_values: ArrayLike,
obs_loc_indices: ArrayLike
) -> jax.Array:
H = jnp.zeros((obs_values.flatten().shape[0], self.system_dim))
H = H.at[jnp.arange(H.shape[0]),
obs_loc_indices.flatten(),
].set(1)
return H

def _calc_default_R(self, obs_values, obs_error_sd):
def _calc_default_R(self,
obs_values: ArrayLike,
obs_error_sd: float
) -> jax.Array:
return jnp.identity(obs_values.flatten().shape[0])*(obs_error_sd**2)

def _calc_default_B(self):
def _calc_default_B(self) -> jax.Array:
"""If B is not provided, identity matrix with shape (system_dim, system_dim."""
return jnp.identity(self.system_dim)

def _step_forecast(self, xa, n_steps=1):
def _step_forecast(self,
xa: XarrayDatasetLike,
n_steps: int = 1
) -> XarrayDatasetLike:
"""Perform forecast using model object"""
return self.model_obj.forecast(xa, n_steps=n_steps)

def _step_cycle(self, xb, obs_vals, obs_locs, obs_time_mask, obs_loc_mask,
H=None, h=None, R=None, B=None, **kwargs):
def _step_cycle(self,
cur_state: XarrayDatasetLike,
obs_vals: ArrayLike,
obs_locs: ArrayLike,
obs_time_mask: ArrayLike,
obs_loc_mask: ArrayLike,
H: ArrayLike | None = None,
h: Callable | None =None,
R: ArrayLike | None = None,
B:ArrayLike | None = None,
**kwargs
) -> XarrayDatasetLike:
if H is not None or h is None:
vals = self._cycle_obsop(
xb, obs_vals, obs_locs, obs_time_mask,
cur_state, obs_vals, obs_locs, obs_time_mask,
obs_loc_mask, H, R, B, **kwargs)
return vals
else:
raise ValueError(
'Only linear obs operators (H) are supported right now.')
vals = self._cycle_general_obsop(
xb, obs_vals, obs_locs, obs_time_mask,
cur_state, obs_vals, obs_locs, obs_time_mask,
obs_loc_mask, h, R, B, **kwargs)
return vals

def _cycle_and_forecast(self, cur_state, filtered_idx):
def _cycle_and_forecast(self,
cur_state: xj.XjDataset,
filtered_idx: ArrayLike
) -> tuple[xj.XjDataset, XarrayDatasetLike]:
# 1. Get data
# 1-b. Calculate obs_time_mask and restore filtered_idx to original values
cur_state = cur_state.to_xarray()
Expand Down Expand Up @@ -119,7 +146,10 @@ def _cycle_and_forecast(self, cur_state, filtered_idx):

return xj.from_xarray(next_state), forecast_states

def _cycle_and_forecast_4d(self, cur_state, filtered_idx):
def _cycle_and_forecast_4d(self,
cur_state: xj.XjDataset,
filtered_idx: ArrayLike
) -> tuple[xj.XjDataset, XarrayDatasetLike]:
# 1. Get data
# 1-b. Calculate obs_time_mask and restore filtered_idx to original values
cur_state = cur_state.to_xarray()
Expand Down Expand Up @@ -160,35 +190,32 @@ def _cycle_and_forecast_4d(self, cur_state, filtered_idx):
return xj.from_xarray(next_state), forecast_states

def cycle(self,
input_state,
start_time,
obs_vector,
n_cycles,
obs_error_sd=None,
analysis_window=0.2,
analysis_time_in_window=None,
return_forecast=False
):
input_state: XarrayDatasetLike,
start_time: float | np.datetime64,
obs_vector: XarrayDatasetLike,
n_cycles: int,
obs_error_sd: float | ArrayLike | None = None,
analysis_window: float = 0.2,
analysis_time_in_window: float | None = None,
return_forecast: bool = False
) -> XarrayDatasetLike:
"""Perform DA cycle repeatedly, including analysis and forecast

Args:
input_state (vector.StateVector): Input state.
start_time (float or datetime-like): Starting time.
obs_vector (vector.ObsVector): Observations vector.
n_cycles (int): Number of analysis cycles to run, each of length
input_state: Input state as a Xarray Dataset
start_time: Starting time.
obs_vector: Observations vector.
n_cycles: Number of analysis cycles to run, each of length
analysis_window.
analysis_window (float): Time window from which to gather
analysis_window: Time window from which to gather
observations for DA Cycle.
analysis_time_in_window (float): Where within analysis_window
analysis_time_in_window: Where within analysis_window
to perform analysis. For example, 0.0 is the start of the
window. Default is None, which selects the middle of the
window.
return_forecast (bool): If True, returns forecast at each model
return_forecast: If True, returns forecast at each model
timestep. If False, returns only analyses, one per analysis
cycle. Default is False.

Returns:
vector.StateVector of analyses and times.
cycle.
"""

# These could be different if observer doesn't observe all variables
Expand All @@ -202,10 +229,11 @@ def cycle(self,
self.analysis_window = analysis_window

# If don't specify analysis_time_in_window, is assumed to be middle
if self.analysis_time_in_window is None and analysis_time_in_window is None:
analysis_time_in_window = self.analysis_window/2
else:
analysis_time_in_window = self.analysis_time_in_window
if analysis_time_in_window is None:
if self.in_4d:
analysis_time_in_window = 0
else:
analysis_time_in_window = self.analysis_window/2

# Steps per window + 1 to include start
self.steps_per_window = round(analysis_window/self.delta_t) + 1
Expand Down Expand Up @@ -256,13 +284,13 @@ def cycle(self,
xj.from_xarray(input_state),
all_filtered_padded)

all_vals_xr = xr.Dataset(
all_vals_ds = xr.Dataset(
{var: (('cycle',) + tuple(all_values[var].dims),
all_values[var].data)
for var in all_values.data_vars}
).rename_dims({'time': 'cycle_timestep'})

if return_forecast:
return all_vals_xr.drop_isel(cycle_timestep=-1)
return all_vals_ds.drop_isel(cycle_timestep=-1)
else:
return all_vals_xr.isel(cycle_timestep=0)
return all_vals_ds.isel(cycle_timestep=0)
Loading