diff --git a/dabench/__init__.py b/dabench/__init__.py index 4bba201..e19f0c9 100644 --- a/dabench/__init__.py +++ b/dabench/__init__.py @@ -1 +1,2 @@ +"""DataAssimBench""" from . import data, model, observer, obsop, dacycler, _suppl_data diff --git a/dabench/dacycler/__init__.py b/dabench/dacycler/__init__.py index eca5762..2478491 100644 --- a/dabench/dacycler/__init__.py +++ b/dabench/dacycler/__init__.py @@ -1,3 +1,5 @@ +"""Data Assimilation cyclers""" + from ._dacycler import DACycler from ._var3d import Var3D from ._etkf import ETKF diff --git a/dabench/dacycler/_dacycler.py b/dabench/dacycler/_dacycler.py index 34ac8f8..feb3f2e 100644 --- a/dabench/dacycler/_dacycler.py +++ b/dabench/dacycler/_dacycler.py @@ -16,16 +16,12 @@ XarrayDatasetLike = xr.Dataset | xj.XjDataset class DACycler(): - """Base class for DACycler object + """Base for all DACyclers - Attributes: + Args: system_dim: System dimension delta_t: The timestep of the model (assumed uniform) model_obj: Forecast model object. - in_4d: True for 4D data assimilation techniques (e.g. 4DVar). - Default is False. - ensemble: True for ensemble-based data assimilation techniques - (ETKF). Default is False B: Initial / static background error covariance. Shape: (system_dim, system_dim). If not provided, will be calculated automatically. @@ -37,13 +33,13 @@ class DACycler(): h: Optional observation operator as function. More flexible (allows for more complex observation operator). Default is None. """ + _in_4d: bool = False + _uses_ensemble: bool = False def __init__(self, system_dim: int, delta_t: float, model_obj: Model, - in_4d: bool = False, - ensemble: bool = False, B: ArrayLike | None = None, R: ArrayLike | None = None, H: ArrayLike | None = None, @@ -54,8 +50,6 @@ def __init__(self, self.H = H self.R = R self.B = B - self.in_4d = in_4d - self.ensemble = ensemble self.system_dim = system_dim self.delta_t = delta_t self.model_obj = model_obj @@ -230,7 +224,7 @@ def cycle(self, # If don't specify analysis_time_in_window, is assumed to be middle if analysis_time_in_window is None: - if self.in_4d: + if self._in_4d: analysis_time_in_window = 0 else: analysis_time_in_window = self.analysis_window/2 @@ -257,7 +251,7 @@ def cycle(self, obs_times=jnp.array(obs_vector.time.values), analysis_times=all_times+_time_offset, start_inclusive=True, - end_inclusive=self.in_4d, + end_inclusive=self._in_4d, analysis_window=analysis_window ) input_state = input_state.assign(_cur_time=start_time) @@ -273,7 +267,7 @@ def cycle(self, obs_vector[self._observed_vars].to_array().data) self._obs_vector=self._obs_vector.fillna(0) - if self.in_4d: + if self._in_4d: cur_state, all_values = jax.lax.scan( self._cycle_and_forecast_4d, xj.from_xarray(input_state), diff --git a/dabench/dacycler/_etkf.py b/dabench/dacycler/_etkf.py index f90a67c..7ad8aa9 100644 --- a/dabench/dacycler/_etkf.py +++ b/dabench/dacycler/_etkf.py @@ -17,9 +17,9 @@ XarrayDatasetLike = xr.Dataset | xj.XjDataset class ETKF(dacycler.DACycler): - """Class for building ETKF DA Cycler + """Ensemble transform Kalman filter DA Cycler - Attributes: + Args: system_dim: System dimension. delta_t: The timestep of the model (assumed uniform) model_obj: Forecast model object. @@ -38,6 +38,8 @@ class ETKF(dacycler.DACycler): multiplicative_inflation: Scaling factor by which to multiply ensemble deviation. Default is 1.0 (no inflation). """ + _in_4d: bool = False + _uses_ensemble: bool = True def __init__(self, system_dim: int, @@ -57,8 +59,6 @@ def __init__(self, super().__init__(system_dim=system_dim, delta_t=delta_t, model_obj=model_obj, - in_4d=False, - ensemble=True, B=B, R=R, H=H, h=h) def _step_forecast(self, diff --git a/dabench/dacycler/_var3d.py b/dabench/dacycler/_var3d.py index 533b85c..c5ff024 100644 --- a/dabench/dacycler/_var3d.py +++ b/dabench/dacycler/_var3d.py @@ -16,9 +16,9 @@ XarrayDatasetLike = xr.Dataset | xj.XjDataset class Var3D(dacycler.DACycler): - """Class for building 3DVar DA Cycler + """3D-Var DA Cycler - Attributes: + Args: system_dim: System dimension. delta_t: The timestep of the model (assumed uniform) model_obj: Forecast model object. @@ -33,6 +33,8 @@ class Var3D(dacycler.DACycler): h: Optional observation operator as function. More flexible (allows for more complex observation operator). Default is None. """ + _in_4d: bool = False + _uses_ensemble: bool = False def __init__(self, system_dim: int, @@ -47,8 +49,6 @@ def __init__(self, super().__init__(system_dim=system_dim, delta_t=delta_t, model_obj=model_obj, - in_4d=False, - ensemble=False, B=B, R=R, H=H, h=h) def _cycle_obsop(self, diff --git a/dabench/dacycler/_var4d.py b/dabench/dacycler/_var4d.py index 52439b2..1b2974e 100644 --- a/dabench/dacycler/_var4d.py +++ b/dabench/dacycler/_var4d.py @@ -25,16 +25,12 @@ XarrayDatasetLike = xr.Dataset | xj.XjDataset class Var4D(dacycler.DACycler): - """Class for building 4D DA Cycler + """4D-Var DA Cycler - Attributes: + Args: system_dim: System dimension. delta_t: The timestep of the model (assumed uniform) model_obj: Forecast model object. - in_4d: True for 4D data assimilation techniques (e.g. 4DVar). - Always True for Var4D. - ensemble: True for ensemble-based data assimilation techniques - (ETKF). Always False for Var4D. B: Initial / static background error covariance. Shape: (system_dim, system_dim). If not provided, will be calculated automatically. @@ -59,6 +55,8 @@ class Var4D(dacycler.DACycler): [0, 1, 2, 3, 4, 5]. If None (default), will calculate automatically. """ + _in_4d: bool = True + _uses_ensemble: bool = False def __init__(self, system_dim: int, @@ -87,8 +85,6 @@ def __init__(self, super().__init__(system_dim=system_dim, delta_t=delta_t, model_obj=model_obj, - in_4d=True, - ensemble=False, B=B, R=R, H=H, h=h) def _calc_default_H(self, diff --git a/dabench/dacycler/_var4d_backprop.py b/dabench/dacycler/_var4d_backprop.py index 25e4274..944f980 100644 --- a/dabench/dacycler/_var4d_backprop.py +++ b/dabench/dacycler/_var4d_backprop.py @@ -25,16 +25,12 @@ ScheduleState = Any class Var4DBackprop(dacycler.DACycler): - """Class for building Backpropagation 4D DA Cycler + """Backpropagation 4D-Var DA Cycler - Attributes: + Args: system_dim: System dimension. delta_t: The timestep of the model (assumed uniform) model_obj: Forecast model object. - in_4d: True for 4D data assimilation techniques (e.g. 4DVar). - Always True for Var4DBackprop. - ensemble: True for ensemble-based data assimilation techniques - (ETKF). Always False for Var4DBackprop. B: Initial / static background error covariance. Shape: (system_dim, system_dim). If not provided, will be calculated automatically. @@ -65,6 +61,8 @@ class Var4DBackprop(dacycler.DACycler): return an error. This prevents it from hanging indefinitely when loss grows exponentionally. Default is 10. """ + _in_4d: bool = True + _uses_ensemble: bool = False def __init__(self, system_dim: int, @@ -97,8 +95,6 @@ def __init__(self, super().__init__(system_dim=system_dim, delta_t=delta_t, model_obj=model_obj, - in_4d=True, - ensemble=False, B=B, R=R, H=H, h=h) def _calc_default_H(self, diff --git a/dabench/data/__init__.py b/dabench/data/__init__.py index 11e367e..02bb34a 100644 --- a/dabench/data/__init__.py +++ b/dabench/data/__init__.py @@ -1,14 +1,15 @@ +"""Data generators and downloaders""" from ._data import Data -from .lorenz63 import Lorenz63 -from .lorenz96 import Lorenz96 -from .sqgturb import SQGTurb -from .gcp import GCP -from .pyqg import PyQG -from .pyqg_jax import PyQGJax -from .barotropic import Barotropic -from .enso_indices import ENSOIndices -from .qgs import QGS +from ._lorenz63 import Lorenz63 +from ._lorenz96 import Lorenz96 +from ._sqgturb import SQGTurb +from ._gcp import GCP +from ._pyqg import PyQG +from ._pyqg_jax import PyQGJax +from ._barotropic import Barotropic +from ._enso_indices import ENSOIndices +from ._qgs import QGS from ._xarray_accessor import DABenchDatasetAccessor, DABenchDataArrayAccessor __all__ = [ diff --git a/dabench/data/barotropic.py b/dabench/data/_barotropic.py similarity index 96% rename from dabench/data/barotropic.py rename to dabench/data/_barotropic.py index d742978..bbc090b 100644 --- a/dabench/data/barotropic.py +++ b/dabench/data/_barotropic.py @@ -30,12 +30,13 @@ class Barotropic(_data.Data): - """ Class to set up barotropic model + """Barotropic model data generator based on pyqg - The data class is a wrapper of a "optional" pyqg package. + This data class is a wrapper of a "optional" pyqg package. See https://pyqg.readthedocs.io Notes: + DEPRECATED Uses default attribute values from pyqg.BTModel: https://pyqg.readthedocs.io/en/latest/api.html#pyqg.BTModel Those values originally come from Mcwilliams 1984: @@ -43,7 +44,7 @@ class Barotropic(_data.Data): vortices in turbulent flow. Journal of Fluid Mechanics, 146, pp 21-43 doi:10.1017/S0022112084001750. - Attributes: + Args: system_dim: system dimension beta: Gradient of coriolis parameter. Units: meters^-1 * seconds^-1. Default is 0. @@ -207,8 +208,8 @@ def __advance__(self,): """Advances the QG model according to set attributes Returns: - qs (array_like): absolute potential vorticity (relative potential - vorticity + background vorticity). + Array of absolute potential vorticity (relative potential + vorticity + background vorticity). """ qs = [] for _ in self.m.run_with_snapshots(tsnapstart=0, tsnapint=self.m.dt): diff --git a/dabench/data/_data.py b/dabench/data/_data.py index e1c16a3..fb92159 100644 --- a/dabench/data/_data.py +++ b/dabench/data/_data.py @@ -16,11 +16,10 @@ ArrayLike = np.ndarray | jax.Array class Data(): - """Generic class for data generator objects. + """Base for all data generator objects. - Attributes: + Args: system_dim: system dimension - time_dim: total time steps original_dim: dimensions in original space, e.g. could be 3x3 for a 2d system with system_dim = 9. Defaults to (system_dim), i.e. 1d. @@ -32,7 +31,6 @@ class Data(): def __init__(self, system_dim: int = 3, - time_dim: int = 1, original_dim: tuple[int, ...] | None = None, random_seed: int = 37, delta_t: float = 0.01, @@ -42,7 +40,6 @@ def __init__(self, """Initializes the base data object""" self.system_dim = system_dim - self.time_dim = time_dim self.random_seed = random_seed self.delta_t = delta_t self.store_as_jax = store_as_jax @@ -98,8 +95,7 @@ def generate(self, Notes: Either provide n_steps or t_final in order to indicate the length - of the forecast. These are used to set the values, times, and - time_dim attributes. + of the forecast. Args: n_steps: Number of timesteps. One of n_steps OR @@ -118,8 +114,8 @@ def generate(self, convergence tolerance, etc.). Returns: - Xarray Dataset of output vector and (if return_tlm=True) - Xarray DataArray of TLMs corresponding to the system trajectory. + Xarray Dataset of output vector, and if return_tlm=True + Xarray DataArray of TLMs corresponding to the system trajectory. """ # Check that n_steps or t_final is supplied @@ -172,8 +168,8 @@ def generate(self, **kwargs) # Convert to JAX if necessary - self.time_dim = t.shape[0] - out_dim = (self.time_dim,) + self.original_dim + time_dim = t.shape[0] + out_dim = (time_dim,) + self.original_dim if self.store_as_jax: y_out = jnp.array(y[:,:self.system_dim].reshape(out_dim)) else: @@ -197,13 +193,13 @@ def generate(self, # Reshape M matrix if self.store_as_jax: M = jnp.reshape(y[:, self.system_dim:], - (self.time_dim, + (time_dim, self.system_dim, self.system_dim) ) else: M = np.reshape(y[:, self.system_dim:], - (self.time_dim, + (time_dim, self.system_dim, self.system_dim) ) @@ -283,7 +279,7 @@ def calc_lyapunov_exponents_series( Returns: Lyapunov exponents for all timesteps, array of size - (total_time/rescale_time - 1, system_dim) + (total_time/rescale_time - 1, system_dim) """ # Set total_time diff --git a/dabench/data/enso_indices.py b/dabench/data/_enso_indices.py similarity index 98% rename from dabench/data/enso_indices.py rename to dabench/data/_enso_indices.py index edfb02a..cda4453 100644 --- a/dabench/data/enso_indices.py +++ b/dabench/data/_enso_indices.py @@ -14,14 +14,13 @@ class ENSOIndices(_data.Data): - """Class to get ENSO indices from CPC website + """Gets ENSO indices from CPC website Notes: Source: https://www.cpc.ncep.noaa.gov/data/indices/ - Attributes: + Args: system_dim: system dimension - time_dim: total time steps store_as_jax: Store values as jax array instead of numpy array. Default is False (store as numpy). file_dict: Lists of files to get. Dict keys are type of data: @@ -58,7 +57,6 @@ def __init__(self, file_dict: dict | None = None, var_types: dict | None = None, system_dim: int | None = None, - time_dim: int | None = None, store_as_jax: bool = False, **kwargs): @@ -66,7 +64,7 @@ def __init__(self, self.file_dict = file_dict self.var_types = var_types - super().__init__(system_dim=system_dim, time_dim=time_dim, + super().__init__(system_dim=system_dim, values=None, delta_t=None, **kwargs, store_as_jax=store_as_jax) diff --git a/dabench/data/gcp.py b/dabench/data/_gcp.py similarity index 98% rename from dabench/data/gcp.py rename to dabench/data/_gcp.py index 36508c3..8b964cb 100644 --- a/dabench/data/gcp.py +++ b/dabench/data/_gcp.py @@ -18,13 +18,13 @@ class GCP(_data.Data): - """Class for loading ERA5 data from Google Cloud Platform + """Loads ERA5 data from Google Cloud Platform Notes: Source: https://cloud.google.com/storage/docs/public-datasets/era5 Data is hourly - Attributes: + Args: variables: Names of ERA5 variables to load. For description of variables, see: https://github.com/google-research/arco-era5?tab=readme-ov-file#full_37-1h-0p25deg-chunk-1zarr-v3 diff --git a/dabench/data/lorenz63.py b/dabench/data/_lorenz63.py similarity index 94% rename from dabench/data/lorenz63.py rename to dabench/data/_lorenz63.py index 48c1423..ee15a26 100644 --- a/dabench/data/lorenz63.py +++ b/dabench/data/_lorenz63.py @@ -13,9 +13,9 @@ ArrayLike = np.ndarray | jax.Array class Lorenz63(_data.Data): - """ Class to set up Lorenz 63 model data + """Lorenz 63 model data generator. - Attributes: + Args: sigma: Lorenz 63 param. Default is 10., the original value used in Lorenz, 1963. https://doi.org/10.1175/1520-0469(1963)020<0130:DNF>2.0.CO;2 @@ -30,7 +30,6 @@ class Lorenz63(_data.Data): and initial conditions [0., 1., 0.], a spinup which replicates the simulation described in Lorenz, 1963. system_dim: system dimension. Must be 3 for Lorenz63. - time_dim: total time steps store_as_jax: Store values as jax array instead of numpy array. Default is False (store as numpy). """ @@ -42,7 +41,6 @@ def __init__(self, delta_t: float = 0.01, x0: ArrayLike | None = jnp.array([-10.0, -15.0, 21.3]), system_dim: int = 3, - time_dim: int | None = None, values: ArrayLike | None = None, store_as_jax: bool = False, **kwargs): @@ -57,7 +55,7 @@ def __init__(self, print('Assigning system_dim to 3.') system_dim = 3 - super().__init__(system_dim=system_dim, time_dim=time_dim, + super().__init__(system_dim=system_dim, values=values, delta_t=delta_t, store_as_jax=store_as_jax, **kwargs) diff --git a/dabench/data/lorenz96.py b/dabench/data/_lorenz96.py similarity index 97% rename from dabench/data/lorenz96.py rename to dabench/data/_lorenz96.py index 991b21e..4b89b16 100644 --- a/dabench/data/lorenz96.py +++ b/dabench/data/_lorenz96.py @@ -13,13 +13,13 @@ class Lorenz96(_data.Data): - """Class to set up Lorenz 96 model data. + """Lorenz 96 model data generator. Notes: Default values come from Lorenz, 1996: eapsweb.mit.edu/sites/default/files/Predicability_a_Problem_2006.pdf - Attributes: + Args: forcing_term: Forcing constant for Lorenz96, prevents energy from decaying to 0. Default is 8.0. x0: Initial state vector, array of floats of size @@ -33,7 +33,6 @@ class Lorenz96(_data.Data): which is set to 0.01. system_dim: System dimension, must be between 4 and 40. Default is 36. - time_dim: Total time steps delta_t: Length of one time step. Default is 0.05 from Lorenz, 1996, but on modern computers 0.01 is often used. store_as_jax: Store values as jax array instead of numpy array. @@ -45,13 +44,12 @@ def __init__(self, delta_t: float = 0.05, x0: ArrayLike | None = None, system_dim: int = 36, - time_dim: int | None = None, values: ArrayLike | None = None, store_as_jax: bool = False, **kwargs): """Initialize Lorenz96 object, subclass of Base""" - super().__init__(system_dim=system_dim, time_dim=time_dim, + super().__init__(system_dim=system_dim, values=values, delta_t=delta_t, store_as_jax=store_as_jax, **kwargs) diff --git a/dabench/data/pyqg.py b/dabench/data/_pyqg.py similarity index 96% rename from dabench/data/pyqg.py rename to dabench/data/_pyqg.py index ea28864..05dc9d2 100644 --- a/dabench/data/pyqg.py +++ b/dabench/data/_pyqg.py @@ -25,16 +25,17 @@ class PyQG(_data.Data): - """ Class to set up quasi-geotropic model + """PyQG quasi-geotropic model data generator. The PyQG class is simply a wrapper of a "optional" pyqg package. See https://pyqg.readthedocs.io Notes: + DEPRECATED Uses default attribute values from pyqg.QGModel: https://pyqg.readthedocs.io/en/latest/api.html#pyqg.QGModel - Attributes: + Args: beta (float): Gradient of coriolis parameter. Units: meters^-1 * seconds^-1 rek (float): Linear drag in lower layer. Units: seconds^-1 @@ -47,7 +48,7 @@ class PyQG(_data.Data): ny (int): Number of grid points in the y direction (default: nx). L (float): Domain length in x direction. Units: meters. W (float): Domain width in y direction. Units: meters (default: L). - filterfac (float): amplitdue of the spectral spherical filter + filterfac (float): amplitude of the spectral spherical filter (originally 18.4, later changed to 23.6). delta_t (float): Numerical timestep. Units: seconds. twrite (int): Interval for cfl writeout. Units: number of timesteps. @@ -191,8 +192,8 @@ def __advance__(self,): """Advances the QG model according to set attributes Returns: - qs (array_like): absolute potential vorticity (relative potential - vorticity + background vorticity). + Array of absolute potential vorticity (relative potential + vorticity + background vorticity). """ qs = [] for _ in self.m.run_with_snapshots(tsnapstart=0, tsnapint=self.m.dt): diff --git a/dabench/data/pyqg_jax.py b/dabench/data/_pyqg_jax.py similarity index 96% rename from dabench/data/pyqg_jax.py rename to dabench/data/_pyqg_jax.py index 812250a..e26282d 100644 --- a/dabench/data/pyqg_jax.py +++ b/dabench/data/_pyqg_jax.py @@ -32,7 +32,7 @@ class PyQGJax(_data.Data): - """Class to set up quasi-geotropic model + """PyQGJax quasi-geotropic model data generator. The PyQGJax class is simply a wrapper of the "optional" pyqg-jax package. See https://pyqg-jax.readthedocs.io @@ -41,7 +41,7 @@ class PyQGJax(_data.Data): Uses default attribute values from pyqg_jax.QGModel: https://pyqg.readthedocs.io/en/latest/api.html#pyqg.QGModel - Attributes: + Args: beta: Gradient of coriolis parameter. Units: meters^-1 * seconds^-1 rd: Deformation radius. Units: meters. @@ -72,7 +72,6 @@ def __init__(self, ny: int | None = None, delta_t: float = 7200, random_seed: int = 37, - time_dim: int | None = None, store_as_jax: bool = False, **kwargs): """ Initialize PyQGJax QGModel object, subclass of Base @@ -110,7 +109,7 @@ def __init__(self, jax.random.PRNGKey(0) ) super().__init__(system_dim=system_dim, original_dim=original_dim, - time_dim=time_dim, delta_t=delta_t, + delta_t=delta_t, store_as_jax=store_as_jax, x0=x0, **kwargs) @@ -157,8 +156,7 @@ def generate(self, Notes: Either provide n_steps or t_final in order to indicate the length - of the forecast. These are used to set the values, times, and - time_dim attributes. + of the forecast. Args: n_steps: Number of timesteps. One of n_steps OR diff --git a/dabench/data/qgs.py b/dabench/data/_qgs.py similarity index 93% rename from dabench/data/qgs.py rename to dabench/data/_qgs.py index f256612..9b8bf86 100644 --- a/dabench/data/qgs.py +++ b/dabench/data/_qgs.py @@ -33,12 +33,12 @@ class QGS(_data.Data): - """ Class to set up QGS quasi-geostrophic model + """QGS quasi-geostrophic model data generator. The QGS class is simply a wrapper of an *optional* qgs package. See https://qgs.readthedocs.io/ - Attributes: + Args: model_params: qgs parameter object. See: https://qgs.readthedocs.io/en/latest/files/technical/configuration.html#qgs.params.params.QgParams If None, will use defaults specified by: @@ -53,7 +53,6 @@ def __init__(self, x0: ArrayLike | None = None, delta_t: ArrayLike | None = 0.1, system_dim: int | None = None, - time_dim: int | None = None, store_as_jax: bool = False, random_seed: int = 37, **kwargs): @@ -86,7 +85,7 @@ def __init__(self, if x0 is None: x0 = self._rng.random(system_dim)*0.001 - super().__init__(system_dim=system_dim, time_dim=time_dim, + super().__init__(system_dim=system_dim, delta_t=delta_t, store_as_jax=store_as_jax, x0=x0, **kwargs) @@ -124,13 +123,12 @@ def rhs(self, ) -> np.ndarray: """Vector field (tendencies) of qgs system - Arg: - x: State vector, shape: (system_dim) + Args: + x: State vector of size (system_dim) t: times vector. Required as argument slot for some numerical integrators but unused. Returns: - dx: vector field of qgs - + Vector field of qgs """ dx = self.f(t, x) @@ -139,18 +137,17 @@ def rhs(self, def Jacobian(self, x: ArrayLike, - t: float | None = 0 + t: float | None = 0 ) -> np.ndarray: """Jacobian of the qgs system - Arg: - x: State vector, shape: (system_dim) + Args: + x: State vector of size (system_dim) t: times vector. Required as argument slot for some numerical integrators but unused. Returns: - J (ndarray): Jacobian matrix, shape: (system_dim, system_dim) - + Jacobian matrix of size (system_dim, system_dim) """ J = self.Df(t, x) @@ -169,8 +166,7 @@ def generate(self, Notes: Either provide n_steps or t_final in order to indicate the length - of the forecast. These are used to set the values, times, and - time_dim attributes. + of the forecast. Args: n_steps (int): Number of timesteps. One of n_steps OR @@ -189,8 +185,8 @@ def generate(self, convergence tolerance, etc.). Returns: - Xarray Dataset of output vector and (if return_tlm=True) - Xarray DataArray of TLMs corresponding to the system trajectory. + Xarray Dataset of output vector, and if return_tlm=True + Xarray DataArray of TLMs corresponding to the system trajectory. """ # Check that n_steps or t_final is supplied @@ -243,8 +239,8 @@ def generate(self, **kwargs) # Convert to JAX if necessary - self.time_dim = t.shape[0] - out_dim = (self.time_dim,) + self.original_dim + time_dim = t.shape[0] + out_dim = (time_dim,) + self.original_dim if self.store_as_jax: y_out = jnp.array(y[:,:self.system_dim].reshape(out_dim)) else: @@ -268,13 +264,13 @@ def generate(self, # Reshape M matrix if self.store_as_jax: M = jnp.reshape(y[:, self.system_dim:], - (self.time_dim, + (time_dim, self.system_dim, self.system_dim) ) else: M = np.reshape(y[:, self.system_dim:], - (self.time_dim, + (time_dim, self.system_dim, self.system_dim) ) @@ -296,7 +292,7 @@ def rhs_aux(self, t: Array of times with size (time_dim) Returns: - dxaux (ndarray): State vector [size: (system_dim,)] + State vector of size (system_dim,) """ # Compute M dxdt = self.rhs(x[:self.system_dim], t) @@ -353,7 +349,7 @@ def calc_lyapunov_exponents_series( Returns: Lyapunov exponents for all timesteps, array of size - (total_time/rescale_time - 1, system_dim) + (total_time/rescale_time - 1, system_dim) """ # Set total_time if total_time is None: diff --git a/dabench/data/sqgturb.py b/dabench/data/_sqgturb.py similarity index 97% rename from dabench/data/sqgturb.py rename to dabench/data/_sqgturb.py index 0cc2738..f0ca83e 100644 --- a/dabench/data/sqgturb.py +++ b/dabench/data/_sqgturb.py @@ -50,16 +50,15 @@ class SQGTurb(_data.Data): - """Class to set up SQGTurb model and manage data. + """SQGTurb model data generator. - Attributes: + Args: pv: Potential vorticity array. If None (default), loads data from 57600 step spinup with initial conditions taken from Jeff Whitaker's original implementation: https://github.com/jswhit/sqgturb. 57600 steps matches the "nature run" spin up in that repository. system_dim: The dimension of the system state - time_dim: The dimension of the timeseries (not used) delta_t: model time step (seconds) x0: Initial state, array of floats of size (system_dim). @@ -499,7 +498,6 @@ def integrate(self, if include_x0: n_steps = n_steps + 1 - self.time_dim = n_steps times = t + jnp.arange(n_steps)*delta_t # Integrate in spectral spacestep_n @@ -548,13 +546,4 @@ def rhs(self, # save wind field self.u = -psiy self.v = psix - return dpvspecdt - - def _to_original_dim(self) -> np.ndarray: - """Going back to 2D is a bit trickier for sqgturb""" - gridded_vals = np.zeros((self.time_dim, self.Nv, self.Nx, self.Nx)) - - for t in np.arange(self.time_dim): - gridded_vals[t] = self.map1dto2d_ifft2(self.values[t]) - - return gridded_vals + return dpvspecdt \ No newline at end of file diff --git a/dabench/model/__init__.py b/dabench/model/__init__.py index f05d128..15591a2 100644 --- a/dabench/model/__init__.py +++ b/dabench/model/__init__.py @@ -1,3 +1,4 @@ +"""Model classes""" from ._model import Model from ._rc import RCModel diff --git a/dabench/model/_model.py b/dabench/model/_model.py index c319ce2..75802ba 100644 --- a/dabench/model/_model.py +++ b/dabench/model/_model.py @@ -8,23 +8,20 @@ import xarray as xr class Model(): - """Base class for Model object + """Base for Model objects - Attributes: - system_dim (int): system dimension - time_dim (int): total time steps - delta_t (float): the timestep of the model (assumed uniform) - model_obj (obj): underlying model object, e.g. pytorch neural network. + Args: + system_dim: system dimension + delta_t: the timestep of the model (assumed uniform) + model_obj: underlying model object, e.g. pytorch neural network. """ def __init__(self, system_dim: int | None = None, - time_dim: int | None = None, delta_t: int | None = None, - model_obj: int | None = None + model_obj: Any | None = None ): self.system_dim = system_dim - self.time_dim = time_dim self.delta_t = delta_t self.model_obj = model_obj diff --git a/dabench/model/_rc.py b/dabench/model/_rc.py index e842dd7..87a3dbf 100644 --- a/dabench/model/_rc.py +++ b/dabench/model/_rc.py @@ -15,9 +15,9 @@ class RCModel(model.Model): - """Class for a simple Reservoir Computing data-driven model + """A simple Reservoir Computing data-driven model - Attributes: + Args: system_dim (int): Dimension of reservoir output. input_dim (int): Dimension of reservoir input signal. reservoir_dim (int): Dimension of reservoir state. Default: 512. @@ -170,6 +170,7 @@ def weights_init(self): def generate(self, state_vec, A=None, Win=None, r0=None): """generate reservoir time series from input signal u + Args: u (array_like): (time_dimension, system_dimension), input signal to reservoir @@ -182,7 +183,7 @@ def generate(self, state_vec, A=None, Win=None, r0=None): If False, returns states. Default: False. Returns: - r (array_like): (time_dim, reservoir_dim), reservoir state + Reservoirs state, size (time_dim, reservoir_dim) """ u = state_vec.to_stacked_array('system',['time']).data r = np.zeros((u.shape[0], self.reservoir_dim)) @@ -214,7 +215,7 @@ def update(self, r, u, A=None, Win=None): reservoir input weight matrix. If None, uses self.Win. Default is None Returns: - q (array_like): (reservoir_dim,) Reservoir state at next time step + Reservoir state at next time step, of size (reservoir_dim,) """ if A is None: @@ -248,8 +249,7 @@ def predict(self, state_vec, delta_t, initial_index=0, n_steps=100, r0 (array_like, optional): initial reservoir state Returns: - dataobj_pred (vector.StateVector): StateVector object covering - prediction period + Data object covering prediction period """ # Recompute the initial reservoir spinup to get reservoir states @@ -287,6 +287,7 @@ def predict(self, state_vec, delta_t, initial_index=0, n_steps=100, def readout(self, rt, Wout=None, utm1=None): """use Wout to map reservoir state to output + Args: rt (array_like): 1D or 2D with dims: (Nr,) or (Ntime, Nr) reservoir state, either passed as single time snapshot, @@ -294,9 +295,11 @@ def readout(self, rt, Wout=None, utm1=None): utm1 (array_like): 1D or 2D with dims: (Nu,) or (Ntime, Nu) u(t-1) for r(t), only used if readout_method = 'biased', then Wout*[1, u(t-1), r(t)]=u(t) + Returns: - vt (array_like): 1D or 2D with dims: (Nout,) or (Ntime, Nout) - depending on shape of input array + 1D or 2D array with dims(Nout,) or (Ntime, Nout) + depending on shape of input array + Todo: generalize similar to DiffRC """ @@ -346,8 +349,9 @@ def _predict_backend(self, n_samples, s_last, u_last, delta_t, Default is None. Wout (array_like, optional): Rutput weight matrix. If None, uses self.Wout. Default is None. + Returns: - y (Data): data object with predicted signal from reservoir + Data object with predicted signal from reservoir """ s = jnp.zeros((n_samples, self.reservoir_dim)) @@ -397,8 +401,8 @@ def _compute_Wout(self, rt, y, update_Wout=True, u=None): initialize it by rewriting the ybar and sbar matrices Returns: - Wout (array_like): 2D with dims (output_dim, reservoir_dim), - this is also stored within the object + Wout array, 2D with dims (output_dim, reservoir_dim), + this is also stored within the object Sets Attributes: ybar (array_like): y.T @ st, st is rt with readout_method accounted @@ -451,8 +455,8 @@ def _compute_Wout(self, rt, y, update_Wout=True, u=None): return self.Wout def _linsolve(self, X, Y, beta=None, **kwargs): - '''Linear solver wrapper - Solve for A in Y = AX + '''Linear solver wrapper for A in Y = AX + Args: X (matrix) : independent variable Y (matrix) : dependent variable @@ -464,11 +468,13 @@ def _linsolve(self, X, Y, beta=None, **kwargs): def _linsolve_pinv(self, X, Y, beta=None): """Solve for A in Y = AX, assuming X and Y are known. + Args: X : independent variable, square matrix Y : dependent variable, square matrix + Returns: - A : Solution matrix, rectangular matrix + Solution matrix, rectangular matrix """ if beta is not None: Xinv = linalg.pinv(X+beta*np.eye(X.shape[0])) diff --git a/dabench/observer/__init__.py b/dabench/observer/__init__.py index 63df24f..16b4d98 100644 --- a/dabench/observer/__init__.py +++ b/dabench/observer/__init__.py @@ -1,3 +1,4 @@ +"""Observer module""" from ._observer import Observer __all__ = [ diff --git a/dabench/observer/_observer.py b/dabench/observer/_observer.py index cd52096..decb844 100644 --- a/dabench/observer/_observer.py +++ b/dabench/observer/_observer.py @@ -17,10 +17,10 @@ class Observer(): - """Base class for Observer objects + """Flexibly samples observations from generated data - Attributes: - data_obj: Data generator/loader object from which + Args: + state_vec: Data generator/loader object from which to gather observations. random_location_density: Fraction of locations in system_dim to randomly select for observing, must be value @@ -74,6 +74,16 @@ class Observer(): store_as_jax: Store values as jax array instead of numpy array. Default is False (store as numpy). + Attributes: + locations (ArrayLike): Location indices for making + observations. In system_dim (1D) or original dim + (>1D) of self.state_vec. + location_dim (int): Number of locations sampled from (max + in a single time step, if non-stationary observers). + times (ArrayLike): Time indices to gather observations + from. + time_dim (int): Number of times sampled from. + """ def __init__(self, @@ -265,7 +275,7 @@ def observe(self) -> xr.Dataset: Returns: ObsVector containing observation values, times, locations, and - errors + errors """ # Define random num generator @@ -353,4 +363,4 @@ def observe(self) -> xr.Dataset: obs_vec[data_var] = obs_vec[data_var] + obs_vec['errors'].sel(variable=data_var) - return obs_vec \ No newline at end of file + return obs_vec diff --git a/docs/conf.py b/docs/conf.py index 86a22e8..ecaaa1f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -13,9 +13,23 @@ 'sphinx.ext.duration', 'sphinx.ext.doctest', 'sphinx.ext.intersphinx', - 'autoapi.extension' + 'autoapi.extension', + 'sphinx.ext.autodoc.typehints', + 'sphinx.ext.napoleon' ] autoapi_dirs = ['../dabench'] +# Important: Because we're not including "undoc-members", +# you need to include a docstring on *everything* you want documented. +# Including in __init__.py for submodules. +autoapi_options = ['members', 'show-module-summary', + 'special-members', 'imported-members'] + +autodoc_typehints = 'description' +autoapi_member_order = 'groupwise' +autoapi_add_toctree_entry = True +autoapi_own_page_level = 'module' +napoleon_numpy_docstring = False +napoleon_google_docstring = True intersphinx_mapping = { 'python': ('https://docs.python.org/3/', None), diff --git a/tests/dacycler_base_test.py b/tests/dacycler_base_test.py index ba33419..b51036a 100644 --- a/tests/dacycler_base_test.py +++ b/tests/dacycler_base_test.py @@ -9,12 +9,11 @@ def test_dacycler_init(): params = {'system_dim': 6, 'delta_t': 0.5, - 'ensemble': True, 'model_obj':dab.model.RCModel(6, 10)} test_dac = dab.dacycler.DACycler(**params) assert test_dac.system_dim == 6 assert test_dac.delta_t == 0.5 - assert test_dac.ensemble - assert not test_dac.in_4d + assert not test_dac._uses_ensemble + assert not test_dac._in_4d