From fda4de0911adcea3325678919db9634dc10c9550 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Fri, 11 Apr 2025 17:08:23 -0600 Subject: [PATCH 01/24] Make data submodules private to match other submodules --- dabench/data/__init__.py | 18 +++++++++--------- dabench/data/{barotropic.py => _barotropic.py} | 0 .../data/{enso_indices.py => _enso_indices.py} | 0 dabench/data/{gcp.py => _gcp.py} | 0 dabench/data/{lorenz63.py => _lorenz63.py} | 0 dabench/data/{lorenz96.py => _lorenz96.py} | 0 dabench/data/{pyqg.py => _pyqg.py} | 0 dabench/data/{pyqg_jax.py => _pyqg_jax.py} | 0 dabench/data/{qgs.py => _qgs.py} | 0 dabench/data/{sqgturb.py => _sqgturb.py} | 0 10 files changed, 9 insertions(+), 9 deletions(-) rename dabench/data/{barotropic.py => _barotropic.py} (100%) rename dabench/data/{enso_indices.py => _enso_indices.py} (100%) rename dabench/data/{gcp.py => _gcp.py} (100%) rename dabench/data/{lorenz63.py => _lorenz63.py} (100%) rename dabench/data/{lorenz96.py => _lorenz96.py} (100%) rename dabench/data/{pyqg.py => _pyqg.py} (100%) rename dabench/data/{pyqg_jax.py => _pyqg_jax.py} (100%) rename dabench/data/{qgs.py => _qgs.py} (100%) rename dabench/data/{sqgturb.py => _sqgturb.py} (100%) diff --git a/dabench/data/__init__.py b/dabench/data/__init__.py index 11e367e..2516033 100644 --- a/dabench/data/__init__.py +++ b/dabench/data/__init__.py @@ -1,14 +1,14 @@ from ._data import Data -from .lorenz63 import Lorenz63 -from .lorenz96 import Lorenz96 -from .sqgturb import SQGTurb -from .gcp import GCP -from .pyqg import PyQG -from .pyqg_jax import PyQGJax -from .barotropic import Barotropic -from .enso_indices import ENSOIndices -from .qgs import QGS +from ._lorenz63 import Lorenz63 +from ._lorenz96 import Lorenz96 +from ._sqgturb import SQGTurb +from ._gcp import GCP +from ._pyqg import PyQG +from ._pyqg_jax import PyQGJax +from ._barotropic import Barotropic +from ._enso_indices import ENSOIndices +from ._qgs import QGS from ._xarray_accessor import 
DABenchDatasetAccessor, DABenchDataArrayAccessor __all__ = [ diff --git a/dabench/data/barotropic.py b/dabench/data/_barotropic.py similarity index 100% rename from dabench/data/barotropic.py rename to dabench/data/_barotropic.py diff --git a/dabench/data/enso_indices.py b/dabench/data/_enso_indices.py similarity index 100% rename from dabench/data/enso_indices.py rename to dabench/data/_enso_indices.py diff --git a/dabench/data/gcp.py b/dabench/data/_gcp.py similarity index 100% rename from dabench/data/gcp.py rename to dabench/data/_gcp.py diff --git a/dabench/data/lorenz63.py b/dabench/data/_lorenz63.py similarity index 100% rename from dabench/data/lorenz63.py rename to dabench/data/_lorenz63.py diff --git a/dabench/data/lorenz96.py b/dabench/data/_lorenz96.py similarity index 100% rename from dabench/data/lorenz96.py rename to dabench/data/_lorenz96.py diff --git a/dabench/data/pyqg.py b/dabench/data/_pyqg.py similarity index 100% rename from dabench/data/pyqg.py rename to dabench/data/_pyqg.py diff --git a/dabench/data/pyqg_jax.py b/dabench/data/_pyqg_jax.py similarity index 100% rename from dabench/data/pyqg_jax.py rename to dabench/data/_pyqg_jax.py diff --git a/dabench/data/qgs.py b/dabench/data/_qgs.py similarity index 100% rename from dabench/data/qgs.py rename to dabench/data/_qgs.py diff --git a/dabench/data/sqgturb.py b/dabench/data/_sqgturb.py similarity index 100% rename from dabench/data/sqgturb.py rename to dabench/data/_sqgturb.py From 329293b17c5b33bbb0d13ee1da1bda977a5cea18 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Fri, 11 Apr 2025 17:13:14 -0600 Subject: [PATCH 02/24] Remove private members from autoapi_options --- docs/conf.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/conf.py b/docs/conf.py index 86a22e8..b885dfc 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -16,6 +16,9 @@ 'autoapi.extension' ] autoapi_dirs = ['../dabench'] +autoapi_options = ['members', 'undoc-members', 'show-inheritance', + 'show-module-summary', 
'special-members', + 'imported-members'] intersphinx_mapping = { 'python': ('https://docs.python.org/3/', None), From e74ca6e585154cf941b40662ce12254eb712d3ad Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Fri, 11 Apr 2025 17:42:08 -0600 Subject: [PATCH 03/24] Add support for Google docstrings and include typehints --- docs/conf.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index b885dfc..9b19c56 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -13,12 +13,15 @@ 'sphinx.ext.duration', 'sphinx.ext.doctest', 'sphinx.ext.intersphinx', - 'autoapi.extension' + 'autoapi.extension', + 'sphinx.ext.autodoc.typehints', + 'sphinx.ext.napoleon' ] autoapi_dirs = ['../dabench'] autoapi_options = ['members', 'undoc-members', 'show-inheritance', 'show-module-summary', 'special-members', 'imported-members'] +autodoc_typehints = 'description' intersphinx_mapping = { 'python': ('https://docs.python.org/3/', None), From f712eb172aea85674f1f56c4169232c9dac051ec Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Thu, 17 Apr 2025 17:40:02 -0600 Subject: [PATCH 04/24] Example update to fix bad docs formatting, need to distinguish between Attributes and Args across classes --- dabench/dacycler/_dacycler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dabench/dacycler/_dacycler.py b/dabench/dacycler/_dacycler.py index 34ac8f8..774b368 100644 --- a/dabench/dacycler/_dacycler.py +++ b/dabench/dacycler/_dacycler.py @@ -18,7 +18,7 @@ class DACycler(): """Base class for DACycler object - Attributes: + Args: system_dim: System dimension delta_t: The timestep of the model (assumed uniform) model_obj: Forecast model object. From 19c20fbbd3b767c855a1ae8d5e4e83df6fac7e9b Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Fri, 18 Apr 2025 14:36:26 -0600 Subject: [PATCH 05/24] Tweaks to conf.py to help with formatting, etc. 
--- docs/conf.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 9b19c56..ecaaa1f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -18,10 +18,18 @@ 'sphinx.ext.napoleon' ] autoapi_dirs = ['../dabench'] -autoapi_options = ['members', 'undoc-members', 'show-inheritance', - 'show-module-summary', 'special-members', - 'imported-members'] +# Important: Because we're not including "undoc-members", +# you need to include a docstring on *everything* you want documented. +# Including in __init__.py for submodules. +autoapi_options = ['members', 'show-module-summary', + 'special-members', 'imported-members'] + autodoc_typehints = 'description' +autoapi_member_order = 'groupwise' +autoapi_add_toctree_entry = True +autoapi_own_page_level = 'module' +napoleon_numpy_docstring = False +napoleon_google_docstring = True intersphinx_mapping = { 'python': ('https://docs.python.org/3/', None), From fba0fab09060e1a121c2e5e7bcc42aa4ec19df16 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Fri, 18 Apr 2025 14:36:56 -0600 Subject: [PATCH 06/24] Add docstrings to __init__.py --- dabench/__init__.py | 1 + dabench/dacycler/__init__.py | 2 ++ dabench/data/__init__.py | 1 + dabench/model/__init__.py | 1 + dabench/observer/__init__.py | 1 + 5 files changed, 6 insertions(+) diff --git a/dabench/__init__.py b/dabench/__init__.py index 4bba201..e19f0c9 100644 --- a/dabench/__init__.py +++ b/dabench/__init__.py @@ -1 +1,2 @@ +"""DataAssimBench""" from . 
import data, model, observer, obsop, dacycler, _suppl_data diff --git a/dabench/dacycler/__init__.py b/dabench/dacycler/__init__.py index eca5762..2478491 100644 --- a/dabench/dacycler/__init__.py +++ b/dabench/dacycler/__init__.py @@ -1,3 +1,5 @@ +"""Data Assimilation cyclers""" + from ._dacycler import DACycler from ._var3d import Var3D from ._etkf import ETKF diff --git a/dabench/data/__init__.py b/dabench/data/__init__.py index 2516033..02bb34a 100644 --- a/dabench/data/__init__.py +++ b/dabench/data/__init__.py @@ -1,3 +1,4 @@ +"""Data generators and downloaders""" from ._data import Data from ._lorenz63 import Lorenz63 diff --git a/dabench/model/__init__.py b/dabench/model/__init__.py index f05d128..15591a2 100644 --- a/dabench/model/__init__.py +++ b/dabench/model/__init__.py @@ -1,3 +1,4 @@ +"""Model classes""" from ._model import Model from ._rc import RCModel diff --git a/dabench/observer/__init__.py b/dabench/observer/__init__.py index 63df24f..16b4d98 100644 --- a/dabench/observer/__init__.py +++ b/dabench/observer/__init__.py @@ -1,3 +1,4 @@ +"""Observer module""" from ._observer import Observer __all__ = [ From e23cbb84759d6c75a4d6e7f9c8908eebe821de3b Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Fri, 18 Apr 2025 14:49:25 -0600 Subject: [PATCH 07/24] Distinguish params from class attributes (only in_4d and ensemble). 
Could improve documentation for class attributes or just make them non-public since user shouldn't need them --- dabench/dacycler/_dacycler.py | 15 +++++++-------- dabench/dacycler/_etkf.py | 6 +++--- dabench/dacycler/_var3d.py | 4 +++- dabench/dacycler/_var4d.py | 8 +++----- dabench/dacycler/_var4d_backprop.py | 8 +++----- 5 files changed, 19 insertions(+), 22 deletions(-) diff --git a/dabench/dacycler/_dacycler.py b/dabench/dacycler/_dacycler.py index 774b368..ee3a853 100644 --- a/dabench/dacycler/_dacycler.py +++ b/dabench/dacycler/_dacycler.py @@ -22,10 +22,6 @@ class DACycler(): system_dim: System dimension delta_t: The timestep of the model (assumed uniform) model_obj: Forecast model object. - in_4d: True for 4D data assimilation techniques (e.g. 4DVar). - Default is False. - ensemble: True for ensemble-based data assimilation techniques - (ETKF). Default is False B: Initial / static background error covariance. Shape: (system_dim, system_dim). If not provided, will be calculated automatically. @@ -36,14 +32,19 @@ class DACycler(): If not provided will be calculated automatically. h: Optional observation operator as function. More flexible (allows for more complex observation operator). Default is None. + + Attributes: + in_4d: True for 4D data assimilation techniques (e.g. 4DVar). + ensemble: True for ensemble-based data assimilation techniques + (ETKF). 
""" + in_4d = False + ensemble = False def __init__(self, system_dim: int, delta_t: float, model_obj: Model, - in_4d: bool = False, - ensemble: bool = False, B: ArrayLike | None = None, R: ArrayLike | None = None, H: ArrayLike | None = None, @@ -54,8 +55,6 @@ def __init__(self, self.H = H self.R = R self.B = B - self.in_4d = in_4d - self.ensemble = ensemble self.system_dim = system_dim self.delta_t = delta_t self.model_obj = model_obj diff --git a/dabench/dacycler/_etkf.py b/dabench/dacycler/_etkf.py index f90a67c..b80c2e9 100644 --- a/dabench/dacycler/_etkf.py +++ b/dabench/dacycler/_etkf.py @@ -19,7 +19,7 @@ class ETKF(dacycler.DACycler): """Class for building ETKF DA Cycler - Attributes: + Args: system_dim: System dimension. delta_t: The timestep of the model (assumed uniform) model_obj: Forecast model object. @@ -38,6 +38,8 @@ class ETKF(dacycler.DACycler): multiplicative_inflation: Scaling factor by which to multiply ensemble deviation. Default is 1.0 (no inflation). """ + in_4d = False + ensemble = True def __init__(self, system_dim: int, @@ -57,8 +59,6 @@ def __init__(self, super().__init__(system_dim=system_dim, delta_t=delta_t, model_obj=model_obj, - in_4d=False, - ensemble=True, B=B, R=R, H=H, h=h) def _step_forecast(self, diff --git a/dabench/dacycler/_var3d.py b/dabench/dacycler/_var3d.py index 533b85c..6abb90b 100644 --- a/dabench/dacycler/_var3d.py +++ b/dabench/dacycler/_var3d.py @@ -18,7 +18,7 @@ class Var3D(dacycler.DACycler): """Class for building 3DVar DA Cycler - Attributes: + Args: system_dim: System dimension. delta_t: The timestep of the model (assumed uniform) model_obj: Forecast model object. @@ -33,6 +33,8 @@ class Var3D(dacycler.DACycler): h: Optional observation operator as function. More flexible (allows for more complex observation operator). Default is None. 
""" + in_4d = False + ensemble = False def __init__(self, system_dim: int, diff --git a/dabench/dacycler/_var4d.py b/dabench/dacycler/_var4d.py index 52439b2..9d60ac8 100644 --- a/dabench/dacycler/_var4d.py +++ b/dabench/dacycler/_var4d.py @@ -27,14 +27,10 @@ class Var4D(dacycler.DACycler): """Class for building 4D DA Cycler - Attributes: + Args: system_dim: System dimension. delta_t: The timestep of the model (assumed uniform) model_obj: Forecast model object. - in_4d: True for 4D data assimilation techniques (e.g. 4DVar). - Always True for Var4D. - ensemble: True for ensemble-based data assimilation techniques - (ETKF). Always False for Var4D. B: Initial / static background error covariance. Shape: (system_dim, system_dim). If not provided, will be calculated automatically. @@ -59,6 +55,8 @@ class Var4D(dacycler.DACycler): [0, 1, 2, 3, 4, 5]. If None (default), will calculate automatically. """ + in_4d = True + ensemble = False def __init__(self, system_dim: int, diff --git a/dabench/dacycler/_var4d_backprop.py b/dabench/dacycler/_var4d_backprop.py index 25e4274..a36a2fe 100644 --- a/dabench/dacycler/_var4d_backprop.py +++ b/dabench/dacycler/_var4d_backprop.py @@ -27,14 +27,10 @@ class Var4DBackprop(dacycler.DACycler): """Class for building Backpropagation 4D DA Cycler - Attributes: + Args: system_dim: System dimension. delta_t: The timestep of the model (assumed uniform) model_obj: Forecast model object. - in_4d: True for 4D data assimilation techniques (e.g. 4DVar). - Always True for Var4DBackprop. - ensemble: True for ensemble-based data assimilation techniques - (ETKF). Always False for Var4DBackprop. B: Initial / static background error covariance. Shape: (system_dim, system_dim). If not provided, will be calculated automatically. @@ -65,6 +61,8 @@ class Var4DBackprop(dacycler.DACycler): return an error. This prevents it from hanging indefinitely when loss grows exponentionally. Default is 10. 
""" + in_4d = True + ensemble = False def __init__(self, system_dim: int, From 9a41f88c29a5620c386db4f0f9f48bc10d6a780d Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Fri, 18 Apr 2025 14:55:39 -0600 Subject: [PATCH 08/24] Updated data classes docstrings to clarify parameters vs. attributes --- dabench/data/_barotropic.py | 2 +- dabench/data/_data.py | 2 +- dabench/data/_enso_indices.py | 2 +- dabench/data/_gcp.py | 2 +- dabench/data/_lorenz63.py | 2 +- dabench/data/_lorenz96.py | 2 +- dabench/data/_pyqg.py | 4 ++-- dabench/data/_pyqg_jax.py | 2 +- dabench/data/_qgs.py | 2 +- dabench/data/_sqgturb.py | 2 +- 10 files changed, 11 insertions(+), 11 deletions(-) diff --git a/dabench/data/_barotropic.py b/dabench/data/_barotropic.py index d742978..fd4fa0b 100644 --- a/dabench/data/_barotropic.py +++ b/dabench/data/_barotropic.py @@ -43,7 +43,7 @@ class Barotropic(_data.Data): vortices in turbulent flow. Journal of Fluid Mechanics, 146, pp 21-43 doi:10.1017/S0022112084001750. - Attributes: + Args: system_dim: system dimension beta: Gradient of coriolis parameter. Units: meters^-1 * seconds^-1. Default is 0. diff --git a/dabench/data/_data.py b/dabench/data/_data.py index e1c16a3..e9bf46f 100644 --- a/dabench/data/_data.py +++ b/dabench/data/_data.py @@ -18,7 +18,7 @@ class Data(): """Generic class for data generator objects. - Attributes: + Args: system_dim: system dimension time_dim: total time steps original_dim: dimensions in original space, e.g. could be 3x3 diff --git a/dabench/data/_enso_indices.py b/dabench/data/_enso_indices.py index edfb02a..171affb 100644 --- a/dabench/data/_enso_indices.py +++ b/dabench/data/_enso_indices.py @@ -19,7 +19,7 @@ class ENSOIndices(_data.Data): Notes: Source: https://www.cpc.ncep.noaa.gov/data/indices/ - Attributes: + Args: system_dim: system dimension time_dim: total time steps store_as_jax: Store values as jax array instead of numpy array. 
diff --git a/dabench/data/_gcp.py b/dabench/data/_gcp.py index 36508c3..d34d5d0 100644 --- a/dabench/data/_gcp.py +++ b/dabench/data/_gcp.py @@ -24,7 +24,7 @@ class GCP(_data.Data): Source: https://cloud.google.com/storage/docs/public-datasets/era5 Data is hourly - Attributes: + Args: variables: Names of ERA5 variables to load. For description of variables, see: https://github.com/google-research/arco-era5?tab=readme-ov-file#full_37-1h-0p25deg-chunk-1zarr-v3 diff --git a/dabench/data/_lorenz63.py b/dabench/data/_lorenz63.py index 48c1423..efef961 100644 --- a/dabench/data/_lorenz63.py +++ b/dabench/data/_lorenz63.py @@ -15,7 +15,7 @@ class Lorenz63(_data.Data): """ Class to set up Lorenz 63 model data - Attributes: + Args: sigma: Lorenz 63 param. Default is 10., the original value used in Lorenz, 1963. https://doi.org/10.1175/1520-0469(1963)020<0130:DNF>2.0.CO;2 diff --git a/dabench/data/_lorenz96.py b/dabench/data/_lorenz96.py index 991b21e..ef917c7 100644 --- a/dabench/data/_lorenz96.py +++ b/dabench/data/_lorenz96.py @@ -19,7 +19,7 @@ class Lorenz96(_data.Data): Default values come from Lorenz, 1996: eapsweb.mit.edu/sites/default/files/Predicability_a_Problem_2006.pdf - Attributes: + Args: forcing_term: Forcing constant for Lorenz96, prevents energy from decaying to 0. Default is 8.0. x0: Initial state vector, array of floats of size diff --git a/dabench/data/_pyqg.py b/dabench/data/_pyqg.py index ea28864..739eca1 100644 --- a/dabench/data/_pyqg.py +++ b/dabench/data/_pyqg.py @@ -34,7 +34,7 @@ class PyQG(_data.Data): Uses default attribute values from pyqg.QGModel: https://pyqg.readthedocs.io/en/latest/api.html#pyqg.QGModel - Attributes: + Args: beta (float): Gradient of coriolis parameter. Units: meters^-1 * seconds^-1 rek (float): Linear drag in lower layer. Units: seconds^-1 @@ -47,7 +47,7 @@ class PyQG(_data.Data): ny (int): Number of grid points in the y direction (default: nx). L (float): Domain length in x direction. Units: meters. 
W (float): Domain width in y direction. Units: meters (default: L). - filterfac (float): amplitdue of the spectral spherical filter + filterfac (float): amplitude of the spectral spherical filter (originally 18.4, later changed to 23.6). delta_t (float): Numerical timestep. Units: seconds. twrite (int): Interval for cfl writeout. Units: number of timesteps. diff --git a/dabench/data/_pyqg_jax.py b/dabench/data/_pyqg_jax.py index 812250a..3f8e2c1 100644 --- a/dabench/data/_pyqg_jax.py +++ b/dabench/data/_pyqg_jax.py @@ -41,7 +41,7 @@ class PyQGJax(_data.Data): Uses default attribute values from pyqg_jax.QGModel: https://pyqg.readthedocs.io/en/latest/api.html#pyqg.QGModel - Attributes: + Args: beta: Gradient of coriolis parameter. Units: meters^-1 * seconds^-1 rd: Deformation radius. Units: meters. diff --git a/dabench/data/_qgs.py b/dabench/data/_qgs.py index f256612..d8881bc 100644 --- a/dabench/data/_qgs.py +++ b/dabench/data/_qgs.py @@ -38,7 +38,7 @@ class QGS(_data.Data): The QGS class is simply a wrapper of an *optional* qgs package. See https://qgs.readthedocs.io/ - Attributes: + Args: model_params: qgs parameter object. See: https://qgs.readthedocs.io/en/latest/files/technical/configuration.html#qgs.params.params.QgParams If None, will use defaults specified by: diff --git a/dabench/data/_sqgturb.py b/dabench/data/_sqgturb.py index 0cc2738..fd4e888 100644 --- a/dabench/data/_sqgturb.py +++ b/dabench/data/_sqgturb.py @@ -52,7 +52,7 @@ class SQGTurb(_data.Data): """Class to set up SQGTurb model and manage data. - Attributes: + Args: pv: Potential vorticity array. 
If None (default), loads data from 57600 step spinup with initial conditions taken from Jeff Whitaker's original implementation: From 3c505be6aa76cf91afca0a9da007147e7dddfbf1 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Fri, 18 Apr 2025 15:22:47 -0600 Subject: [PATCH 09/24] Remove class attributes from super.__init__ for dacyclers --- dabench/dacycler/_var3d.py | 2 -- dabench/dacycler/_var4d.py | 2 -- dabench/dacycler/_var4d_backprop.py | 2 -- 3 files changed, 6 deletions(-) diff --git a/dabench/dacycler/_var3d.py b/dabench/dacycler/_var3d.py index 6abb90b..d12ce05 100644 --- a/dabench/dacycler/_var3d.py +++ b/dabench/dacycler/_var3d.py @@ -49,8 +49,6 @@ def __init__(self, super().__init__(system_dim=system_dim, delta_t=delta_t, model_obj=model_obj, - in_4d=False, - ensemble=False, B=B, R=R, H=H, h=h) def _cycle_obsop(self, diff --git a/dabench/dacycler/_var4d.py b/dabench/dacycler/_var4d.py index 9d60ac8..a437e88 100644 --- a/dabench/dacycler/_var4d.py +++ b/dabench/dacycler/_var4d.py @@ -85,8 +85,6 @@ def __init__(self, super().__init__(system_dim=system_dim, delta_t=delta_t, model_obj=model_obj, - in_4d=True, - ensemble=False, B=B, R=R, H=H, h=h) def _calc_default_H(self, diff --git a/dabench/dacycler/_var4d_backprop.py b/dabench/dacycler/_var4d_backprop.py index a36a2fe..fc7f713 100644 --- a/dabench/dacycler/_var4d_backprop.py +++ b/dabench/dacycler/_var4d_backprop.py @@ -95,8 +95,6 @@ def __init__(self, super().__init__(system_dim=system_dim, delta_t=delta_t, model_obj=model_obj, - in_4d=True, - ensemble=False, B=B, R=R, H=H, h=h) def _calc_default_H(self, From 0e9cde0ad881f8ea59e8338e210f1fff3dab191a Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Sat, 19 Apr 2025 15:28:21 -0600 Subject: [PATCH 10/24] Make in_4d and uses_ensemble non-public class attributes, user doesn't need them --- dabench/dacycler/_dacycler.py | 15 +++++---------- dabench/dacycler/_etkf.py | 4 ++-- dabench/dacycler/_var3d.py | 4 ++-- dabench/dacycler/_var4d.py | 4 ++-- 
dabench/dacycler/_var4d_backprop.py | 4 ++-- 5 files changed, 13 insertions(+), 18 deletions(-) diff --git a/dabench/dacycler/_dacycler.py b/dabench/dacycler/_dacycler.py index ee3a853..35f48fd 100644 --- a/dabench/dacycler/_dacycler.py +++ b/dabench/dacycler/_dacycler.py @@ -32,14 +32,9 @@ class DACycler(): If not provided will be calculated automatically. h: Optional observation operator as function. More flexible (allows for more complex observation operator). Default is None. - - Attributes: - in_4d: True for 4D data assimilation techniques (e.g. 4DVar). - ensemble: True for ensemble-based data assimilation techniques - (ETKF). """ - in_4d = False - ensemble = False + _in_4d: bool = False + _uses_ensemble: bool = False def __init__(self, system_dim: int, @@ -229,7 +224,7 @@ def cycle(self, # If don't specify analysis_time_in_window, is assumed to be middle if analysis_time_in_window is None: - if self.in_4d: + if self._in_4d: analysis_time_in_window = 0 else: analysis_time_in_window = self.analysis_window/2 @@ -256,7 +251,7 @@ def cycle(self, obs_times=jnp.array(obs_vector.time.values), analysis_times=all_times+_time_offset, start_inclusive=True, - end_inclusive=self.in_4d, + end_inclusive=self._in_4d, analysis_window=analysis_window ) input_state = input_state.assign(_cur_time=start_time) @@ -272,7 +267,7 @@ def cycle(self, obs_vector[self._observed_vars].to_array().data) self._obs_vector=self._obs_vector.fillna(0) - if self.in_4d: + if self._in_4d: cur_state, all_values = jax.lax.scan( self._cycle_and_forecast_4d, xj.from_xarray(input_state), diff --git a/dabench/dacycler/_etkf.py b/dabench/dacycler/_etkf.py index b80c2e9..9b30cf3 100644 --- a/dabench/dacycler/_etkf.py +++ b/dabench/dacycler/_etkf.py @@ -38,8 +38,8 @@ class ETKF(dacycler.DACycler): multiplicative_inflation: Scaling factor by which to multiply ensemble deviation. Default is 1.0 (no inflation). 
""" - in_4d = False - ensemble = True + _in_4d: bool = False + _uses_ensemble: bool = True def __init__(self, system_dim: int, diff --git a/dabench/dacycler/_var3d.py b/dabench/dacycler/_var3d.py index d12ce05..1e571ae 100644 --- a/dabench/dacycler/_var3d.py +++ b/dabench/dacycler/_var3d.py @@ -33,8 +33,8 @@ class Var3D(dacycler.DACycler): h: Optional observation operator as function. More flexible (allows for more complex observation operator). Default is None. """ - in_4d = False - ensemble = False + _in_4d: bool = False + _uses_ensemble: bool = False def __init__(self, system_dim: int, diff --git a/dabench/dacycler/_var4d.py b/dabench/dacycler/_var4d.py index a437e88..5f9a50d 100644 --- a/dabench/dacycler/_var4d.py +++ b/dabench/dacycler/_var4d.py @@ -55,8 +55,8 @@ class Var4D(dacycler.DACycler): [0, 1, 2, 3, 4, 5]. If None (default), will calculate automatically. """ - in_4d = True - ensemble = False + _in_4d: bool = True + _uses_ensemble: bool = False def __init__(self, system_dim: int, diff --git a/dabench/dacycler/_var4d_backprop.py b/dabench/dacycler/_var4d_backprop.py index fc7f713..6bb4dd6 100644 --- a/dabench/dacycler/_var4d_backprop.py +++ b/dabench/dacycler/_var4d_backprop.py @@ -61,8 +61,8 @@ class Var4DBackprop(dacycler.DACycler): return an error. This prevents it from hanging indefinitely when loss grows exponentionally. Default is 10. 
""" - in_4d = True - ensemble = False + _in_4d: bool = True + _uses_ensemble: bool = False def __init__(self, system_dim: int, From 9a3bcdccaad01f6c30995fcba4f02b9f1b36d94e Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Sat, 19 Apr 2025 15:29:29 -0600 Subject: [PATCH 11/24] Update dacycler base test with in_4d and uses_ensemble as class attributes --- tests/dacycler_base_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/dacycler_base_test.py b/tests/dacycler_base_test.py index ba33419..26b5a1a 100644 --- a/tests/dacycler_base_test.py +++ b/tests/dacycler_base_test.py @@ -9,7 +9,6 @@ def test_dacycler_init(): params = {'system_dim': 6, 'delta_t': 0.5, - 'ensemble': True, 'model_obj':dab.model.RCModel(6, 10)} test_dac = dab.dacycler.DACycler(**params) @@ -17,4 +16,4 @@ def test_dacycler_init(): assert test_dac.system_dim == 6 assert test_dac.delta_t == 0.5 assert test_dac.ensemble - assert not test_dac.in_4d + assert not test_dac._in_4d From 7ac9b1b9fe95ee169981b30c6ba024185ac09520 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Sat, 19 Apr 2025 15:56:27 -0600 Subject: [PATCH 12/24] Remove self.time_dim attribute from data generators, doesn't make sense since we now produce xarray objects which have that information (used to save values in the data generator object instead) --- dabench/data/_data.py | 14 +++++--------- dabench/data/_enso_indices.py | 4 +--- dabench/data/_lorenz63.py | 4 +--- dabench/data/_lorenz96.py | 4 +--- dabench/data/_pyqg_jax.py | 6 ++---- dabench/data/_qgs.py | 14 ++++++-------- dabench/data/_sqgturb.py | 13 +------------ 7 files changed, 17 insertions(+), 42 deletions(-) diff --git a/dabench/data/_data.py b/dabench/data/_data.py index e9bf46f..f594660 100644 --- a/dabench/data/_data.py +++ b/dabench/data/_data.py @@ -20,7 +20,6 @@ class Data(): Args: system_dim: system dimension - time_dim: total time steps original_dim: dimensions in original space, e.g. could be 3x3 for a 2d system with system_dim = 9. 
Defaults to (system_dim), i.e. 1d. @@ -32,7 +31,6 @@ class Data(): def __init__(self, system_dim: int = 3, - time_dim: int = 1, original_dim: tuple[int, ...] | None = None, random_seed: int = 37, delta_t: float = 0.01, @@ -42,7 +40,6 @@ def __init__(self, """Initializes the base data object""" self.system_dim = system_dim - self.time_dim = time_dim self.random_seed = random_seed self.delta_t = delta_t self.store_as_jax = store_as_jax @@ -98,8 +95,7 @@ def generate(self, Notes: Either provide n_steps or t_final in order to indicate the length - of the forecast. These are used to set the values, times, and - time_dim attributes. + of the forecast. Args: n_steps: Number of timesteps. One of n_steps OR @@ -172,8 +168,8 @@ def generate(self, **kwargs) # Convert to JAX if necessary - self.time_dim = t.shape[0] - out_dim = (self.time_dim,) + self.original_dim + time_dim = t.shape[0] + out_dim = (time_dim,) + self.original_dim if self.store_as_jax: y_out = jnp.array(y[:,:self.system_dim].reshape(out_dim)) else: @@ -197,13 +193,13 @@ def generate(self, # Reshape M matrix if self.store_as_jax: M = jnp.reshape(y[:, self.system_dim:], - (self.time_dim, + (time_dim, self.system_dim, self.system_dim) ) else: M = np.reshape(y[:, self.system_dim:], - (self.time_dim, + (time_dim, self.system_dim, self.system_dim) ) diff --git a/dabench/data/_enso_indices.py b/dabench/data/_enso_indices.py index 171affb..d2aafae 100644 --- a/dabench/data/_enso_indices.py +++ b/dabench/data/_enso_indices.py @@ -21,7 +21,6 @@ class ENSOIndices(_data.Data): Args: system_dim: system dimension - time_dim: total time steps store_as_jax: Store values as jax array instead of numpy array. Default is False (store as numpy). file_dict: Lists of files to get. 
Dict keys are type of data: @@ -58,7 +57,6 @@ def __init__(self, file_dict: dict | None = None, var_types: dict | None = None, system_dim: int | None = None, - time_dim: int | None = None, store_as_jax: bool = False, **kwargs): @@ -66,7 +64,7 @@ def __init__(self, self.file_dict = file_dict self.var_types = var_types - super().__init__(system_dim=system_dim, time_dim=time_dim, + super().__init__(system_dim=system_dim, values=None, delta_t=None, **kwargs, store_as_jax=store_as_jax) diff --git a/dabench/data/_lorenz63.py b/dabench/data/_lorenz63.py index efef961..8b87b17 100644 --- a/dabench/data/_lorenz63.py +++ b/dabench/data/_lorenz63.py @@ -30,7 +30,6 @@ class Lorenz63(_data.Data): and initial conditions [0., 1., 0.], a spinup which replicates the simulation described in Lorenz, 1963. system_dim: system dimension. Must be 3 for Lorenz63. - time_dim: total time steps store_as_jax: Store values as jax array instead of numpy array. Default is False (store as numpy). """ @@ -42,7 +41,6 @@ def __init__(self, delta_t: float = 0.01, x0: ArrayLike | None = jnp.array([-10.0, -15.0, 21.3]), system_dim: int = 3, - time_dim: int | None = None, values: ArrayLike | None = None, store_as_jax: bool = False, **kwargs): @@ -57,7 +55,7 @@ def __init__(self, print('Assigning system_dim to 3.') system_dim = 3 - super().__init__(system_dim=system_dim, time_dim=time_dim, + super().__init__(system_dim=system_dim, values=values, delta_t=delta_t, store_as_jax=store_as_jax, **kwargs) diff --git a/dabench/data/_lorenz96.py b/dabench/data/_lorenz96.py index ef917c7..194bd71 100644 --- a/dabench/data/_lorenz96.py +++ b/dabench/data/_lorenz96.py @@ -33,7 +33,6 @@ class Lorenz96(_data.Data): which is set to 0.01. system_dim: System dimension, must be between 4 and 40. Default is 36. - time_dim: Total time steps delta_t: Length of one time step. Default is 0.05 from Lorenz, 1996, but on modern computers 0.01 is often used. store_as_jax: Store values as jax array instead of numpy array. 
@@ -45,13 +44,12 @@ def __init__(self, delta_t: float = 0.05, x0: ArrayLike | None = None, system_dim: int = 36, - time_dim: int | None = None, values: ArrayLike | None = None, store_as_jax: bool = False, **kwargs): """Initialize Lorenz96 object, subclass of Base""" - super().__init__(system_dim=system_dim, time_dim=time_dim, + super().__init__(system_dim=system_dim, values=values, delta_t=delta_t, store_as_jax=store_as_jax, **kwargs) diff --git a/dabench/data/_pyqg_jax.py b/dabench/data/_pyqg_jax.py index 3f8e2c1..14d1a70 100644 --- a/dabench/data/_pyqg_jax.py +++ b/dabench/data/_pyqg_jax.py @@ -72,7 +72,6 @@ def __init__(self, ny: int | None = None, delta_t: float = 7200, random_seed: int = 37, - time_dim: int | None = None, store_as_jax: bool = False, **kwargs): """ Initialize PyQGJax QGModel object, subclass of Base @@ -110,7 +109,7 @@ def __init__(self, jax.random.PRNGKey(0) ) super().__init__(system_dim=system_dim, original_dim=original_dim, - time_dim=time_dim, delta_t=delta_t, + delta_t=delta_t, store_as_jax=store_as_jax, x0=x0, **kwargs) @@ -157,8 +156,7 @@ def generate(self, Notes: Either provide n_steps or t_final in order to indicate the length - of the forecast. These are used to set the values, times, and - time_dim attributes. + of the forecast. Args: n_steps: Number of timesteps. 
One of n_steps OR diff --git a/dabench/data/_qgs.py b/dabench/data/_qgs.py index d8881bc..a77ca5d 100644 --- a/dabench/data/_qgs.py +++ b/dabench/data/_qgs.py @@ -53,7 +53,6 @@ def __init__(self, x0: ArrayLike | None = None, delta_t: ArrayLike | None = 0.1, system_dim: int | None = None, - time_dim: int | None = None, store_as_jax: bool = False, random_seed: int = 37, **kwargs): @@ -86,7 +85,7 @@ def __init__(self, if x0 is None: x0 = self._rng.random(system_dim)*0.001 - super().__init__(system_dim=system_dim, time_dim=time_dim, + super().__init__(system_dim=system_dim, delta_t=delta_t, store_as_jax=store_as_jax, x0=x0, **kwargs) @@ -169,8 +168,7 @@ def generate(self, Notes: Either provide n_steps or t_final in order to indicate the length - of the forecast. These are used to set the values, times, and - time_dim attributes. + of the forecast. Args: n_steps (int): Number of timesteps. One of n_steps OR @@ -243,8 +241,8 @@ def generate(self, **kwargs) # Convert to JAX if necessary - self.time_dim = t.shape[0] - out_dim = (self.time_dim,) + self.original_dim + time_dim = t.shape[0] + out_dim = (time_dim,) + self.original_dim if self.store_as_jax: y_out = jnp.array(y[:,:self.system_dim].reshape(out_dim)) else: @@ -268,13 +266,13 @@ def generate(self, # Reshape M matrix if self.store_as_jax: M = jnp.reshape(y[:, self.system_dim:], - (self.time_dim, + (time_dim, self.system_dim, self.system_dim) ) else: M = np.reshape(y[:, self.system_dim:], - (self.time_dim, + (time_dim, self.system_dim, self.system_dim) ) diff --git a/dabench/data/_sqgturb.py b/dabench/data/_sqgturb.py index fd4e888..919eb49 100644 --- a/dabench/data/_sqgturb.py +++ b/dabench/data/_sqgturb.py @@ -59,7 +59,6 @@ class SQGTurb(_data.Data): https://github.com/jswhit/sqgturb. 57600 steps matches the "nature run" spin up in that repository. 
system_dim: The dimension of the system state - time_dim: The dimension of the timeseries (not used) delta_t: model time step (seconds) x0: Initial state, array of floats of size (system_dim). @@ -499,7 +498,6 @@ def integrate(self, if include_x0: n_steps = n_steps + 1 - self.time_dim = n_steps times = t + jnp.arange(n_steps)*delta_t # Integrate in spectral spacestep_n @@ -548,13 +546,4 @@ def rhs(self, # save wind field self.u = -psiy self.v = psix - return dpvspecdt - - def _to_original_dim(self) -> np.ndarray: - """Going back to 2D is a bit trickier for sqgturb""" - gridded_vals = np.zeros((self.time_dim, self.Nv, self.Nx, self.Nx)) - - for t in np.arange(self.time_dim): - gridded_vals[t] = self.map1dto2d_ifft2(self.values[t]) - - return gridded_vals + return dpvspecdt \ No newline at end of file From 6b615984c10e6dceb321d929d0264ab8c173651d Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Sat, 19 Apr 2025 15:57:37 -0600 Subject: [PATCH 13/24] Update dacycler base test with new _uses_ensemble attribute --- tests/dacycler_base_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dacycler_base_test.py b/tests/dacycler_base_test.py index 26b5a1a..b51036a 100644 --- a/tests/dacycler_base_test.py +++ b/tests/dacycler_base_test.py @@ -15,5 +15,5 @@ def test_dacycler_init(): assert test_dac.system_dim == 6 assert test_dac.delta_t == 0.5 - assert test_dac.ensemble + assert not test_dac._uses_ensemble assert not test_dac._in_4d From d00cae9920a00485a693728c2ddcd6a3f64087fd Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Sat, 19 Apr 2025 16:02:26 -0600 Subject: [PATCH 14/24] Fix typo in a couple qgs docstrings (replace Arg: with Args:) --- dabench/data/_qgs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dabench/data/_qgs.py b/dabench/data/_qgs.py index a77ca5d..f7d25f1 100644 --- a/dabench/data/_qgs.py +++ b/dabench/data/_qgs.py @@ -123,7 +123,7 @@ def rhs(self, ) -> np.ndarray: """Vector field (tendencies) of qgs 
system - Arg: + Args: x: State vector, shape: (system_dim) t: times vector. Required as argument slot for some numerical integrators but unused. @@ -142,7 +142,7 @@ def Jacobian(self, ) -> np.ndarray: """Jacobian of the qgs system - Arg: + Args: x: State vector, shape: (system_dim) t: times vector. Required as argument slot for some numerical integrators but unused. From 2f10eabfce6b290c19adbd74df940ed09dd9b2d4 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Sat, 19 Apr 2025 16:08:29 -0600 Subject: [PATCH 15/24] Leftover docstring type annotations in qgs --- dabench/data/_qgs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dabench/data/_qgs.py b/dabench/data/_qgs.py index f7d25f1..ef69097 100644 --- a/dabench/data/_qgs.py +++ b/dabench/data/_qgs.py @@ -128,7 +128,7 @@ def rhs(self, t: times vector. Required as argument slot for some numerical integrators but unused. Returns: - dx: vector field of qgs + Vector field of qgs """ @@ -138,7 +138,7 @@ def rhs(self, def Jacobian(self, x: ArrayLike, - t: float | None = 0 + t: float | None = 0 ) -> np.ndarray: """Jacobian of the qgs system @@ -148,7 +148,7 @@ def Jacobian(self, integrators but unused. Returns: - J (ndarray): Jacobian matrix, shape: (system_dim, system_dim) + J: Jacobian matrix, shape: (system_dim, system_dim) """ @@ -294,7 +294,7 @@ def rhs_aux(self, t: Array of times with size (time_dim) Returns: - dxaux (ndarray): State vector [size: (system_dim,)] + State vector [size: (system_dim,)] """ # Compute M dxdt = self.rhs(x[:self.system_dim], t) From a462b0097568984544283d2e1c9340d2d7e5be61 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Sat, 19 Apr 2025 16:27:07 -0600 Subject: [PATCH 16/24] Update docstrings for observer to work with sphinx. 
Clearly distinguish between constructor args and attributes assigned based on those args (and 'locations' and 'times' which can be both) --- dabench/observer/_observer.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/dabench/observer/_observer.py b/dabench/observer/_observer.py index cd52096..b95f303 100644 --- a/dabench/observer/_observer.py +++ b/dabench/observer/_observer.py @@ -19,7 +19,7 @@ class Observer(): """Base class for Observer objects - Attributes: + Args: data_obj: Data generator/loader object from which to gather observations. random_location_density: Fraction of locations in @@ -74,6 +74,16 @@ class Observer(): store_as_jax: Store values as jax array instead of numpy array. Default is False (store as numpy). + Attributes: + locations (ArrayLike): Location indices for making + observations. In system_dim (1D) or original dim + (>1D) of self.state_vec. + location_dim (int): Number of locations sampled from (max + in a single time step, if non-stationary observers). + times (ArrayLike): Time indices to gather observations + from. + time_dim (int): Number of times sampled from. 
+ """ def __init__(self, @@ -353,4 +363,4 @@ def observe(self) -> xr.Dataset: obs_vec[data_var] = obs_vec[data_var] + obs_vec['errors'].sel(variable=data_var) - return obs_vec \ No newline at end of file + return obs_vec From a4132d6c7102f6da9b38c374601b9dcdd6606c4c Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Sat, 19 Apr 2025 16:31:20 -0600 Subject: [PATCH 17/24] Update docstrings and remove time_dim attribute from model object --- dabench/model/_model.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/dabench/model/_model.py b/dabench/model/_model.py index c319ce2..d55fc7f 100644 --- a/dabench/model/_model.py +++ b/dabench/model/_model.py @@ -10,21 +10,18 @@ class Model(): """Base class for Model object - Attributes: - system_dim (int): system dimension - time_dim (int): total time steps - delta_t (float): the timestep of the model (assumed uniform) - model_obj (obj): underlying model object, e.g. pytorch neural network. + Args: + system_dim: system dimension + delta_t: the timestep of the model (assumed uniform) + model_obj: underlying model object, e.g. pytorch neural network. """ def __init__(self, system_dim: int | None = None, - time_dim: int | None = None, delta_t: int | None = None, - model_obj: int | None = None + model_obj: Any | None = None ): self.system_dim = system_dim - self.time_dim = time_dim self.delta_t = delta_t self.model_obj = model_obj From 921ea90a69b347ac826f4dfe02d6b40a2ee1ab59 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Sat, 19 Apr 2025 16:32:41 -0600 Subject: [PATCH 18/24] Update rc.py docstrings, but also needs type hints and fixes to confusing method names. 
Wait for its own PR --- dabench/model/_rc.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/dabench/model/_rc.py b/dabench/model/_rc.py index e842dd7..6b2c20f 100644 --- a/dabench/model/_rc.py +++ b/dabench/model/_rc.py @@ -170,6 +170,7 @@ def weights_init(self): def generate(self, state_vec, A=None, Win=None, r0=None): """generate reservoir time series from input signal u + Args: u (array_like): (time_dimension, system_dimension), input signal to reservoir @@ -287,6 +288,7 @@ def predict(self, state_vec, delta_t, initial_index=0, n_steps=100, def readout(self, rt, Wout=None, utm1=None): """use Wout to map reservoir state to output + Args: rt (array_like): 1D or 2D with dims: (Nr,) or (Ntime, Nr) reservoir state, either passed as single time snapshot, @@ -294,9 +296,11 @@ def readout(self, rt, Wout=None, utm1=None): utm1 (array_like): 1D or 2D with dims: (Nu,) or (Ntime, Nu) u(t-1) for r(t), only used if readout_method = 'biased', then Wout*[1, u(t-1), r(t)]=u(t) + Returns: vt (array_like): 1D or 2D with dims: (Nout,) or (Ntime, Nout) depending on shape of input array + Todo: generalize similar to DiffRC """ @@ -346,6 +350,7 @@ def _predict_backend(self, n_samples, s_last, u_last, delta_t, Default is None. Wout (array_like, optional): Rutput weight matrix. If None, uses self.Wout. Default is None. + Returns: y (Data): data object with predicted signal from reservoir """ @@ -451,8 +456,8 @@ def _compute_Wout(self, rt, y, update_Wout=True, u=None): return self.Wout def _linsolve(self, X, Y, beta=None, **kwargs): - '''Linear solver wrapper - Solve for A in Y = AX + '''Linear solver wrapper for A in Y = AX + Args: X (matrix) : independent variable Y (matrix) : dependent variable @@ -464,9 +469,11 @@ def _linsolve(self, X, Y, beta=None, **kwargs): def _linsolve_pinv(self, X, Y, beta=None): """Solve for A in Y = AX, assuming X and Y are known. 
+ Args: X : independent variable, square matrix Y : dependent variable, square matrix + Returns: A : Solution matrix, rectangular matrix """ From 0e42f9ddd0d9c7cbb62331e1819610594c37d5cc Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Sat, 19 Apr 2025 16:33:37 -0600 Subject: [PATCH 19/24] Attributes: -> Args: for params in main rc class docstring --- dabench/model/_rc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dabench/model/_rc.py b/dabench/model/_rc.py index 6b2c20f..8c7569c 100644 --- a/dabench/model/_rc.py +++ b/dabench/model/_rc.py @@ -17,7 +17,7 @@ class RCModel(model.Model): """Class for a simple Reservoir Computing data-driven model - Attributes: + Args: system_dim (int): Dimension of reservoir output. input_dim (int): Dimension of reservoir input signal. reservoir_dim (int): Dimension of reservoir state. Default: 512. From 1429fcc3d972e8d153d02d9f22db9e81ed5b4b87 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Sat, 19 Apr 2025 16:37:42 -0600 Subject: [PATCH 20/24] Fix class docstrings to remove redundant 'Class' statement, based on Google style guide --- dabench/dacycler/_dacycler.py | 2 +- dabench/dacycler/_etkf.py | 2 +- dabench/dacycler/_var3d.py | 2 +- dabench/dacycler/_var4d.py | 2 +- dabench/dacycler/_var4d_backprop.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dabench/dacycler/_dacycler.py b/dabench/dacycler/_dacycler.py index 35f48fd..feb3f2e 100644 --- a/dabench/dacycler/_dacycler.py +++ b/dabench/dacycler/_dacycler.py @@ -16,7 +16,7 @@ XarrayDatasetLike = xr.Dataset | xj.XjDataset class DACycler(): - """Base class for DACycler object + """Base for all DACyclers Args: system_dim: System dimension diff --git a/dabench/dacycler/_etkf.py b/dabench/dacycler/_etkf.py index 9b30cf3..7ad8aa9 100644 --- a/dabench/dacycler/_etkf.py +++ b/dabench/dacycler/_etkf.py @@ -17,7 +17,7 @@ XarrayDatasetLike = xr.Dataset | xj.XjDataset class ETKF(dacycler.DACycler): - """Class for building ETKF DA Cycler + 
"""Ensemble transform Kalman filter DA Cycler Args: system_dim: System dimension. diff --git a/dabench/dacycler/_var3d.py b/dabench/dacycler/_var3d.py index 1e571ae..c5ff024 100644 --- a/dabench/dacycler/_var3d.py +++ b/dabench/dacycler/_var3d.py @@ -16,7 +16,7 @@ XarrayDatasetLike = xr.Dataset | xj.XjDataset class Var3D(dacycler.DACycler): - """Class for building 3DVar DA Cycler + """3D-Var DA Cycler Args: system_dim: System dimension. diff --git a/dabench/dacycler/_var4d.py b/dabench/dacycler/_var4d.py index 5f9a50d..1b2974e 100644 --- a/dabench/dacycler/_var4d.py +++ b/dabench/dacycler/_var4d.py @@ -25,7 +25,7 @@ XarrayDatasetLike = xr.Dataset | xj.XjDataset class Var4D(dacycler.DACycler): - """Class for building 4D DA Cycler + """4D-Var DA Cycler Args: system_dim: System dimension. diff --git a/dabench/dacycler/_var4d_backprop.py b/dabench/dacycler/_var4d_backprop.py index 6bb4dd6..944f980 100644 --- a/dabench/dacycler/_var4d_backprop.py +++ b/dabench/dacycler/_var4d_backprop.py @@ -25,7 +25,7 @@ ScheduleState = Any class Var4DBackprop(dacycler.DACycler): - """Class for building Backpropagation 4D DA Cycler + """Backpropagation 4D-Var DA Cycler Args: system_dim: System dimension. 
From a0e1db54472dd5c3fb17d6450383dbb403fc68ec Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Sat, 19 Apr 2025 16:43:59 -0600 Subject: [PATCH 21/24] Remove unnecessary redundant statement ('Class') from Class docstrings to match Google style guide --- dabench/data/_barotropic.py | 4 ++-- dabench/data/_data.py | 2 +- dabench/data/_enso_indices.py | 2 +- dabench/data/_gcp.py | 2 +- dabench/data/_lorenz63.py | 2 +- dabench/data/_lorenz96.py | 2 +- dabench/data/_pyqg.py | 2 +- dabench/data/_pyqg_jax.py | 2 +- dabench/data/_qgs.py | 2 +- dabench/data/_sqgturb.py | 2 +- dabench/model/_model.py | 2 +- dabench/model/_rc.py | 2 +- dabench/observer/_observer.py | 2 +- 13 files changed, 14 insertions(+), 14 deletions(-) diff --git a/dabench/data/_barotropic.py b/dabench/data/_barotropic.py index fd4fa0b..d32b760 100644 --- a/dabench/data/_barotropic.py +++ b/dabench/data/_barotropic.py @@ -30,9 +30,9 @@ class Barotropic(_data.Data): - """ Class to set up barotropic model + """Barotropic model data generator based on pyqg - The data class is a wrapper of a "optional" pyqg package. + This data class is a wrapper of a "optional" pyqg package. See https://pyqg.readthedocs.io Notes: diff --git a/dabench/data/_data.py b/dabench/data/_data.py index f594660..2777ead 100644 --- a/dabench/data/_data.py +++ b/dabench/data/_data.py @@ -16,7 +16,7 @@ ArrayLike = np.ndarray | jax.Array class Data(): - """Generic class for data generator objects. + """Base for all data generator objects. 
Args: system_dim: system dimension diff --git a/dabench/data/_enso_indices.py b/dabench/data/_enso_indices.py index d2aafae..cda4453 100644 --- a/dabench/data/_enso_indices.py +++ b/dabench/data/_enso_indices.py @@ -14,7 +14,7 @@ class ENSOIndices(_data.Data): - """Class to get ENSO indices from CPC website + """Gets ENSO indices from CPC website Notes: Source: https://www.cpc.ncep.noaa.gov/data/indices/ diff --git a/dabench/data/_gcp.py b/dabench/data/_gcp.py index d34d5d0..8b964cb 100644 --- a/dabench/data/_gcp.py +++ b/dabench/data/_gcp.py @@ -18,7 +18,7 @@ class GCP(_data.Data): - """Class for loading ERA5 data from Google Cloud Platform + """Loads ERA5 data from Google Cloud Platform Notes: Source: https://cloud.google.com/storage/docs/public-datasets/era5 diff --git a/dabench/data/_lorenz63.py b/dabench/data/_lorenz63.py index 8b87b17..ee15a26 100644 --- a/dabench/data/_lorenz63.py +++ b/dabench/data/_lorenz63.py @@ -13,7 +13,7 @@ ArrayLike = np.ndarray | jax.Array class Lorenz63(_data.Data): - """ Class to set up Lorenz 63 model data + """Lorenz 63 model data generator. Args: sigma: Lorenz 63 param. Default is 10., the original value diff --git a/dabench/data/_lorenz96.py b/dabench/data/_lorenz96.py index 194bd71..4b89b16 100644 --- a/dabench/data/_lorenz96.py +++ b/dabench/data/_lorenz96.py @@ -13,7 +13,7 @@ class Lorenz96(_data.Data): - """Class to set up Lorenz 96 model data. + """Lorenz 96 model data generator. Notes: Default values come from Lorenz, 1996: diff --git a/dabench/data/_pyqg.py b/dabench/data/_pyqg.py index 739eca1..d8f90c6 100644 --- a/dabench/data/_pyqg.py +++ b/dabench/data/_pyqg.py @@ -25,7 +25,7 @@ class PyQG(_data.Data): - """ Class to set up quasi-geotropic model + """PyQG quasi-geostrophic model data generator. The PyQG class is simply a wrapper of a "optional" pyqg package.
See https://pyqg.readthedocs.io diff --git a/dabench/data/_pyqg_jax.py b/dabench/data/_pyqg_jax.py index 14d1a70..e26282d 100644 --- a/dabench/data/_pyqg_jax.py +++ b/dabench/data/_pyqg_jax.py @@ -32,7 +32,7 @@ class PyQGJax(_data.Data): - """Class to set up quasi-geotropic model + """PyQGJax quasi-geostrophic model data generator. The PyQGJax class is simply a wrapper of the "optional" pyqg-jax package. See https://pyqg-jax.readthedocs.io diff --git a/dabench/data/_qgs.py b/dabench/data/_qgs.py index ef69097..08cad68 100644 --- a/dabench/data/_qgs.py +++ b/dabench/data/_qgs.py @@ -33,7 +33,7 @@ class QGS(_data.Data): - """ Class to set up QGS quasi-geostrophic model + """QGS quasi-geostrophic model data generator. The QGS class is simply a wrapper of an *optional* qgs package. See https://qgs.readthedocs.io/ diff --git a/dabench/data/_sqgturb.py b/dabench/data/_sqgturb.py index 919eb49..f0ca83e 100644 --- a/dabench/data/_sqgturb.py +++ b/dabench/data/_sqgturb.py @@ -50,7 +50,7 @@ class SQGTurb(_data.Data): - """Class to set up SQGTurb model and manage data. + """SQGTurb model data generator. Args: pv: Potential vorticity array. If None (default), diff --git a/dabench/model/_model.py b/dabench/model/_model.py index d55fc7f..75802ba 100644 --- a/dabench/model/_model.py +++ b/dabench/model/_model.py @@ -8,7 +8,7 @@ import xarray as xr class Model(): - """Base class for Model object + """Base for Model objects Args: system_dim: system dimension diff --git a/dabench/model/_rc.py b/dabench/model/_rc.py index 8c7569c..f93d667 100644 --- a/dabench/model/_rc.py +++ b/dabench/model/_rc.py @@ -15,7 +15,7 @@ class RCModel(model.Model): - """Class for a simple Reservoir Computing data-driven model + """A simple Reservoir Computing data-driven model Args: system_dim (int): Dimension of reservoir output.
diff --git a/dabench/observer/_observer.py b/dabench/observer/_observer.py index b95f303..c5c59e5 100644 --- a/dabench/observer/_observer.py +++ b/dabench/observer/_observer.py @@ -17,7 +17,7 @@ class Observer(): - """Base class for Observer objects + """Flexibly samples observations from generated data Args: data_obj: Data generator/loader object from which From afaff7ac15d7631442439e0fa8c290c76f674708 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Wed, 23 Apr 2025 18:16:15 -0600 Subject: [PATCH 22/24] Observer fix bad return docstring and mismatched arg name --- dabench/observer/_observer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dabench/observer/_observer.py b/dabench/observer/_observer.py index c5c59e5..decb844 100644 --- a/dabench/observer/_observer.py +++ b/dabench/observer/_observer.py @@ -20,7 +20,7 @@ class Observer(): """Flexibly samples observations from generated data Args: - data_obj: Data generator/loader object from which + state_vec: Data generator/loader object from which to gather observations. random_location_density: Fraction of locations in system_dim to randomly select for observing, must be value @@ -275,7 +275,7 @@ def observe(self) -> xr.Dataset: Returns: ObsVector containing observation values, times, locations, and - errors + errors """ # Define random num generator From 638845c322c21dc310c6f4416b3c9f45c384810a Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Wed, 23 Apr 2025 18:16:50 -0600 Subject: [PATCH 23/24] RC model docstring updates for sphinx, but still need to update methods themselves --- dabench/model/_rc.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/dabench/model/_rc.py b/dabench/model/_rc.py index f93d667..87a3dbf 100644 --- a/dabench/model/_rc.py +++ b/dabench/model/_rc.py @@ -183,7 +183,7 @@ def generate(self, state_vec, A=None, Win=None, r0=None): If False, returns states. Default: False. 
Returns: - r (array_like): (time_dim, reservoir_dim), reservoir state + Reservoir state, size (time_dim, reservoir_dim) """ u = state_vec.to_stacked_array('system',['time']).data r = np.zeros((u.shape[0], self.reservoir_dim)) @@ -215,7 +215,7 @@ def update(self, r, u, A=None, Win=None): reservoir input weight matrix. If None, uses self.Win. Default is None Returns: - q (array_like): (reservoir_dim,) Reservoir state at next time step + Reservoir state at next time step, of size (reservoir_dim,) """ if A is None: @@ -249,8 +249,7 @@ def predict(self, state_vec, delta_t, initial_index=0, n_steps=100, r0 (array_like, optional): initial reservoir state Returns: - dataobj_pred (vector.StateVector): StateVector object covering - prediction period + Data object covering prediction period """ # Recompute the initial reservoir spinup to get reservoir states @@ -298,8 +297,8 @@ def readout(self, rt, Wout=None, utm1=None): then Wout*[1, u(t-1), r(t)]=u(t) Returns: - vt (array_like): 1D or 2D with dims: (Nout,) or (Ntime, Nout) - depending on shape of input array + 1D or 2D array with dims (Nout,) or (Ntime, Nout) + depending on shape of input array Todo: generalize similar to DiffRC @@ -352,7 +351,7 @@ def _predict_backend(self, n_samples, s_last, u_last, delta_t, uses self.Wout. Default is None.
Returns: - y (Data): data object with predicted signal from reservoir + Data object with predicted signal from reservoir """ s = jnp.zeros((n_samples, self.reservoir_dim)) @@ -402,8 +401,8 @@ def _compute_Wout(self, rt, y, update_Wout=True, u=None): initialize it by rewriting the ybar and sbar matrices Returns: - Wout (array_like): 2D with dims (output_dim, reservoir_dim), - this is also stored within the object + Wout array, 2D with dims (output_dim, reservoir_dim), + this is also stored within the object Sets Attributes: ybar (array_like): y.T @ st, st is rt with readout_method accounted @@ -475,7 +474,7 @@ def _linsolve_pinv(self, X, Y, beta=None): Y : dependent variable, square matrix Returns: - A : Solution matrix, rectangular matrix + Solution matrix, rectangular matrix """ if beta is not None: Xinv = linalg.pinv(X+beta*np.eye(X.shape[0])) From 560d575e427ca1024443ef11a6fe191695e17909 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Wed, 23 Apr 2025 18:17:33 -0600 Subject: [PATCH 24/24] Data class docstring fixes --- dabench/data/_barotropic.py | 5 +++-- dabench/data/_data.py | 6 +++--- dabench/data/_pyqg.py | 5 +++-- dabench/data/_qgs.py | 16 +++++++--------- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/dabench/data/_barotropic.py b/dabench/data/_barotropic.py index d32b760..bbc090b 100644 --- a/dabench/data/_barotropic.py +++ b/dabench/data/_barotropic.py @@ -36,6 +36,7 @@ class Barotropic(_data.Data): See https://pyqg.readthedocs.io Notes: + DEPRECATED Uses default attribute values from pyqg.BTModel: https://pyqg.readthedocs.io/en/latest/api.html#pyqg.BTModel Those values originally come from Mcwilliams 1984: @@ -207,8 +208,8 @@ def __advance__(self,): """Advances the QG model according to set attributes Returns: - qs (array_like): absolute potential vorticity (relative potential - vorticity + background vorticity). + Array of absolute potential vorticity (relative potential + vorticity + background vorticity). 
""" qs = [] for _ in self.m.run_with_snapshots(tsnapstart=0, tsnapint=self.m.dt): diff --git a/dabench/data/_data.py b/dabench/data/_data.py index 2777ead..fb92159 100644 --- a/dabench/data/_data.py +++ b/dabench/data/_data.py @@ -114,8 +114,8 @@ def generate(self, convergence tolerance, etc.). Returns: - Xarray Dataset of output vector and (if return_tlm=True) - Xarray DataArray of TLMs corresponding to the system trajectory. + Xarray Dataset of output vector, and if return_tlm=True + Xarray DataArray of TLMs corresponding to the system trajectory. """ # Check that n_steps or t_final is supplied @@ -279,7 +279,7 @@ def calc_lyapunov_exponents_series( Returns: Lyapunov exponents for all timesteps, array of size - (total_time/rescale_time - 1, system_dim) + (total_time/rescale_time - 1, system_dim) """ # Set total_time diff --git a/dabench/data/_pyqg.py b/dabench/data/_pyqg.py index d8f90c6..05dc9d2 100644 --- a/dabench/data/_pyqg.py +++ b/dabench/data/_pyqg.py @@ -31,6 +31,7 @@ class PyQG(_data.Data): See https://pyqg.readthedocs.io Notes: + DEPRECATED Uses default attribute values from pyqg.QGModel: https://pyqg.readthedocs.io/en/latest/api.html#pyqg.QGModel @@ -191,8 +192,8 @@ def __advance__(self,): """Advances the QG model according to set attributes Returns: - qs (array_like): absolute potential vorticity (relative potential - vorticity + background vorticity). + Array of absolute potential vorticity (relative potential + vorticity + background vorticity). """ qs = [] for _ in self.m.run_with_snapshots(tsnapstart=0, tsnapint=self.m.dt): diff --git a/dabench/data/_qgs.py b/dabench/data/_qgs.py index 08cad68..9b8bf86 100644 --- a/dabench/data/_qgs.py +++ b/dabench/data/_qgs.py @@ -124,12 +124,11 @@ def rhs(self, """Vector field (tendencies) of qgs system Args: - x: State vector, shape: (system_dim) + x: State vector of size (system_dim) t: times vector. Required as argument slot for some numerical integrators but unused. 
Returns: Vector field of qgs - """ dx = self.f(t, x) @@ -143,13 +142,12 @@ def Jacobian(self, """Jacobian of the qgs system Args: - x: State vector, shape: (system_dim) + x: State vector of size (system_dim) t: times vector. Required as argument slot for some numerical integrators but unused. Returns: - J: Jacobian matrix, shape: (system_dim, system_dim) - + Jacobian matrix of size (system_dim, system_dim) """ J = self.Df(t, x) @@ -187,8 +185,8 @@ def generate(self, convergence tolerance, etc.). Returns: - Xarray Dataset of output vector and (if return_tlm=True) - Xarray DataArray of TLMs corresponding to the system trajectory. + Xarray Dataset of output vector, and if return_tlm=True + Xarray DataArray of TLMs corresponding to the system trajectory. """ # Check that n_steps or t_final is supplied @@ -294,7 +292,7 @@ def rhs_aux(self, t: Array of times with size (time_dim) Returns: - State vector [size: (system_dim,)] + State vector of size (system_dim,) """ # Compute M dxdt = self.rhs(x[:self.system_dim], t) @@ -351,7 +349,7 @@ def calc_lyapunov_exponents_series( Returns: Lyapunov exponents for all timesteps, array of size - (total_time/rescale_time - 1, system_dim) + (total_time/rescale_time - 1, system_dim) """ # Set total_time if total_time is None: