Skip to content

feat(arviz): add arviz import and export#32

Draft
davecwright3 wants to merge 1 commit intonanograv:mainfrom
davecwright3:feat/arviz-core
Draft

feat(arviz): add arviz import and export#32
davecwright3 wants to merge 1 commit intonanograv:mainfrom
davecwright3:feat/arviz-core

Conversation

@davecwright3
Copy link

@davecwright3 davecwright3 commented Oct 3, 2024

This PR implements a very basic import and export scheme with ArviZ.

Checks for new file types

If it finds netcdf ".nc" or zarr ".zarr" files, it tries to read them with ArviZ.

la_forge/la_forge/core.py

Lines 118 to 126 in 6d23069

# Check if it's in a common arviz format
elif os.path.isfile(chaindir + '/chain.nc') or os.path.isfile(chaindir + '/chain.zarr'):
self.chainpath = chaindir + '/chain.nc' if os.path.isfile(chaindir + '/chain.nc') else chaindir + '/chain.zarr'
extension = self.chainpath.split(".")[-1]
try:
inf_data = az.from_netcdf(self.chainpath) if extension=="nc" else az.from_zarr(self.chainpath)
except:
msg = f"{self.chainpath} is not a valid ArviZ InferenceData object."
raise ValueError(msg)

Converts ArviZ data and metadata into a single "chain" for La Forge to consume

It filters out any parameters that do not correspond to samples in the chain. This means that the number of parameters always equals the number of columns in the chain. It therefore skips the check later on that would add extra PTMCMC parameters, such as lnpost.

In this implementation, you need to have already named your variables in the ArviZ InferenceData object to their desired final names. I would recommend naming them to their usual PTMCMC values for backwards compatibility.

la_forge/la_forge/core.py

Lines 127 to 129 in 6d23069

stacked = az.extract(inf_data) # combines chains
self.chain = stacked.to_array().to_numpy().T # ArviZ uses dimension 1 for samples, we want it to be 0
self.params = [param for param in stacked.variables if param not in ['sample', 'chain', 'draw']]

Adds an arviz cached_property to the Core class

Assuming you have a Core named my_core, calling my_core.arviz will return an ArviZ InferenceData object populated with the Core's data.

la_forge/la_forge/core.py

Lines 687 to 713 in 6d23069

@cached_property
def arviz(self) -> az.InferenceData:
"""Create an arviz.InferenceData object from a Core."""
# Easiest to make a dataframe first
df = pd.DataFrame(data=self.chain, columns=self.params)
# ArviZ wants to see `chain` and `draw` dimensions
df["chain"] = 0
df["draw"] = np.arange(len(df), dtype=int)
df = df.set_index(["chain", "draw"])
# Make an xarray `Dataset` to give ArviZ
xdata = xr.Dataset.from_dataframe(df)
# Store some metadata
xdata.attrs.update(
source="la_forge_core",
created_at=datetime.datetime.now(datetime.timezone.utc)
.replace(microsecond=0)
.isoformat(),
)
# Make the ArviZ object
dataset = az.InferenceData(posterior=xdata)
return dataset

Self contained example

import arviz as az
from la_forge.core import Core
from pathlib import Path

chain_dir = Path("test_chain/")
chain_dir.mkdir(parents=True, exist_ok=True)

inf_data = az.load_arviz_data("regression1d")
inf_data.to_netcdf(chain_dir/"chain.nc")

az_core = Core(chain_dir.as_posix())
>>> print(az_core.chain)
[[ 1.5665029  -1.33202579  1.8831507  -1.34775906  1.21947951]
 [ 1.80178445 -1.18480163  1.8831507  -1.34775906  1.10626457]
 [ 1.84332941 -1.22758633  1.8831507  -1.34775906  1.0078494 ]
 ...
 [ 1.68224315 -1.18911818  1.8831507  -1.34775906  0.91507343]
 [ 2.01824417 -1.34356813  1.8831507  -1.34775906  1.14519821]
 [ 1.94768056 -1.37682251  1.8831507  -1.34775906  0.99698405]]

>>> print(az_core.params)
['slope', 'intercept', 'true_slope', 'true_intercept', 'eps']
>>> print(az_core.arviz)
Inference data with groups:
	> posterior

>>> print(az_core.arviz.posterior)
<xarray.Dataset> Size: 96kB
Dimensions:         (chain: 1, draw: 2000)
Coordinates:
  * chain           (chain) int64 8B 0
  * draw            (draw) int64 16kB 0 1 2 3 4 5 ... 1995 1996 1997 1998 1999
Data variables:
    slope           (chain, draw) float64 16kB 1.567 1.802 1.843 ... 2.018 1.948
    intercept       (chain, draw) float64 16kB -1.332 -1.185 ... -1.344 -1.377
    true_slope      (chain, draw) float64 16kB 1.883 1.883 1.883 ... 1.883 1.883
    true_intercept  (chain, draw) float64 16kB -1.348 -1.348 ... -1.348 -1.348
    eps             (chain, draw) float64 16kB 1.219 1.106 1.008 ... 1.145 0.997
Attributes:
    source:      la_forge_core
    created_at:  2024-10-04T00:00:57+00:00

@jeremy-baier jeremy-baier mentioned this pull request Oct 21, 2024
13 tasks
@davecwright3
Copy link
Author

Tests are failing because the minimum arviz version I specified is too high for these python versions. I'll lower the arviz version bound.

@davecwright3
Copy link
Author

Also I should raise a warning if an inference data object is imported that doesn't have all of the usual PTMCMC fields defined. This is just so users are aware some methods may fail that depend on those fields existing

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant