From 5ecbef27d1618cbe008cfe051cfa5f520aa2d89c Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Wed, 18 Sep 2024 14:30:35 -0600 Subject: [PATCH 01/44] Data().generate() now returns an Xarray 'state vector' --- dabench/data/_data.py | 398 ++++-------------------------------------- 1 file changed, 31 insertions(+), 367 deletions(-) diff --git a/dabench/data/_data.py b/dabench/data/_data.py index 0c375e1..7762611 100644 --- a/dabench/data/_data.py +++ b/dabench/data/_data.py @@ -22,9 +22,6 @@ class Data(): i.e. 1d. random_seed (int): random seed, defaults to 37 delta_t (float): the timestep of the data (assumed uniform) - values (ndarray): 2d array of data (time_dim, system_dim), - set by generate() method - times (ndarray): 1d array of times (time_dim), set by generate() method store_as_jax (bool): Store values as jax array instead of numpy array. Default is False (store as numpy). """ @@ -35,7 +32,6 @@ def __init__(self, original_dim=None, random_seed=37, delta_t=0.01, - values=None, store_as_jax=False, x0=None, **kwargs): @@ -46,10 +42,8 @@ def __init__(self, self.random_seed = random_seed self.delta_t = delta_t self.store_as_jax = store_as_jax - # values and x0 atts are properties to better convert between jax/numpy - self._values = values + # x0 attribute is property to better convert between jax/numpy self._x0 = x0 - self._times = None if original_dim is None: self.original_dim = (system_dim,) @@ -59,69 +53,10 @@ def __init__(self, self._values_gridded = None self._x0_gridded = None - def __getitem__(self, subscript): - if self.values is None: - raise AttributeError('Object does not contain any data values.\n' - 'Run .generate() or .load() and try again') - - if isinstance(subscript, slice): - new_copy = copy.deepcopy(self) - new_copy.values = new_copy.values[ - subscript.start:subscript.stop:subscript.step] - new_copy.times = new_copy.times[ - subscript.start:subscript.stop:subscript.step] - new_copy.time_dim = new_copy.times.shape[0] - return new_copy - else: - new_copy = copy.deepcopy(self) - new_copy.values = new_copy.values[subscript] - new_copy.times = new_copy.times[subscript] - if isinstance(subscript, int): - new_copy.time_dim = 1 - else: - new_copy.time_dim = new_copy.times.shape[0] - return new_copy - - @property - def values(self): - return self._values - - @values.setter - def values(self, vals): - if vals is None: - self._values = None - else: - if self.store_as_jax: - self._values = jnp.asarray(vals) - else: - self._values = np.asarray(vals) - - @values.deleter - def values(self): - del self._values - @property def x0(self): return self._x0 - @property - def times(self): - return self._times - - @times.setter - def times(self, vals): - if vals is None: - self._times = None - else: - if self.store_as_jax: - self._times = jnp.asarray(vals) - else: - self._times = np.asarray(vals) - - @times.deleter - def times(self): - del self._times - @x0.setter def x0(self, x0_vals): if x0_vals is None: @@ -136,13 +71,6 @@ def x0(self, x0_vals): def x0(self): del self._x0 - @property - def values_gridded(self): - if self._values is None: - return None - else: - return self._to_original_dim() - @property def x0_gridded(self): if self._x0 is None: @@ -150,31 +78,9 @@ def x0_gridded(self): else: return self._x0.reshape(self.original_dim) - def _to_original_dim(self): - """Converts 1D representation of system back to original dimensions. - - Returns: - Multidimensional array with shape: - (time_dim, original_dim[0], ..., original_dim[n]) - """ - return jnp.reshape(self.values, (self.time_dim,) + self.original_dim) - - def sample_cells(self, targets): - """Samples values at a list of multidimensional array indices. - - Args: - targets (ndarray): Array of target indices in shape: - (num_of_target_indices, time_dim + original_dim). E.g. - [[0,0], [0,1]] samples the first and second cell values in the - first timestep (in this case original_dim = 1). - """ - tupled_targets = tuple(tuple(targets[:, i]) for - i in range(len(self.original_dim) + 1)) - return self._to_original_dim()[tupled_targets] - def generate(self, n_steps=None, t_final=None, x0=None, M0=None, return_tlm=False, stride=None, **kwargs): - """Generates a dataset and assigns values and times to the data object. + """Generates a dataset and returns xarray state vector. Notes: Either provide n_steps or t_final in order to indicate the length @@ -251,254 +157,40 @@ def generate(self, n_steps=None, t_final=None, x0=None, M0=None, jax_comps=self.store_as_jax, **kwargs) - # The generate method specifically stores data in the object, - # as opposed to the forecast method, which does not. - # Store values and times as part of data object - self.values = y[:, :self.system_dim] - self.times = t - self.time_dim = len(t) + # Convert to JAX if necessary + if self.store_as_jax: + y_out = jnp.array(y[:,:self.system_dim]) + else: + y_out = np.array(y[:,:self.system_dim]) + # Build Xarray object for output + out_vec = xr.Dataset( + {'x': (['time','i'],y_out)}, + coords={'time': t, + 'i':np.arange(self.system_dim)}, + attrs={'store_as_jax':self.store_as_jax, + 'system_dim': self.system_dim, + 'time_dim': self.time_dim + } + ) # Return the data series and associated TLMs if requested if return_tlm: # Reshape M matrix - M = jnp.reshape(y[:, self.system_dim:], - (self.time_dim, - self.system_dim, - self.system_dim) - ) - if self.store_as_jax: - return M - else: - return np.array(M) - - def _import_xarray_ds(self, ds, include_vars=None, exclude_vars=None, - years_select=None, dates_select=None, - lat_sorting=None): - # Convert to numpy background - ds = ds.as_numpy() - - if dates_select is not None: - dates_filter_indices = ds.time.dt.date.isin(dates_select) - # First check to make sure the dates exist in the object - if dates_filter_indices.sum() == 0: - raise ValueError('Dataset does not contain any of the dates' - ' specified in dates_select\n' - 'dates_select = {}\n' - 'NetCDF contains {}'.format( - dates_select, - np.unique(ds.time.dt.date) - ) - ) + M = jnp.reshape(y[:, self.system_dim:], + (self.time_dim, + self.system_dim, + self.system_dim) + ) else: - ds = ds.isel(time=dates_filter_indices) - else: - if years_select is not None: - year_filter_indices = ds.time.dt.year.isin(years_select) - # First check to make sure the years exist in the object - if year_filter_indices.sum() == 0: - raise ValueError('Dataset does not contain any of the ' - 'years specified in years_select\n' - 'years_select = {}\n' - 'NetCDF contains {}'.format( - years_select, - np.unique(ds.time.dt.year) - ) - ) - else: - ds = ds.isel(time=year_filter_indices) - - # Check size before loading - size_gb = ds.nbytes / (1024 ** 3) - if size_gb > 1: - warnings.warn('Trying to load large xarray dataset into memory. \n' - 'Size: {} GB. Operation may take a long time, ' - 'stall, or crash.'.format(size_gb)) - - # Get variable names and shapes - names_list = [] - shapes_list = [] - if exclude_vars is not None: - ds = ds.drop_vars(exclude_vars) - if include_vars is not None: - ds = ds[include_vars] - for data_var in ds.data_vars: - shapes_list.append(ds[data_var].shape) - names_list.append(data_var) - - # Load - ds.load() - - # Get dims - dims = ds.sizes - dims_names = list(ds.sizes) - - # Set times - time_key = None - dims_keys = dims.keys() - if 'time' in dims_keys: - time_key = 'time' - elif 'times' in dims_keys: - time_key = 'times' - elif 'time0' in dims_keys: - time_key = 'time0' - if time_key is not None: - self.times = ds[time_key].values - self.time_dim = self.times.shape[0] - else: - self.times = np.array([0]) - self.time_dim = 1 - - # Find names for key dimensions: lat, lon, level (if it exists) - lat_key = None - lon_key = None - lev_key = None - if 'level' in dims_keys: - lev_key = 'level' - elif 'lev' in dims_keys: - lev_key = 'lev' - if 'latitude' in dims_keys: - lat_key = 'latitude' - elif 'lat' in dims_keys: - lat_key = 'lat' - if 'longitude' in dims_keys: - lon_key = 'longitude' - elif 'lon' in dims_keys: - lon_key = 'lon' - - # Reorder dimensions: time, level, lat, lon, etc. - dim_order = np.array([time_key, lev_key, lat_key, lon_key]) - dim_order = dim_order[dim_order != np.array(None)] - remaining_dims = [d for d in dims_names if d not in dim_order] - full_dim_order = list(dim_order) + remaining_dims - - if len(full_dim_order) > 0: - ds = ds.transpose(*full_dim_order) - - # Orient data vertically - if lat_key is not None: - if lat_sorting is not None: - if lat_sorting == 'ascending': - ds = ds.sortby(lat_key, ascending=True) - elif lat_sorting == 'descending': - ds = ds.sortby(lat_key, ascending=False) - else: - warnings.warn('{} is not a valid value for lat_sorting.\n' - 'Choose one of None, "ascending", or ' - '"descending".\n' - 'Proceeding without sorting.'.format( - lat_sorting) - ) - - # Check if all elements' data shapes are equal - if len(names_list) == 0: - raise ValueError('No valid data_vars were found in dataset.\n' - 'Check your include_vars and exclude_vars args.') - if not shapes_list.count(shapes_list[0]) == len(shapes_list): - # Formatting for showing variable names and shapes - var_shape_warn_list = ['{:<12} {:<15}'.format( - 'Variable', 'Dimensions')] - var_shape_warn_list += ['{:<16} {:<16}'.format( - names_list[i], str(shapes_list[i])) - for i in range(len(shapes_list))] - warnings.warn('data_vars do not all share the same dimensions.\n' - 'Broadcasting variables to same dimensions.\n' - 'To avoid, use include_vars or exclude_vars args.\n' - 'Variable dimensions are:\n' - '{}'.format('\n'.join(var_shape_warn_list)) - ) - - # Gather values and set dimensions - temp_values = np.moveaxis(ds.to_dataarray().values, 0, -1) - self.original_dim = temp_values.shape[1:] - if self.original_dim[-1] == 1 and len(self.original_dim) > 2: - self.original_dim = self.original_dim[:-1] - - self.values = temp_values.reshape( - temp_values.shape[0], -1) - self.var_names = np.array(names_list) - if self.x0 is None: - self.x0 = self.values[0] - self.time_dim = self.values.shape[0] - self.system_dim = self.values.shape[1] - if len(full_dim_order) == 0: - warnings.warn('Unable to find any spatial or level dimensions ' - 'in dataset. Setting original_dim to system_dim: ' - '{}'.format(self.system_dim)) - - def load_netcdf(self, filepath=None, include_vars=None, exclude_vars=None, - years_select=None, dates_select=None, - lat_sorting='descending'): - """Loads values from netCDF file, saves them in values attribute - - Args: - filepath (str): Path to netCDF file to load. If not given, - defaults to loading ERA5 ECMWF SLP data over Japan - from 2018 to 2021. - include_vars (list-like): Data variables to load from NetCDF. If - None (default), loads all variables. Can be used to exclude bad - variables. - exclude_vars (list-like): Data variabes to exclude from NetCDF - loading. If None (default), loads all vars (or only those - specified in include_vars). It's recommended to only specify - include_vars OR exclude_vars (unless you want to do extra - typing). - years_select (list-like): Years to load (ints). If None, loads all - timesteps. - dates_select (list-like): Dates to load. Elements must be - datetime date or datetime objects, depending on type of time - indices in NetCDF. If both years_select and dates_select - are specified, time_stamps overwrites "years" argument. If - None, loads all timesteps. - lat_sorting (str): Orient data by latitude: - descending (default), ascending, or None (uses orientation - from data file). - """ - if filepath is None: - # Use importlib.resources to get the default netCDF from dabench - filepath = resources.files(_suppl_data).joinpath('era5_japan_slp.nc') - with xr.open_dataset(filepath, decode_coords='all') as ds: - self._import_xarray_ds( - ds, include_vars=include_vars, - exclude_vars=exclude_vars, - years_select=years_select, dates_select=dates_select, - lat_sorting=lat_sorting) - - def save_netcdf(self, filename): - """Saves values in values attribute to netCDF file - - Args: - filepath (str): Path to netCDF file to save - """ - - # Set variable names - if not hasattr(self, 'var_names') or self.var_names is None: - var_names = ['var{}'.format(i) for - i in range(self.values.shape[1])] - else: - var_names = self.var_names - - # Set times - if not hasattr(self, 'times') or self.times is None: - times = np.arange(self.values.shape[0]) + M = np.reshape(y[:, self.system_dim:], + (self.time_dim, + self.system_dim, + self.system_dim) + ) + return out_vec, M else: - times = self.times - - # Get values as list: - values_list = [('time', self.values[:, i]) for i in range( - self.values.shape[1])] - - data_dict = dict(zip(var_names, values_list)) - coords_dict = { - 'time': times, - 'system_dim': range(len(var_names)) - } - ds = xr.Dataset( - data_vars=data_dict, - coords=coords_dict - ) - - ds.to_netcdf(filename, mode='w') + return out_vec def rhs_aux(self, x, t): """The auxiliary model used to compute the TLM. @@ -592,8 +284,8 @@ def calc_lyapunov_exponents_series(self, total_time=None, rescale_time=1, # Loop over rescale time periods for i, (t1, t2) in enumerate(zip(times[:-1], times[1:])): - M = self.generate(t_final=t2-t1, x0=x0, M0=M0, return_tlm=True) - x_t2 = self.values[-1] + x, M = self.generate(t_final=t2-t1, x0=x0, M0=M0, return_tlm=True) + x_t2 = x.isel(time=-1).to_array().data.flatten() M_t2 = M[-1] Q, R = jnp.linalg.qr(M_t2) @@ -646,31 +338,3 @@ def calc_lyapunov_exponents_final(self, total_time=None, rescale_time=1, rescale_time=rescale_time, x0=x0, convergence=convergence)[-1] - - def split_train_valid_test(self, train_size, valid_size, test_size): - """Splits data into train, validation, and test sets by time - - Args: - train_size, valid_size, test_size (float or int): Size of sets. - If < 1, represents the fraction of the time series to use. - If > 1, represents the number of timesteps. - - Returns: - (train_obj, valid_obj, test_obj): Data objects - """ - - if 0 < train_size < 1: - train_size = round(train_size*self.time_dim) - if 0 < valid_size < 1: - valid_size = round(valid_size*self.time_dim) - if 0 < test_size < 1: - test_size = round(test_size*self.time_dim) - - # Round up train_size - if train_size + valid_size + test_size < self.time_dim: - train_size = self.time_dim - valid_size - test_size - - train_end = train_size - valid_end = train_size + valid_size - - return self[:train_end], self[train_end:valid_end], self[valid_end:] From f5f1bf48a48d7e71e272e52f911a2ec9bc5bd053 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Wed, 18 Sep 2024 14:32:05 -0600 Subject: [PATCH 02/44] Remove vector class import for now, replacing vector objects with simple xarrays --- dabench/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dabench/__init__.py b/dabench/__init__.py index 8d4cbd2..4bba201 100644 --- a/dabench/__init__.py +++ b/dabench/__init__.py @@ -1 +1 @@ -from . import data, vector, model, observer, obsop, dacycler, _suppl_data +from . import data, model, observer, obsop, dacycler, _suppl_data From 8f604367a1a433a5e809174653731b8eac89af04 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Wed, 18 Sep 2024 15:35:55 -0600 Subject: [PATCH 03/44] Update sqgturb to work with xarray. Returns in real space, not spectral. --- dabench/data/_data.py | 18 +++++++++++++----- dabench/data/sqgturb.py | 8 ++++++-- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/dabench/data/_data.py b/dabench/data/_data.py index 7762611..a9ef6c8 100644 --- a/dabench/data/_data.py +++ b/dabench/data/_data.py @@ -42,6 +42,10 @@ def __init__(self, self.random_seed = random_seed self.delta_t = delta_t self.store_as_jax = store_as_jax + + # Default var and coord names + self.var_names = ['x'] + self.coord_names = ['i'] # x0 attribute is property to better convert between jax/numpy self._x0 = x0 @@ -158,15 +162,19 @@ def generate(self, n_steps=None, t_final=None, x0=None, M0=None, **kwargs) # Convert to JAX if necessary + out_dim = (t.shape[0],) + self.original_dim if self.store_as_jax: - y_out = jnp.array(y[:,:self.system_dim]) + y_out = jnp.array(y[:,:self.system_dim].reshape(out_dim)) else: - y_out = np.array(y[:,:self.system_dim]) + y_out = np.array(y[:,:self.system_dim].reshape(out_dim)) # Build Xarray object for output + coord_dict = dict(zip( + ['time'] + self.coord_names, + [t] + [np.arange(dim) for dim in self.original_dim] + )) out_vec = xr.Dataset( - {'x': (['time','i'],y_out)}, - coords={'time': t, - 'i':np.arange(self.system_dim)}, + {self.var_names[0]: (coord_dict.keys(),y_out)}, + coords=coord_dict, attrs={'store_as_jax':self.store_as_jax, 'system_dim': self.system_dim, 'time_dim': self.time_dim diff --git a/dabench/data/sqgturb.py b/dabench/data/sqgturb.py index 5bd1db5..cccff62 100644 --- a/dabench/data/sqgturb.py +++ b/dabench/data/sqgturb.py @@ -113,6 +113,10 @@ def __init__(self, values=values, times=times, delta_t=delta_t, store_as_jax=store_as_jax, **kwargs) + + self.coord_names = ['level','x','y'] + self.var_names=['pv'] + # Fall back on default if no pv if pv is None: with resources.open_binary( @@ -470,8 +474,8 @@ def integrate(self, f, x0, t_final, delta_t=None, include_x0=True, pvspec, values = jax.lax.scan(self._rk4, pvspec, xs=None, length=n_steps) - # Reshape to (time_dim, system_dim) - values = values.reshape((self.time_dim, -1)) + # Apply reverse fft to + values = self.ifft2(values) # Update internal states self.pvspec = pvspec From 9c33a8ca8cb2394bac99e2d62b16f65af0cbd358 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Fri, 20 Sep 2024 13:49:53 -0600 Subject: [PATCH 04/44] Updated observer with xarray object, but only basic functionality is working --- dabench/observer/_observer.py | 199 +++++++++++++++++++--------------- 1 file changed, 113 insertions(+), 86 deletions(-) diff --git a/dabench/observer/_observer.py b/dabench/observer/_observer.py index 80998e6..ad8f811 100644 --- a/dabench/observer/_observer.py +++ b/dabench/observer/_observer.py @@ -7,8 +7,7 @@ import numpy as np import jax.numpy as jnp - -from dabench.vector import ObsVector +import xarray as xr class Observer(): @@ -72,11 +71,15 @@ class Observer(): """ def __init__(self, - data_obj, + state_vec, random_time_density=1., random_location_density=1., + random_variable_density=1., random_time_count=None, random_location_count=None, + random_variable_count=None, + times=None, + locations=None, time_indices=None, location_indices=None, stationary_observers=True, @@ -87,17 +90,26 @@ def __init__(self, store_as_jax=False, ): - self.data_obj = data_obj - - if time_indices is not None: - time_indices = np.array(time_indices) - self.time_indices = time_indices + self.state_vec = state_vec + self._coord_names = list(self.state_vec.coords.keys()) + self._nontime_coord_names = [coord for coord in self._coord_names + if coord != 'time'] + + # if time_indices is not None: + # time_indices = np.array(time_indices) + # self.time_indices = time_indices + if times is not None: + times = np.array(times) + self.times = times self.random_time_density = random_time_density self.random_time_count = random_time_count - if location_indices is not None: - location_indices = np.array(location_indices) - self.location_indices = location_indices + # if location_indices is not None: + # location_indices = np.array(location_indices) + # self.location_indices = location_indices + # if locations is not None: + # locations = np.array(locations) + self.locations = locations self.random_location_density = random_location_density self.random_location_count = random_location_count self.stationary_observers = stationary_observers @@ -120,12 +132,12 @@ def __init__(self, if isinstance(self.error_bias, (list, np.ndarray, jnp.ndarray)): if len(self.error_bias) == 1: self._error_bias_is_list = False - elif not len(self.error_bias) == self.data_obj.system_dim: + elif not len(self.error_bias) == self.state_vec.system_dim: raise ValueError( "List of error biases has length {}." "Must match either system_dim ({}) or " "number of location indices ({})".format( - len(self.error_bias), self.data_obj.system_dim, + len(self.error_bias), self.state_vec.system_dim, self.location_indices.shape[0])) elif isinstance(self.error_bias, list): if self.store_as_jax: @@ -139,12 +151,12 @@ def __init__(self, if isinstance(self.error_sd, (list, np.ndarray, jnp.ndarray)): if len(self.error_sd) == 1: self._error_sd_is_list = False - elif not len(self.error_sd) == self.data_obj.system_dim: + elif not len(self.error_sd) == self.state_vec.system_dim: raise ValueError( "List of error sds has length {}." "Must match either system_dim ({}) or " "number of location indices ({})".format( - len(self.error_sd), self.data_obj.system_dim, + len(self.error_sd), self.state_vec.system_dim, self.location_indices.shape[0])) elif isinstance(self.error_sd, list): if self.store_as_jax: @@ -159,37 +171,43 @@ def __init__(self, def _generate_time_indices(self, rng): if self.random_time_count is not None: - self.time_indices = np.sort(rng.choice( - self.data_obj.time_dim, + self.times = np.sort(rng.choice( + self.state_vec['time'], size=self.random_time_count, replace=False, shuffle=False)) else: self.time_indices = np.where( rng.binomial(1, p=self.random_time_density, - size=self.data_obj.time_dim + size=self.state_vec.time_dim ).astype('bool') )[0] def _generate_stationary_indices(self, rng): if self.random_location_count is not None: - self.location_indices = rng.choice( - self.data_obj.system_dim, + location_count = self.random_location_count + else: + location_count = np.sum( + rng.binomial(1, + p=self.random_location_density, + size=self.state_vec.system_dim)) + self.locations = { + coord_name: xr.DataArray( + rng.choice( + self.state_vec[coord_name], size=self.random_location_count, replace=False, - shuffle=False) - else: - self.location_indices = np.where( - rng.binomial(1, p=self.random_location_density, - size=self.data_obj.system_dim - ).astype('bool') - )[0] + shuffle=False), + dims=['observations']) + for coord_name in self._nontime_coord_names + } + self.location_dim = location_count def _generate_nonstationary_indices(self, rng): if self.random_location_count is not None: self.location_indices = np.array([ rng.choice( - self.data_obj.system_dim, + self.state_vec.system_dim, size=self.random_location_count, replace=False, shuffle=False) @@ -198,17 +216,17 @@ def _generate_nonstationary_indices(self, rng): self.location_indices = np.array([ np.where( rng.binomial(1, p=self.random_location_density, - size=self.data_obj.system_dim + size=self.state_vec.system_dim ).astype('bool'))[0] for i in range(self.time_indices.shape[0]) ], dtype=object) def _generate_stationary_indices_gridded(self, rng): if self.random_location_count is not None: - arange_list = [np.arange(n) for n in self.data_obj.original_dim] + arange_list = [np.arange(n) for n in self.state_vec.original_dim] ind_possibilities = np.array( np.meshgrid(*arange_list)).T.reshape( - -1, len(self.data_obj.original_dim)) + -1, len(self.state_vec.original_dim)) self.location_indices = rng.choice( ind_possibilities, size=self.random_location_count, @@ -217,16 +235,16 @@ def _generate_stationary_indices_gridded(self, rng): else: self.location_indices = np.array(np.where( rng.binomial(1, p=self.random_location_density, - size=self.data_obj.original_dim + size=self.state_vec.original_dim ).astype('bool') )).T def _generate_nonstationary_indices_gridded(self, rng): if self.random_location_count is not None: - arange_list = [np.arange(n) for n in self.data_obj.original_dim] + arange_list = [np.arange(n) for n in self.state_vec.original_dim] ind_possibilities = np.array( np.meshgrid(*arange_list)).T.reshape( - -1, len(self.data_obj.original_dim)) + -1, len(self.state_vec.original_dim)) self.location_indices = np.array([rng.choice( ind_possibilities, size=self.random_location_count, @@ -236,7 +254,7 @@ def _generate_nonstationary_indices_gridded(self, rng): self.location_indices = np.array([ np.array(np.where( rng.binomial(1, p=self.random_location_density, - size=self.data_obj.original_dim + size=self.state_vec.original_dim ).astype('bool'))).T for i in range(self.time_indices.shape[0]) ], dtype=object) @@ -244,11 +262,11 @@ def _generate_nonstationary_indices_gridded(self, rng): def _sample_stationary(self, errors_vector, sample_in_system_dim): if sample_in_system_dim: values_vector = ( - self.data_obj.values[self.time_indices][ + self.state_vec.values[self.time_indices][ :, self.location_indices] + errors_vector) else: - values_gridded = self.data_obj.values_gridded + values_gridded = self.state_vec.values_gridded values_vector = np.array([ values_gridded[t][tuple(self.location_indices.T)] for t in self.time_indices]) + errors_vector @@ -257,11 +275,11 @@ def _sample_stationary(self, errors_vector, sample_in_system_dim): def _sample_nonstationary(self, errors_vector, sample_in_system_dim): if sample_in_system_dim: values_vector = np.array([ - (self.data_obj.values[self.time_indices[i]] + (self.state_vec.values[self.time_indices[i]] [self.location_indices[i]] + errors_vector[i]) for i in range(self.time_dim)], dtype=object) else: - values_gridded = self.data_obj.values_gridded + values_gridded = self.state_vec.values_gridded values_vector = np.array( [values_gridded[self.time_indices[i]][ tuple(self.location_indices[i].T)] @@ -277,47 +295,46 @@ def observe(self): errors """ - if self.data_obj.values is None: - raise ValueError('Data have not been generated/loaded. Run:\n' - 'self.data_obj.generate() to create data for ' - 'observer') - # Define random num generator rng = np.random.default_rng(self.random_seed) # Set time indices - if self.time_indices is None: + if self.times is None: self._generate_time_indices(rng) - self.time_dim = self.time_indices.shape[0] + self.time_dim = self.times.shape[0] # For stationary observers (default) if self.stationary_observers: # Generate location_indices if not specified - if self.location_indices is None: + if self.locations is None: # Check if data is in spectral or physical space - if (hasattr(self.data_obj, 'is_spectral') and - self.data_obj.is_spectral): + if (hasattr(self.state_vec, 'is_spectral') and + self.state_vec.is_spectral): self._generate_stationary_indices_gridded(rng) else: self._generate_stationary_indices(rng) - # Check that location_indices are in correct dimensions - if self.location_indices.shape[0] == 0: - raise ValueError('location_indices is an empty list') - elif len(self.location_indices.shape) == 1: - sample_in_system_dim = True - elif (self.location_indices.shape[1] == - len(self.data_obj.original_dim)): - sample_in_system_dim = False - else: - raise ValueError('location_indices must be 1D or match\n' - 'len(self.data_obj.original_dim)') + # # Check that location_indices are in correct dimensions + # if self.locations.shape[0] == 0: + # raise ValueError('locations is an empty list') + # elif len(self.locations.shape) == 1: + # sample_in_system_dim = True + # elif (self.locations.shape[1] == + # len(self.state_vec.original_dim)): + # sample_in_system_dim = False + # else: + # raise ValueError('locations must be 1D or match\n' + # 'len(self.state_vec.original_dim)') + + # self.location_dims = tuple([v.shape[0] for k, v in self.locations.items()]) + + self.location_dim = next(iter(self.locations.items()))[1] ['observations'].size # Generate errors - self.location_dim = np.repeat(self.location_indices.shape[0], - self.time_dim) - errors_vec_size = (self.time_dim,) + (self.location_dim[0],) + errors_vec_size = ((self.time_dim,) + + (self.location_dim,) + + (len(self.state_vec.data_vars),)) if self._error_bias_is_list: error_bias = self.error_bias[self.location_indices] else: @@ -330,26 +347,26 @@ def observe(self): scale=error_sd, size=errors_vec_size) - # Clip errors to positive only + # # Clip errors to positive only if self.error_positive_only: errors_vector[errors_vector < 0.] = 0. - # Get values - values_vector = self._sample_stationary( - errors_vector, - sample_in_system_dim) + # # Get values + # values_vector = self._sample_stationary( + # errors_vector, + # sample_in_system_dim) - # Repeat location indices across time_dim for passing to ObsVector - full_loc_indices = np.array( - [self.location_indices] * self.time_dim) + # # Repeat location indices across time_dim for passing to ObsVector + # full_loc_indices = np.array( + # [self.location_indices] * self.time_dim) # If NON-stationary observer else: # Generate location_indices if not specified if self.location_indices is None: # Check if data is in spectral or physical space - if (hasattr(self.data_obj, 'is_spectral') and - self.data_obj.is_spectral): + if (hasattr(self.state_vec, 'is_spectral') and + self.state_vec.is_spectral): self._generate_nonstationary_indices_gridded(rng) else: self._generate_nonstationary_indices(rng) @@ -414,16 +431,26 @@ def observe(self): # For passing to ObsVector full_loc_indices = self.location_indices - return ObsVector(values=values_vector, - times=self.data_obj.times[self.time_indices], - time_indices=self.time_indices, - location_indices=full_loc_indices, - obs_dims=self.location_dim, - num_obs=values_vector.shape[0], - errors=errors_vector, - error_dist='normal', - error_sd=self.error_sd, - error_bias=self.error_bias, - store_as_jax=self.store_as_jax, - stationary_observers=self.stationary_observers - ) + obs_vec = self.state_vec.sel(time=self.times).sel(self.locations) + + obs_vec = obs_vec.assign_coords(variable = list(obs_vec.data_vars)) + obs_vec = obs_vec.assign(errors=(obs_vec.dims, errors_vector)) + + for data_var in obs_vec['variable'].values: + obs_vec[data_var] = obs_vec[data_var] + obs_vec['errors'].sel(variable=data_var) + return obs_vec + + # return self.state_vec.sel(time=self.times).sel(self.locations) + errors_ve + # return ObsVector(values=values_vector, + # times=self.data_obj.times[self.time_indices], + # time_indices=self.time_indices, + # location_indices=full_loc_indices, + # obs_dims=self.location_dim, + # num_obs=values_vector.shape[0], + # errors=errors_vector, + # error_dist='normal', + # error_sd=self.error_sd, + # error_bias=self.error_bias, + # store_as_jax=self.store_as_jax, + # stationary_observers=self.stationary_observers + # ) From 8264485898457f333a82700b5698d35a7c9e1564 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Fri, 20 Sep 2024 16:08:03 -0600 Subject: [PATCH 05/44] 3DVar working with xarray: --- dabench/dacycler/_dacycler.py | 22 ++++++++++------------ dabench/dacycler/_var3d.py | 15 ++++++++------- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/dabench/dacycler/_dacycler.py b/dabench/dacycler/_dacycler.py index 9f24a91..f7a4f3a 100644 --- a/dabench/dacycler/_dacycler.py +++ b/dabench/dacycler/_dacycler.py @@ -88,6 +88,7 @@ def cycle(self, # Number of model steps to run per window steps_per_window = round(analysis_window/self.delta_t) + 1 + print(steps_per_window) # For storing outputs all_output_states = [] @@ -100,11 +101,11 @@ def cycle(self, window_middle = cur_time + _time_offset window_start = window_middle - analysis_window/2 window_end = window_middle + analysis_window/2 - obs_vec_timefilt = obs_vector.filter_times( - window_start, window_end + obs_vec_timefilt = obs_vector.sel( + time=slice(window_start, window_end) ) - if obs_vec_timefilt.values.shape[0] > 0: + if obs_vec_timefilt.sizes['time'] > 0: # 2. Calculate analysis analysis, kh = self._step_cycle(cur_state, obs_vec_timefilt) # 3. Forecast through analysis window @@ -113,18 +114,15 @@ def cycle(self, # 4. Save outputs if return_forecast: # Append forecast to current state, excluding last step - all_output_states.append(forecast_states.values[:-1]) - all_times.append( - np.arange(steps_per_window-1)*self.delta_t + cur_time - ) + print(forecast_states) + all_output_states.append(forecast_states.isel(time=slice(0,steps_per_window-1))) else: - all_output_states.append(analysis.values[np.newaxis]) - all_times.append([cur_time]) + all_output_states.append(analysis) # Starting point for next cycle is last step of forecast - cur_state = forecast_states[-1] + cur_state = forecast_states.isel(time=steps_per_window-1) + print(cur_state) cur_time += analysis_window - return vector.StateVector(values=np.concatenate(all_output_states), - times=np.concatenate(all_times)) + return all_output_states diff --git a/dabench/dacycler/_var3d.py b/dabench/dacycler/_var3d.py index f84fd87..a1a45dd 100644 --- a/dabench/dacycler/_var3d.py +++ b/dabench/dacycler/_var3d.py @@ -62,18 +62,19 @@ def _step_cycle(self, xb, yo, H=None, h=None, R=None, B=None): def _calc_default_H(self, obs_vec): """If H is not provided, creates identity matrix to serve as H""" - H = jnp.zeros((obs_vec.values.flatten().shape[0], self.system_dim)) - H = H.at[jnp.arange(H.shape[0]), obs_vec.location_indices.flatten() + H = jnp.zeros((obs_vec.sizes['observations']*obs_vec.sizes['time'], + self.system_dim)) + H = H.at[jnp.arange(H.shape[0]), np.where(obs_vec.indices.data)[1] ].set(1) return H def _calc_default_R(self, obs_vec): """If R i s not provided, calculates default based on observation error""" - return jnp.identity(obs_vec.values.flatten().shape[0])*obs_vec.error_sd**2 + return jnp.identity( + obs_vec.sizes['observations']*obs_vec.sizes['time'])*obs_vec.error_sd**2 def _calc_default_B(self): """If B is not provided, identity matrix with shape (system_dim, system_dim.""" - return jnp.identity(self.system_dim) def _cycle_general_obsop(self, forecast, obs_vec): @@ -99,8 +100,8 @@ def _cycle_linear_obsop(self, forecast, obs_vec, H=None, R=None, B = self.B # make inputs column vectors - xb = jnp.array([forecast.values.flatten()]).T - yo = jnp.array([obs_vec.values.flatten()]).T + xb = jnp.array([forecast[self.data_vars].to_array().data.flatten()]).T + yo = jnp.array([obs_vec[self.data_vars].to_array().data.flatten()]).T # Set parameters xdim = xb.size # Size or get one of the shape params? @@ -121,7 +122,7 @@ def _cycle_linear_obsop(self, forecast, obs_vec, H=None, R=None, HBHtPlusR_inv = jnp.linalg.inv(H @ BHt + R) KH = BHt @ HBHtPlusR_inv @ H - return vector.StateVector(values=xa.T[0], store_as_jax=True), KH + return forecast.assign(x=(['i'], xa.T[0])), KH def _step_forecast(self, xa, n_steps): """n_steps forward of model forecast""" From b7285511771ea8c511f968c70fe38f609e2750b5 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Mon, 23 Sep 2024 13:25:26 -0600 Subject: [PATCH 06/44] ETKF working with xarray, initial commit --- dabench/dacycler/_etkf.py | 117 +++++++++++++++++++------------------- 1 file changed, 58 insertions(+), 59 deletions(-) diff --git a/dabench/dacycler/_etkf.py b/dabench/dacycler/_etkf.py index ed509e6..9c6c264 100644 --- a/dabench/dacycler/_etkf.py +++ b/dabench/dacycler/_etkf.py @@ -5,6 +5,8 @@ import jax import jax.numpy as jnp from jax.scipy import linalg +import xarray as xr +import xarray_jax as xj from dabench import dacycler, vector import dabench.dacycler._utils as dac_utils @@ -61,20 +63,20 @@ def __init__(self, ensemble=True, B=B, R=R, H=H, h=h) - def _step_cycle(self, xb, yo, obs_time_mask, obs_loc_mask, + def _step_cycle(self, xb, obs_vals, obs_locs, obs_time_mask, obs_loc_mask, H=None, h=None, R=None, B=None): if H is not None or h is None: - vals, kh = self._cycle_obsop( - xb.values, yo.values, yo.location_indices, yo.error_sd, obs_time_mask, + vals = self._cycle_obsop( + xb, obs_vals, obs_locs, obs_time_mask, obs_loc_mask, H, R, B) - return vector.StateVector(values=vals, store_as_jax=True), kh + return vals else: return self._cycle_general_obsop(xb, yo, h, R, B) def _calc_default_H(self, obs_values, obs_loc_indices): H = jnp.zeros((obs_values.flatten().shape[0], self.system_dim)) H = H.at[jnp.arange(H.shape[0]), - obs_loc_indices.flatten() + obs_loc_indices.flatten(), ].set(1) return H @@ -84,7 +86,7 @@ def _calc_default_R(self, obs_values, obs_error_sd): def _calc_default_B(self): return jnp.identity(self.system_dim) - def _cycle_obsop(self, Xbt, obs_values, obs_loc_indices, obs_error_sd, + def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices, obs_time_mask, obs_loc_mask, H=None, h=None, R=None, B=None): if H is None and h is None: @@ -97,7 +99,7 @@ def _cycle_obsop(self, Xbt, obs_values, obs_loc_indices, obs_error_sd, H = self.H if R is None: if self.R is None: - R = self._calc_default_R(obs_values, obs_error_sd) + R = self._calc_default_R(obs_values, self.obs_error_sd) else: R = self.R if B is None: @@ -106,14 +108,16 @@ def _cycle_obsop(self, Xbt, obs_values, obs_loc_indices, obs_error_sd, else: B = self.B - nr, nc = Xbt.shape + x0_xarray = x0_xarray.to_xarray() + Xbt = x0_xarray[self._data_vars].to_array().data[0] + nr,nc = Xbt.shape assert nr == self.ensemble_dim, ( 'cycle:: model_forecast must have dimension {}x{}').format( self.ensemble_dim, self.system_dim) # Apply obs masks to H - H = jnp.where(obs_time_mask, H.T, 0).T - H = jnp.where(obs_loc_mask.flatten(), H.T, 0).T + # H = jnp.where(obs_time_mask, H.T, 0).T + # H = jnp.where(obs_loc_mask.flatten(), H.T, 0).T # Analysis cycles over all obs in data_obs Xa = self._compute_analysis(Xb=Xbt.T, @@ -123,20 +127,21 @@ def _cycle_obsop(self, Xbt, obs_values, obs_loc_indices, obs_error_sd, R=R, rho=self.multiplicative_inflation) - return Xa.T, 0 + return x0_xarray.assign(x=(['ensemble','i'], Xa.T)) def _step_forecast(self, xa, n_steps): - data_forecast = [] + ensemble_forecasts = [] + ensemble_inputs = [] for i in range(self.ensemble_dim): - new_vals = self.model_obj.forecast( - vector.StateVector(values=xa.values[i], store_as_jax=True), + cur_inputs, cur_forecast = self.model_obj.forecast( + xa.isel(ensemble=i), n_steps=n_steps - ).values - data_forecast.append(new_vals) + ) + ensemble_inputs.append(cur_inputs) + ensemble_forecasts.append(cur_forecast) - out_vals = jnp.moveaxis(jnp.stack(data_forecast), [0,1,2],[1,0,2]) - return vector.StateVector(values=out_vals, - store_as_jax=True) + return (xr.concat(ensemble_inputs, dim='ensemble'), + xr.concat(ensemble_forecasts, dim='ensemble')) def _apply_obsop(self, Xb, H, h): if H is not None: @@ -215,37 +220,29 @@ def _compute_analysis(self, Xb, y, H, h, R, rho=1.0, Yb=None): return Xa - def _cycle_and_forecast(self, state_obs_tuple, filtered_idx): + def _cycle_and_forecast(self, cur_state, filtered_idx): # 1. Get data - cur_state_vals = state_obs_tuple[0] - obs_vals = state_obs_tuple[1] - obs_times = state_obs_tuple[2] - obs_loc_indices = state_obs_tuple[3] - obs_loc_masks = state_obs_tuple[4] - obs_error_sd = state_obs_tuple[5] + # cur_state_vals = state[self.data_vars].data state_obs_tuple[0] # 1-b. Calculate obs_time_mask and restore filtered_idx to original values - obs_time_mask = jnp.repeat(filtered_idx > 0, obs_loc_indices.shape[1]) + obs_time_mask = filtered_idx > 0 filtered_idx = filtered_idx - 1 # 2. Calculate analysis - new_obs_vals = obs_vals[filtered_idx] - new_obs_loc_indices = obs_loc_indices[filtered_idx] - new_obs_loc_mask = obs_loc_masks[filtered_idx] - analysis, kh = self._step_cycle( - vector.StateVector(values=cur_state_vals, store_as_jax=True), - vector.ObsVector(values=new_obs_vals, - location_indices=new_obs_loc_indices, - error_sd=obs_error_sd, store_as_jax=True), - obs_loc_mask=new_obs_loc_mask, + cur_obs_vals = jnp.array(self._obs_vector[self._observed_vars].to_array().data).at[:, filtered_idx].get() + cur_obs_loc_indices = jnp.array(self._obs_vector.indices.data).at[filtered_idx].get() + cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[filtered_idx].get().astype(bool) + analysis = self._step_cycle( + cur_state, + cur_obs_vals, + cur_obs_loc_indices, + obs_loc_mask=cur_obs_loc_mask, obs_time_mask=obs_time_mask ) # 3. Forecast next timestep - forecast_states = self._step_forecast(analysis, n_steps=self.steps_per_window) - next_state = forecast_states.values[-1] + next_state, forecast_states = self._step_forecast(analysis, n_steps=self.steps_per_window) - return (next_state, obs_vals, obs_times, obs_loc_indices, - obs_loc_masks, obs_error_sd), forecast_states.values[:-1] + return xj.from_xarray(next_state), forecast_states def cycle(self, input_state, @@ -278,6 +275,11 @@ def cycle(self, vector.StateVector of analyses and times. """ + # These could be different if observer doesn't observe all variables + # For now, making them the same + self._observed_vars = obs_vector['variable'].values + self._data_vars = self._observed_vars + if obs_error_sd is None: obs_error_sd = obs_vector.error_sd self.analysis_window = analysis_window @@ -301,7 +303,7 @@ def cycle(self, # Get the obs vectors for each analysis window all_filtered_idx = dac_utils._get_obs_indices( - obs_times=obs_vector.times, + obs_times=jnp.array(obs_vector.time.values), analysis_times=all_times+_time_offset, start_inclusive=True, end_inclusive=False, @@ -309,34 +311,31 @@ def cycle(self, ) all_filtered_padded = dac_utils._pad_time_indices(all_filtered_idx, add_one=True) - - # Padding observations + self._obs_vector=obs_vector + self.obs_error_sd = obs_error_sd if obs_vector.stationary_observers: - obs_loc_masks = jnp.ones(obs_vector.values.shape, dtype=bool) + self._obs_loc_masks = jnp.ones( + obs_vector[self._observed_vars].to_array().shape, dtype=bool) cur_state, all_values = jax.lax.scan( self._cycle_and_forecast, - (input_state.values, obs_vector.values, obs_vector.times, - obs_vector.location_indices, obs_loc_masks, obs_error_sd), + xj.from_xarray(input_state), all_filtered_padded) else: - obs_vals, obs_locs, obs_loc_masks = dac_utils._pad_obs_locs(obs_vector) + obs_vals, obs_locs, self._obs_loc_masks = dac_utils._pad_obs_locs(obs_vector) cur_state, all_values = jax.lax.scan( self._cycle_and_forecast, (input_state.values, obs_vals, obs_vector.times, obs_locs, obs_loc_masks, obs_error_sd), all_filtered_padded) - + + + all_vals_xr = xr.Dataset( + {var: (('cycle',) + tuple(all_values[var].dims), + all_values[var].data) + for var in all_values.data_vars} + ).rename_dims({'time': 'cycle_timestep'}) if return_forecast: - all_times_forecast = jnp.arange( - 0, - n_cycles*analysis_window, - self.delta_t - ) + start_time - return vector.StateVector(values=jnp.concatenate(all_values), - times=all_times_forecast) + return all_vals_xr else: - return vector.StateVector(values=jnp.vstack([ - forecast[0][jnp.newaxis] for forecast in all_values] - ), - times=all_times) + return all_vals_xr.isel(cycle_timestep=0) \ No newline at end of file From 4ccbeba46020e2c499cdec81c2b8c54987d8c592 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Mon, 23 Sep 2024 13:26:18 -0600 Subject: [PATCH 07/44] Sort observed locations --- dabench/observer/_observer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dabench/observer/_observer.py b/dabench/observer/_observer.py index ad8f811..574ceee 100644 --- a/dabench/observer/_observer.py +++ b/dabench/observer/_observer.py @@ -193,11 +193,11 @@ def _generate_stationary_indices(self, rng): size=self.state_vec.system_dim)) self.locations = { coord_name: xr.DataArray( - rng.choice( + np.sort(rng.choice( self.state_vec[coord_name], size=self.random_location_count, replace=False, - shuffle=False), + shuffle=False)), dims=['observations']) for coord_name in self._nontime_coord_names } @@ -431,6 +431,7 @@ def observe(self): # For passing to ObsVector full_loc_indices = self.location_indices + # loc_indices = xr.where(self.state_vec) obs_vec = self.state_vec.sel(time=self.times).sel(self.locations) obs_vec = obs_vec.assign_coords(variable = list(obs_vec.data_vars)) From e41e60d2091d00035474a099b32eef856ab51479 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Mon, 23 Sep 2024 13:27:24 -0600 Subject: [PATCH 08/44] Data times are stored as base numpy arrays, since xarray coords cannot be jax --- dabench/data/_utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/dabench/data/_utils.py b/dabench/data/_utils.py index f119674..41ce004 100644 --- a/dabench/data/_utils.py +++ b/dabench/data/_utils.py @@ -31,10 +31,7 @@ def integrate(function, x0, t_final, delta_t, method='odeint', stride=None, """ if method == 'odeint': # Define timesteps - if jax_comps: - t = jnp.arange(0.0, t_final, delta_t) - else: - t = np.arange(0.0, t_final, delta_t) + t = np.arange(0.0, t_final, delta_t) # If stride is defined, remove timesteps that are not on stride steps if stride is not None: assert stride > 1 and isinstance(stride, int), \ From 4e8b97098c4267ccdcd3985be0332df6c595c71d Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Mon, 23 Sep 2024 13:27:44 -0600 Subject: [PATCH 09/44] Reinserting netcdf utils into data class --- dabench/data/_data.py | 54 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 3 deletions(-) diff --git a/dabench/data/_data.py b/dabench/data/_data.py index a9ef6c8..9950e8f 100644 --- a/dabench/data/_data.py +++ b/dabench/data/_data.py @@ -162,7 +162,8 @@ def generate(self, n_steps=None, t_final=None, x0=None, M0=None, **kwargs) # Convert to JAX if necessary - out_dim = (t.shape[0],) + self.original_dim + self.time_dim = t.shape[0] + out_dim = (self.time_dim,) + self.original_dim if self.store_as_jax: y_out = jnp.array(y[:,:self.system_dim].reshape(out_dim)) else: @@ -176,8 +177,7 @@ def generate(self, n_steps=None, t_final=None, x0=None, M0=None, {self.var_names[0]: (coord_dict.keys(),y_out)}, coords=coord_dict, attrs={'store_as_jax':self.store_as_jax, - 'system_dim': self.system_dim, - 'time_dim': self.time_dim + 'system_dim': self.system_dim } ) @@ -346,3 +346,51 @@ def calc_lyapunov_exponents_final(self, total_time=None, rescale_time=1, rescale_time=rescale_time, x0=x0, convergence=convergence)[-1] + + def load_netcdf(self, filepath=None, include_vars=None, exclude_vars=None, + years_select=None, dates_select=None, + lat_sorting='descending'): + """Loads values from netCDF file, saves them in values attribute + + Args: + filepath (str): Path to netCDF file to load. If not given, + defaults to loading ERA5 ECMWF SLP data over Japan + from 2018 to 2021. + include_vars (list-like): Data variables to load from NetCDF. If + None (default), loads all variables. Can be used to exclude bad + variables. + exclude_vars (list-like): Data variabes to exclude from NetCDF + loading. If None (default), loads all vars (or only those + specified in include_vars). It's recommended to only specify + include_vars OR exclude_vars (unless you want to do extra + typing). + years_select (list-like): Years to load (ints). If None, loads all + timesteps. + dates_select (list-like): Dates to load. Elements must be + datetime date or datetime objects, depending on type of time + indices in NetCDF. If both years_select and dates_select + are specified, time_stamps overwrites "years" argument. If + None, loads all timesteps. + lat_sorting (str): Orient data by latitude: + descending (default), ascending, or None (uses orientation + from data file). + """ + if filepath is None: + # Use importlib.resources to get the default netCDF from dabench + filepath = resources.files(_suppl_data).joinpath('era5_japan_slp.nc') + return xr.open_dataset(filepath, decode_coords='all', engine='scipy').as_numpy() + # self._import_xarray_ds( + # ds, include_vars=include_vars, + # exclude_vars=exclude_vars, + # years_select=years_select, dates_select=dates_select, + # lat_sorting=lat_sorting) + + def save_netcdf(self, ds, filename): + """Saves values in values attribute to netCDF file + + Args: + ds (Xarray Dataset): Xarray dataset + filepath (str): Path to netCDF file to save + """ + + ds.to_netcdf(filename, mode='w') \ No newline at end of file From 99b1b3cd33aa56774f57b5015124b2647b959266 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Mon, 23 Sep 2024 13:27:58 -0600 Subject: [PATCH 10/44] gcp with xarray --- dabench/data/gcp.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dabench/data/gcp.py b/dabench/data/gcp.py index 632d431..03e6a5f 100644 --- a/dabench/data/gcp.py +++ b/dabench/data/gcp.py @@ -112,14 +112,14 @@ def _load_gcp_era5(self): # Subset by lon boundaries ds = ds.sel(longitude=slice(subset_min_lon, subset_max_lon)) - self._import_xarray_ds(ds) + return ds def generate(self): """Alias for _load_gcp_era5""" warnings.warn('GCP.generate() is an alias for the load() method. ' 'Proceeding with downloading ERA5 data from GCP...') - self._load_gcp_era5() + return self._load_gcp_era5() def load(self): """Alias for _load_gcp_era5""" - self._load_gcp_era5() + return self._load_gcp_era5() From 4d2e1959469302ed9dfcb0e28e3270f458eb777e Mon Sep 17 00:00:00 2001 From: Steve Penny Date: Mon, 23 Sep 2024 15:19:59 -0600 Subject: [PATCH 11/44] adding capbility to generate ensemble from era5 data on gcp --- dabench/dasupport/generate_era5_ensemble.py | 217 ++++++++++++++++++++ 1 file changed, 217 insertions(+) create mode 100644 dabench/dasupport/generate_era5_ensemble.py diff --git a/dabench/dasupport/generate_era5_ensemble.py b/dabench/dasupport/generate_era5_ensemble.py new file mode 100644 index 0000000..03cf9d0 --- /dev/null +++ b/dabench/dasupport/generate_era5_ensemble.py @@ -0,0 +1,217 @@ +# Sample a series of initial conditions from era5 in order to generate a test initial ensemble + +import argparse + +# For converting strings into datetime objects +from datetime import datetime, timedelta + +# Interface to Google Cloud Services +import gcsfs +import xarray as xr +from dateutil.relativedelta import relativedelta + +from helpers.constants import ERA5_CONTROL_VARIABLES +from helpers.timing import report_timing + + +#%% Parse arguments +def parse_arguments(): + parser = argparse.ArgumentParser(description="Process command line inputs.") + + # Define the arguments + parser.add_argument( + "--atmosphere_ensemble_s3_key", + type=str, + required=True, + default=None, + help="The s3 path for the ensemble zarr store.", + ) + parser.add_argument( + "--date_format", + type=str, + required=False, + default="%Y%m%dZ%H", + help="Date format. Default: %Y%m%dZ%H", + ) + parser.add_argument( + "--target_date", + type=str, + required=True, + default=None, #datetime.strptime(f"{YEAR}{MONTH}{DAY}Z{HOUR}",'%Y%m%dZ%H'), + help="Initialization date. Default format: %Y%m%dZ%H", + ) + parser.add_argument( + "--ensemble_size", + type=int, + required=True, + default=None, + help="Number of ensemble members", + ) + parser.add_argument( + "--sample_strategy", + type=str, + required=False, + default="consecutive_day", + help="{'multi_year'|'multi_month'|'consecutive_day'}", + ) + parser.add_argument( + "--start_date", + type=str, + required=True, + default=None, #datetime.strptime(f"{YEAR-1}{MONTH}{DAY}Z{HOUR}",'%Y%m%dZ%H'), + help="Date to start backwards count for sample strategy", + ) + parser.add_argument( + "--era5_path", + type=str, + required=False, + default="gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3", + help="Cloud-based source of the ERA5 dataset to access as ensemble members.", + ) + # Parse the arguments + args = parser.parse_args() + return args + + +#%% Define the initial ensemble + + +def define_init_ensemble( + ensemble_size, init_ensemble_start_date, init_ensemble_sample_strategy="multi_year" +): + + if init_ensemble_sample_strategy == "multi_year": + increment = relativedelta(years=1) + elif init_ensemble_sample_strategy == "multi_month": + increment = timedelta(months=1) + elif init_ensemble_sample_strategy == "consecutive_day": + increment = timedelta(days=1) + else: + raise Exception( + f"Not a valid init_ensemble_sampling_strategy = {init_ensemble_sample_strategy}" + ) + + init_ensemble_member_dates = [] + for i in range(ensemble_size): + init_ensemble_member_dates.append(init_ensemble_start_date - i * increment) + + print(f"ensemble member init date list = {init_ensemble_member_dates}") + + return init_ensemble_member_dates + + +def main( + date_format:str="%Y%m%dZ%H", + atmosphere_ensemble_s3_key:str=None, + target_date:datetime=datetime.strptime("19990101Z00","%Y%m%dZ%H"), + sample_strategy:str="consecutive_day", + start_date:datetime=datetime.strptime("19981231Z00","%Y%m%dZ%H"), + ensemble_size:int=4, + era5_path:str="gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3", + ): + + #%% Set up the gcp access to era5 + if era5_path[0:2] == "gs": + gcs = gcsfs.GCSFileSystem(token="anon") + ds_era5 = xr.open_zarr(gcs.get_mapper(era5_path), chunks=None) + else: + raise Exception("Non-GCP source for ERA5 not yet supported. EXITING...") + report_timing(timing_label="build_test_ensemble_era5:: access remote zarr store") + + #%% Reorder the latitudes + # Following: + # https://stackoverflow.com/questions/54677161/xarray-reverse-an-array-along-one-coordinate + # (ECMWF latitudes are often stored N to S instead of - to +) + ds_era5 = ds_era5.isel(latitude=slice(None, None, -1)) + print(ds_era5.latitude) + assert ds_era5.latitude[0] < ds_era5.latitude[-1] + + #%% Determine dates for initial ensemble sampling + init_ensemble_member_dates = define_init_ensemble( + ensemble_size=ensemble_size, + init_ensemble_start_date=start_date, + init_ensemble_sample_strategy=sample_strategy, + ) + + #%% Now sample the selection from era5 and put into new zarr store on s3 + + #%% Sample from era5 + ds_init_ens = ds_era5[ERA5_CONTROL_VARIABLES].sel(time=init_ensemble_member_dates) + report_timing( + timing_label="build_test_ensemble_era5:: select time steps as ensemble members" + ) + print(ds_init_ens) + + #%% Update time to target and add ensemble dimension + ds_init_ens = ds_init_ens.rename_dims(dims_dict={"time": "member"}) + ds_init_ens["member"] = range(ensemble_size) + ds_init_ens = ds_init_ens.drop_vars("time") + report_timing( + timing_label="build_test_ensemble_era5:: add member dimension to replace time" + ) + print(ds_init_ens) + + #%% Select target date from era5 for recentering the ensemble + ds_target = ds_era5[ERA5_CONTROL_VARIABLES].sel(time=target_date) + + #%% Compute the 10m diagnostic wind speed and 10m neutral wind speed + if ('10m_u_component_of_neutral_wind' in ERA5_CONTROL_VARIABLES and + '10m_v_component_of_neutral_wind' in ERA5_CONTROL_VARIABLES): + ds_init_ens['ws10n'] = (ds_init_ens['10m_u_component_of_neutral_wind']**2 + ds_init_ens['10m_v_component_of_neutral_wind']**2)**(0.5) + ds_target['ws10n'] = (ds_target['10m_u_component_of_neutral_wind']**2 + ds_target['10m_v_component_of_neutral_wind']**2)**(0.5) + report_timing( + timing_label="build_test_ensemble_era5:: computing neutral wind speeds at 10m (ws10n)" + ) + if ('10m_u_component_of_wind' in ERA5_CONTROL_VARIABLES and + '10m_v_component_of_wind' in ERA5_CONTROL_VARIABLES): + ds_init_ens['ws10m'] = (ds_init_ens['10m_u_component_of_wind']**2 + ds_init_ens['10m_v_component_of_wind']**2)**(0.5) + ds_target['ws10m'] = (ds_target['10m_u_component_of_wind']**2 + ds_target['10m_v_component_of_wind']**2)**(0.5) + report_timing( + timing_label="build_test_ensemble_era5:: computing diagnostic wind speeds at 10m (ws10m)" + ) + + #%% Recenter ensemble to target date + print(f'build_test_ensemble_era5:: re-centering ensemble with ensemble_size = {ensemble_size} to target_date = {target_date}...') + ds_mean = ds_init_ens.mean(dim="member") + ds_diff = ds_target - ds_mean + ds_init_ens = ds_init_ens + ds_diff + report_timing( + timing_label="build_test_ensemble_era5:: recenter ensemble to target date" + ) + print(ds_init_ens) + + #%% Now add time back on as a singleton dimension + ds_init_ens = ds_init_ens.expand_dims(dim={"time": [target_date]}, axis=0) + report_timing( + timing_label="build_test_ensemble_era5:: add time dimension back on to dataset structure" + ) + print(ds_init_ens) + + #%% Add some checks to make sure dimensions haven't changed + assert ds_era5.sizes['latitude'] == ds_init_ens.sizes['latitude'] + assert ds_era5.sizes['longitude'] == ds_init_ens.sizes['longitude'] + assert ds_era5.sizes['level'] == ds_init_ens.sizes['level'] + + #%% Upload to s3 as zarr + print('Uploading to s3 zarr...') + ds_init_ens.to_zarr(atmosphere_ensemble_s3_key, mode="w") + report_timing( + timing_label="build_test_ensemble_era5:: upload to s3 as a new zarr store" + ) + + +#%% Main access +if __name__ == "__main__": + args = parse_arguments() + + # %% Process input arguments + report_timing(timing_label="build_test_ensemble_era5:: initializing...") + + main( + date_format=args.date_format, + atmosphere_ensemble_s3_key=args.atmosphere_ensemble_s3_key, + target_date=args.target_date, + sample_strategy=args.sample_strategy, + start_date=args.start_date, + ensemble_size=args.ensemble_size, + era5_path=args.e From 3a1506fbaf287b02d4ed75ad45e8784b6c8c9ed2 Mon Sep 17 00:00:00 2001 From: Steve Penny Date: Mon, 23 Sep 2024 15:27:52 -0600 Subject: [PATCH 12/44] added era5 var selection to main file --- dabench/dasupport/generate_era5_ensemble.py | 32 ++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/dabench/dasupport/generate_era5_ensemble.py b/dabench/dasupport/generate_era5_ensemble.py index 03cf9d0..50dcb6b 100644 --- a/dabench/dasupport/generate_era5_ensemble.py +++ b/dabench/dasupport/generate_era5_ensemble.py @@ -10,9 +10,39 @@ import xarray as xr from dateutil.relativedelta import relativedelta -from helpers.constants import ERA5_CONTROL_VARIABLES from helpers.timing import report_timing +# Selected vars for ERA5 ensemble +# This will reduce the number of model fields processed and stored in the ensemble +# A number of these fields are used, for example, by the Google Research NeuralGCM, +# while additional variables are added to support DA of surface satellite observations. +ERA5_CONTROL_VARIABLES = [ + 'geopotential', + 'temperature', + 'specific_humidity', + 'u_component_of_wind', + 'v_component_of_wind', + 'specific_cloud_ice_water_content', + 'specific_cloud_liquid_water_content', + 'surface_pressure', + 'sea_surface_temperature', + 'sea_ice_cover', + # additional variables for DA support: 10m wind speed, u/v neutral winds at 10m + # (wind speed is precomputed upon ensemble file generation) + '10m_u_component_of_wind', + '10m_v_component_of_wind', + '10m_u_component_of_neutral_wind', + '10m_v_component_of_neutral_wind', + 'significant_height_of_combined_wind_waves_and_swell', + 'mean_wave_direction', + 'mean_wave_period', + 'geopotential_at_surface' +] +# From ECMWF docs (for wave parameters): +# https://codes.ecmwf.int/grib/param-db/140229 +# https://codes.ecmwf.int/grib/param-db/140230 +# https://codes.ecmwf.int/grib/param-db/140232 + #%% Parse arguments def parse_arguments(): From 9eddcf1ad5d686ec7d3a393183a8085b9754398f Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Mon, 23 Sep 2024 16:35:26 -0600 Subject: [PATCH 13/44] ETKF can handle irregular obs now, but doesn't have proper indices info --- dabench/dacycler/_etkf.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/dabench/dacycler/_etkf.py b/dabench/dacycler/_etkf.py index 9c6c264..aa0cf15 100644 --- a/dabench/dacycler/_etkf.py +++ b/dabench/dacycler/_etkf.py @@ -316,17 +316,14 @@ def cycle(self, if obs_vector.stationary_observers: self._obs_loc_masks = jnp.ones( obs_vector[self._observed_vars].to_array().shape, dtype=bool) - cur_state, all_values = jax.lax.scan( - self._cycle_and_forecast, - xj.from_xarray(input_state), - all_filtered_padded) else: - obs_vals, obs_locs, self._obs_loc_masks = dac_utils._pad_obs_locs(obs_vector) - cur_state, all_values = jax.lax.scan( - self._cycle_and_forecast, - (input_state.values, obs_vals, obs_vector.times, - obs_locs, obs_loc_masks, obs_error_sd), - all_filtered_padded) + self._obs_loc_masks = ~np.isnan( + obs_vector[self._observed_vars].to_array().data)[0] + self._obs_vector=self._obs_vector.fillna(0) + cur_state, all_values = jax.lax.scan( + self._cycle_and_forecast, + xj.from_xarray(input_state), + all_filtered_padded) all_vals_xr = xr.Dataset( From bffcfbc4567037d8c34f05146b2d7715c3f23fa5 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Mon, 23 Sep 2024 18:12:33 -0600 Subject: [PATCH 14/44] Use system_index variable for H --- dabench/dacycler/_etkf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dabench/dacycler/_etkf.py b/dabench/dacycler/_etkf.py index aa0cf15..13901d3 100644 --- a/dabench/dacycler/_etkf.py +++ b/dabench/dacycler/_etkf.py @@ -230,7 +230,7 @@ def _cycle_and_forecast(self, cur_state, filtered_idx): # 2. Calculate analysis cur_obs_vals = jnp.array(self._obs_vector[self._observed_vars].to_array().data).at[:, filtered_idx].get() - cur_obs_loc_indices = jnp.array(self._obs_vector.indices.data).at[filtered_idx].get() + cur_obs_loc_indices = jnp.array(self._obs_vector.system_index.data).at[:, filtered_idx].get() cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[filtered_idx].get().astype(bool) analysis = self._step_cycle( cur_state, From 0cbd940d2a3706b9add5fb4ca0dfe4c1095a275e Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Mon, 23 Sep 2024 18:13:29 -0600 Subject: [PATCH 15/44] Observer adds integer indices for easier H calculation --- dabench/observer/_observer.py | 321 ++++++++++------------------------ 1 file changed, 93 insertions(+), 228 deletions(-) diff --git a/dabench/observer/_observer.py b/dabench/observer/_observer.py index 574ceee..a0d92c4 100644 --- a/dabench/observer/_observer.py +++ b/dabench/observer/_observer.py @@ -74,14 +74,10 @@ def __init__(self, state_vec, random_time_density=1., random_location_density=1., - random_variable_density=1., random_time_count=None, random_location_count=None, - random_variable_count=None, times=None, locations=None, - time_indices=None, - location_indices=None, stationary_observers=True, error_bias=0., error_sd=0., @@ -94,21 +90,28 @@ def __init__(self, self._coord_names = list(self.state_vec.coords.keys()) self._nontime_coord_names = [coord for coord in self._coord_names if coord != 'time'] + self.state_vec = self.state_vec.assign_coords( + {'variable': self.state_vec.data_vars} + # 'variable_index': np.arange(len(self.state_vec.data_vars))} + ) + self.state_vec = self.state_vec.assign_coords( + {'{}_index'.format(coord): (coord, np.arange(self.state_vec.sizes[coord])) + for coord in self._coord_names} + ) + self.state_vec = self.state_vec.assign( + {'system_index': (['variable'] + ['time'] + self._nontime_coord_names, + np.tile(np.arange(self.state_vec.system_dim), self.state_vec.sizes['time']).reshape( + self.state_vec.to_array().shape + ))} + ) - # if time_indices is not None: - # time_indices = np.array(time_indices) - # self.time_indices = time_indices if times is not None: times = np.array(times) self.times = times + self.random_time_density = random_time_density self.random_time_count = random_time_count - # if location_indices is not None: - # location_indices = np.array(location_indices) - # self.location_indices = location_indices - # if locations is not None: - # locations = np.array(locations) self.locations = locations self.random_location_density = random_location_density self.random_location_count = random_location_count @@ -129,6 +132,7 @@ def __init__(self, self.error_bias = error_bias self.error_sd = error_sd + if isinstance(self.error_bias, (list, np.ndarray, jnp.ndarray)): if len(self.error_bias) == 1: self._error_bias_is_list = False @@ -169,7 +173,7 @@ def __init__(self, self.error_positive_only = error_positive_only - def _generate_time_indices(self, rng): + def _generate_times(self, rng): if self.random_time_count is not None: self.times = np.sort(rng.choice( self.state_vec['time'], @@ -183,7 +187,7 @@ def _generate_time_indices(self, rng): ).astype('bool') )[0] - def _generate_stationary_indices(self, rng): + def _generate_stationary_locs(self, rng): if self.random_location_count is not None: location_count = self.random_location_count else: @@ -195,7 +199,7 @@ def _generate_stationary_indices(self, rng): coord_name: xr.DataArray( np.sort(rng.choice( self.state_vec[coord_name], - size=self.random_location_count, + size=location_count, replace=False, shuffle=False)), dims=['observations']) @@ -203,89 +207,35 @@ def _generate_stationary_indices(self, rng): } self.location_dim = location_count - def _generate_nonstationary_indices(self, rng): + def _generate_nonstationary_locs(self, rng): + """Generate different locations for each observation time""" if self.random_location_count is not None: - self.location_indices = np.array([ - rng.choice( - self.state_vec.system_dim, - size=self.random_location_count, - replace=False, - shuffle=False) - for i in range(self.time_indices.shape[0])]) + self._location_counts = np.repeat( + self.random_location_count, self.times.shape[0] + ) else: - self.location_indices = np.array([ - np.where( - rng.binomial(1, p=self.random_location_density, - size=self.state_vec.system_dim - ).astype('bool'))[0] - for i in range(self.time_indices.shape[0]) - ], dtype=object) - - def _generate_stationary_indices_gridded(self, rng): - if self.random_location_count is not None: - arange_list = [np.arange(n) for n in self.state_vec.original_dim] - ind_possibilities = np.array( - np.meshgrid(*arange_list)).T.reshape( - -1, len(self.state_vec.original_dim)) - self.location_indices = rng.choice( - ind_possibilities, - size=self.random_location_count, + # An unequal amount of locations per time + self._location_counts = [np.sum( + rng.binomial(1, + p=self.random_location_density, + size=self.state_vec.system_dim) + ) + for i in range(self.times.shape[0])] + + self.locations = [{ + coord_name: xr.DataArray( + np.sort(rng.choice( + self.state_vec[coord_name], + size=lc, replace=False, shuffle=False) - else: - self.location_indices = np.array(np.where( - rng.binomial(1, p=self.random_location_density, - size=self.state_vec.original_dim - ).astype('bool') - )).T + ), + dims=['observations']) + for coord_name in self._nontime_coord_names + } + for lc in self._location_counts] - def _generate_nonstationary_indices_gridded(self, rng): - if self.random_location_count is not None: - arange_list = [np.arange(n) for n in self.state_vec.original_dim] - ind_possibilities = np.array( - np.meshgrid(*arange_list)).T.reshape( - -1, len(self.state_vec.original_dim)) - self.location_indices = np.array([rng.choice( - ind_possibilities, - size=self.random_location_count, - replace=False, - shuffle=False) for i in range(self.time_indices.shape[0])]) - else: - self.location_indices = np.array([ - np.array(np.where( - rng.binomial(1, p=self.random_location_density, - size=self.state_vec.original_dim - ).astype('bool'))).T - for i in range(self.time_indices.shape[0]) - ], dtype=object) - - def _sample_stationary(self, errors_vector, sample_in_system_dim): - if sample_in_system_dim: - values_vector = ( - self.state_vec.values[self.time_indices][ - :, self.location_indices] - + errors_vector) - else: - values_gridded = self.state_vec.values_gridded - values_vector = np.array([ - values_gridded[t][tuple(self.location_indices.T)] - for t in self.time_indices]) + errors_vector - return values_vector - - def _sample_nonstationary(self, errors_vector, sample_in_system_dim): - if sample_in_system_dim: - values_vector = np.array([ - (self.state_vec.values[self.time_indices[i]] - [self.location_indices[i]] + errors_vector[i]) - for i in range(self.time_dim)], dtype=object) - else: - values_gridded = self.state_vec.values_gridded - values_vector = np.array( - [values_gridded[self.time_indices[i]][ - tuple(self.location_indices[i].T)] - + errors_vector[i] for i in range(self.time_dim)], - dtype=object) - return values_vector + self.location_dim = np.max(self._location_counts) def observe(self): """Generate observations. @@ -300,158 +250,73 @@ def observe(self): # Set time indices if self.times is None: - self._generate_time_indices(rng) + self._generate_times(rng) self.time_dim = self.times.shape[0] # For stationary observers (default) if self.stationary_observers: - # Generate location_indices if not specified + # Generate locations if not specified if self.locations is None: - # Check if data is in spectral or physical space - if (hasattr(self.state_vec, 'is_spectral') and - self.state_vec.is_spectral): - self._generate_stationary_indices_gridded(rng) - else: - self._generate_stationary_indices(rng) - - # # Check that location_indices are in correct dimensions - # if self.locations.shape[0] == 0: - # raise ValueError('locations is an empty list') - # elif len(self.locations.shape) == 1: - # sample_in_system_dim = True - # elif (self.locations.shape[1] == - # len(self.state_vec.original_dim)): - # sample_in_system_dim = False - # else: - # raise ValueError('locations must be 1D or match\n' - # 'len(self.state_vec.original_dim)') - - # self.location_dims = tuple([v.shape[0] for k, v in self.locations.items()]) - - self.location_dim = next(iter(self.locations.items()))[1] ['observations'].size - - # Generate errors - errors_vec_size = ((self.time_dim,) - + (self.location_dim,) - + (len(self.state_vec.data_vars),)) - if self._error_bias_is_list: - error_bias = self.error_bias[self.location_indices] - else: - error_bias = self.error_bias - if self._error_sd_is_list: - error_sd = self.error_sd[self.location_indices] + self._generate_stationary_locs(rng) else: - error_sd = self.error_sd - errors_vector = rng.normal(loc=error_bias, - scale=error_sd, - size=errors_vec_size) + self.location_dim = next(iter(self.locations.items()))[1]['observations'].size - # # Clip errors to positive only - if self.error_positive_only: - errors_vector[errors_vector < 0.] = 0. - # # Get values - # values_vector = self._sample_stationary( - # errors_vector, - # sample_in_system_dim) - - # # Repeat location indices across time_dim for passing to ObsVector - # full_loc_indices = np.array( - # [self.location_indices] * self.time_dim) + # Sample + obs_vec = self.state_vec.sel(time=self.times).sel(self.locations) # If NON-stationary observer else: # Generate location_indices if not specified - if self.location_indices is None: - # Check if data is in spectral or physical space - if (hasattr(self.state_vec, 'is_spectral') and - self.state_vec.is_spectral): - self._generate_nonstationary_indices_gridded(rng) - else: - self._generate_nonstationary_indices(rng) - - # Check that location_indices are in correct dimensions - if self.location_indices.shape[0] == 0: - raise ValueError('location_indices is an empty list') - elif len(self.location_indices[0].shape) == 1: - sample_in_system_dim = True - elif (self.location_indices[0].shape[1] == - len(self.data_obj.original_dim)): - sample_in_system_dim = False - else: - raise ValueError('With stationary_observers=False,' - 'location_indices must be 1D array of arrays,' - ' with each element being 1D or matching\n' - 'self.data_obj.original_dim') - self.location_dim = np.array([a.shape[0] for a in - self.location_indices]) - - # Generate errors - if self._error_bias_is_list: - if self._error_sd_is_list: - errors_vector = np.array([ - rng.normal( - loc=self.error_bias[ld], - scale=self.error_sd[ld], - size=ld) - for ld in self.location_dim], dtype=object) - else: - errors_vector = np.array([ - rng.normal( - loc=self.error_bias[ld], - scale=self.error_sd, - size=ld) - for ld in self.location_dim], dtype=object) - else: - if self._error_sd_is_list: - errors_vector = np.array([ - rng.normal( - loc=self.error_bias, - scale=self.error_sd[ld], - size=ld) - for ld in self.location_dim], dtype=object) - else: - errors_vector = np.array([ - rng.normal( - loc=self.error_bias, - scale=self.error_sd, - size=ld) - for ld in self.location_dim], dtype=object) - - if self.error_positive_only: - errors_vector = np.array([ - np.maximum(e, 0.) for e in errors_vector]) - - # Get values from generator - values_vector = self._sample_nonstationary( - errors_vector, - sample_in_system_dim) + if self.locations is None: + self._generate_nonstationary_locs(rng) + + # If there's an unequal number of obs, will pad + pad_widths = self.location_dim - np.array(self._location_counts) + + # Sample + obs_vec = xr.concat([ + # Select by time + self.state_vec.sel( + time=t + # Select locations + ).sel( + self.locations[i] + # Pad observations to max number + ).pad( + observations=(0, pad_widths[i]) + ) + for i, t in enumerate(self.times)], + dim='time') + + # Generate errors + errors_vec_size = ((self.time_dim,) + + (self.location_dim,) + + (obs_vec.sizes['variable'],)) + if self._error_bias_is_list: + error_bias = self.error_bias[self.location_indices] + else: + error_bias = self.error_bias + if self._error_sd_is_list: + error_sd = self.error_sd[self.location_indices] + else: + error_sd = self.error_sd + errors_vector = rng.normal(loc=error_bias, + scale=error_sd, + size=errors_vec_size) - # For passing to ObsVector - full_loc_indices = self.location_indices + # Clip errors to positive only + if self.error_positive_only: + errors_vector[errors_vector < 0.] = 0. # loc_indices = xr.where(self.state_vec) - obs_vec = self.state_vec.sel(time=self.times).sel(self.locations) - - obs_vec = obs_vec.assign_coords(variable = list(obs_vec.data_vars)) + # obs_vec = obs_vec.assign_coords(variable = list(obs_vec.data_vars)) + # print(errors_vector.shape) + # print(obs_vec) + print(obs_vec.dims) obs_vec = obs_vec.assign(errors=(obs_vec.dims, errors_vector)) for data_var in obs_vec['variable'].values: obs_vec[data_var] = obs_vec[data_var] + obs_vec['errors'].sel(variable=data_var) - return obs_vec - - # return self.state_vec.sel(time=self.times).sel(self.locations) + errors_ve - # return ObsVector(values=values_vector, - # times=self.data_obj.times[self.time_indices], - # time_indices=self.time_indices, - # location_indices=full_loc_indices, - # obs_dims=self.location_dim, - # num_obs=values_vector.shape[0], - # errors=errors_vector, - # error_dist='normal', - # error_sd=self.error_sd, - # error_bias=self.error_bias, - # store_as_jax=self.store_as_jax, - # stationary_observers=self.stationary_observers - # ) + return obs_vec \ No newline at end of file From e023257616fc4ec8f9ba6a4db58c31a5a3919200 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Wed, 25 Sep 2024 10:26:47 -0600 Subject: [PATCH 16/44] Allow datavars to not match observed vars --- dabench/dacycler/_etkf.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dabench/dacycler/_etkf.py b/dabench/dacycler/_etkf.py index 13901d3..300e18b 100644 --- a/dabench/dacycler/_etkf.py +++ b/dabench/dacycler/_etkf.py @@ -252,7 +252,8 @@ def cycle(self, obs_error_sd=None, analysis_window=0.2, analysis_time_in_window=None, - return_forecast=False): + return_forecast=False + ): """Perform DA cycle repeatedly, including analysis and forecast Args: @@ -278,7 +279,7 @@ def cycle(self, # These could be different if observer doesn't observe all variables # For now, making them the same self._observed_vars = obs_vector['variable'].values - self._data_vars = self._observed_vars + self._data_vars = list(input_state.data_vars) if obs_error_sd is None: obs_error_sd = obs_vector.error_sd From 6e9de1ab52cad260de4e54631251c129a4e34a1d Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Wed, 25 Sep 2024 11:38:01 -0600 Subject: [PATCH 17/44] Updated system index and remove sort from the observer --- dabench/observer/_observer.py | 43 ++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/dabench/observer/_observer.py b/dabench/observer/_observer.py index a0d92c4..950dba5 100644 --- a/dabench/observer/_observer.py +++ b/dabench/observer/_observer.py @@ -98,11 +98,19 @@ def __init__(self, {'{}_index'.format(coord): (coord, np.arange(self.state_vec.sizes[coord])) for coord in self._coord_names} ) + # The system_index corresponds to the points location in a flattened + # array (i.e. state_vec[state_vec.data_vars].to_array().data.flatten()) self.state_vec = self.state_vec.assign( - {'system_index': (['variable'] + ['time'] + self._nontime_coord_names, - np.tile(np.arange(self.state_vec.system_dim), self.state_vec.sizes['time']).reshape( - self.state_vec.to_array().shape - ))} + {'system_index': ( + ['variable'] + ['time'] + self._nontime_coord_names, + np.tile( + np.arange(self.state_vec.system_dim).reshape( + self.state_vec.sizes['variable'], -1 + ), + self.state_vec.sizes['time'] + ).reshape(self.state_vec.to_array().shape) + ) + } ) if times is not None: @@ -197,11 +205,11 @@ def _generate_stationary_locs(self, rng): size=self.state_vec.system_dim)) self.locations = { coord_name: xr.DataArray( - np.sort(rng.choice( + rng.choice( self.state_vec[coord_name], size=location_count, replace=False, - shuffle=False)), + shuffle=False), dims=['observations']) for coord_name in self._nontime_coord_names } @@ -224,13 +232,12 @@ def _generate_nonstationary_locs(self, rng): self.locations = [{ coord_name: xr.DataArray( - np.sort(rng.choice( + rng.choice( self.state_vec[coord_name], size=lc, replace=False, - shuffle=False) - ), - dims=['observations']) + shuffle=False), + dims=['observations']) for coord_name in self._nontime_coord_names } for lc in self._location_counts] @@ -306,17 +313,21 @@ def observe(self): scale=error_sd, size=errors_vec_size) + # Include flag for whether observations are stationary or not + obs_vec = obs_vec.assign_attrs( + stationary_observers=self.stationary_observers) + # Clip errors to positive only if self.error_positive_only: errors_vector[errors_vector < 0.] = 0. - # loc_indices = xr.where(self.state_vec) - # obs_vec = obs_vec.assign_coords(variable = list(obs_vec.data_vars)) - # print(errors_vector.shape) - # print(obs_vec) - print(obs_vec.dims) + # Save errors and apply them to observations obs_vec = obs_vec.assign(errors=(obs_vec.dims, errors_vector)) - for data_var in obs_vec['variable'].values: obs_vec[data_var] = obs_vec[data_var] + obs_vec['errors'].sel(variable=data_var) + + # Transpose system_index to ensure consistency with flattened data + obs_vec['system_index'] = obs_vec['system_index'].transpose('variable','time','observations').fillna( + 0).astype(int) + return obs_vec \ No newline at end of file From 2f87b454f657add2339a641289259fcb938779ff Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Wed, 25 Sep 2024 18:10:17 -0600 Subject: [PATCH 18/44] Making some dacycler methods part of parent class for ease of maintenance --- dabench/dacycler/_dacycler.py | 121 ++++++++++++++++++++++------------ dabench/dacycler/_etkf.py | 117 ++------------------------------ 2 files changed, 85 insertions(+), 153 deletions(-) diff --git a/dabench/dacycler/_dacycler.py b/dabench/dacycler/_dacycler.py index f7a4f3a..44d4207 100644 --- a/dabench/dacycler/_dacycler.py +++ b/dabench/dacycler/_dacycler.py @@ -1,8 +1,12 @@ """Base class for Data Assimilation Cycler object (DACycler)""" -from dabench import vector import numpy as np +import jax.numpy as jnp +import jax +import xarray as xr +import xarray_jax as xj +import dabench.dacycler._utils as dac_utils class DACycler(): """Base class for DACycler object @@ -49,14 +53,31 @@ def __init__(self, self.delta_t = delta_t self.model_obj = model_obj + + def _step_forecast(self, xa, n_steps=1): + """Perform forecast using model object""" + return self.model_obj.forecast(xa, n_steps=n_steps) + + def _step_cycle(self, xb, obs_vals, obs_locs, obs_time_mask, obs_loc_mask, + H=None, h=None, R=None, B=None, **kwargs): + if H is not None or h is None: + vals = self._cycle_obsop( + xb, obs_vals, obs_locs, obs_time_mask, + obs_loc_mask, H, R, B, **kwargs) + return vals + else: + return self._cycle_general_obsop(xb, yo, h, R, B) + def cycle(self, input_state, start_time, obs_vector, n_cycles, - analysis_window, + obs_error_sd=None, + analysis_window=0.2, analysis_time_in_window=None, - return_forecast=False): + return_forecast=False + ): """Perform DA cycle repeatedly, including analysis and forecast Args: @@ -79,50 +100,64 @@ def cycle(self, vector.StateVector of analyses and times. """ + # These could be different if observer doesn't observe all variables + # For now, making them the same + self._observed_vars = obs_vector['variable'].values + self._data_vars = list(input_state.data_vars) + + if obs_error_sd is None: + obs_error_sd = obs_vector.error_sd + self.analysis_window = analysis_window + # If don't specify analysis_time_in_window, is assumed to be middle if analysis_time_in_window is None: analysis_time_in_window = analysis_window/2 + # Steps per window + 1 to include start + self.steps_per_window = round(analysis_window/self.delta_t) + 1 + self._model_timesteps = jnp.arange(self.steps_per_window)*self.delta_t + # Time offset from middle of time window, for gathering observations _time_offset = (analysis_window/2) - analysis_time_in_window - # Number of model steps to run per window - steps_per_window = round(analysis_window/self.delta_t) + 1 - print(steps_per_window) - - # For storing outputs - all_output_states = [] - all_times = [] - cur_time = start_time - cur_state = input_state - - for i in range(n_cycles): - # 1. Filter observations to inside analysis window - window_middle = cur_time + _time_offset - window_start = window_middle - analysis_window/2 - window_end = window_middle + analysis_window/2 - obs_vec_timefilt = obs_vector.sel( - time=slice(window_start, window_end) - ) - - if obs_vec_timefilt.sizes['time'] > 0: - # 2. Calculate analysis - analysis, kh = self._step_cycle(cur_state, obs_vec_timefilt) - # 3. Forecast through analysis window - forecast_states = self._step_forecast(analysis, - n_steps=steps_per_window) - # 4. Save outputs - if return_forecast: - # Append forecast to current state, excluding last step - print(forecast_states) - all_output_states.append(forecast_states.isel(time=slice(0,steps_per_window-1))) - else: - all_output_states.append(analysis) - - # Starting point for next cycle is last step of forecast - cur_state = forecast_states.isel(time=steps_per_window-1) - print(cur_state) - cur_time += analysis_window - - return all_output_states - + # Set up for jax.lax.scan, which is very fast + all_times = dac_utils._get_all_times( + start_time, + analysis_window, + n_cycles) + + # Get the obs vectors for each analysis window + all_filtered_idx = dac_utils._get_obs_indices( + obs_times=jnp.array(obs_vector.time.values), + analysis_times=all_times+_time_offset, + start_inclusive=True, + end_inclusive=False, + analysis_window=analysis_window + ) + input_state = input_state.assign(_cur_time=start_time) + + all_filtered_padded = dac_utils._pad_time_indices(all_filtered_idx, add_one=True) + self._obs_vector=obs_vector + self.obs_error_sd = obs_error_sd + if obs_vector.stationary_observers: + self._obs_loc_masks = jnp.ones( + obs_vector[self._observed_vars].to_array().shape, dtype=bool) + else: + self._obs_loc_masks = ~np.isnan( + obs_vector[self._observed_vars].to_array().data)[0] + self._obs_vector=self._obs_vector.fillna(0) + cur_state, all_values = jax.lax.scan( + self._cycle_and_forecast, + xj.from_xarray(input_state), + all_filtered_padded) + + all_vals_xr = xr.Dataset( + {var: (('cycle',) + tuple(all_values[var].dims), + all_values[var].data) + for var in all_values.data_vars} + ).rename_dims({'time': 'cycle_timestep'}) + + if return_forecast: + return all_vals_xr + else: + return all_vals_xr.isel(cycle_timestep=0) \ No newline at end of file diff --git a/dabench/dacycler/_etkf.py b/dabench/dacycler/_etkf.py index 300e18b..b069322 100644 --- a/dabench/dacycler/_etkf.py +++ b/dabench/dacycler/_etkf.py @@ -8,8 +8,7 @@ import xarray as xr import xarray_jax as xj -from dabench import dacycler, vector -import dabench.dacycler._utils as dac_utils +from dabench import dacycler class ETKF(dacycler.DACycler): @@ -63,15 +62,6 @@ def __init__(self, ensemble=True, B=B, R=R, H=H, h=h) - def _step_cycle(self, xb, obs_vals, obs_locs, obs_time_mask, obs_loc_mask, - H=None, h=None, R=None, B=None): - if H is not None or h is None: - vals = self._cycle_obsop( - xb, obs_vals, obs_locs, obs_time_mask, - obs_loc_mask, H, R, B) - return vals - else: - return self._cycle_general_obsop(xb, yo, h, R, B) def _calc_default_H(self, obs_values, obs_loc_indices): H = jnp.zeros((obs_values.flatten().shape[0], self.system_dim)) @@ -108,7 +98,6 @@ def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices, else: B = self.B - x0_xarray = x0_xarray.to_xarray() Xbt = x0_xarray[self._data_vars].to_array().data[0] nr,nc = Xbt.shape assert nr == self.ensemble_dim, ( @@ -116,8 +105,8 @@ def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices, self.ensemble_dim, self.system_dim) # Apply obs masks to H - # H = jnp.where(obs_time_mask, H.T, 0).T - # H = jnp.where(obs_loc_mask.flatten(), H.T, 0).T + # H = jnp.where(obs_time_mask.flatten(), H.T, 0).T + H = jnp.where(obs_loc_mask.flatten(), H.T, 0).T # Analysis cycles over all obs in data_obs Xa = self._compute_analysis(Xb=Xbt.T, @@ -222,9 +211,10 @@ def _compute_analysis(self, Xb, y, H, h, R, rho=1.0, Yb=None): def _cycle_and_forecast(self, cur_state, filtered_idx): # 1. Get data - # cur_state_vals = state[self.data_vars].data state_obs_tuple[0] - # 1-b. Calculate obs_time_mask and restore filtered_idx to original values + cur_state = cur_state.to_xarray() + cur_time = cur_state['_cur_time'].data + cur_state = cur_state.drop_vars(['_cur_time']) obs_time_mask = filtered_idx > 0 filtered_idx = filtered_idx - 1 @@ -241,99 +231,6 @@ def _cycle_and_forecast(self, cur_state, filtered_idx): ) # 3. Forecast next timestep next_state, forecast_states = self._step_forecast(analysis, n_steps=self.steps_per_window) + next_state = next_state.assign(_cur_time = cur_time + self.analysis_window) return xj.from_xarray(next_state), forecast_states - - def cycle(self, - input_state, - start_time, - obs_vector, - n_cycles, - obs_error_sd=None, - analysis_window=0.2, - analysis_time_in_window=None, - return_forecast=False - ): - """Perform DA cycle repeatedly, including analysis and forecast - - Args: - input_state (vector.StateVector): Input state. - start_time (float or datetime-like): Starting time. - obs_vector (vector.ObsVector): Observations vector. - n_cycles (int): Number of analysis cycles to run, each of length - analysis_window. - analysis_window (float): Time window from which to gather - observations for DA Cycle. - analysis_time_in_window (float): Where within analysis_window - to perform analysis. For example, 0.0 is the start of the - window. Default is None, which selects the middle of the - window. - return_forecast (bool): If True, returns forecast at each model - timestep. If False, returns only analyses, one per analysis - cycle. Default is False. - - Returns: - vector.StateVector of analyses and times. - """ - - # These could be different if observer doesn't observe all variables - # For now, making them the same - self._observed_vars = obs_vector['variable'].values - self._data_vars = list(input_state.data_vars) - - if obs_error_sd is None: - obs_error_sd = obs_vector.error_sd - self.analysis_window = analysis_window - - # If don't specify analysis_time_in_window, is assumed to be middle - if analysis_time_in_window is None: - analysis_time_in_window = analysis_window/2 - - # Steps per window + 1 to include start - self.steps_per_window = round(analysis_window/self.delta_t) + 1 - - # Time offset from middle of time window, for gathering observations - _time_offset = (analysis_window/2) - analysis_time_in_window - - # Set up for jax.lax.scan, which is very fast - all_times = dac_utils._get_all_times( - start_time, - analysis_window, - n_cycles) - - - # Get the obs vectors for each analysis window - all_filtered_idx = dac_utils._get_obs_indices( - obs_times=jnp.array(obs_vector.time.values), - analysis_times=all_times+_time_offset, - start_inclusive=True, - end_inclusive=False, - analysis_window=analysis_window - ) - - all_filtered_padded = dac_utils._pad_time_indices(all_filtered_idx, add_one=True) - self._obs_vector=obs_vector - self.obs_error_sd = obs_error_sd - if obs_vector.stationary_observers: - self._obs_loc_masks = jnp.ones( - obs_vector[self._observed_vars].to_array().shape, dtype=bool) - else: - self._obs_loc_masks = ~np.isnan( - obs_vector[self._observed_vars].to_array().data)[0] - self._obs_vector=self._obs_vector.fillna(0) - cur_state, all_values = jax.lax.scan( - self._cycle_and_forecast, - xj.from_xarray(input_state), - all_filtered_padded) - - - all_vals_xr = xr.Dataset( - {var: (('cycle',) + tuple(all_values[var].dims), - all_values[var].data) - for var in all_values.data_vars} - ).rename_dims({'time': 'cycle_timestep'}) - - if return_forecast: - return all_vals_xr - else: - return all_vals_xr.isel(cycle_timestep=0) \ No newline at end of file From ccf7e0316dab310f2a0817fb8a019cef95042876 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Wed, 25 Sep 2024 18:10:33 -0600 Subject: [PATCH 19/44] Initial version of xarray var4dBP, needs testing and cleaning --- dabench/dacycler/_var4d_backprop.py | 269 +++++++++------------------- 1 file changed, 83 insertions(+), 186 deletions(-) diff --git a/dabench/dacycler/_var4d_backprop.py b/dabench/dacycler/_var4d_backprop.py index 161f66a..4a72a6e 100644 --- a/dabench/dacycler/_var4d_backprop.py +++ b/dabench/dacycler/_var4d_backprop.py @@ -11,6 +11,8 @@ import jax import optax from functools import partial +import xarray as xr +import xarray_jax as xj from dabench import dacycler, vector import dabench.dacycler._utils as dac_utils @@ -79,6 +81,9 @@ def __init__(self, self.learning_rate = learning_rate self.lr_decay = lr_decay self.steps_per_window = steps_per_window + # if obs_window_indices is None: + # self.obs_window_indices + # else: self.obs_window_indices = obs_window_indices self.loss_growth_limit = loss_growth_limit @@ -135,12 +140,12 @@ def _make_loss(self, xb0, obs_vals, Hs, Binv, Rinv, @jax.jit def loss_4dvarcost(x0): # Get initial departure - db0 = (x0.ravel() - xb0.ravel()) + db0 = (x0.to_array().data.ravel() - xb0.to_array().data.ravel()) # Make new prediction - pred_x = self.step_forecast( - vector.StateVector(values=x0, store_as_jax=True), - n_steps).values + # NOTE: [1] selects the full forecast + pred_x = self._step_forecast( + x0, n_steps)[1]['x'].data # Calculate observation term of J_0 obs_term = 0 @@ -173,8 +178,10 @@ def _make_backprop_epoch(self, loss_func, optimizer, hessian_inv): @jax.jit def _backprop_epoch(epoch_state_tuple, i): x0, init_loss, opt_state = epoch_state_tuple + x0 = x0.to_xarray() loss_val, dx0 = loss_value_grad(x0) - dx0_hess = hessian_inv @ dx0 + x0_array = x0.to_stacked_array('system', []) + dx0_hess = hessian_inv @ dx0.to_stacked_array('system',[]).data init_loss = jax.lax.cond( i == 0, lambda: loss_val, @@ -186,17 +193,30 @@ def _backprop_epoch(epoch_state_tuple, i): lambda: loss_val) updates, opt_state = optimizer.update(dx0_hess, opt_state) - x0_new = optax.apply_updates(x0, updates) - - return (x0_new, init_loss, opt_state), loss_val + # x0_new = x0.assign(x=x0['x']-dx0_hess*optimizer.) + # updated_vals = optax.apply_updates( + # x0.to_stacked_array('system',[]).data, updates) + x0_array.data = optax.apply_updates( + x0_array.data, updates) + # print(updates[0]) + # updated_vals = jax.tree.map( + # lambda p, u: jnp.asarray(jnp.array(p)+jnp.array(u)).astype(jnp.asarray(p).dtype), + # x0['x'].data, updates + # ) + # x0_new = x0.assign(x=((x0.dims), updated_vals)) + # x0_array.data = updated_vals + x0_new = x0_array.to_unstacked_dataset('system').assign_attrs( + x0.attrs + ) + # return (xj.from_xarray(x0_new), init_loss, opt_state), loss_val + return (xj.from_xarray(x0_new), init_loss, opt_state), loss_val return _backprop_epoch - def _cycle_obsop(self, x0, obs_values, obs_loc_indices, obs_error_sd, + def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices, obs_time_mask, obs_loc_mask, - H=None, h=None, R=None, B=None, obs_window_indices=None, - n_steps=1): + H=None, h=None, R=None, B=None, obs_window_indices=None): if H is None and h is None: if self.H is None: if self.h is None: @@ -221,7 +241,7 @@ def _cycle_obsop(self, x0, obs_values, obs_loc_indices, obs_error_sd, if R is None: if self.R is None: - R = self._calc_default_R(obs_values, obs_error_sd) + R = self._calc_default_R(obs_values, self.obs_error_sd) else: R = self.R @@ -231,6 +251,8 @@ def _cycle_obsop(self, x0, obs_values, obs_loc_indices, obs_error_sd, else: B = self.B + # x0 = x0_xarray + Rinv = jscipy.linalg.inv(R) Binv = jscipy.linalg.inv(B) @@ -240,205 +262,80 @@ def _cycle_obsop(self, x0, obs_values, obs_loc_indices, obs_error_sd, Binv + Hs.at[0].get().T @ Rinv @ Hs.at[0].get()) loss_func = self._make_loss( - x0, + x0_xarray, obs_values, Hs, Binv, Rinv, obs_window_indices, obs_time_mask, - n_steps=n_steps) + n_steps=self.steps_per_window) lr = optax.exponential_decay( self.learning_rate, 1, self.lr_decay) optimizer = optax.sgd(lr) - opt_state = optimizer.init(x0) + opt_state = optimizer.init(x0_xarray.to_stacked_array('system',[]).data) # Make initial forecast and calculate loss backprop_epoch_func = self._make_backprop_epoch(loss_func, optimizer, hessian_inv) epoch_state_tuple, loss_vals = jax.lax.scan( - backprop_epoch_func, init=(x0, 0., opt_state), + backprop_epoch_func, init=(xj.from_xarray(x0_xarray), 0., opt_state), xs=jnp.arange(self.num_iters)) - x0, init_loss, opt_state = epoch_state_tuple - - xa = self.step_forecast( - vector.StateVector(values=x0, store_as_jax=True), - n_steps=n_steps) - - return xa, loss_vals - - def step_cycle(self, xb, yo, obs_time_mask, obs_loc_mask, - obs_window_indices, H=None, h=None, R=None, B=None, - n_steps=1): - """Perform one step of DA Cycle""" - if H is not None or h is None: - return self._cycle_obsop( - xb.values, yo.values, yo.location_indices, yo.error_sd, - obs_time_mask=obs_time_mask, obs_loc_mask=obs_loc_mask, - H=H, R=R, B=B, - obs_window_indices=obs_window_indices, n_steps=n_steps) - else: - return self._cycle_obsop( - xb, yo, h, R, B, obs_window_indices=obs_window_indices, - n_steps=n_steps) - - def step_forecast(self, xa, n_steps=1): - """Perform forecast using model object""" - if 'n_steps' in inspect.getfullargspec(self.model_obj.forecast).args: - return self.model_obj.forecast(xa, n_steps=n_steps) - else: - if n_steps == 1: - return self.model_obj.forecast(xa) - else: - out = [xa] - xi = xa - for s in range(n_steps): - xi = self.model.forecast(xi) - out.append(xi) - return vector.StateVector(jnp.vstack(xi), store_as_jax=True) - - def _cycle_and_forecast(self, cur_state_vals_time_tuple, filtered_idx): - cur_state_vals, cur_time = cur_state_vals_time_tuple - obs_error_sd = self._obs_error_sd - - # Calculate obs_time_mask and restore filtered_idx to original values + x0_new = epoch_state_tuple[0].to_xarray() + + return x0_new + + def _cycle_and_forecast(self, cur_state, filtered_idx): + # 1. Get data + # 1-b. Calculate obs_time_mask and restore filtered_idx to original values + cur_state = cur_state.to_xarray() + cur_time = cur_state['_cur_time'].data + cur_state = cur_state.drop_vars(['_cur_time']) obs_time_mask = filtered_idx > 0 filtered_idx = filtered_idx - 1 - cur_obs_vals = jnp.array(self._obs_vector.values).at[filtered_idx].get() - cur_obs_loc_indices = jnp.array(self._obs_vector.location_indices).at[filtered_idx].get() - cur_obs_times = jnp.array(self._obs_vector.times).at[filtered_idx].get() - cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[filtered_idx].get().astype(bool) + # cur_obs_vals = jnp.array(self._obs_vector[self._observed_vars].to_array().data).at[0, filtered_idx].get() + cur_obs_vals = jnp.array(self._obs_vector[self._observed_vars].to_stacked_array('system',['time']).data).at[filtered_idx].get() + cur_obs_times = jnp.array(self._obs_vector.time.data).at[filtered_idx].get() + # NOTE: .at[0] selects the first "variable". If there are multiple variables, not sure how we could tweak this + # cur_obs_loc_indices = jnp.array(self._obs_vector.system_index.data).at[:, filtered_idx].get().flatten() + # cur_obs_loc_indices = jnp.array(self._obs_vector.system_index.data).at[0,filtered_idx].get() + cur_obs_loc_indices = jnp.array(self._obs_vector.system_index.data).at[:, filtered_idx].get().reshape(filtered_idx.shape[0], -1) + # cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[0, filtered_idx].get().astype(bool) + cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[:, filtered_idx].get().astype(bool).reshape(filtered_idx.shape[0], -1) # Calculate obs window indices: closest model timesteps that match obs - obs_window_indices = jax.lax.cond( - self.obs_window_indices is None, - lambda: jnp.array([ + obs_window_indices =jnp.array([ jnp.argmin( jnp.abs(obs_time - (cur_time + self._model_timesteps)) ) for obs_time in cur_obs_times - ]), - lambda: jnp.array(self.obs_window_indices) - ) - - analysis, loss_vals = self.step_cycle( - vector.StateVector(values=cur_state_vals, store_as_jax=True), - vector.ObsVector(values=cur_obs_vals, - location_indices=cur_obs_loc_indices, - error_sd=obs_error_sd, - store_as_jax=True), - obs_time_mask=obs_time_mask, + ]) + # obs_window_indices = jax.lax.cond( + # self.obs_window_indices is None, + # lambda: jnp.array([ + # jnp.argmin( + # jnp.abs(obs_time - (cur_time + self._model_timesteps)) + # ) for obs_time in cur_obs_times + # ], jnp.float64), + # lambda: jnp.array(self.obs_window_indices) + # ) + + # 2. Calculate analysis + analysis = self._step_cycle( + cur_state, + cur_obs_vals, + cur_obs_loc_indices, obs_loc_mask=cur_obs_loc_mask, - n_steps=self.steps_per_window, - obs_window_indices=obs_window_indices) - new_time = cur_time + self.analysis_window - - return (analysis.values[-1], new_time), (analysis.values[:-1], loss_vals) - - def cycle(self, - input_state, - start_time, - obs_vector, - obs_error_sd, - n_cycles, - analysis_window, - analysis_time_in_window=0, - return_forecast=False): - """Perform DA cycle repeatedly, including analysis and forecast - - Args: - input_state (vector.StateVector): Input state. - start_time (float or datetime-like): Starting time. - obs_vector (vector.ObsVector): Observations vector. - obs_error_sd (float): Standard deviation of observation error. - Typically not known, so provide a best-guess. - n_cycles (int): Number of analysis cycles to run, each of length - analysis_window. - analysis_window (float): Length of time window from which to gather - observations for each DA Cycle, in model time units. - analysis_time_in_window (float): At what time within analysis_window - to perform analysis. For example, 0.0 is the start of the - window. Default is 0, the start of the window. - return_forecast (bool): If True, returns forecast at each model - timestep. If False, returns only analyses, one per analysis - cycle. Default is False. - - Returns: - vector.StateVector of analyses and times. - """ - if (not obs_vector.stationary_observers and - (self.H is not None or self.h is not None)): - warnings.warn( - "Provided obs vector has nonstationary observers. When" - " providing a custom obs operator (H/h), the Var4DBackprop" - "DA cycler may not function properly. If you encounter " - "errors, try again with an observer where" - "stationary_observers=True or without specifying H or h (a " - "default H matrix will be used to map observations to system " - "space)." - ) - self.analysis_window = analysis_window - - # If don't specify analysis_time_in_window, is assumed to be middle - if analysis_time_in_window is None: - analysis_time_in_window = self.analysis_window/2 - - # Time offset from middle of time window, for gathering observations - _time_offset = (analysis_window/2) - analysis_time_in_window - - # Set up for jax.lax.scan, which is very fast - all_times = dac_utils._get_all_times(start_time, analysis_window, - n_cycles) - - if self.steps_per_window is None: - self.steps_per_window = round(analysis_window/self.delta_t) + 1 - self._model_timesteps = jnp.arange(self.steps_per_window)*self.delta_t - - # Get the obs vectors for each analysis window - all_filtered_idx = dac_utils._get_obs_indices( - obs_times=obs_vector.times, - analysis_times=all_times+_time_offset, - start_inclusive=True, - end_inclusive=True, - analysis_window=analysis_window - ) - - all_filtered_padded = dac_utils._pad_time_indices(all_filtered_idx) - - self._obs_vector = obs_vector - self._obs_error_sd = obs_error_sd - - # Padding observations - if obs_vector.stationary_observers: - self._obs_loc_masks = jnp.ones(obs_vector.values.shape, dtype=bool) - else: - obs_vals, obs_locs, obs_loc_masks = dac_utils._pad_obs_locs( - obs_vector) - self._obs_vector.values = obs_vals - self._obs_vector.location_indices = obs_locs - self._obs_loc_masks = jnp.array(obs_loc_masks) - - cur_state, all_results = jax.lax.scan( - self._cycle_and_forecast, - init=(input_state.values, start_time), - xs=all_filtered_padded) - self.loss_values = all_results[1] - all_values = all_results[0] - - if return_forecast: - all_times_forecast = jnp.arange( - 0, - n_cycles*analysis_window, - self.delta_t - ) + start_time - return vector.StateVector(values=jnp.concatenate(all_values), - times=all_times_forecast) - else: - return vector.StateVector(values=jnp.vstack([ - forecast[0] for forecast in all_values] - ), - times=all_times) + obs_time_mask=obs_time_mask, + obs_window_indices=obs_window_indices + ) + + # 3. Forecast forward + next_state, forecast_states = self._step_forecast(analysis, n_steps=self.steps_per_window) + next_state = next_state.assign(_cur_time = cur_time + self.analysis_window) + + return xj.from_xarray(next_state), forecast_states \ No newline at end of file From 3c09ef36e3d13bd6f6fb1172e5090601daf21ae0 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Wed, 25 Sep 2024 18:12:57 -0600 Subject: [PATCH 20/44] Cleaned var4dbp using xarray --- dabench/dacycler/_var4d_backprop.py | 30 +---------------------------- 1 file changed, 1 insertion(+), 29 deletions(-) diff --git a/dabench/dacycler/_var4d_backprop.py b/dabench/dacycler/_var4d_backprop.py index 4a72a6e..38d51eb 100644 --- a/dabench/dacycler/_var4d_backprop.py +++ b/dabench/dacycler/_var4d_backprop.py @@ -143,7 +143,7 @@ def loss_4dvarcost(x0): db0 = (x0.to_array().data.ravel() - xb0.to_array().data.ravel()) # Make new prediction - # NOTE: [1] selects the full forecast + # NOTE: [1] selects the full forecast instead of last timestep only pred_x = self._step_forecast( x0, n_steps)[1]['x'].data @@ -193,22 +193,11 @@ def _backprop_epoch(epoch_state_tuple, i): lambda: loss_val) updates, opt_state = optimizer.update(dx0_hess, opt_state) - # x0_new = x0.assign(x=x0['x']-dx0_hess*optimizer.) - # updated_vals = optax.apply_updates( - # x0.to_stacked_array('system',[]).data, updates) x0_array.data = optax.apply_updates( x0_array.data, updates) - # print(updates[0]) - # updated_vals = jax.tree.map( - # lambda p, u: jnp.asarray(jnp.array(p)+jnp.array(u)).astype(jnp.asarray(p).dtype), - # x0['x'].data, updates - # ) - # x0_new = x0.assign(x=((x0.dims), updated_vals)) - # x0_array.data = updated_vals x0_new = x0_array.to_unstacked_dataset('system').assign_attrs( x0.attrs ) - # return (xj.from_xarray(x0_new), init_loss, opt_state), loss_val return (xj.from_xarray(x0_new), init_loss, opt_state), loss_val return _backprop_epoch @@ -251,9 +240,6 @@ def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices, else: B = self.B - # x0 = x0_xarray - - Rinv = jscipy.linalg.inv(R) Binv = jscipy.linalg.inv(B) @@ -298,14 +284,9 @@ def _cycle_and_forecast(self, cur_state, filtered_idx): obs_time_mask = filtered_idx > 0 filtered_idx = filtered_idx - 1 - # cur_obs_vals = jnp.array(self._obs_vector[self._observed_vars].to_array().data).at[0, filtered_idx].get() cur_obs_vals = jnp.array(self._obs_vector[self._observed_vars].to_stacked_array('system',['time']).data).at[filtered_idx].get() cur_obs_times = jnp.array(self._obs_vector.time.data).at[filtered_idx].get() - # NOTE: .at[0] selects the first "variable". If there are multiple variables, not sure how we could tweak this - # cur_obs_loc_indices = jnp.array(self._obs_vector.system_index.data).at[:, filtered_idx].get().flatten() - # cur_obs_loc_indices = jnp.array(self._obs_vector.system_index.data).at[0,filtered_idx].get() cur_obs_loc_indices = jnp.array(self._obs_vector.system_index.data).at[:, filtered_idx].get().reshape(filtered_idx.shape[0], -1) - # cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[0, filtered_idx].get().astype(bool) cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[:, filtered_idx].get().astype(bool).reshape(filtered_idx.shape[0], -1) # Calculate obs window indices: closest model timesteps that match obs @@ -314,15 +295,6 @@ def _cycle_and_forecast(self, cur_state, filtered_idx): jnp.abs(obs_time - (cur_time + self._model_timesteps)) ) for obs_time in cur_obs_times ]) - # obs_window_indices = jax.lax.cond( - # self.obs_window_indices is None, - # lambda: jnp.array([ - # jnp.argmin( - # jnp.abs(obs_time - (cur_time + self._model_timesteps)) - # ) for obs_time in cur_obs_times - # ], jnp.float64), - # lambda: jnp.array(self.obs_window_indices) - # ) # 2. Calculate analysis analysis = self._step_cycle( From cbc308fb25d81a0722191b250539893d4e65f9d7 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Fri, 27 Sep 2024 10:36:44 -0600 Subject: [PATCH 21/44] All dacyclers working with xarray, but possibly some accuracy issues with 4dvar and 4dvarBP --- dabench/dacycler/_dacycler.py | 102 ++++++++- dabench/dacycler/_etkf.py | 176 ++++++--------- dabench/dacycler/_var3d.py | 65 ++---- dabench/dacycler/_var4d.py | 334 +++++++--------------------- dabench/dacycler/_var4d_backprop.py | 52 +---- 5 files changed, 256 insertions(+), 473 deletions(-) diff --git a/dabench/dacycler/_dacycler.py b/dabench/dacycler/_dacycler.py index 44d4207..3bcd589 100644 --- a/dabench/dacycler/_dacycler.py +++ b/dabench/dacycler/_dacycler.py @@ -54,6 +54,20 @@ def __init__(self, self.model_obj = model_obj + def _calc_default_H(self, obs_values, obs_loc_indices): + H = jnp.zeros((obs_values.flatten().shape[0], self.system_dim)) + H = H.at[jnp.arange(H.shape[0]), + obs_loc_indices.flatten(), + ].set(1) + return H + + def _calc_default_R(self, obs_values, obs_error_sd): + return jnp.identity(obs_values.flatten().shape[0])*(obs_error_sd**2) + + def _calc_default_B(self): + """If B is not provided, identity matrix with shape (system_dim, system_dim.""" + return jnp.identity(self.system_dim) + def _step_forecast(self, xa, n_steps=1): """Perform forecast using model object""" return self.model_obj.forecast(xa, n_steps=n_steps) @@ -66,7 +80,76 @@ def _step_cycle(self, xb, obs_vals, obs_locs, obs_time_mask, obs_loc_mask, obs_loc_mask, H, R, B, **kwargs) return vals else: - return self._cycle_general_obsop(xb, yo, h, R, B) + raise ValueError( + 'Only linear obs operators (H) are supported right now.') + vals = self._cycle_general_obsop( + xb, obs_vals, obs_locs, obs_time_mask, + obs_loc_mask, h, R, B, **kwargs) + return vals + + def _cycle_and_forecast(self, cur_state, filtered_idx): + # 1. Get data + # 1-b. Calculate obs_time_mask and restore filtered_idx to original values + cur_state = cur_state.to_xarray() + cur_time = cur_state['_cur_time'].data + cur_state = cur_state.drop_vars(['_cur_time']) + obs_time_mask = filtered_idx > 0 + filtered_idx = filtered_idx - 1 + + # 2. Calculate analysis + cur_obs_vals = jnp.array(self._obs_vector[self._observed_vars].to_array().data).at[:, filtered_idx].get() + cur_obs_loc_indices = jnp.array(self._obs_vector.system_index.data).at[:, filtered_idx].get() + cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[:, filtered_idx].get().astype(bool) + cur_obs_time_mask = jnp.repeat(obs_time_mask, cur_obs_vals.shape[-1]) + analysis = self._step_cycle( + cur_state, + cur_obs_vals, + cur_obs_loc_indices, + obs_loc_mask=cur_obs_loc_mask, + obs_time_mask=cur_obs_time_mask + ) + # 3. Forecast next timestep + next_state, forecast_states = self._step_forecast(analysis, n_steps=self.steps_per_window) + next_state = next_state.assign(_cur_time = cur_time + self.analysis_window) + + return xj.from_xarray(next_state), forecast_states + + def _cycle_and_forecast_4d(self, cur_state, filtered_idx): + # 1. Get data + # 1-b. Calculate obs_time_mask and restore filtered_idx to original values + cur_state = cur_state.to_xarray() + cur_time = cur_state['_cur_time'].data + cur_state = cur_state.drop_vars(['_cur_time']) + obs_time_mask = filtered_idx > 0 + filtered_idx = filtered_idx - 1 + + cur_obs_vals = jnp.array(self._obs_vector[self._observed_vars].to_stacked_array('system',['time']).data).at[filtered_idx].get() + cur_obs_times = jnp.array(self._obs_vector.time.data).at[filtered_idx].get() + cur_obs_loc_indices = jnp.array(self._obs_vector.system_index.data).at[:, filtered_idx].get().reshape(filtered_idx.shape[0], -1) + cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[:, filtered_idx].get().astype(bool).reshape(filtered_idx.shape[0], -1) + + # Calculate obs window indices: closest model timesteps that match obs + obs_window_indices =jnp.array([ + jnp.argmin( + jnp.abs(obs_time - (cur_time + self._model_timesteps)) + ) for obs_time in cur_obs_times + ]) + + # 2. Calculate analysis + analysis = self._step_cycle( + cur_state, + cur_obs_vals, + cur_obs_loc_indices, + obs_loc_mask=cur_obs_loc_mask, + obs_time_mask=obs_time_mask, + obs_window_indices=obs_window_indices + ) + + # 3. Forecast forward + next_state, forecast_states = self._step_forecast(analysis, n_steps=self.steps_per_window) + next_state = next_state.assign(_cur_time = cur_time + self.analysis_window) + + return xj.from_xarray(next_state), forecast_states def cycle(self, input_state, @@ -144,12 +227,19 @@ def cycle(self, obs_vector[self._observed_vars].to_array().shape, dtype=bool) else: self._obs_loc_masks = ~np.isnan( - obs_vector[self._observed_vars].to_array().data)[0] + obs_vector[self._observed_vars].to_array().data) self._obs_vector=self._obs_vector.fillna(0) - cur_state, all_values = jax.lax.scan( - self._cycle_and_forecast, - xj.from_xarray(input_state), - all_filtered_padded) + + if self.in_4d: + cur_state, all_values = jax.lax.scan( + self._cycle_and_forecast_4d, + xj.from_xarray(input_state), + all_filtered_padded) + else: + cur_state, all_values = jax.lax.scan( + self._cycle_and_forecast, + xj.from_xarray(input_state), + all_filtered_padded) all_vals_xr = xr.Dataset( {var: (('cycle',) + tuple(all_values[var].dims), diff --git a/dabench/dacycler/_etkf.py b/dabench/dacycler/_etkf.py index b069322..c2347b5 100644 --- a/dabench/dacycler/_etkf.py +++ b/dabench/dacycler/_etkf.py @@ -62,63 +62,8 @@ def __init__(self, ensemble=True, B=B, R=R, H=H, h=h) - - def _calc_default_H(self, obs_values, obs_loc_indices): - H = jnp.zeros((obs_values.flatten().shape[0], self.system_dim)) - H = H.at[jnp.arange(H.shape[0]), - obs_loc_indices.flatten(), - ].set(1) - return H - - def _calc_default_R(self, obs_values, obs_error_sd): - return jnp.identity(obs_values.flatten().shape[0])*(obs_error_sd**2) - - def _calc_default_B(self): - return jnp.identity(self.system_dim) - - def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices, - obs_time_mask, obs_loc_mask, - H=None, h=None, R=None, B=None): - if H is None and h is None: - if self.H is None: - if self.h is None: - H = self._calc_default_H(obs_values, obs_loc_indices) - else: - h = self.h - else: - H = self.H - if R is None: - if self.R is None: - R = self._calc_default_R(obs_values, self.obs_error_sd) - else: - R = self.R - if B is None: - if self.B is None: - B = self._calc_default_B() - else: - B = self.B - - Xbt = x0_xarray[self._data_vars].to_array().data[0] - nr,nc = Xbt.shape - assert nr == self.ensemble_dim, ( - 'cycle:: model_forecast must have dimension {}x{}').format( - self.ensemble_dim, self.system_dim) - - # Apply obs masks to H - # H = jnp.where(obs_time_mask.flatten(), H.T, 0).T - H = jnp.where(obs_loc_mask.flatten(), H.T, 0).T - - # Analysis cycles over all obs in data_obs - Xa = self._compute_analysis(Xb=Xbt.T, - y=obs_values, - H=H, - h=h, - R=R, - rho=self.multiplicative_inflation) - - return x0_xarray.assign(x=(['ensemble','i'], Xa.T)) - def _step_forecast(self, xa, n_steps): + """Ensemble method needs a slightly different _step_forecast method""" ensemble_forecasts = [] ensemble_inputs = [] for i in range(self.ensemble_dim): @@ -132,19 +77,19 @@ def _step_forecast(self, xa, n_steps): return (xr.concat(ensemble_inputs, dim='ensemble'), xr.concat(ensemble_forecasts, dim='ensemble')) - def _apply_obsop(self, Xb, H, h): + def _apply_obsop(self, xb, H, h): if H is not None: - Yb = H @ Xb + yb = H @ xb else: - Yb = h(Xb) + yb = h(xb) - return Yb + return yb - def _compute_analysis(self, Xb, y, H, h, R, rho=1.0, Yb=None): + def _compute_analysis(self, xb, y, H, h, R, rho=1.0, yb=None): """ETKF analysis algorithm Args: - Xb (ndarray): Forecast/background ensemble with shape + xb (ndarray): Forecast/background ensemble with shape (system_dim, ensemble_dim). y (ndarray): Observation array with shape (observation_dim,) H (ndarray): Observation operator with shape (observation_dim, @@ -155,10 +100,10 @@ def _compute_analysis(self, Xb, y, H, h, R, rho=1.0, Yb=None): (i.e. no inflation) Returns: - Xa (ndarray): Analysis ensemble [size: (system_dim, ensemble_dim)] + xa (ndarray): Analysis ensemble [size: (system_dim, ensemble_dim)] """ # Number of state variables, ensemble members and observations - system_dim, ensemble_dim = Xb.shape + system_dim, ensemble_dim = xb.shape observation_dim = y.shape[0] # Auxiliary matrices that will ease the computations @@ -166,30 +111,25 @@ def _compute_analysis(self, Xb, y, H, h, R, rho=1.0, Yb=None): I = jnp.identity(ensemble_dim) # The ensemble is inflated (rho=1.0 is no inflation) - Xb_pert = Xb @ (I-U) - Xb = Xb_pert + Xb @ U - - # Ensemble Transform Kalman Filter - # Initialize the ensemble in observation space - if Yb is None: - Yb = jnp.empty((observation_dim, ensemble_dim)) + xb_pert = xb @ (I-U) + xb = xb_pert + xb @ U - # Map every ensemble member into observation space - Yb = self._apply_obsop(Xb, H, h) + # Map every ensemble member into observation space + yb = self._apply_obsop(xb, H, h) # Get ensemble means and perturbations - xb_bar = jnp.mean(Xb, axis=1) - Xb_pert = Xb @ (I-U) + xb_bar = jnp.mean(xb, axis=1) + xb_pert = xb @ (I-U) - yb_bar = jnp.mean(Yb, axis=1) - Yb_pert = Yb @ (I-U) + yb_bar = jnp.mean(yb, axis=1) + yb_pert = yb @ (I-U) # Compute the analysis if len(R) > 0: Rinv = jnp.linalg.pinv(R, rtol=1e-15) Pa_ens = jnp.linalg.pinv((ensemble_dim-1)/rho*I - + Yb_pert.T @ Rinv @ Yb_pert, + + yb_pert.T @ Rinv @ yb_pert, rtol=1e-15) Wa = linalg.sqrtm((ensemble_dim-1) * Pa_ens) Wa = Wa.real @@ -198,39 +138,55 @@ def _compute_analysis(self, Xb, y, H, h, R, rho=1.0, Yb=None): Pa_ens = jnp.zeros((ensemble_dim, ensemble_dim), dtype=R.dtype) Wa = jnp.zeros((ensemble_dim, ensemble_dim), dtype=R.dtype) - wa = Pa_ens @ Yb_pert.T @ Rinv @ (y.flatten()-yb_bar) + wa = Pa_ens @ yb_pert.T @ Rinv @ (y.flatten()-yb_bar) - Xa_pert = Xb_pert @ Wa + xa_pert = xb_pert @ Wa - xa_bar = xb_bar + jnp.ravel(Xb_pert @ wa) + xa_bar = xb_bar + jnp.ravel(xb_pert @ wa) v = jnp.ones((1, ensemble_dim)) - Xa = Xa_pert + xa_bar[:, None] @ v - - return Xa - - def _cycle_and_forecast(self, cur_state, filtered_idx): - # 1. Get data - # 1-b. Calculate obs_time_mask and restore filtered_idx to original values - cur_state = cur_state.to_xarray() - cur_time = cur_state['_cur_time'].data - cur_state = cur_state.drop_vars(['_cur_time']) - obs_time_mask = filtered_idx > 0 - filtered_idx = filtered_idx - 1 - - # 2. Calculate analysis - cur_obs_vals = jnp.array(self._obs_vector[self._observed_vars].to_array().data).at[:, filtered_idx].get() - cur_obs_loc_indices = jnp.array(self._obs_vector.system_index.data).at[:, filtered_idx].get() - cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[filtered_idx].get().astype(bool) - analysis = self._step_cycle( - cur_state, - cur_obs_vals, - cur_obs_loc_indices, - obs_loc_mask=cur_obs_loc_mask, - obs_time_mask=obs_time_mask - ) - # 3. Forecast next timestep - next_state, forecast_states = self._step_forecast(analysis, n_steps=self.steps_per_window) - next_state = next_state.assign(_cur_time = cur_time + self.analysis_window) - - return xj.from_xarray(next_state), forecast_states + xa = xa_pert + xa_bar[:, None] @ v + + return xa + + def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices, + obs_time_mask, obs_loc_mask, + H=None, h=None, R=None, B=None): + if H is None and h is None: + if self.H is None: + if self.h is None: + H = self._calc_default_H(obs_values, obs_loc_indices) + else: + h = self.h + else: + H = self.H + if R is None: + if self.R is None: + R = self._calc_default_R(obs_values, self.obs_error_sd) + else: + R = self.R + if B is None: + if self.B is None: + B = self._calc_default_B() + else: + B = self.B + + xb = x0_xarray.to_stacked_array('system',['ensemble']).data.T + n_sys, n_ens = xb.shape + assert n_ens == self.ensemble_dim, ( + 'cycle:: model_forecast must have dimension {}x{}').format( + self.ensemble_dim, self.system_dim) + + # Apply obs masks to H + H = jnp.where(obs_time_mask.flatten(), H.T, 0).T + H = jnp.where(obs_loc_mask.flatten(), H.T, 0).T + + # Analysis cycles over all obs in data_obs + xa = self._compute_analysis(xb=xb, + y=obs_values, + H=H, + h=h, + R=R, + rho=self.multiplicative_inflation) + + return x0_xarray.assign(x=(['ensemble','i'], xa.T)) \ No newline at end of file diff --git a/dabench/dacycler/_var3d.py b/dabench/dacycler/_var3d.py index a1a45dd..eacffe2 100644 --- a/dabench/dacycler/_var3d.py +++ b/dabench/dacycler/_var3d.py @@ -48,49 +48,21 @@ def __init__(self, ensemble=False, B=B, R=R, H=H, h=h) - def _step_cycle(self, xb, yo, H=None, h=None, R=None, B=None): - """Perform one step of DA Cycle - - Returns: - vector.StateVector containing analysis results - - """ - if H is not None or h is None: - return self._cycle_linear_obsop(xb, yo, H, R, B) - else: - return self._cycle_general_obsop(xb, yo, h, R, B) - - def _calc_default_H(self, obs_vec): - """If H is not provided, creates identity matrix to serve as H""" - H = jnp.zeros((obs_vec.sizes['observations']*obs_vec.sizes['time'], - self.system_dim)) - H = H.at[jnp.arange(H.shape[0]), np.where(obs_vec.indices.data)[1] - ].set(1) - return H - - def _calc_default_R(self, obs_vec): - """If R i s not provided, calculates default based on observation error""" - return jnp.identity( - obs_vec.sizes['observations']*obs_vec.sizes['time'])*obs_vec.error_sd**2 - - def _calc_default_B(self): - """If B is not provided, identity matrix with shape (system_dim, system_dim.""" - return jnp.identity(self.system_dim) - - def _cycle_general_obsop(self, forecast, obs_vec): - return - - def _cycle_linear_obsop(self, forecast, obs_vec, H=None, R=None, - B=None): + def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices, + obs_time_mask, obs_loc_mask, + H=None, h=None, R=None, B=None): """When obsop (H) is linear""" - if H is None: + if H is None and h is None: if self.H is None: - H = self._calc_default_H(obs_vec) + if self.h is None: + H = self._calc_default_H(obs_values, obs_loc_indices) + else: + h = self.h else: H = self.H if R is None: if self.R is None: - R = self._calc_default_R(obs_vec) + R = self._calc_default_R(obs_values, self.obs_error_sd) else: R = self.R if B is None: @@ -98,10 +70,15 @@ def _cycle_linear_obsop(self, forecast, obs_vec, H=None, R=None, B = self._calc_default_B() else: B = self.B + print(H) # make inputs column vectors - xb = jnp.array([forecast[self.data_vars].to_array().data.flatten()]).T - yo = jnp.array([obs_vec[self.data_vars].to_array().data.flatten()]).T + xb = x0_xarray.to_stacked_array('system',[]).data.flatten() + yo = obs_values.flatten().T + + # Apply masks to H + H = jnp.where(obs_time_mask.flatten(), H.T, 0).T + H = jnp.where(obs_loc_mask.flatten(), H.T, 0).T # Set parameters xdim = xb.size # Size or get one of the shape params? @@ -118,12 +95,4 @@ def _cycle_linear_obsop(self, forecast, obs_vec, H=None, R=None, xa, ierr = jscipy.sparse.linalg.cg(A, b1, x0=xb, tol=1e-05, maxiter=1000) - # Compute KH: - HBHtPlusR_inv = jnp.linalg.inv(H @ BHt + R) - KH = BHt @ HBHtPlusR_inv @ H - - return forecast.assign(x=(['i'], xa.T[0])), KH - - def _step_forecast(self, xa, n_steps): - """n_steps forward of model forecast""" - return self.model_obj.forecast(xa, n_steps=n_steps) + return x0_xarray.assign(x=(x0_xarray.dims, xa.T)) \ No newline at end of file diff --git a/dabench/dacycler/_var4d.py b/dabench/dacycler/_var4d.py index 7dd6cee..20f505e 100644 --- a/dabench/dacycler/_var4d.py +++ b/dabench/dacycler/_var4d.py @@ -11,6 +11,8 @@ from jax.scipy.sparse.linalg import bicgstab from copy import deepcopy from functools import partial +import xarray as xr +import xarray_jax as xj from dabench import dacycler, vector import dabench.dacycler._utils as dac_utils @@ -92,128 +94,6 @@ def _calc_default_H(self, obs_loc_indices): ].set(1) return Hs - - def _calc_default_R(self, obs_values, obs_error_sd): - return jnp.identity(obs_values[0].shape[0])*(obs_error_sd**2) - - def _calc_default_B(self): - return jnp.identity(self.system_dim) - - def _make_outerloop_4d(self, xb0, Hs, B, Rinv, - obs_values, obs_window_indices, obs_time_mask, - n_steps): - - def _outerloop_4d(x0, _): - # Get TLM and current forecast trajectory - # Based on current best guess for x0 - M, x = self.model_obj.compute_tlm( - n_steps=n_steps, - state_vec=vector.StateVector(values=x0, - store_as_jax=True) - ) - - # 4D-Var inner loop - x0 = self._innerloop_4d(self.system_dim, - x, xb0, obs_values, - Hs, B, Rinv, M, - obs_window_indices, - obs_time_mask) - - return x0, x0 - - return _outerloop_4d - - def _cycle_obsop(self, xb0, obs_values, obs_loc_indices, obs_error_sd, - obs_window_indices, obs_time_mask, obs_loc_mask, - H=None, h=None, R=None, B=None, - n_steps=1): - if H is None and h is None: - if self.H is None: - if self.h is None: - H = self._calc_default_H(obs_loc_indices) - # Apply obs loc mask - # NOTE: nonstationary observer case runs MUCH slower. Not sure why - # Ideally, this conditional would not be necessary, but this is a - # workaround to prevent slowing down stationary observer case. - Hs = jax.lax.cond( - self._obs_vector.stationary_observers, - lambda: H, - lambda: (obs_loc_mask[:, :, jnp.newaxis] * H)) - else: - h = self.h - else: - # Assumes self.H is for a single timestep - H = self.H[jnp.newaxis] - Hs = jax.lax.cond( - self._obs_vector.stationary_observers, - lambda: jnp.repeat(H, obs_values.shape[0], axis=0), - lambda: (obs_loc_mask[:, :, jnp.newaxis] * H)) - - if R is None: - if self.R is None: - R = self._calc_default_R(obs_values, obs_error_sd) - else: - R = self.R - if B is None: - if self.B is None: - B = self._calc_default_B() - else: - B = self.B - - # Static Variables - Rinv = jscipy.linalg.inv(R) - - # Best guess for x0 starts as background - x0 = deepcopy(xb0) - - outerloop_4d_func = self._make_outerloop_4d( - xb0, Hs, B, Rinv, obs_values, obs_window_indices, - obs_time_mask, n_steps) - - x0, all_x0s = jax.lax.scan(outerloop_4d_func, init=x0, - xs=None, length=self.n_outer_loops) - - # forecast - x = self.step_forecast( - n_steps=n_steps, - x0=vector.StateVector(values=x0, store_as_jax=True) - ).values - - return x - - def step_cycle(self, x0, yo, obs_time_mask, obs_loc_mask, - obs_window_indices, H=None, h=None, R=None, B=None, - n_steps=1): - """Perform one step of DA Cycle""" - if H is not None or h is None: - return self._cycle_obsop( - x0.values, yo.values, yo.location_indices, yo.error_sd, - obs_loc_mask=obs_loc_mask, obs_time_mask=obs_time_mask, - obs_window_indices=obs_window_indices, - H=H, R=R, B=B, - n_steps=n_steps) - else: - return self._cycle_obsop( - x0.values, yo.values, yo.location_indices, yo.error_sd, h=h, - R=R, B=B, obs_window_indices=obs_window_indices, - n_steps=n_steps) - - def step_forecast(self, x0, n_steps=1): - """Perform forecast using model object""" - if 'n_steps' in inspect.getfullargspec(self.model_obj.forecast).args: - return self.model_obj.forecast(x0, n_steps=n_steps) - else: - if n_steps == 1: - return self.model_obj.forecast(x0) - else: - out = [x0] - xi = x0 - for s in range(n_steps): - xi = self.model.forecast(xi) - out.append(xi) - return vector.StateVector(jnp.vstack(xi), store_as_jax=True) - - def _calc_J_term(self, H, M, Rinv, y, x): # The Jb Term (A) HM = H @ M @@ -223,12 +103,12 @@ def _calc_J_term(self, H, M, Rinv, y, x): D = (y - (H @ x)) return MtHtRinv @ HM, MtHtRinv @ D[:, None] - @partial(jax.jit, static_argnums=[0, 1]) def _innerloop_4d(self, system_dim, x, xb0, obs_vals, Hs, B, Rinv, M, obs_window_indices, obs_time_mask): """4DVar innerloop""" - x0_last = x[0] + x0_last = x.isel(time=0) + x = x.to_stacked_array('system',['time']) # Set up Variables SumMtHtRinvHM = jnp.zeros_like(B) # A input @@ -238,15 +118,15 @@ def _innerloop_4d(self, system_dim, x, xb0, obs_vals, Hs, B, Rinv, M, for i, j in enumerate(obs_window_indices): Jb, Jo = jax.lax.cond( obs_time_mask.at[i].get(mode='fill', fill_value=0), - lambda: self._calc_J_term(Hs.at[i].get(mode='clip'), M[j], - Rinv, obs_vals[i], x[j]), + lambda: self._calc_J_term(Hs.at[i].get(mode='clip'), M.data[j], + Rinv, obs_vals[i], x.data[j]), lambda: (jnp.zeros_like(SumMtHtRinvHM), jnp.zeros_like(SumMtHtRinvD)) ) SumMtHtRinvHM += Jb SumMtHtRinvD += Jo # Compute initial departure - db0 = xb0 - x0_last + db0 = (xb0 - x0_last).to_stacked_array('system',[]).data # Solve Ax=b for the initial perturbation dx0 = self._solve(db0, SumMtHtRinvHM, SumMtHtRinvD, B) @@ -256,6 +136,30 @@ def _innerloop_4d(self, system_dim, x, xb0, obs_vals, Hs, B, Rinv, M, return x0_new + def _make_outerloop_4d(self, xb0, Hs, B, Rinv, + obs_values, obs_window_indices, obs_time_mask, + n_steps): + + def _outerloop_4d(x0, _): + # Get TLM and current forecast trajectory + # Based on current best guess for x0 + x0 = x0.to_xarray() + x, M = self.model_obj.compute_tlm( + n_steps=n_steps, + state_vec=x0 + ) + + # 4D-Var inner loop + x0 = self._innerloop_4d(self.system_dim, + x, xb0, obs_values, + Hs, B, Rinv, M, + obs_window_indices, + obs_time_mask) + + return xj.from_xarray(x0.drop_vars('time')), x0 + + return _outerloop_4d + @partial(jax.jit, static_argnums=0) def _solve(self, db0, SumMtHtRinvHM, SumMtHtRinvD, B): """Solve the 4D-Var linear optimization @@ -286,142 +190,54 @@ def _solve(self, db0, SumMtHtRinvHM, SumMtHtRinvD, B): return dx0 - def _cycle_and_forecast(self, cur_state_vals_time_tuple, filtered_idx): - cur_state_vals, cur_time = cur_state_vals_time_tuple - obs_error_sd = self._obs_error_sd - - # Calculate obs_time_mask and restore filtered_idx to original values - obs_time_mask = filtered_idx > 0 - filtered_idx = filtered_idx - 1 - - cur_obs_vals = jnp.array(self._obs_vector.values).at[filtered_idx].get() - cur_obs_loc_indices = jnp.array(self._obs_vector.location_indices).at[filtered_idx].get() - cur_obs_times = jnp.array(self._obs_vector.times).at[filtered_idx].get() - cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[filtered_idx].get().astype(bool) - - # Calculate obs window indices: closest model timesteps that match obs - obs_window_indices = jax.lax.cond( - self.obs_window_indices is None, - lambda: jnp.array([ - jnp.argmin( - jnp.abs(obs_time - (cur_time + self._model_timesteps)) - ) for obs_time in cur_obs_times - ]), - lambda: jnp.array(self.obs_window_indices) - ) - - analysis = self.step_cycle( - vector.StateVector(values=cur_state_vals, store_as_jax=True), - vector.ObsVector(values=cur_obs_vals, - location_indices=cur_obs_loc_indices, - error_sd=obs_error_sd, - store_as_jax=True), - obs_time_mask=obs_time_mask, - obs_loc_mask=cur_obs_loc_mask, - n_steps=self.steps_per_window, - obs_window_indices=obs_window_indices) - new_time = cur_time + self.analysis_window - - return (analysis[-1], new_time), analysis[:-1] - - def cycle(self, - input_state, - start_time, - obs_vector, - obs_error_sd, - n_cycles, - analysis_window, - analysis_time_in_window=0, - return_forecast=False): - """Perform DA cycle repeatedly, including analysis and forecast - - Args: - input_state (vector.StateVector): Input state. - start_time (float or datetime-like): Starting time. - obs_vector (vector.ObsVector): Observations vector. - obs_error_sd (float): Standard deviation of observation error. - Typically not known, so provide a best-guess. - n_cycles (int): Number of analysis cycles to run, each of length - analysis_window. - analysis_window (float): Length of time window from which to gather - observations for each DA Cycle, in model time units. - analysis_time_in_window (float): At what time within analysis_window - to perform analysis. For example, 0.0 is the start of the - window. Default is 0, the start of the window. - return_forecast (bool): If True, returns forecast at each model - timestep. If False, returns only analyses, one per analysis - cycle. Default is False. - - Returns: - vector.StateVector of analyses and times. - """ - if (not obs_vector.stationary_observers and - (self.H is not None or self.h is not None)): - warnings.warn( - "Provided obs vector has nonstationary observers. When" - " providing a custom obs operator (H/h), the Var4DBackprop" - "DA cycler may not function properly. If you encounter " - "errors, try again with an observer where" - "stationary_observers=True or without specifying H or h (a " - "default H matrix will be used to map observations to system " - "space)." - ) - self.analysis_window = analysis_window - - # If don't specify analysis_time_in_window, is assumed to be middle - if analysis_time_in_window is None: - analysis_time_in_window = self.analysis_window/2 + def _cycle_obsop(self, xb0, obs_values, obs_loc_indices, + obs_time_mask, obs_loc_mask, + H=None, h=None, R=None, B=None, obs_window_indices=None): + if H is None and h is None: + if self.H is None: + if self.h is None: + H = self._calc_default_H(obs_loc_indices) + # Apply obs loc mask + # NOTE: nonstationary observer case runs MUCH slower. Not sure why + # Ideally, this conditional would not be necessary, but this is a + # workaround to prevent slowing down stationary observer case. + Hs = jax.lax.cond( + self._obs_vector.stationary_observers, + lambda: H, + lambda: (obs_loc_mask[:, :, jnp.newaxis] * H)) + else: + h = self.h + else: + # Assumes self.H is for a single timestep + H = self.H[jnp.newaxis] + Hs = jax.lax.cond( + self._obs_vector.stationary_observers, + lambda: jnp.repeat(H, obs_values.shape[0], axis=0), + lambda: (obs_loc_mask[:, :, jnp.newaxis] * H)) - # Time offset from middle of time window, for gathering observations - _time_offset = (analysis_window/2) - analysis_time_in_window + if R is None: + if self.R is None: + R = self._calc_default_R(obs_values, self.obs_error_sd) + else: + R = self.R - # Set up for jax.lax.scan, which is very fast - all_times = dac_utils._get_all_times(start_time, analysis_window, - n_cycles) + if B is None: + if self.B is None: + B = self._calc_default_B() + else: + B = self.B - if self.steps_per_window is None: - self.steps_per_window = round(analysis_window/self.delta_t) + 1 - self._model_timesteps = jnp.arange(self.steps_per_window)*self.delta_t + # Static Variables + Rinv = jscipy.linalg.inv(R) - # Get the obs vectors for each analysis window - all_filtered_idx = dac_utils._get_obs_indices( - obs_times=obs_vector.times, - analysis_times=all_times+_time_offset, - start_inclusive=True, - end_inclusive=True, - analysis_window=analysis_window - ) + # Best guess for x0 starts as background + x0_new = deepcopy(xb0) - all_filtered_padded = dac_utils._pad_time_indices(all_filtered_idx) + outerloop_4d_func = self._make_outerloop_4d( + xb0, Hs, B, Rinv, obs_values, obs_window_indices, + obs_time_mask, self.steps_per_window) - self._obs_vector = obs_vector - self._obs_error_sd = obs_error_sd + x0_new, all_x0s = jax.lax.scan(outerloop_4d_func, init=xj.from_xarray(x0_new), + xs=None, length=self.n_outer_loops) - # Padding observations - if obs_vector.stationary_observers: - self._obs_loc_masks = jnp.ones(obs_vector.values.shape, dtype=bool) - else: - obs_vals, obs_locs, obs_loc_masks = dac_utils._pad_obs_locs( - obs_vector) - self._obs_vector.values = obs_vals - self._obs_vector.location_indices = obs_locs - self._obs_loc_masks = jnp.array(obs_loc_masks) - - cur_state, all_values = jax.lax.scan( - self._cycle_and_forecast, - init=(input_state.values, start_time), - xs=all_filtered_padded) - - if return_forecast: - all_times_forecast = jnp.arange( - 0, - n_cycles*analysis_window, - self.delta_t - ) + start_time - return vector.StateVector(values=jnp.concatenate(all_values), - times=all_times_forecast) - else: - return vector.StateVector(values=jnp.vstack([ - forecast[0] for forecast in all_values] - ), - times=all_times) + return x0_new.to_xarray() \ No newline at end of file diff --git a/dabench/dacycler/_var4d_backprop.py b/dabench/dacycler/_var4d_backprop.py index 38d51eb..6cd3a1c 100644 --- a/dabench/dacycler/_var4d_backprop.py +++ b/dabench/dacycler/_var4d_backprop.py @@ -81,9 +81,6 @@ def __init__(self, self.learning_rate = learning_rate self.lr_decay = lr_decay self.steps_per_window = steps_per_window - # if obs_window_indices is None: - # self.obs_window_indices - # else: self.obs_window_indices = obs_window_indices self.loss_growth_limit = loss_growth_limit @@ -98,7 +95,6 @@ def __init__(self, ensemble=False, B=B, R=R, H=H, h=h) - def _calc_default_H(self, obs_loc_indices): Hs = jnp.zeros((obs_loc_indices.shape[0], obs_loc_indices.shape[1], self.system_dim), @@ -109,12 +105,6 @@ def _calc_default_H(self, obs_loc_indices): return Hs - def _calc_default_R(self, obs_values, obs_error_sd): - return jnp.identity(obs_values[0].shape[0])*(obs_error_sd**2) - - def _calc_default_B(self): - return jnp.identity(self.system_dim) - def _raise_nan_error(self): raise ValueError('Loss value is nan, exiting optimization') @@ -145,7 +135,7 @@ def loss_4dvarcost(x0): # Make new prediction # NOTE: [1] selects the full forecast instead of last timestep only pred_x = self._step_forecast( - x0, n_steps)[1]['x'].data + x0, n_steps)[1].to_stacked_array('system',['time']).data # Calculate observation term of J_0 obs_term = 0 @@ -202,7 +192,6 @@ def _backprop_epoch(epoch_state_tuple, i): return _backprop_epoch - def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices, obs_time_mask, obs_loc_mask, H=None, h=None, R=None, B=None, obs_window_indices=None): @@ -273,41 +262,4 @@ def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices, x0_new = epoch_state_tuple[0].to_xarray() - return x0_new - - def _cycle_and_forecast(self, cur_state, filtered_idx): - # 1. Get data - # 1-b. Calculate obs_time_mask and restore filtered_idx to original values - cur_state = cur_state.to_xarray() - cur_time = cur_state['_cur_time'].data - cur_state = cur_state.drop_vars(['_cur_time']) - obs_time_mask = filtered_idx > 0 - filtered_idx = filtered_idx - 1 - - cur_obs_vals = jnp.array(self._obs_vector[self._observed_vars].to_stacked_array('system',['time']).data).at[filtered_idx].get() - cur_obs_times = jnp.array(self._obs_vector.time.data).at[filtered_idx].get() - cur_obs_loc_indices = jnp.array(self._obs_vector.system_index.data).at[:, filtered_idx].get().reshape(filtered_idx.shape[0], -1) - cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[:, filtered_idx].get().astype(bool).reshape(filtered_idx.shape[0], -1) - - # Calculate obs window indices: closest model timesteps that match obs - obs_window_indices =jnp.array([ - jnp.argmin( - jnp.abs(obs_time - (cur_time + self._model_timesteps)) - ) for obs_time in cur_obs_times - ]) - - # 2. Calculate analysis - analysis = self._step_cycle( - cur_state, - cur_obs_vals, - cur_obs_loc_indices, - obs_loc_mask=cur_obs_loc_mask, - obs_time_mask=obs_time_mask, - obs_window_indices=obs_window_indices - ) - - # 3. Forecast forward - next_state, forecast_states = self._step_forecast(analysis, n_steps=self.steps_per_window) - next_state = next_state.assign(_cur_time = cur_time + self.analysis_window) - - return xj.from_xarray(next_state), forecast_states \ No newline at end of file + return x0_new \ No newline at end of file From 1690992901de1559701fec7e94b8eb6321110ea7 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Fri, 27 Sep 2024 10:37:28 -0600 Subject: [PATCH 22/44] State Vec has delta_t attribute and M is provided as xarray --- dabench/data/_data.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dabench/data/_data.py b/dabench/data/_data.py index 9950e8f..5ffaa88 100644 --- a/dabench/data/_data.py +++ b/dabench/data/_data.py @@ -177,7 +177,8 @@ def generate(self, n_steps=None, t_final=None, x0=None, M0=None, {self.var_names[0]: (coord_dict.keys(),y_out)}, coords=coord_dict, attrs={'store_as_jax':self.store_as_jax, - 'system_dim': self.system_dim + 'system_dim': self.system_dim, + 'delta_t': self.delta_t } ) @@ -196,6 +197,9 @@ def generate(self, n_steps=None, t_final=None, x0=None, M0=None, self.system_dim, self.system_dim) ) + M = xr.DataArray( + M, dims=('time','system_0','system_n') + ) return out_vec, M else: return out_vec From 8b81d55b8884e8afbd4881c56a882ed1bfe8fdcb Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Fri, 27 Sep 2024 10:40:14 -0600 Subject: [PATCH 23/44] Working RC Model with xarray --- dabench/model/_rc.py | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/dabench/model/_rc.py b/dabench/model/_rc.py index e3636a6..e842dd7 100644 --- a/dabench/model/_rc.py +++ b/dabench/model/_rc.py @@ -7,6 +7,7 @@ from scipy import sparse, stats, linalg import numpy as np import jax.numpy as jnp +import xarray as xr from dabench import vector, model @@ -167,7 +168,7 @@ def weights_init(self): self.states = None self.Adense = A.asformat('array') if self.sparse_adj_matrix else A - def generate(self, u, A=None, Win=None, r0=None, save_states=False): + def generate(self, state_vec, A=None, Win=None, r0=None): """generate reservoir time series from input signal u Args: u (array_like): (time_dimension, system_dimension), input signal to @@ -183,6 +184,7 @@ def generate(self, u, A=None, Win=None, r0=None, save_states=False): Returns: r (array_like): (time_dim, reservoir_dim), reservoir state """ + u = state_vec.to_stacked_array('system',['time']).data r = np.zeros((u.shape[0], self.reservoir_dim)) if r0 is not None: @@ -194,11 +196,10 @@ def generate(self, u, A=None, Win=None, r0=None, save_states=False): for t in range(0, u.shape[0]): r[t, :] = self.update(r[t - 1], u[t - 1, :], A, Win) - if save_states: - self.states = r - self.s_last = r[-1] - else: - return r + return xr.Dataset( + {'r': (('time', 'reservoir'), r)}, + coords={'time':state_vec.time} + ) def update(self, r, u, A=None, Win=None): """Update reservoir state with input signal and previous state @@ -379,10 +380,9 @@ def train(self, data_obj, update_Wout=True): Wout (array_like): Trained output weight matrix """ - if self.states is None: - self.generate(data_obj.values, save_states=True) - r = self.states[:, :] - u = data_obj.values[:, :] + r = self.generate(data_obj)['r'].data + # u = data_obj.to_array().transpose(..., 'variable').data.reshape(data_obj.sizes['time'], -1) + u = data_obj.to_array().stack(system=['variable','i']).data self.Wout = self._compute_Wout(r, u, update_Wout=update_Wout, u=u.T) def _compute_Wout(self, rt, y, update_Wout=True, u=None): @@ -480,21 +480,23 @@ def _linsolve_pinv(self, X, Y, beta=None): def forecast(self, state_vec, n_steps=1): if n_steps == 1: - new_vals = self.update(state_vec.values, - self.readout(state_vec.values)) - new_vec = vector.StateVector(values=new_vals, store_as_jax=True) - + new_vals = self.update(state_vec['r'].data, + self.readout(state_vec['r'].data)) + new_vec = xr.Dataset( + {'r':(('time','reservoir'), new_vals)} + ) else: - r = state_vec.values + r = state_vec['r'].data r_full = jnp.zeros((n_steps, self.reservoir_dim)) for i in range(n_steps): r_full = r_full.at[i].set(r) if i < n_steps-1: r = self.update(r, self.readout(r)) - new_vec = vector.StateVector(values=r_full, store_as_jax=True) - - return new_vec + new_vec = xr.Dataset( + {'r':(('time','reservoir'), r_full)} + ) + return new_vec.isel(time=-1), new_vec.drop_isel(time=-1) def save_weights(self, pkl_path): """Save RC reservoir weights as pkl file. From e3cd32043280c6d5ff588c29e205f86d958533ec Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Fri, 27 Sep 2024 10:40:56 -0600 Subject: [PATCH 24/44] Fixed generator extra step rounding error --- dabench/data/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dabench/data/_utils.py b/dabench/data/_utils.py index 41ce004..6034237 100644 --- a/dabench/data/_utils.py +++ b/dabench/data/_utils.py @@ -31,7 +31,7 @@ def integrate(function, x0, t_final, delta_t, method='odeint', stride=None, """ if method == 'odeint': # Define timesteps - t = np.arange(0.0, t_final, delta_t) + t = np.arange(0.0, t_final - delta_t/2, delta_t) # If stride is defined, remove timesteps that are not on stride steps if stride is not None: assert stride > 1 and isinstance(stride, int), \ From 19572e3cc49f3cfab9f9156f3ba14c0f27d9b9cc Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Fri, 27 Sep 2024 10:41:21 -0600 Subject: [PATCH 25/44] Observer can accept random_time_density now --- dabench/observer/_observer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dabench/observer/_observer.py b/dabench/observer/_observer.py index 950dba5..afefded 100644 --- a/dabench/observer/_observer.py +++ b/dabench/observer/_observer.py @@ -189,11 +189,11 @@ def _generate_times(self, rng): replace=False, shuffle=False)) else: - self.time_indices = np.where( + self.times = self.state_vec.time[np.where( rng.binomial(1, p=self.random_time_density, - size=self.state_vec.time_dim + size=self.state_vec.sizes['time'] ).astype('bool') - )[0] + )[0]] def _generate_stationary_locs(self, rng): if self.random_location_count is not None: From e5bfe0e6a2a38e47ca1312b2a6f8cf1156f49c7a Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Fri, 27 Sep 2024 10:50:22 -0600 Subject: [PATCH 26/44] Add permissible xarray jax to pyproject toml --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 77dfabe..d6976c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,8 @@ dependencies = [ "scipy", "optax", "xarray", - "cftime" + "cftime", + "xarray_jax@git+git+https://github.com/kysolvik/xarray_jax_permissible.git" ] [project.optional-dependencies] From 789019aed233c5b9df4f08bb03f88a71bc8988ad Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Fri, 27 Sep 2024 10:56:12 -0600 Subject: [PATCH 27/44] Remove all vector module imports --- dabench/dacycler/_etkf.py | 2 +- dabench/dacycler/_var3d.py | 7 +++---- dabench/dacycler/_var4d.py | 4 ++-- dabench/dacycler/_var4d_backprop.py | 4 ++-- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/dabench/dacycler/_etkf.py b/dabench/dacycler/_etkf.py index c2347b5..b5514f5 100644 --- a/dabench/dacycler/_etkf.py +++ b/dabench/dacycler/_etkf.py @@ -189,4 +189,4 @@ def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices, R=R, rho=self.multiplicative_inflation) - return x0_xarray.assign(x=(['ensemble','i'], xa.T)) \ No newline at end of file + return x0_xarray.assign(x=(['ensemble','i'], xa.T)) diff --git a/dabench/dacycler/_var3d.py b/dabench/dacycler/_var3d.py index eacffe2..2511902 100644 --- a/dabench/dacycler/_var3d.py +++ b/dabench/dacycler/_var3d.py @@ -4,7 +4,7 @@ import jax.numpy as jnp import jax.scipy as jscipy -from dabench import dacycler, vector +from dabench import dacycler class Var3D(dacycler.DACycler): @@ -72,9 +72,8 @@ def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices, B = self.B print(H) - # make inputs column vectors xb = x0_xarray.to_stacked_array('system',[]).data.flatten() - yo = obs_values.flatten().T + yo = obs_values.flatten() # Apply masks to H H = jnp.where(obs_time_mask.flatten(), H.T, 0).T @@ -95,4 +94,4 @@ def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices, xa, ierr = jscipy.sparse.linalg.cg(A, b1, x0=xb, tol=1e-05, maxiter=1000) - return x0_xarray.assign(x=(x0_xarray.dims, xa.T)) \ No newline at end of file + return x0_xarray.assign(x=(x0_xarray.dims, xa.T)) diff --git a/dabench/dacycler/_var4d.py b/dabench/dacycler/_var4d.py index 20f505e..934f716 100644 --- a/dabench/dacycler/_var4d.py +++ b/dabench/dacycler/_var4d.py @@ -14,7 +14,7 @@ import xarray as xr import xarray_jax as xj -from dabench import dacycler, vector +from dabench import dacycler import dabench.dacycler._utils as dac_utils @@ -240,4 +240,4 @@ def _cycle_obsop(self, xb0, obs_values, obs_loc_indices, x0_new, all_x0s = jax.lax.scan(outerloop_4d_func, init=xj.from_xarray(x0_new), xs=None, length=self.n_outer_loops) - return x0_new.to_xarray() \ No newline at end of file + return x0_new.to_xarray() diff --git a/dabench/dacycler/_var4d_backprop.py b/dabench/dacycler/_var4d_backprop.py index 6cd3a1c..6a04a06 100644 --- a/dabench/dacycler/_var4d_backprop.py +++ b/dabench/dacycler/_var4d_backprop.py @@ -14,7 +14,7 @@ import xarray as xr import xarray_jax as xj -from dabench import dacycler, vector +from dabench import dacycler import dabench.dacycler._utils as dac_utils @@ -262,4 +262,4 @@ def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices, x0_new = epoch_state_tuple[0].to_xarray() - return x0_new \ No newline at end of file + return x0_new From f0f982db3cb16c947ab492f878c6401e31225416 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Fri, 27 Sep 2024 11:48:53 -0600 Subject: [PATCH 28/44] Rename i to index for toy data generators --- dabench/data/_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dabench/data/_data.py b/dabench/data/_data.py index 5ffaa88..94e3cd6 100644 --- a/dabench/data/_data.py +++ b/dabench/data/_data.py @@ -45,7 +45,7 @@ def __init__(self, # Default var and coord names self.var_names = ['x'] - self.coord_names = ['i'] + self.coord_names = ['index'] # x0 attribute is property to better convert between jax/numpy self._x0 = x0 From e8bd8979cfce7df39d361492ad181e8c75963b7b Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Fri, 27 Sep 2024 11:50:47 -0600 Subject: [PATCH 29/44] GCP system dim now specified --- dabench/data/gcp.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dabench/data/gcp.py b/dabench/data/gcp.py index 03e6a5f..8d0e9b7 100644 --- a/dabench/data/gcp.py +++ b/dabench/data/gcp.py @@ -112,6 +112,9 @@ def _load_gcp_era5(self): # Subset by lon boundaries ds = ds.sel(longitude=slice(subset_min_lon, subset_max_lon)) + # Assign system dimension + ds = ds.to_stacked_array('system',['time']).sizes['system'] + return ds def generate(self): From 17cab8f8f6b84c055d3708e5f2b4de849fa6fca2 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Fri, 27 Sep 2024 11:51:25 -0600 Subject: [PATCH 30/44] Observer can accept list of error_sds, and now samples with replacement when there is more than 1 dimension to sample along --- dabench/observer/_observer.py | 46 ++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/dabench/observer/_observer.py b/dabench/observer/_observer.py index afefded..5e79b61 100644 --- a/dabench/observer/_observer.py +++ b/dabench/observer/_observer.py @@ -94,10 +94,6 @@ def __init__(self, {'variable': self.state_vec.data_vars} # 'variable_index': np.arange(len(self.state_vec.data_vars))} ) - self.state_vec = self.state_vec.assign_coords( - {'{}_index'.format(coord): (coord, np.arange(self.state_vec.sizes[coord])) - for coord in self._coord_names} - ) # The system_index corresponds to the points location in a flattened # array (i.e. state_vec[state_vec.data_vars].to_array().data.flatten()) self.state_vec = self.state_vec.assign( @@ -147,10 +143,8 @@ def __init__(self, elif not len(self.error_bias) == self.state_vec.system_dim: raise ValueError( "List of error biases has length {}." - "Must match either system_dim ({}) or " - "number of location indices ({})".format( - len(self.error_bias), self.state_vec.system_dim, - self.location_indices.shape[0])) + "Must match system_dim ({}) or ".format( + len(self.error_bias), self.state_vec.system_dim)) elif isinstance(self.error_bias, list): if self.store_as_jax: self.error_bias = jnp.array(self.error_bias) @@ -166,10 +160,8 @@ def __init__(self, elif not len(self.error_sd) == self.state_vec.system_dim: raise ValueError( "List of error sds has length {}." - "Must match either system_dim ({}) or " - "number of location indices ({})".format( - len(self.error_sd), self.state_vec.system_dim, - self.location_indices.shape[0])) + "Must match system_dim ({})".format( + len(self.error_sd), self.state_vec.system_dim)) elif isinstance(self.error_sd, list): if self.store_as_jax: self.error_sd = jnp.array(self.error_sd) @@ -203,12 +195,16 @@ def _generate_stationary_locs(self, rng): rng.binomial(1, p=self.random_location_density, size=self.state_vec.system_dim)) + if len(self._nontime_coord_names) > 1: + sample_w_replace=True + else: + sample_w_replace=False self.locations = { coord_name: xr.DataArray( rng.choice( self.state_vec[coord_name], size=location_count, - replace=False, + replace=sample_w_replace, shuffle=False), dims=['observations']) for coord_name in self._nontime_coord_names @@ -230,12 +226,17 @@ def _generate_nonstationary_locs(self, rng): ) for i in range(self.times.shape[0])] + if len(self._nontime_coord_names) > 1: + sample_w_replace=True + else: + sample_w_replace=False + self.locations = [{ coord_name: xr.DataArray( rng.choice( self.state_vec[coord_name], size=lc, - replace=False, + replace=sample_w_replace, shuffle=False), dims=['observations']) for coord_name in self._nontime_coord_names @@ -297,16 +298,24 @@ def observe(self): for i, t in enumerate(self.times)], dim='time') + # Transpose system_index to ensure consistency with flattened data + obs_vec['system_index'] = obs_vec['system_index'].transpose('variable','time','observations').fillna( + 0).astype(int) + # Generate errors errors_vec_size = ((self.time_dim,) + (self.location_dim,) + (obs_vec.sizes['variable'],)) + errors_vec_size = ((obs_vec.sizes['variable'],) + + (self.time_dim,) + + (self.location_dim,)) + if self._error_bias_is_list: - error_bias = self.error_bias[self.location_indices] + error_bias = self.error_bias[obs_vec['system_index'].data] else: error_bias = self.error_bias if self._error_sd_is_list: - error_sd = self.error_sd[self.location_indices] + error_sd = self.error_sd[obs_vec['system_index'].data] else: error_sd = self.error_sd errors_vector = rng.normal(loc=error_bias, @@ -322,12 +331,9 @@ def observe(self): errors_vector[errors_vector < 0.] = 0. # Save errors and apply them to observations - obs_vec = obs_vec.assign(errors=(obs_vec.dims, errors_vector)) + obs_vec = obs_vec.assign(errors=(obs_vec['system_index'].dims, errors_vector)) for data_var in obs_vec['variable'].values: obs_vec[data_var] = obs_vec[data_var] + obs_vec['errors'].sel(variable=data_var) - # Transpose system_index to ensure consistency with flattened data - obs_vec['system_index'] = obs_vec['system_index'].transpose('variable','time','observations').fillna( - 0).astype(int) return obs_vec \ No newline at end of file From 43c9d2b1713b1bef03dcf991acd044a2c2457886 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Fri, 27 Sep 2024 14:41:08 -0600 Subject: [PATCH 31/44] Fixed issue with missing time offset for 4dvar and 4dvarBP --- dabench/dacycler/_dacycler.py | 19 ++++++++++++++----- dabench/dacycler/_var4d.py | 4 +++- dabench/dacycler/_var4d_backprop.py | 11 +++++++---- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/dabench/dacycler/_dacycler.py b/dabench/dacycler/_dacycler.py index 3bcd589..7389998 100644 --- a/dabench/dacycler/_dacycler.py +++ b/dabench/dacycler/_dacycler.py @@ -41,6 +41,7 @@ def __init__(self, R=None, H=None, h=None, + analysis_time_in_window=None ): self.h = h @@ -52,6 +53,7 @@ def __init__(self, self.system_dim = system_dim self.delta_t = delta_t self.model_obj = model_obj + self.analysis_time_in_window = analysis_time_in_window def _calc_default_H(self, obs_values, obs_loc_indices): @@ -190,11 +192,14 @@ def cycle(self, if obs_error_sd is None: obs_error_sd = obs_vector.error_sd + self.analysis_window = analysis_window # If don't specify analysis_time_in_window, is assumed to be middle - if analysis_time_in_window is None: - analysis_time_in_window = analysis_window/2 + if self.analysis_time_in_window is None and analysis_time_in_window is None: + analysis_time_in_window = self.analysis_window/2 + else: + analysis_time_in_window = self.analysis_time_in_window # Steps per window + 1 to include start self.steps_per_window = round(analysis_window/self.delta_t) + 1 @@ -209,12 +214,16 @@ def cycle(self, analysis_window, n_cycles) + + if self.steps_per_window is None: + self.steps_per_window = round(analysis_window/self.delta_t) + 1 + self._model_timesteps = jnp.arange(self.steps_per_window)*self.delta_t # Get the obs vectors for each analysis window all_filtered_idx = dac_utils._get_obs_indices( obs_times=jnp.array(obs_vector.time.values), analysis_times=all_times+_time_offset, start_inclusive=True, - end_inclusive=False, + end_inclusive=self.in_4d, analysis_window=analysis_window ) input_state = input_state.assign(_cur_time=start_time) @@ -248,6 +257,6 @@ def cycle(self, ).rename_dims({'time': 'cycle_timestep'}) if return_forecast: - return all_vals_xr + return all_vals_xr.drop_isel(cycle_timestep=-1) else: - return all_vals_xr.isel(cycle_timestep=0) \ No newline at end of file + return all_vals_xr.isel(cycle_timestep=0) diff --git a/dabench/dacycler/_var4d.py b/dabench/dacycler/_var4d.py index 934f716..96535b7 100644 --- a/dabench/dacycler/_var4d.py +++ b/dabench/dacycler/_var4d.py @@ -66,6 +66,7 @@ def __init__(self, n_outer_loops=1, steps_per_window=1, obs_window_indices=None, + analysis_time_in_window=0, **kwargs ): @@ -83,7 +84,8 @@ def __init__(self, model_obj=model_obj, in_4d=True, ensemble=False, - B=B, R=R, H=H, h=h) + B=B, R=R, H=H, h=h, + analysis_time_in_window=analysis_time_in_window) def _calc_default_H(self, obs_loc_indices): Hs = jnp.zeros((obs_loc_indices.shape[0], obs_loc_indices.shape[1], diff --git a/dabench/dacycler/_var4d_backprop.py b/dabench/dacycler/_var4d_backprop.py index 6a04a06..80cfcfc 100644 --- a/dabench/dacycler/_var4d_backprop.py +++ b/dabench/dacycler/_var4d_backprop.py @@ -74,6 +74,7 @@ def __init__(self, steps_per_window=None, obs_window_indices=None, loss_growth_limit=10, + analysis_time_in_window=0, **kwargs ): @@ -93,7 +94,8 @@ def __init__(self, model_obj=model_obj, in_4d=True, ensemble=False, - B=B, R=R, H=H, h=h) + B=B, R=R, H=H, h=h, + analysis_time_in_window=analysis_time_in_window) def _calc_default_H(self, obs_loc_indices): Hs = jnp.zeros((obs_loc_indices.shape[0], obs_loc_indices.shape[1], @@ -115,7 +117,7 @@ def _callback_raise_error(self, error_method, loss_val): jax.debug.callback(error_method) return loss_val - @partial(jax.jit, static_argnums=[0]) + # @partial(jax.jit, static_argnums=[0]) def _calc_obs_term(self, pred_x, obs_vals, Ht, Rinv): pred_obs = pred_x @ Ht resid = pred_obs.ravel() - obs_vals.ravel() @@ -127,7 +129,7 @@ def _make_loss(self, xb0, obs_vals, Hs, Binv, Rinv, obs_time_mask, n_steps): """Define loss function based on 4dvar cost""" - @jax.jit + # @jax.jit def loss_4dvarcost(x0): # Get initial departure db0 = (x0.to_array().data.ravel() - xb0.to_array().data.ravel()) @@ -165,7 +167,7 @@ def _make_backprop_epoch(self, loss_func, optimizer, hessian_inv): loss_value_grad = value_and_grad(loss_func, argnums=0) - @jax.jit + # @jax.jit def _backprop_epoch(epoch_state_tuple, i): x0, init_loss, opt_state = epoch_state_tuple x0 = x0.to_xarray() @@ -256,6 +258,7 @@ def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices, # Make initial forecast and calculate loss backprop_epoch_func = self._make_backprop_epoch(loss_func, optimizer, hessian_inv) + # epoch_state_tuple, loss_vals = backprop_epoch_func((xj.from_xarray(x0_xarray), 0., opt_state),0) epoch_state_tuple, loss_vals = jax.lax.scan( backprop_epoch_func, init=(xj.from_xarray(x0_xarray), 0., opt_state), xs=jnp.arange(self.num_iters)) From 5a916f5cd026ded72bd051d57a88a0ff782d35cd Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Mon, 30 Sep 2024 10:00:02 -0600 Subject: [PATCH 32/44] Reassign coords to match input state within dacycler --- dabench/dacycler/_dacycler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dabench/dacycler/_dacycler.py b/dabench/dacycler/_dacycler.py index 7389998..5f7d06b 100644 --- a/dabench/dacycler/_dacycler.py +++ b/dabench/dacycler/_dacycler.py @@ -112,7 +112,10 @@ def _cycle_and_forecast(self, cur_state, filtered_idx): ) # 3. Forecast next timestep next_state, forecast_states = self._step_forecast(analysis, n_steps=self.steps_per_window) - next_state = next_state.assign(_cur_time = cur_time + self.analysis_window) + next_state = next_state.assign( + _cur_time = cur_time + self.analysis_window + ).assign_coords( + cur_state.coords) return xj.from_xarray(next_state), forecast_states From 6a2e44e82d0e03b8a4463758a0b3e67dca28104f Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Mon, 30 Sep 2024 10:00:31 -0600 Subject: [PATCH 33/44] Apply coord reassing for 4dvar cycler too --- dabench/dacycler/_dacycler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dabench/dacycler/_dacycler.py b/dabench/dacycler/_dacycler.py index 5f7d06b..6afc605 100644 --- a/dabench/dacycler/_dacycler.py +++ b/dabench/dacycler/_dacycler.py @@ -152,7 +152,10 @@ def _cycle_and_forecast_4d(self, cur_state, filtered_idx): # 3. Forecast forward next_state, forecast_states = self._step_forecast(analysis, n_steps=self.steps_per_window) - next_state = next_state.assign(_cur_time = cur_time + self.analysis_window) + next_state = next_state.assign( + _cur_time = cur_time + self.analysis_window + ).assign_coords( + cur_state.coords) return xj.from_xarray(next_state), forecast_states From 2758b823f3d83614beaaf9e62477ad77fc1bf9b9 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Mon, 30 Sep 2024 10:01:01 -0600 Subject: [PATCH 34/44] Remove unnecessary print from 3dvar --- dabench/dacycler/_var3d.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dabench/dacycler/_var3d.py b/dabench/dacycler/_var3d.py index 2511902..271b969 100644 --- a/dabench/dacycler/_var3d.py +++ b/dabench/dacycler/_var3d.py @@ -70,7 +70,6 @@ def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices, B = self._calc_default_B() else: B = self.B - print(H) xb = x0_xarray.to_stacked_array('system',[]).data.flatten() yo = obs_values.flatten() From a5157680e1babf8e6cee69ebd3a319f5ff74c6b0 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Mon, 30 Sep 2024 10:45:01 -0600 Subject: [PATCH 35/44] Reassign coords for outer loop carry instead of dropping time --- dabench/dacycler/_var4d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dabench/dacycler/_var4d.py b/dabench/dacycler/_var4d.py index 96535b7..2a5242f 100644 --- a/dabench/dacycler/_var4d.py +++ b/dabench/dacycler/_var4d.py @@ -152,13 +152,13 @@ def _outerloop_4d(x0, _): ) # 4D-Var inner loop - x0 = self._innerloop_4d(self.system_dim, + x0_new = self._innerloop_4d(self.system_dim, x, xb0, obs_values, Hs, B, Rinv, M, obs_window_indices, obs_time_mask) - return xj.from_xarray(x0.drop_vars('time')), x0 + return xj.from_xarray(x0_new.assign_coords(x0.coords)), x0 return _outerloop_4d From 96721df2b4522ef181746ba05425da343b7590d0 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Mon, 30 Sep 2024 11:02:20 -0600 Subject: [PATCH 36/44] Updated gcp to properly assign system_dim --- dabench/data/gcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dabench/data/gcp.py b/dabench/data/gcp.py index 8d0e9b7..e479989 100644 --- a/dabench/data/gcp.py +++ b/dabench/data/gcp.py @@ -113,7 +113,7 @@ def _load_gcp_era5(self): ds = ds.sel(longitude=slice(subset_min_lon, subset_max_lon)) # Assign system dimension - ds = ds.to_stacked_array('system',['time']).sizes['system'] + ds = ds.assign_coords(system_dim=ds.to_stacked_array('system',['time']).sizes['system']) return ds From 2c7686de90948d8870790a7724dd4d35aa354a8a Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Mon, 30 Sep 2024 11:19:02 -0600 Subject: [PATCH 37/44] Assign system_dim as attr, not coord --- dabench/data/gcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dabench/data/gcp.py b/dabench/data/gcp.py index e479989..3e8727c 100644 --- a/dabench/data/gcp.py +++ b/dabench/data/gcp.py @@ -113,7 +113,7 @@ def _load_gcp_era5(self): ds = ds.sel(longitude=slice(subset_min_lon, subset_max_lon)) # Assign system dimension - ds = ds.assign_coords(system_dim=ds.to_stacked_array('system',['time']).sizes['system']) + ds = ds.assign_attrs(system_dim=ds.to_stacked_array('system',['time']).sizes['system']) return ds From 974613cabdc73588186cc1bf9062b236fb2c65e5 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Mon, 30 Sep 2024 11:54:14 -0600 Subject: [PATCH 38/44] Fix typo for xarray_jax git repo --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d6976c4..6714bc2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "optax", "xarray", "cftime", - "xarray_jax@git+git+https://github.com/kysolvik/xarray_jax_permissible.git" + "xarray_jax@git+https://github.com/kysolvik/xarray_jax_permissible.git" ] [project.optional-dependencies] From 3fdfabd100d7ec6cc426e9a5c72c2b0b624527a9 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Tue, 1 Oct 2024 19:48:46 -0600 Subject: [PATCH 39/44] XArray accessors for helper methods --- dabench/data/__init__.py | 1 + dabench/data/_xarray_accessor.py | 62 ++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) create mode 100644 dabench/data/_xarray_accessor.py diff --git a/dabench/data/__init__.py b/dabench/data/__init__.py index 5c9509e..11e367e 100644 --- a/dabench/data/__init__.py +++ b/dabench/data/__init__.py @@ -9,6 +9,7 @@ from .barotropic import Barotropic from .enso_indices import ENSOIndices from .qgs import QGS +from ._xarray_accessor import DABenchDatasetAccessor, DABenchDataArrayAccessor __all__ = [ 'Data', diff --git a/dabench/data/_xarray_accessor.py b/dabench/data/_xarray_accessor.py new file mode 100644 index 0000000..205cc6d --- /dev/null +++ b/dabench/data/_xarray_accessor.py @@ -0,0 +1,62 @@ +import xarray as xr +import numpy as np +import warnings + + +def _check_split_lengths(xr_obj, split_lengths): + total_length = np.sum(split_lengths) + xr_timedim = xr_obj.sizes['time'] + if xr_timedim < total_length: + warnings.warn("Specified split lengths ({}) exceed \n" + "Xarray object's time dimension ({}).".format( + split_lengths, xr_timedim + )) + elif xr_timedim > total_length: + warnings.warn("Specified split lengths ({}) are shorter than " + "Xarray object's time dimension ({}).".format( + split_lengths, xr_timedim + )) + + +@xr.register_dataset_accessor('dabench') +class DABenchDatasetAccessor: + """Helper methods for manipulating xarray Datasets""" + def __init__(self, xarray_obj): + self._obj = xarray_obj + + def to_system(self): + if 'time' in self._obj.coords: + remaining_dim = ['time'] + else: + remaining_dim = [] + return self._obj.to_stacked_array('system', remaining_dim) + + def split_train_val_test(self, split_lengths): + _check_split_lengths(self._obj, split_lengths) + out_ds = [] + start_i = 0 + for sl in split_lengths: + end_i = start_i + sl + out_ds.append(self._obj.isel(time=slice(start_i, end_i))) + return tuple(out_ds) + + +@xr.register_dataarray_accessor('dabench') +class DABenchDataArrayAccessor: + """Helper methods for manipulating xarray DataArrays""" + def __init__(self, xarray_obj): + self._obj = xarray_obj + + def to_gridded(self): + return self._obj.to_unstacked_dataset('system') + + def split_train_val_test(self, split_lengths): + _check_split_lengths(self._obj, split_lengths) + out_ds = [] + start_i = 0 + for sl in split_lengths: + end_i = start_i + sl + out_ds.append(self._obj.isel(time=slice(start_i, end_i))) + return tuple(out_ds) + + From 93aff5cc3f6da23561d88b450d006d4970d67750 Mon Sep 17 00:00:00 2001 From: Steve Penny Date: Wed, 2 Oct 2024 19:54:37 -0600 Subject: [PATCH 40/44] bugfix --- dabench/dasupport/generate_era5_ensemble.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dabench/dasupport/generate_era5_ensemble.py b/dabench/dasupport/generate_era5_ensemble.py index 50dcb6b..bdefccc 100644 --- a/dabench/dasupport/generate_era5_ensemble.py +++ b/dabench/dasupport/generate_era5_ensemble.py @@ -244,4 +244,5 @@ def main( sample_strategy=args.sample_strategy, start_date=args.start_date, ensemble_size=args.ensemble_size, - era5_path=args.e + era5_path=args.era5_path, + ) From becfa0f460521035b2d1e7ab9bb349903f3152fe Mon Sep 17 00:00:00 2001 From: Steve Penny Date: Thu, 3 Oct 2024 01:41:42 -0600 Subject: [PATCH 41/44] add support for setting up era5 initial ensemble --- dabench/__init__.py | 2 +- dabench/dasupport/__init__.py | 5 + dabench/dasupport/__pycache__ | 0 dabench/dasupport/generate_era5_ensemble.py | 47 ++--- dabench/metrics/_ensemble.py | 191 ++++++++++++++++++++ dabench/utils/__init__.py | 5 + dabench/utils/timing.py | 62 +++++++ 7 files changed, 290 insertions(+), 22 deletions(-) create mode 100644 dabench/dasupport/__init__.py create mode 100644 dabench/dasupport/__pycache__ create mode 100644 dabench/metrics/_ensemble.py create mode 100644 dabench/utils/__init__.py create mode 100644 dabench/utils/timing.py diff --git a/dabench/__init__.py b/dabench/__init__.py index 8d4cbd2..011c448 100644 --- a/dabench/__init__.py +++ b/dabench/__init__.py @@ -1 +1 @@ -from . import data, vector, model, observer, obsop, dacycler, _suppl_data +from . import data, vector, model, observer, obsop, dacycler, dasupport, _suppl_data, utils diff --git a/dabench/dasupport/__init__.py b/dabench/dasupport/__init__.py new file mode 100644 index 0000000..c234c37 --- /dev/null +++ b/dabench/dasupport/__init__.py @@ -0,0 +1,5 @@ +from .generate_era5_ensemble import GenEra5Ens + +__all__ = [ + 'GenEra5Ens', + ] diff --git a/dabench/dasupport/__pycache__ b/dabench/dasupport/__pycache__ new file mode 100644 index 0000000..e69de29 diff --git a/dabench/dasupport/generate_era5_ensemble.py b/dabench/dasupport/generate_era5_ensemble.py index bdefccc..b7d9876 100644 --- a/dabench/dasupport/generate_era5_ensemble.py +++ b/dabench/dasupport/generate_era5_ensemble.py @@ -10,7 +10,7 @@ import xarray as xr from dateutil.relativedelta import relativedelta -from helpers.timing import report_timing +from ..utils.timing import report_timing # Selected vars for ERA5 ensemble # This will reduce the number of model fields processed and stored in the ensemble @@ -106,7 +106,7 @@ def parse_arguments(): #%% Define the initial ensemble -def define_init_ensemble( +def _define_init_ensemble( ensemble_size, init_ensemble_start_date, init_ensemble_sample_strategy="multi_year" ): @@ -130,7 +130,7 @@ def define_init_ensemble( return init_ensemble_member_dates -def main( +def GenEra5Ens( date_format:str="%Y%m%dZ%H", atmosphere_ensemble_s3_key:str=None, target_date:datetime=datetime.strptime("19990101Z00","%Y%m%dZ%H"), @@ -138,6 +138,7 @@ def main( start_date:datetime=datetime.strptime("19981231Z00","%Y%m%dZ%H"), ensemble_size:int=4, era5_path:str="gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3", + verbose:bool=False, ): #%% Set up the gcp access to era5 @@ -146,7 +147,7 @@ def main( ds_era5 = xr.open_zarr(gcs.get_mapper(era5_path), chunks=None) else: raise Exception("Non-GCP source for ERA5 not yet supported. EXITING...") - report_timing(timing_label="build_test_ensemble_era5:: access remote zarr store") + report_timing(timing_label="GenEra5Ens:: access remote zarr store") #%% Reorder the latitudes # Following: @@ -157,7 +158,7 @@ def main( assert ds_era5.latitude[0] < ds_era5.latitude[-1] #%% Determine dates for initial ensemble sampling - init_ensemble_member_dates = define_init_ensemble( + init_ensemble_member_dates = _define_init_ensemble( ensemble_size=ensemble_size, init_ensemble_start_date=start_date, init_ensemble_sample_strategy=sample_strategy, @@ -168,18 +169,20 @@ def main( #%% Sample from era5 ds_init_ens = ds_era5[ERA5_CONTROL_VARIABLES].sel(time=init_ensemble_member_dates) report_timing( - timing_label="build_test_ensemble_era5:: select time steps as ensemble members" + timing_label="GenEra5Ens:: select time steps as ensemble members" ) - print(ds_init_ens) + if verbose: + print(ds_init_ens) #%% Update time to target and add ensemble dimension ds_init_ens = ds_init_ens.rename_dims(dims_dict={"time": "member"}) ds_init_ens["member"] = range(ensemble_size) ds_init_ens = ds_init_ens.drop_vars("time") report_timing( - timing_label="build_test_ensemble_era5:: add member dimension to replace time" + timing_label="GenEra5Ens:: add member dimension to replace time" ) - print(ds_init_ens) + if verbose: + print(ds_init_ens) #%% Select target date from era5 for recentering the ensemble ds_target = ds_era5[ERA5_CONTROL_VARIABLES].sel(time=target_date) @@ -190,43 +193,45 @@ def main( ds_init_ens['ws10n'] = (ds_init_ens['10m_u_component_of_neutral_wind']**2 + ds_init_ens['10m_v_component_of_neutral_wind']**2)**(0.5) ds_target['ws10n'] = (ds_target['10m_u_component_of_neutral_wind']**2 + ds_target['10m_v_component_of_neutral_wind']**2)**(0.5) report_timing( - timing_label="build_test_ensemble_era5:: computing neutral wind speeds at 10m (ws10n)" + timing_label="GenEra5Ens:: computing neutral wind speeds at 10m (ws10n)" ) if ('10m_u_component_of_wind' in ERA5_CONTROL_VARIABLES and '10m_v_component_of_wind' in ERA5_CONTROL_VARIABLES): ds_init_ens['ws10m'] = (ds_init_ens['10m_u_component_of_wind']**2 + ds_init_ens['10m_v_component_of_wind']**2)**(0.5) ds_target['ws10m'] = (ds_target['10m_u_component_of_wind']**2 + ds_target['10m_v_component_of_wind']**2)**(0.5) report_timing( - timing_label="build_test_ensemble_era5:: computing diagnostic wind speeds at 10m (ws10m)" + timing_label="GenEra5Ens:: computing diagnostic wind speeds at 10m (ws10m)" ) #%% Recenter ensemble to target date - print(f'build_test_ensemble_era5:: re-centering ensemble with ensemble_size = {ensemble_size} to target_date = {target_date}...') + print(f'GenEra5Ens:: re-centering ensemble with ensemble_size = {ensemble_size} to target_date = {target_date}...') ds_mean = ds_init_ens.mean(dim="member") ds_diff = ds_target - ds_mean ds_init_ens = ds_init_ens + ds_diff report_timing( - timing_label="build_test_ensemble_era5:: recenter ensemble to target date" + timing_label="GenEra5Ens:: recenter ensemble to target date" ) - print(ds_init_ens) + if verbose: + print(ds_init_ens) #%% Now add time back on as a singleton dimension ds_init_ens = ds_init_ens.expand_dims(dim={"time": [target_date]}, axis=0) report_timing( - timing_label="build_test_ensemble_era5:: add time dimension back on to dataset structure" + timing_label="GenEra5Ens:: add time dimension back on to dataset structure" ) - print(ds_init_ens) + if verbose: + print(ds_init_ens) #%% Add some checks to make sure dimensions haven't changed assert ds_era5.sizes['latitude'] == ds_init_ens.sizes['latitude'] assert ds_era5.sizes['longitude'] == ds_init_ens.sizes['longitude'] assert ds_era5.sizes['level'] == ds_init_ens.sizes['level'] - #%% Upload to s3 as zarr - print('Uploading to s3 zarr...') + #%% Store to zarr (locally or on e.g. AWS s3) + print('Storing as zarr...') ds_init_ens.to_zarr(atmosphere_ensemble_s3_key, mode="w") report_timing( - timing_label="build_test_ensemble_era5:: upload to s3 as a new zarr store" + timing_label="GenEra5Ens:: upload to s3 as a new zarr store" ) @@ -235,9 +240,9 @@ def main( args = parse_arguments() # %% Process input arguments - report_timing(timing_label="build_test_ensemble_era5:: initializing...") + report_timing(timing_label="GenEra5Ens:: initializing...") - main( + GenEra5Ens( date_format=args.date_format, atmosphere_ensemble_s3_key=args.atmosphere_ensemble_s3_key, target_date=args.target_date, diff --git a/dabench/metrics/_ensemble.py b/dabench/metrics/_ensemble.py new file mode 100644 index 0000000..763b366 --- /dev/null +++ b/dabench/metrics/_ensemble.py @@ -0,0 +1,191 @@ +"""Ensemble forecast metrics""" + +import jax.numpy as jnp +from dabench.metrics import _utils + + +__all__ = [ + 'rank_histogram', + 'crps_ensemble', + ] + + +def rank_histogram(observations, forecasts, dim=None, member_dim="member"): + """JAX array implementation of Rank Histogram + + Description: + (from https://www.cawcr.gov.au/projects/verification/#Methods_for_EPS) + + Answers the question: How well does the ensemble spread of the forecast represent the true variability (uncertainty) of the observations? + + Also known as a "Talagrand diagram", this method checks where the verifying observation usually falls with respect to the ensemble forecast data, which is arranged in increasing order at each grid point. In an ensemble with perfect spread, each member represents an equally likely scenario, so the observation is equally likely to fall between any two members. + + To construct a rank histogram, do the following: + 1. At every observation (or analysis) point rank the N ensemble members from lowest to highest. This represents N+1 possible bins that the observation could fit into, including the two extremes + 2. Identify which bin the observation falls into at each point + 3. Tally over many observations to create a histogram of rank. + + Interpretation: + Flat - ensemble spread about right to represent forecast uncertainty + U-shaped - ensemble spread too small, many observations falling outside the extremes of the ensemble + Dome-shaped - ensemble spread too large, most observations falling near the center of the ensemble + Asymmetric - ensemble contains bias + + Note: A flat rank histogram does not necessarily indicate a good forecast, it only measures whether the observed probability distribution is well represented by the ensemble. + + Args: + predictions (ndarray): Array of predictions + targets (ndarray): Array of targets to compare against. Shape must + be broadcastable to shape of predictions. + + Returns: + [UPDATE] Float, Pearson's R correlation coefficient. + """ + + # RMSD = sqrt( 1/(N+1) * sum(Sk - M/(N+1)^2) ) + + # See: https://github.com/xarray-contrib/xskillscore/blob/64f17fdd1816b64b9e13c3f2febb9800a7e6ed0c/xskillscore/core/probabilistic.py#L830C20-L830C76 + + def _rank_first(x, y): + """Concatenates x and y and returns the rank of the + first element along the last axes""" + xy = jnp.concatenate((x[..., jnp.newaxis], y), axis=-1) + return bn.nanrankdata(xy, axis=-1)[..., 0] + + if dim is not None: + if len(dim) == 0: + raise ValueError( + "At least one dimension must be supplied to compute rank histogram over" + ) + if member_dim in dim: + raise ValueError(f'"{member_dim}" cannot be specified as an input to dim') + + ranks = xr.apply_ufunc( + _rank_first, + observations, + forecasts, + input_core_dims=[[], [member_dim]], + dask="parallelized", + output_dtypes=[int], + ) + + bin_edges = jnp.arange(0.5, len(forecasts[member_dim]) + 2) + return histogram(ranks, bins=[bin_edges], bin_names=["rank"], dim=dim, bin_dim_suffix="") + + +def crps_ensemble(observations, forecasts, axis=-1): + """JAX array implementation of Continuous Ranked Probability Score + + (From: https://confluence.ecmwf.int/display/FUG/Section+12.B+Statistical+Concepts+-+Probabilistic+Data#:~:text=The%20Continuous%20Ranked%20Probability%20Score,the%20forecast%20is%20wholly%20inaccurate.) + + A generalisation of Ranked Probability Score (RPS) is the Continuous Rank Probability Score (CRPSS) where the thresholds are continuous rather than discrete (see Nurmi, 2003; Jollife and Stephenson, 2003; Wilks, 2006). The Continuous Ranked Probability Score (CRPS) is a measure of how good forecasts are in matching observed outcomes. Where: + + CRPS = 0 the forecast is wholly accurate; + CRPS = 1 the forecast is wholly inaccurate. + CRPS is calculated by comparing the Cumulative Distribution Functions (CDF) for the forecast against a reference dataset (observations, or analyses, or climatology) over a given period. + + Args: + predictions (ndarray): Array of predictions + targets (ndarray): Array of targets to compare against. Shape must + be broadcastable to shape of predictions. + + Returns: + [UPDATE] Float, Mean Squared Error + """ + + # Integral from -inf to inf: (1/M) * sum[ S [P_j(x) - H(x - x_oj)]^2 dx ] + # where Pj, H, and x_oj are the predicted cumulative distribution for case j, the Heaviside step function, + # and the observed value, respectively. + # (see: https://www.ecmwf.int/sites/default/files/elibrary/2007/10729-ensemble-forecasting.pdf) + # with M independent cases (e.g. different dates) + + # See: https://github.com/properscoring/properscoring/blob/a465b5578d4b661e662933e84fa7673a70e75e94/properscoring/_crps.py#L244 + + # Manage input quality + observations = jnp.asarray(observations) + forecasts = jnp.asarray(forecasts) + + if axis != -1: + # Move the axis to the end + forecasts = jnp.rollaxis(forecasts, axis, start=forecasts.ndim) + + if observations.shape not in [forecasts.shape, forecasts.shape[:-1]]: + raise ValueError('observations and forecasts must have matching ' + 'shapes or matching shapes except along `axis=%s`' + % axis) + + if observations.shape == forecasts.shape: + if weights is not None: + raise ValueError('cannot supply weights unless you also supply ' + 'an ensemble forecast') + return abs(observations - forecasts) + + # Sort forecast members by target quantity + idx = jnp.argsort(forecasts, axis=-1) + forecasts = forecasts[idx] + weights = jnp.ones_like(forecasts) + + return _crps_ensemble_vectorized(observation, forecasts, weights, result) + +# @guvectorize(["void(float64[:], float64[:], float64[:], float64[:])"], +# "(),(n),(n)->()", nopython=True) + + @partial(jnp.vectorize, signature='(),(n),(n)->()') + def _crps_ensemble_vectorized(observation, forecasts, weights, result): + # beware: forecasts are assumed sorted in NumPy's sort order + + # add asserts here: + + # we index the 0th element to get the scalar value from this 0d array: + # http://numba.pydata.org/numba-doc/0.18.2/user/vectorize.html#the-guvectorize-decorator + obs = observation[0] + + if jnp.isnan(obs): + result[0] = jnp.nan + return + + total_weight = 0.0 + for n, weight in enumerate(weights): + if jnp.isnan(forecasts[n]): + # NumPy sorts NaN to the end + break + if not weight >= 0: + # this catches NaN weights + result[0] = jnp.nan + return + total_weight += weight + + obs_cdf = 0 + forecast_cdf = 0 + prev_forecast = 0 + integral = 0 + + for n, forecast in enumerate(forecasts): + if jnp.isnan(forecast): + # NumPy sorts NaN to the end + if n == 0: + integral = jnp.nan + # reset for the sake of the conditional below + forecast = prev_forecast + break + + if obs_cdf == 0 and obs < forecast: + integral += (obs - prev_forecast) * forecast_cdf ** 2 + integral += (forecast - obs) * (forecast_cdf - 1) ** 2 + obs_cdf = 1 + else: + integral += ((forecast - prev_forecast) + * (forecast_cdf - obs_cdf) ** 2) + + forecast_cdf += weights[n] / total_weight + prev_forecast = forecast + + if obs_cdf == 0: + # forecast can be undefined here if the loop body is never executed + # (because forecasts have size 0), but don't worry about that because + # we want to raise an error in that case, anyways + integral += obs - forecast + + result[0] = integral + + diff --git a/dabench/utils/__init__.py b/dabench/utils/__init__.py new file mode 100644 index 0000000..b4600fe --- /dev/null +++ b/dabench/utils/__init__.py @@ -0,0 +1,5 @@ +from .timing import report_timing + +__all__ = [ + 'report_timing', + ] diff --git a/dabench/utils/timing.py b/dabench/utils/timing.py new file mode 100644 index 0000000..a7f0fc9 --- /dev/null +++ b/dabench/utils/timing.py @@ -0,0 +1,62 @@ +import datetime +import time + + +def report_timing(timing_label=""): + + if not hasattr(report_timing, "timing_start_time"): + report_timing.timing_start_process_time = time.process_time() + report_timing.timing_start_time = time.time() + report_timing.last_process_time = report_timing.timing_start_process_time + report_timing.last_time = report_timing.timing_start_time + return + + print(f"\n< === {timing_label} ===") + + # Print the current time, helpful for tracking long runs + now = datetime.datetime.now() + print(f"Current datetime is: {now}") + + print(" === ") + + # get process execution time + timing_end_process_time = time.process_time() + seconds = timing_end_process_time - report_timing.last_process_time + minutes = seconds / 60.0 + print(f"CPU Execution time of this step: {seconds} seconds or {minutes} minutes.") + seconds = timing_end_process_time - report_timing.timing_start_process_time + minutes = seconds / 60.0 + print(f"CPU Execution time so far: {seconds} seconds or {minutes} minutes.") + + print(" === ") + + # get wall clock time + timing_end_time = time.time() + seconds = timing_end_time - report_timing.last_time + minutes = seconds / 60.0 + print( + f"Wall Clock Execution time of this step: {seconds} seconds or {minutes} minutes." + ) + seconds = timing_end_time - report_timing.timing_start_time + minutes = seconds / 60.0 + print(f"Wall Clock Execution time so far: {seconds} seconds or {minutes} minutes.") + + print(f" === {timing_label} === >\n") + + # Set up to get estimate of time between calls + report_timing.last_process_time = timing_end_process_time + report_timing.last_time = timing_end_time + + +def _test(): + report_timing(timing_label="initializing...") + time.sleep(3) + report_timing(timing_label="3 second sleep.") + time.sleep(10) + report_timing(timing_label="10 second sleep.") + + +# %% Main access +if __name__ == "__main__": + # main(sys.argv) + _test() From 25bb39bbbaff3ce3da2ac03399afd12162222377 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Thu, 3 Oct 2024 13:40:30 -0600 Subject: [PATCH 42/44] Update xarray accessor methods: ds.dab.flatten() and da.dab.unflatten() --- dabench/data/_xarray_accessor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dabench/data/_xarray_accessor.py b/dabench/data/_xarray_accessor.py index 205cc6d..faa344c 100644 --- a/dabench/data/_xarray_accessor.py +++ b/dabench/data/_xarray_accessor.py @@ -18,13 +18,13 @@ def _check_split_lengths(xr_obj, split_lengths): )) -@xr.register_dataset_accessor('dabench') +@xr.register_dataset_accessor('dab') class DABenchDatasetAccessor: """Helper methods for manipulating xarray Datasets""" def __init__(self, xarray_obj): self._obj = xarray_obj - def to_system(self): + def flatten(self): if 'time' in self._obj.coords: remaining_dim = ['time'] else: @@ -41,13 +41,13 @@ def split_train_val_test(self, split_lengths): return tuple(out_ds) -@xr.register_dataarray_accessor('dabench') +@xr.register_dataarray_accessor('dab') class DABenchDataArrayAccessor: """Helper methods for manipulating xarray DataArrays""" def __init__(self, xarray_obj): self._obj = xarray_obj - def to_gridded(self): + def unflatten(self): return self._obj.to_unstacked_dataset('system') def split_train_val_test(self, split_lengths): From 0d998735910ef237ccbca5fb29bd8867b52b54c4 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Thu, 3 Oct 2024 16:00:52 -0600 Subject: [PATCH 43/44] NeuralGCM model with configuration YAML --- dabench/model/_neuralgcm.py | 611 ++++++++++++++++++++++++++++++++++ dabench/model/_neuralgcm.yaml | 40 +++ 2 files changed, 651 insertions(+) create mode 100644 dabench/model/_neuralgcm.py create mode 100644 dabench/model/_neuralgcm.yaml diff --git a/dabench/model/_neuralgcm.py b/dabench/model/_neuralgcm.py new file mode 100644 index 0000000..081d721 --- /dev/null +++ b/dabench/model/_neuralgcm.py @@ -0,0 +1,611 @@ +#!/usr/bin/env python + +# Author: +# Stephen G. Penny +# 7/30/24 - 8/9/24 +# Adapted from: +# https://neuralgcm.readthedocs.io/en/latest/inference_demo.html + +# ============================================== +# dabench interface: +from dabench import vector, model + +# ============================================== +# Required for neuralGCM: +import jax +jax.config.update('jax_enable_x64', False) +import numpy as np +import pickle +import xarray as xr + +# For timing the run +import time + +# For managing time stamps +from datetime import datetime, timedelta + +# Dynamical core tools +from dinosaur import horizontal_interpolation +from dinosaur import spherical_harmonic +from dinosaur import xarray_utils + +# Full model with NN and dycore +import neuralgcm + +# Interface to Google Cloud Services +import gcsfs + +# For reading input yaml file +import yaml + +# For plotting +import matplotlib.pyplot as plt +# ============================================== + +# For type checking +from typing import Any, Callable, Mapping, MutableMapping, Sequence, TypeVar +DatasetOrDataArray = TypeVar( + 'DatasetOrDataArray', xr.Dataset, xr.DataArray +) + +class NeuralGCM(model.Model): + + def __init__(self, + system_dim=None, + time_dim=None, + delta_t=None, + model_obj=None, + params=None, + infile=None): + super().__init__(system_dim=None, + time_dim=None, + delta_t=None, + model_obj=None) + + # Infile override, if provided + if infile is not None: + params = self.load_config(infile) + + # ----------------------- + # Set up input defaults + # ----------------------- + + # Initialize gcs token + self.gcs = gcsfs.GCSFileSystem(token='anon') + + # Load the model + self.forcing_type = params.get("forcing_type", "deterministic") + self.atm_res = params.get("atm_res", "1_4") + self.model_name = f'neural_gcm_dynamic_forcing_{self.forcing_type}_{self.atm_res}_deg.pkl' + self.gcs_key = params.get("gcs_key", f'gs://gresearch/neuralgcm/04_30_2024/{self.model_name}') + + # Load ics + self.era5_path = params.get("era5_path", 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3') + self.ics_path = params.get("ics_path", self.era5_path) + self.bcs_path = params.get("bcs_path", self.era5_path) + self.forecast_hours = params.get("forecast_hours",4*24) + self.forecast_delta = timedelta(hours=self.forecast_hours) + self.data_stride = params.get("data_stride",24) + + self.start_time = params.get("start_time",'2020-02-14') + self.datetime_starttime = datetime.strptime(self.start_time, "%Y-%m-%d") # %H:%M:%S') + end_time = self.datetime_starttime + self.forecast_delta + self.end_time = end_time.strftime("%Y-%m-%d") #, %H:%M:%S") + + # Regrid the ics and bcs + self.interpolation_method = params.get("interpolation_method","conservative") + + # Timing options + self.inner_steps = params.get("inner_steps",6) # save model outputs once every 6 hours + self.outer_steps = params.get("outer_steps",4) # 4*24 // inner_steps, # total of 4 days + self.timedelta = np.timedelta64(1, 'h') * self.inner_steps + self.times = np.arange(self.outer_steps + 1)*self.timedelta + + # Pre-proceessing support + input_variables = {'u': 'u_component_of_wind', + 'v': 'v_component_of_wind', + 't': 'temperature', + 'q': 'specific_humidity', + 'zh': 'geopotential', + 'clwc': 'specific_cloud_liquid_water_content', + 'ciwc': 'specific_cloud_ice_water_content', + 'sst': 'sea_surface_temperature', + 'ciconc': 'sea_ice_cover', + } + + self.input_variables = params.get("input_variables", input_variables) + + # Needed to match up observations and forecasting for da cycler + + self.data_var_order = ['specific_cloud_liquid_water_content', 'specific_cloud_ice_water_content', + 'temperature', 'geopotential', 'specific_humidity', + 'u_component_of_wind', 'v_component_of_wind'] + + # May be useful to pull directly on these levels: + # https://www.ecmwf.int/en/forecasts/datasets/set-i#I-i-b + input_levels = np.array([ 1, 2, 3, 5, 7, 10, 20, 30, 50, 70, 100, 125, \ + 150, 175, 200, 225, 250, 300, 350, 400, 450, 500, 550, 600, \ + 650, 700, 750, 775, 800, 825, 850, 875, 900, 925, 950, 975, \ + 1000]) + + self.input_levels = params.get("input_levels", input_levels) + + # Run forecast + self.use_sfc_forecast = params.get("use_sfc_forecast",False) + self.random_seed = params.get("random_seed",42) + + # Plotting + self.plot_hpa_level=params.get("plot_hpa_level",850) + self.plot_x=params.get("plot_x",'longitude') + self.plot_y=params.get("plot_y",'latitude') + self.plot_row=params.get("plot_row",'time') + self.plot_col=params.get("plot_col",'model') + self.plot_robust=params.get("plot_robust",True) + self.plot_aspect=params.get("plot_aspect",2) + self.plot_size=params.get("plot_size",2) + self.plot_show=params.get("plot_show",True) + + attrs = vars(self) + print(', '.join("%s: %s" % item for item in attrs.items())) + + + def load_config(self, infile): + with open(infile, 'r') as f: + params = yaml.load(f, Loader=yaml.SafeLoader) + + params['model_name'] = f"neural_gcm_dynamic_forcing_{params['forcing_type']}_{params['atm_res']}_deg.pkl" + params['gcs_key'] = f"gs://gresearch/neuralgcm/04_30_2024/{params['model_name']}" + params['outer_steps'] = int(params['forecast_hours']) // int(params['inner_steps']) # Total of n days + print(type(self).__name__, params) + return params + + + def load_model(self): + + forcing_type = self.forcing_type + atm_res = self.atm_res + gcs_key = self.gcs_key + + # ----------------------------- + # Load the model + # ----------------------------- + with self.gcs.open(gcs_key, 'rb') as f: + ckpt = pickle.load(f) + + self._model = neuralgcm.PressureLevelModel.from_checkpoint(ckpt) + + def set_grid_info(self, full_data): + self.latitude_nodes=full_data.sizes['latitude'] + self.longitude_nodes=full_data.sizes['longitude'] + self.latitude_spacing=xarray_utils.infer_latitude_spacing(full_data.latitude) + self.longitude_offset=xarray_utils.infer_longitude_offset(full_data.longitude) + + + def load_ics(self, override=''): + + start_time = self.start_time + end_time = self.end_time + data_stride = self.data_stride + + era5_path = self.era5_path + full_data = xr.open_zarr(self.gcs.get_mapper(era5_path), chunks=None) + self.set_grid_info(full_data) + + print(f'Variables = {[self._model.input_variables]}') + + # TODO: make this only a single time step and add another method + # to load a 'reference' dataset for evaluation. + + # NOTE: The data slice below collects data for the evaluation + # only a single datetime is needed to initialize the forecast model. + # ALSO: Regridding requires the data to be first loaded into memory. + # Because this full dataset is gigantic (100s of TB) we’ll only + # regrid a few time slices: + sliced_data = ( + full_data + [self._model.input_variables] + .sel(time=slice(start_time, end_time, data_stride)) # Select data range from full reanalysis dataset + .compute() # Pull it into memory + ) + + return sliced_data + + + def load_bcs(self, override=''): + + start_time = self.start_time + end_time = self.end_time + data_stride = self.data_stride + + era5_path = self.era5_path + full_data = xr.open_zarr(self.gcs.get_mapper(era5_path), chunks=None) + + print(f'Variables = {[self._model.forcing_variables]}') + + # Regridding requires the data to be first loaded into memory. + # Because this full dataset is gigantic (100s of TB) we’ll only + # regrid a single time point: + sliced_data = ( + full_data + [self._model.forcing_variables] +# [self._model.input_variables + self._model.forcing_variables] +## See: +## https://neuralgcm.readthedocs.io/en/latest/datasets.html#time-shifting + .pipe( + xarray_utils.selective_temporal_shift, + variables=self._model.forcing_variables, + time_shift='24 hours', + ) + .sel(time=slice(start_time, end_time, data_stride)) # Select data range from full reanalysis dataset + .compute() # Pull it into memory + ) + + return sliced_data + + + def get_regridder(self, skipna=True): + + method = self.interpolation_method + + latitude_nodes=self.latitude_nodes + longitude_nodes=self.longitude_nodes + latitude_spacing=self.latitude_spacing + longitude_offset=self.longitude_offset + + print('get_regridder::') + print(f'latitude_nodes = {latitude_nodes}') + print(f'longitude_nodes = {longitude_nodes}') + print(f'latitude_spacing = {latitude_spacing}') + print(f'longitude_offset = {longitude_offset}') + + # Get grid for source dataset + source_grid = spherical_harmonic.Grid( + latitude_nodes=latitude_nodes, + longitude_nodes=longitude_nodes, + latitude_spacing=latitude_spacing, + longitude_offset=longitude_offset, + ) + + # Get grid for neural GCM + target_grid = self._model.data_coords.horizontal + + # ------------------------------------ + # build a Regridder object: + # ------------------------------------ + # Note: Other available regridders include BilinearRegridder and NearestRegridder. + # Note: skipna=True in ConservativeRegridder means grid cells with a mix of NaN/non-NaN + # values should be filled skipping NaN values. This ensures sea surface + # temperature and sea ice cover remains defined in coarse grid cells that + # overlap coastlines. + if method=='conservative': + regridder = horizontal_interpolation.ConservativeRegridder( + source_grid, + target_grid, + skipna=skipna + ) + elif method=='bilinear': + regridder = horizontal_interpolation.BilinearRegridder( + source_grid, + target_grid, + skipna=skipna + ) + elif method=='nearest': + regridder = horizontal_interpolation.NearestRegridder( + source_grid, + target_grid, + skipna=skipna + ) + else: + raise Exception(f'No valid interpolation method provided. method = {method}') + + return regridder + + + def regrid_input(self, data, fill_nans=False): + # Regridding data + # See: https://neuralgcm.readthedocs.io/en/latest/datasets.html + # Preparing a dataset stored on a different horizontal grid for NeuralGCM requires two steps: + + # 1) Horizontal regridding to a Gaussian grid. For processing fine-resolution data conservative + # regridding is most appropriate (and is what we used to train NeuralGCM). + # + # 2) Filling in all missing values (NaN), to ensure all inputs are valid. Forcing fields like + # sea_surface_temperature are only defined over ocean in ERA5, and NeuralGCM’s surface model + # also includes a mask that ignores values over land, but we still need to fill all NaN values + # to them leaking into our model outputs. + # + # Utilities for both of these operations are packaged as part of Dinosaur. + + # build a Regridder object: + print ('regrid_input:: self.get_regridder...') + regridder = self.get_regridder() + + # Perform regridding operation + print ('regrid_input:: xarray_utils.regrid...') + eval_data = xarray_utils.regrid(data, regridder) + + # Fill in fields like SST that may be NaN over land + if (fill_nans): + eval_data = xarray_utils.fill_nan_with_nearest(eval_data) + + return eval_data + + def forecast(self, state_vec, n_steps): + # Template forecast method to interface with DA + input_modelstate = self._model.inputs_from_xarray(state_vec) + encoded = self._model.encode(input_modelstate, self.input_forcings_t0) + final_state, predictions = self._model.unroll( + encoded, + self.sfc_forcing_forecast, + steps=n_steps, + timedelta=self.timedelta, + start_with_input=True + ) + preds_xarray = self._model.data_to_xarray( + predictions, + times=self._model.sim_time_to_datetime64(predictions['sim_time']) + ) + return preds_xarray + + def postprocess_helper(self, out_state, forcings): + decoded = self._model.decode(out_state, forcings) + return self._model.data_to_xarray( + decoded, + times=decoded['sim_time'] + ) + + def run_forecast(self, ics_data, bcs_data): + # NOTE: the ICs and BCs are extracted from the 'eval_data'. + # It is preferable to input these separately, so that + # the eval dataset can change, and since the IC and BC + # may have different time dimensions + + use_sfc_forecast = self.use_sfc_forecast + + # Get the initial conditions + inputs = self._model.inputs_from_xarray(ics_data.isel(time=0)) + + # Get initial surface boundary conditions + input_forcings_t0 = self._model.forcings_from_xarray(bcs_data.isel(time=0)) + rng_key = jax.random.key(self.random_seed) # optional for deterministic models + + # Set up combined ICs and BCs + initial_state = self._model.encode(inputs, input_forcings_t0, rng_key) + + # Get forecast surface boundary conditions. Either: + # (a) use persistence for forcing variables (SST and sea ice cover), or + # (b) use a forecast of SBCs (sst and sea ice) + if not use_sfc_forecast: + # Use a persistence forecast instead + self.sfc_forcing_forecast = self._model.forcings_from_xarray(bcs_data.head(time=1)) + #NOTE: ".head(time=1)" gets the first time step and keeps the time dimension, + # unlike ".isel(time=0)" which collapses the time dimension + else: + self.sfc_forcing_forecast = self._model.forcings_from_xarray(bcs_data) + + # make forecast + # see: https://neuralgcm.readthedocs.io/en/latest/trained_models.html#advancing-in-time + print('run_forecast:: steps = {self.outer_steps}') + print('run_forecast:: timedelta = {self.timedelta}') + final_state, predictions = self.forecast(state_vec=initial_state, n_steps=self.outer_steps) + predictions_ds = self._model.data_to_xarray(predictions, times=self.times) + + return predictions_ds + + + def plot_results(self, eval_data, predictions_ds): + + inner_steps = self.inner_steps + data_stride = self.data_stride + outer_steps = self.outer_steps + + plot_x = self.plot_x + plot_y = self.plot_y + plot_row = self.plot_row + plot_col = self.plot_col + plot_robust = self.plot_robust + plot_aspect = self.plot_aspect + plot_size = self.plot_size + plot_hpa_level = self.plot_hpa_level + plot_show = self.plot_show + + #STEVE: get times from predictions_ds + times = predictions_ds['time'].values + + # Selecting ERA5 targets from exactly the same time slice + target_trajectory = self._model.inputs_from_xarray( + eval_data + .thin(time=(inner_steps // data_stride)) + .isel(time=slice(outer_steps)) + ) + target_data_ds = self._model.data_to_xarray(target_trajectory, times=times) + + combined_ds = xr.concat([target_data_ds, predictions_ds], 'model') + combined_ds.coords['model'] = ['ERA5', 'NeuralGCM'] + + # Visualize ERA5 vs NeuralGCM trajectories + combined_ds.specific_humidity.sel(level=plot_hpa_level).plot( + x=plot_x, y=plot_y, row=plot_row, col=plot_col, robust=plot_robust, aspect=plot_aspect, size=plot_size + ); + filename = 'plot_specific_humidity' + plt.savefig(filename) + if (plot_show): + plt.show() + + combined_ds.u_component_of_wind.sel(level=plot_hpa_level).plot( + x=plot_x, y=plot_y, row=plot_row, col=plot_col, robust=plot_robust, aspect=plot_aspect, size=plot_size + ); + filename = 'plot_u_component_of_wind' + plt.savefig(filename) + if (plot_show): + plt.show() + + combined_ds.temperature.sel(level=plot_hpa_level).plot( + x=plot_x, y=plot_y, row=plot_row, col=plot_col, robust=plot_robust, aspect=plot_aspect, size=plot_size + ); + filename = 'plot_temperature' + plt.savefig(filename) + if (plot_show): + plt.show() + + combined_ds.geopotential.sel(level=plot_hpa_level).plot( + x=plot_x, y=plot_y, row=plot_row, col=plot_col, robust=plot_robust, aspect=plot_aspect, size=plot_size + ); + filename = 'plot_geopotential' + plt.savefig(filename) + if (plot_show): + plt.show() + + + def report_timing(self, timing_label=''): + + if not hasattr(self, "timing_start_time"): + self.timing_start_process_time = time.process_time() + self.timing_start_time = time.time() + return + + # get execution time + timing_end_process_time = time.process_time() + seconds = timing_end_process_time - self.timing_start_process_time + minutes = seconds / 60.0 + print(f'< === {timing_label} ===') + print(f'CPU Execution time so far: {minutes} minutes.') + + # get wall clock time + timing_end_time = time.time() + seconds = timing_end_time - self.timing_start_time + minutes = seconds / 60.0 + print(f'Wall Clock Execution time so far: {minutes} minutes.') + print(f' === {timing_label} === >') + + + def prepare_inputs(self): + self.load_model() + ics_sliced = self.load_ics() + bcs_sliced = self.load_bcs() + self.ics_eval = self.regrid_input(data=ics_sliced, fill_nans=False) + self.ics_eval0 = self.ics_eval.head(time=1) + self.bcs_eval = self.regrid_input(data=bcs_sliced, fill_nans=True) + use_sfc_forecast = self.use_sfc_forecast + + var_size = self.ics_eval[self.data_var_order[0]].isel(time=0).size + self.flat_vars_indices = { + 'specific_cloud_liquid_water_content':np.arange(var_size), + 'specific_cloud_ice_water_content':np.arange(var_size, 2*var_size), + 'temperature':np.arange(2*var_size,3*var_size), + 'geopotential':np.arange(3*var_size,4*var_size), + 'specific_humidity':np.arange(4*var_size,5*var_size), + 'u_component_of_wind':np.arange(5*var_size,6*var_size), + 'v_component_of_wind':np.arange(6*var_size,7*var_size) + } + + + # Get the initial conditions + self.inputs = self._model.inputs_from_xarray(self.ics_eval.isel(time=0)) + + # Get initial surface boundary conditions + self.input_forcings_t0 = self._model.forcings_from_xarray(self.bcs_eval.isel(time=0)) + rng_key = jax.random.key(self.random_seed) # optional for deterministic models + + # Set up combined ICs and BCs + self.initial_state = self._model.encode( + self.inputs, self.input_forcings_t0, rng_key) + + # Get forecast surface boundary conditions. Either: + # (a) use persistence for forcing variables (SST and sea ice cover), or + # (b) use a forecast of SBCs (sst and sea ice) + if not use_sfc_forecast: + # Use a persistence forecast instead + self.sfc_forcing_forecast = self._model.forcings_from_xarray(self.bcs_eval.head(time=1)) + #NOTE: ".head(time=1)" gets the first time step and keeps the time dimension, + # unlike ".isel(time=0)" which collapses the time dimension + else: + self.sfc_forcing_forecast = self._model.forcings_from_xarray(self.bcs_eval) + + + def full_sequence(self): + + # initialize the start time + self.report_timing() + + # Load the model and model weights from a file + print('Loading model...') + self.load_model() + self.report_timing("load model") + + # -------------------------- + # Load ICs and BCs + # -------------------------- + + # Get the ECMWF initial conditions and boundary conditions + print('Getting ECMWF ICs (this may take about 5-7 minutes)...') + ics_sliced = self.load_ics() + self.report_timing("get ics") + + # Get the ECMWF initial conditions and boundary conditions + print('Getting ECMWF BCs (should be < 1 min)...') + bcs_sliced = self.load_bcs() + self.report_timing("get bcs") + + # -------------------------- + # Regrid ICs and BCs + # -------------------------- + + # Regrid the ECMWF ICs to the model input grid + print('Regridding ECMWF ICs (should be < 1 min)...') + ics_eval = self.regrid_input(data=ics_sliced, fill_nans=False) + self.report_timing("regrid ics") + era5_eval = ics_eval + + + # Regrid the ECMWF BCs to the model input grid + print('Regridding ECMWF BCs (should be < 1 min)...') + bcs_eval = self.regrid_input(data=bcs_sliced, fill_nans=True) + self.report_timing("regrid bcs") + + # -------------------------- + # Run the forecast and store intermediate steps + # -------------------------- + + print('Running forecast (this may take a while, e.g. about 15-minutes per simulation day)...') + predictions_ds = self.run_forecast(ics_eval,bcs_eval) #eval_data) + print('Finished forecast!') + self.report_timing("run forecast") + + # Final timing + self.report_timing("Final") + + # -------------------------- + # Plot the results + # -------------------------- + + print('Plotting results...') + self.plot_results(era5_eval,predictions_ds) + self.ics = ics_eval + self.bcs = bcs_eval + self.prediction_ds = predictions_ds + + +if __name__ == "__main__": + + # Get input parameters: + params = None + infile = '_neuralgcm.yaml' + + # Create model class + model = NeuralGCM(params=params, infile=infile) + + # Print key input params: + print(f'demo_start_time = {model.start_time}') + print(f'demo_end_time = {model.end_time}') + print(f'data_inner_steps = {model.data_stride}') + print(f'inner_steps = {model.inner_steps}') + print(f'outer_steps = {model.outer_steps}') + print(f'timedelta = {model.timedelta}') + print(f'datetime_starttime = {model.datetime_starttime}') + print(f'forecast_delta = {model.forecast_delta}') + print(f'times = {model.times}') +# input("Press Enter to continue...") + + # Run all + model.full_sequence() diff --git a/dabench/model/_neuralgcm.yaml b/dabench/model/_neuralgcm.yaml new file mode 100644 index 0000000..9f91c03 --- /dev/null +++ b/dabench/model/_neuralgcm.yaml @@ -0,0 +1,40 @@ +# Model setup +forcing_type: "deterministic" +atm_res: "1_4" +#Interal Format: +# model_name: f'neural_gcm_dynamic_forcing_{self.forcing_type}_{self.atm_res}_deg.pkl' +# gcs_key: f'gs://gresearch/neuralgcm/04_30_2024/{self.model_name}' +# Options: +#@param ['neural_gcm_dynamic_forcing_deterministic_0_7_deg.pkl', + # 'neural_gcm_dynamic_forcing_deterministic_1_4_deg.pkl', + # 'neural_gcm_dynamic_forcing_deterministic_2_8_deg.pkl', + # 'neural_gcm_dynamic_forcing_stochastic_1_4_deg.pkl'] {type: "string"} + +# Load ics +era5_path: 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3' +ics_path: 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3' +bcs_path: 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3' +start_time: '2023-09-05' +#forecast_hours: 96 +forecast_hours: 24 +data_stride: 24 + +# Regridding +interpolation_method: "conservative" + +# Run forecast +inner_steps: 24 +outer_steps: 4 +random_seed: 42 +use_sfc_persistence: True + +# Plotting +plot_hpa_level: 850 +plot_x: 'longitude' +plot_y: 'latitude' +plot_row: 'time' +plot_col: 'model' +plot_robust: True +plot_aspect: 2 +plot_size: 2 +plot_show: True From 2623ce7a92b14b13efdfeb2a7c4f7a60058499c1 Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Thu, 3 Oct 2024 16:17:31 -0600 Subject: [PATCH 44/44] Neuralgcm forecast returns last step and full forecast tuple --- dabench/model/_neuralgcm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dabench/model/_neuralgcm.py b/dabench/model/_neuralgcm.py index 081d721..2b408e8 100644 --- a/dabench/model/_neuralgcm.py +++ b/dabench/model/_neuralgcm.py @@ -345,7 +345,7 @@ def forecast(self, state_vec, n_steps): predictions, times=self._model.sim_time_to_datetime64(predictions['sim_time']) ) - return preds_xarray + return preds_xarray.isel(time=-1), preds_xarray def postprocess_helper(self, out_state, forcings): decoded = self._model.decode(out_state, forcings)