Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
fda4de0
Make data submodules private to match other submodules
kysolvik Apr 11, 2025
329293b
Remove private members from autoapi_options'
kysolvik Apr 11, 2025
e74ca6e
Add support for Google docstrings and include typehints
kysolvik Apr 11, 2025
f712eb1
Example update to fix bad docs formatting, need to distinguish betwee…
kysolvik Apr 17, 2025
19c20fb
Tweaks to conf.py to help with formatting, etc.
kysolvik Apr 18, 2025
fba0fab
Add docstrings to __init__.py
kysolvik Apr 18, 2025
e23cbb8
Distinguish params from class attributes (only in_4d and ensemble). C…
kysolvik Apr 18, 2025
9a41f88
Updated data classes docstrings to clarify parameters vs. attributes
kysolvik Apr 18, 2025
3c505be
Remove class attributes from super.__init__ for dacyclers
kysolvik Apr 18, 2025
0e9cde0
Make in_4d and uses_ensemble non-public class attributes, user doesn'…
kysolvik Apr 19, 2025
9a3bcdc
Update dacycler base test with in_4d and uses_ensemble as class attri…
kysolvik Apr 19, 2025
7ac9b1b
Remove self.time_dim attribute from data generators, doesn't make sen…
kysolvik Apr 19, 2025
6b61598
Update dacycler base test with new _uses_ensemble attribute
kysolvik Apr 19, 2025
d00cae9
Fix typo in a couple qgs docstrings (replace Arg: with Args:)
kysolvik Apr 19, 2025
2f10eab
Leftover docstring type annotations in qgs
kysolvik Apr 19, 2025
a462b00
Update docstrings for observer to work with sphinx. Clearly distingui…
kysolvik Apr 19, 2025
a4132d6
Update docstrings and remove time_dim attribute from model object
kysolvik Apr 19, 2025
921ea90
Update rc.py docstrings, but also needs type hints and fixes to confu…
kysolvik Apr 19, 2025
0e42f9d
Attributes: -> Args: for params in main rc class docstring
kysolvik Apr 19, 2025
1429fcc
Fix class docstrings to remove redundant 'Class' statement, based on …
kysolvik Apr 19, 2025
a0e1db5
Remove unnecessary redundant statement ('Class') from Class docstring…
kysolvik Apr 19, 2025
afaff7a
Observer fix bad return docstring and mismatched arg name
kysolvik Apr 24, 2025
638845c
RC model docstring updates for sphinx, but still need to update metho…
kysolvik Apr 24, 2025
560d575
Data class docstring fixes
kysolvik Apr 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dabench/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
"""DataAssimBench"""
from . import data, model, observer, obsop, dacycler, _suppl_data
2 changes: 2 additions & 0 deletions dabench/dacycler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Data Assimilation cyclers"""

from ._dacycler import DACycler
from ._var3d import Var3D
from ._etkf import ETKF
Expand Down
20 changes: 7 additions & 13 deletions dabench/dacycler/_dacycler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,12 @@
XarrayDatasetLike = xr.Dataset | xj.XjDataset

class DACycler():
"""Base class for DACycler object
"""Base for all DACyclers

Attributes:
Args:
system_dim: System dimension
delta_t: The timestep of the model (assumed uniform)
model_obj: Forecast model object.
in_4d: True for 4D data assimilation techniques (e.g. 4DVar).
Default is False.
ensemble: True for ensemble-based data assimilation techniques
(ETKF). Default is False
B: Initial / static background error covariance. Shape:
(system_dim, system_dim). If not provided, will be calculated
automatically.
Expand All @@ -37,13 +33,13 @@ class DACycler():
h: Optional observation operator as function. More flexible
(allows for more complex observation operator). Default is None.
"""
_in_4d: bool = False
_uses_ensemble: bool = False

def __init__(self,
system_dim: int,
delta_t: float,
model_obj: Model,
in_4d: bool = False,
ensemble: bool = False,
B: ArrayLike | None = None,
R: ArrayLike | None = None,
H: ArrayLike | None = None,
Expand All @@ -54,8 +50,6 @@ def __init__(self,
self.H = H
self.R = R
self.B = B
self.in_4d = in_4d
self.ensemble = ensemble
self.system_dim = system_dim
self.delta_t = delta_t
self.model_obj = model_obj
Expand Down Expand Up @@ -230,7 +224,7 @@ def cycle(self,

# If don't specify analysis_time_in_window, is assumed to be middle
if analysis_time_in_window is None:
if self.in_4d:
if self._in_4d:
analysis_time_in_window = 0
else:
analysis_time_in_window = self.analysis_window/2
Expand All @@ -257,7 +251,7 @@ def cycle(self,
obs_times=jnp.array(obs_vector.time.values),
analysis_times=all_times+_time_offset,
start_inclusive=True,
end_inclusive=self.in_4d,
end_inclusive=self._in_4d,
analysis_window=analysis_window
)
input_state = input_state.assign(_cur_time=start_time)
Expand All @@ -273,7 +267,7 @@ def cycle(self,
obs_vector[self._observed_vars].to_array().data)
self._obs_vector=self._obs_vector.fillna(0)

if self.in_4d:
if self._in_4d:
cur_state, all_values = jax.lax.scan(
self._cycle_and_forecast_4d,
xj.from_xarray(input_state),
Expand Down
8 changes: 4 additions & 4 deletions dabench/dacycler/_etkf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
XarrayDatasetLike = xr.Dataset | xj.XjDataset

class ETKF(dacycler.DACycler):
"""Class for building ETKF DA Cycler
"""Ensemble transform Kalman filter DA Cycler

Attributes:
Args:
system_dim: System dimension.
delta_t: The timestep of the model (assumed uniform)
model_obj: Forecast model object.
Expand All @@ -38,6 +38,8 @@ class ETKF(dacycler.DACycler):
multiplicative_inflation: Scaling factor by which to multiply ensemble
deviation. Default is 1.0 (no inflation).
"""
_in_4d: bool = False
_uses_ensemble: bool = True

def __init__(self,
system_dim: int,
Expand All @@ -57,8 +59,6 @@ def __init__(self,
super().__init__(system_dim=system_dim,
delta_t=delta_t,
model_obj=model_obj,
in_4d=False,
ensemble=True,
B=B, R=R, H=H, h=h)

def _step_forecast(self,
Expand Down
8 changes: 4 additions & 4 deletions dabench/dacycler/_var3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
XarrayDatasetLike = xr.Dataset | xj.XjDataset

class Var3D(dacycler.DACycler):
"""Class for building 3DVar DA Cycler
"""3D-Var DA Cycler

Attributes:
Args:
system_dim: System dimension.
delta_t: The timestep of the model (assumed uniform)
model_obj: Forecast model object.
Expand All @@ -33,6 +33,8 @@ class Var3D(dacycler.DACycler):
h: Optional observation operator as function. More flexible
(allows for more complex observation operator). Default is None.
"""
_in_4d: bool = False
_uses_ensemble: bool = False

def __init__(self,
system_dim: int,
Expand All @@ -47,8 +49,6 @@ def __init__(self,
super().__init__(system_dim=system_dim,
delta_t=delta_t,
model_obj=model_obj,
in_4d=False,
ensemble=False,
B=B, R=R, H=H, h=h)

def _cycle_obsop(self,
Expand Down
12 changes: 4 additions & 8 deletions dabench/dacycler/_var4d.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,12 @@
XarrayDatasetLike = xr.Dataset | xj.XjDataset

class Var4D(dacycler.DACycler):
"""Class for building 4D DA Cycler
"""4D-Var DA Cycler

Attributes:
Args:
system_dim: System dimension.
delta_t: The timestep of the model (assumed uniform)
model_obj: Forecast model object.
in_4d: True for 4D data assimilation techniques (e.g. 4DVar).
Always True for Var4D.
ensemble: True for ensemble-based data assimilation techniques
(ETKF). Always False for Var4D.
B: Initial / static background error covariance. Shape:
(system_dim, system_dim). If not provided, will be calculated
automatically.
Expand All @@ -59,6 +55,8 @@ class Var4D(dacycler.DACycler):
[0, 1, 2, 3, 4, 5]. If None (default), will calculate
automatically.
"""
_in_4d: bool = True
_uses_ensemble: bool = False

def __init__(self,
system_dim: int,
Expand Down Expand Up @@ -87,8 +85,6 @@ def __init__(self,
super().__init__(system_dim=system_dim,
delta_t=delta_t,
model_obj=model_obj,
in_4d=True,
ensemble=False,
B=B, R=R, H=H, h=h)

def _calc_default_H(self,
Expand Down
12 changes: 4 additions & 8 deletions dabench/dacycler/_var4d_backprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,12 @@
ScheduleState = Any

class Var4DBackprop(dacycler.DACycler):
"""Class for building Backpropagation 4D DA Cycler
"""Backpropagation 4D-Var DA Cycler

Attributes:
Args:
system_dim: System dimension.
delta_t: The timestep of the model (assumed uniform)
model_obj: Forecast model object.
in_4d: True for 4D data assimilation techniques (e.g. 4DVar).
Always True for Var4DBackprop.
ensemble: True for ensemble-based data assimilation techniques
(ETKF). Always False for Var4DBackprop.
B: Initial / static background error covariance. Shape:
(system_dim, system_dim). If not provided, will be calculated
automatically.
Expand Down Expand Up @@ -65,6 +61,8 @@ class Var4DBackprop(dacycler.DACycler):
return an error. This prevents it from hanging indefinitely
when loss grows exponentionally. Default is 10.
"""
_in_4d: bool = True
_uses_ensemble: bool = False

def __init__(self,
system_dim: int,
Expand Down Expand Up @@ -97,8 +95,6 @@ def __init__(self,
super().__init__(system_dim=system_dim,
delta_t=delta_t,
model_obj=model_obj,
in_4d=True,
ensemble=False,
B=B, R=R, H=H, h=h)

def _calc_default_H(self,
Expand Down
19 changes: 10 additions & 9 deletions dabench/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""Data generators and downloaders"""
from ._data import Data

from .lorenz63 import Lorenz63
from .lorenz96 import Lorenz96
from .sqgturb import SQGTurb
from .gcp import GCP
from .pyqg import PyQG
from .pyqg_jax import PyQGJax
from .barotropic import Barotropic
from .enso_indices import ENSOIndices
from .qgs import QGS
from ._lorenz63 import Lorenz63
from ._lorenz96 import Lorenz96
from ._sqgturb import SQGTurb
from ._gcp import GCP
from ._pyqg import PyQG
from ._pyqg_jax import PyQGJax
from ._barotropic import Barotropic
from ._enso_indices import ENSOIndices
from ._qgs import QGS
from ._xarray_accessor import DABenchDatasetAccessor, DABenchDataArrayAccessor

__all__ = [
Expand Down
11 changes: 6 additions & 5 deletions dabench/data/barotropic.py → dabench/data/_barotropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,21 @@


class Barotropic(_data.Data):
""" Class to set up barotropic model
"""Barotropic model data generator based on pyqg

The data class is a wrapper of a "optional" pyqg package.
This data class is a wrapper of a "optional" pyqg package.
See https://pyqg.readthedocs.io

Notes:
DEPRECATED
Uses default attribute values from pyqg.BTModel:
https://pyqg.readthedocs.io/en/latest/api.html#pyqg.BTModel
Those values originally come from Mcwilliams 1984:
J. C. Mcwilliams (1984). The emergence of isolated coherent
vortices in turbulent flow. Journal of Fluid Mechanics, 146,
pp 21-43 doi:10.1017/S0022112084001750.

Attributes:
Args:
system_dim: system dimension
beta: Gradient of coriolis parameter. Units: meters^-1 *
seconds^-1. Default is 0.
Expand Down Expand Up @@ -207,8 +208,8 @@ def __advance__(self,):
"""Advances the QG model according to set attributes

Returns:
qs (array_like): absolute potential vorticity (relative potential
vorticity + background vorticity).
Array of absolute potential vorticity (relative potential
vorticity + background vorticity).
"""
qs = []
for _ in self.m.run_with_snapshots(tsnapstart=0, tsnapint=self.m.dt):
Expand Down
24 changes: 10 additions & 14 deletions dabench/data/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@
ArrayLike = np.ndarray | jax.Array

class Data():
"""Generic class for data generator objects.
"""Base for all data generator objects.

Attributes:
Args:
system_dim: system dimension
time_dim: total time steps
original_dim: dimensions in original space, e.g. could be 3x3
for a 2d system with system_dim = 9. Defaults to (system_dim),
i.e. 1d.
Expand All @@ -32,7 +31,6 @@ class Data():

def __init__(self,
system_dim: int = 3,
time_dim: int = 1,
original_dim: tuple[int, ...] | None = None,
random_seed: int = 37,
delta_t: float = 0.01,
Expand All @@ -42,7 +40,6 @@ def __init__(self,
"""Initializes the base data object"""

self.system_dim = system_dim
self.time_dim = time_dim
self.random_seed = random_seed
self.delta_t = delta_t
self.store_as_jax = store_as_jax
Expand Down Expand Up @@ -98,8 +95,7 @@ def generate(self,

Notes:
Either provide n_steps or t_final in order to indicate the length
of the forecast. These are used to set the values, times, and
time_dim attributes.
of the forecast.

Args:
n_steps: Number of timesteps. One of n_steps OR
Expand All @@ -118,8 +114,8 @@ def generate(self,
convergence tolerance, etc.).

Returns:
Xarray Dataset of output vector and (if return_tlm=True)
Xarray DataArray of TLMs corresponding to the system trajectory.
Xarray Dataset of output vector, and if return_tlm=True
Xarray DataArray of TLMs corresponding to the system trajectory.
"""

# Check that n_steps or t_final is supplied
Expand Down Expand Up @@ -172,8 +168,8 @@ def generate(self,
**kwargs)

# Convert to JAX if necessary
self.time_dim = t.shape[0]
out_dim = (self.time_dim,) + self.original_dim
time_dim = t.shape[0]
out_dim = (time_dim,) + self.original_dim
if self.store_as_jax:
y_out = jnp.array(y[:,:self.system_dim].reshape(out_dim))
else:
Expand All @@ -197,13 +193,13 @@ def generate(self,
# Reshape M matrix
if self.store_as_jax:
M = jnp.reshape(y[:, self.system_dim:],
(self.time_dim,
(time_dim,
self.system_dim,
self.system_dim)
)
else:
M = np.reshape(y[:, self.system_dim:],
(self.time_dim,
(time_dim,
self.system_dim,
self.system_dim)
)
Expand Down Expand Up @@ -283,7 +279,7 @@ def calc_lyapunov_exponents_series(

Returns:
Lyapunov exponents for all timesteps, array of size
(total_time/rescale_time - 1, system_dim)
(total_time/rescale_time - 1, system_dim)
"""

# Set total_time
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@


class ENSOIndices(_data.Data):
"""Class to get ENSO indices from CPC website
"""Gets ENSO indices from CPC website

Notes:
Source: https://www.cpc.ncep.noaa.gov/data/indices/

Attributes:
Args:
system_dim: system dimension
time_dim: total time steps
store_as_jax: Store values as jax array instead of numpy array.
Default is False (store as numpy).
file_dict: Lists of files to get. Dict keys are type of data:
Expand Down Expand Up @@ -58,15 +57,14 @@ def __init__(self,
file_dict: dict | None = None,
var_types: dict | None = None,
system_dim: int | None = None,
time_dim: int | None = None,
store_as_jax: bool = False,
**kwargs):

"""Initialize ENSOIndices object, subclass of Base"""

self.file_dict = file_dict
self.var_types = var_types
super().__init__(system_dim=system_dim, time_dim=time_dim,
super().__init__(system_dim=system_dim,
values=None, delta_t=None, **kwargs,
store_as_jax=store_as_jax)

Expand Down
Loading