From a631b9c71f572df9503ee52c70161a820d8ea9a2 Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Thu, 20 Jul 2023 16:28:33 -0700 Subject: [PATCH 01/35] 2d b-splines --- examples/basis_spline_example.py | 1 + gwinferno/interpolation.py | 3 +- gwinferno/models/bsplines/joint.py | 38 +++++++++++++++++- gwinferno/models/bsplines/separable.py | 55 +++++++++++++++++++++++++- 4 files changed, 93 insertions(+), 4 deletions(-) diff --git a/examples/basis_spline_example.py b/examples/basis_spline_example.py index b5413dda..000d6d30 100755 --- a/examples/basis_spline_example.py +++ b/examples/basis_spline_example.py @@ -189,6 +189,7 @@ def get_weights(z, prior): p_ct1ct2 = tilt_model(len(z.shape), tilt_cs) p_z = z_model(z, lamb, z_cs) wts = p_m1q * p_a1a2 * p_ct1ct2 * p_z / prior + return jnp.where(jnp.isnan(wts) | jnp.isinf(wts), 0, wts) peweights = get_weights(pedict["redshift"], pedict["prior"]) diff --git a/gwinferno/interpolation.py b/gwinferno/interpolation.py index 9c221acb..d0cd088a 100644 --- a/gwinferno/interpolation.py +++ b/gwinferno/interpolation.py @@ -515,13 +515,14 @@ def bases(self, xs, ys): Args: xs (array_like): input values to evaluate the X basis spline at - xs (array_like): input values to evaluate the Y basis spline at + ys (array_like): input values to evaluate the Y basis spline at Returns: array_like: the design matrix evaluated at xs. shape (xdf, ydf, *xs.shape) """ self.x_bases = self.x_interpolator.bases(xs) self.y_bases = self.y_interpolator.bases(ys) + #Check that out is a proper tensor product out = jnp.array([[self.x_bases[i] * self.y_bases[j] for i in range(self.xdf)] for j in range(self.ydf)]).reshape( self.xdf, self.ydf, *xs.shape ) diff --git a/gwinferno/models/bsplines/joint.py b/gwinferno/models/bsplines/joint.py index a7549ba4..b209d309 100644 --- a/gwinferno/models/bsplines/joint.py +++ b/gwinferno/models/bsplines/joint.py @@ -4,7 +4,7 @@ import jax.numpy as jnp -from ...interpolation import RectBivariateBasisSpline +from ...interpolation import RectBivariateBasisSpline, BSpline class Base2DBSplineModel(object): @@ -16,8 +16,12 @@ def __init__( yy, xx_inj, yy_inj, + xorder = 3, + yorder = 3, xrange=(0, 1), yrange=(0, 1), + xbasis = BSpline, + ybasis = BSpline, basis=RectBivariateBasisSpline, **kwargs, ): @@ -25,7 +29,7 @@ def __init__( self.yknots = ynknots self.xmin, self.xmax = xrange self.ymin, self.ymax = yrange - self.interpolator = basis(xnknots, ynknots, xrange=xrange, yrange=yrange, **kwargs) + self.interpolator = basis(xnknots, ynknots, xrange=xrange, yrange=yrange, xbasis=xbasis, ybasis=ybasis, kx=xorder, ky=yorder, **kwargs) self.pe_design_matrix = jnp.array(self.interpolator.bases(xx, yy)) self.inj_design_matrix = jnp.array(self.interpolator.bases(xx_inj, yy_inj)) self.funcs = [self.inj_pdf, self.pe_pdf] @@ -65,3 +69,33 @@ def __init__( yrange=(0, 1), **kwargs, ) +class BSplineJointMassRedshift(Base2DBSplineModel): + def __init__( + nknots_m, + nknots_z, + m1, + z, + m1_inj, + z_inj, + mmin=3., + mmax=100., + order_m=3, + order_z=3, + basis_m=BSpline, + basis_z=BSpline, + **kwargs, + ): + super().__init__( + nknots_m, + nknots_z, + m1, + z, + m1_inj, + z_inj, + xorder = order_m, + yorder = order_z, + xrange = (mmin, mmax), + yrange = (0, 2), + xbasis = basis_m, + ybasis = basis_z, + ) \ No newline at end of file diff --git a/gwinferno/models/bsplines/separable.py b/gwinferno/models/bsplines/separable.py index 6b0d2575..0bd3c7f7 100644 --- a/gwinferno/models/bsplines/separable.py +++ b/gwinferno/models/bsplines/separable.py @@ -439,7 +439,7 @@ def __init__( **kwargs, ) self.ratio_model = BSplineRatio( - n_splines_m, + n_splines_q, q, q_inj, qmin=m2min / mmax, @@ -676,3 +676,56 @@ def __call__(self, ecoefs, pcoefs, pe_samples=True): p_chieff = self.chi_eff_model(ecoefs, pe_samples=pe_samples) p_chip = self.chi_p_model(pcoefs, pe_samples=pe_samples) return p_chieff * p_chip + +class BSplineJointMassRedshiftBSplineRatio(object): + def __init__( + self, + nknots_m, + nknots_z, + nknots_q, + m1, + m1_inj, + q, + q_inj, + z, + z_inj, + order_m=3, + order_q=3, + order_z=3, + m1min=3.0, + m2min=3.0, + mmax=100.0, + basis_m=BSpline, + basis_q=BSpline, + basis_z=BSpline, + **kwargs, + ): + self.primary_model = BSplineJointMassRedshift( + nknots_m, + nknots_z, + m1, + z, + m1_inj, + z_inj, + mmin=m1min, + mmax=mmax, + order_m=order_m, + order_z=order_z, + basis_m=basis_m, + basis_z=basis_z, + **kwargs, + ) + self.ratio_model = BSplineRatio( + nknots_q, + q, + q_inj, + qmin=m2min / mmax, + knots=knots_q, + order=order_q, + prefix=prefix_q, + basis=basis_q, + **kwargs, + ) + + def __call__(self, ndim, mcoefs, qcoefs): + return self.ratio_model(ndim, qcoefs) * self.primary_model(ndim, mcoefs) From 0c671740d7e0e24feff1f7975a197317f2e00874 Mon Sep 17 00:00:00 2001 From: Jaxen Godfrey Date: Tue, 11 Jul 2023 14:07:42 -0700 Subject: [PATCH 02/35] bug fixes --- gwinferno/interpolation.py | 16 +++++---------- gwinferno/postprocess/calculate_ppds.py | 26 ++++++++++++------------- 2 files changed, 18 insertions(+), 24 deletions(-) diff --git a/gwinferno/interpolation.py b/gwinferno/interpolation.py index d0cd088a..71ce40e9 100644 --- a/gwinferno/interpolation.py +++ b/gwinferno/interpolation.py @@ -96,18 +96,12 @@ def __init__( self.xrange = xrange if knots is None: if interior_knots is None: - interior_knots = np.linspace(*(0, 1), n_df - k + 2) - if proper: - dx = interior_knots[1] - interior_knots[0] - knots = (xrange[1] - xrange[0]) * jnp.linspace(-dx * (k - 1), 1 + dx * (k - 1), len(interior_knots) + (k - 1) * 2) - - else: - knots = np.append( - np.append(np.array([xrange[0]] * (k - 1)), interior_knots), - np.array([xrange[1]] * (k - 1)), - ) + interior_knots = np.linspace(xrange[0], xrange[1], n_df - k + 2) + dx = interior_knots[1] - interior_knots[0] + knots = jnp.linspace(xrange[0] - dx * (k - 1), xrange[1] + dx * (k - 1), len(interior_knots) + (k - 1) * 2) + self.knots = knots - self.interior_knots = (xrange[1] - xrange[0]) * interior_knots + self.interior_knots = interior_knots assert len(self.knots) == self.N + self.order self.normalize = normalize diff --git a/gwinferno/postprocess/calculate_ppds.py b/gwinferno/postprocess/calculate_ppds.py index 5843a80b..e29c6d98 100644 --- a/gwinferno/postprocess/calculate_ppds.py +++ b/gwinferno/postprocess/calculate_ppds.py @@ -93,13 +93,13 @@ def calc_rz(cs, la, r): def calculate_iid_spin_bspline_ppds(coefs, model, nknots, rate=None, xmin=0, xmax=1, k=4, ngrid=500, pop_frac=None, pop_num=None, **model_kwargs): xs = np.linspace(xmin, xmax, ngrid) - pdf = model(nknots, xs, xs, xs, xs, order=k - 1, **model_kwargs) + pdf = model(nknots, xs, xs, xs, xs, degree=k - 1, **model_kwargs) pdfs = np.zeros((coefs.shape[0], len(xs))) if rate is None: rate = jnp.ones(coefs.shape[0]) def calc_pdf(cs, r): - return pdf.primary_model(1, cs) # * r + return pdf.primary_model(cs) #* r calc_pdf = jit(calc_pdf) _ = calc_pdf(coefs[0], rate[0]) @@ -116,14 +116,14 @@ def calc_pdf(cs, r): def calculate_ind_spin_bspline_ppds(coefs, scoefs, model, nknots, rate=None, xmin=0, xmax=1, k=4, ngrid=750, **model_kwargs): xs = jnp.linspace(xmin, xmax, ngrid) - pdf = model(nknots, xs, xs, xs, xs, order=k - 1, **model_kwargs) + pdf = model(nknots, xs, xs, xs, xs, degree=k - 1, **model_kwargs) ppdfs = np.zeros((coefs.shape[0], len(xs))) spdfs = np.zeros((coefs.shape[0], len(xs))) if rate is None: rate = jnp.ones(coefs.shape[0]) def calc_pdf(pcs, scs, r): - return pdf.primary_model(1, pcs), pdf.secondary_model(1, scs) # * r + return pdf.primary_model(pcs), pdf.secondary_model(scs) # * r calc_pdf = jit(calc_pdf) _, _ = calc_pdf(coefs[0], scoefs[0], rate[0]) @@ -135,13 +135,13 @@ def calc_pdf(pcs, scs, r): def calculate_chieff_bspline_ppds(coefs, model, nknots, rate=None, xmin=-1, xmax=1, k=4, ngrid=750, **model_kwargs): xs = jnp.linspace(xmin, xmax, ngrid) - pdf = model(nknots, xs, xs, order=k - 1, **model_kwargs) + pdf = model(nknots, xs, xs, degree=k - 1, **model_kwargs) pdfs = np.zeros((coefs.shape[0], len(xs))) if rate is None: rate = jnp.ones(coefs.shape[0]) def calc_pdf(cs, r): - return pdf(1, cs) * r + return pdf(cs) * r calc_pdf = jit(calc_pdf) _ = calc_pdf(coefs[0], rate[0]) @@ -164,7 +164,7 @@ def calculate_m1q_bspline_ppds( rate = jnp.ones(mcoefs.shape[0]) def calc_pdf(mcs, qcs, r, pop_frac): - p_mq = mass_pdf(2, mcs, qcs) + p_mq = mass_pdf(mcs, qcs) p_mq = jnp.where(jnp.less(mm, m1mmin) | jnp.less(mm * qq, mmin), 0, p_mq) p_m = jnp.trapz(p_mq, qs, axis=0) p_q = jnp.trapz(p_mq, ms, axis=1) @@ -189,7 +189,7 @@ def calculate_m1q_bspline_weights( rate = 1 def calc_pdf(mcs, qcs, r): - p_mq = mass_pdf(2, mcs, qcs) + p_mq = mass_pdf(mcs, qcs) p_mq = jnp.where(jnp.less(m, m1mmin) | jnp.less(m * q, mmin), 0, p_mq) return r * p_mq @@ -200,11 +200,11 @@ def calc_pdf(mcs, qcs, r): def calculate_iid_spin_bspline_weights(xs, coefs, model, nknots, rate=1, xmin=0, xmax=1, k=4, ngrid=500, pop_frac=None, pop_num=None, **model_kwargs): - pdf = model(nknots, xs, xs, xs, xs, order=k - 1, **model_kwargs) + pdf = model(nknots, xs, xs, xs, xs, degree=k - 1, **model_kwargs) pdfs = np.zeros((coefs.shape[0], len(xs))) def calc_pdf(cs, r): - return pdf.primary_model(1, cs) # * r + return pdf.primary_model(cs) # * r calc_pdf = jit(calc_pdf) # loop through hyperposterior samples @@ -219,7 +219,7 @@ def calc_pdf(cs, r): def calculate_m1_bspline_q_powerlaw_ppds( - mcoefs, mass_model, nknots, rate=None, mmin=3.0, m1mmin=3.0, mmax=100.0, pop_frac=1, basis=LogXLogYBSpline, **model_kwargs + mcoefs, mass_model, nknots, rate=None, mmin=3.0, m1mmin=3.0, mmax=100.0, pop_frac=1, pop_num=2, basis=LogXLogYBSpline, **model_kwargs ): ms = np.linspace(m1mmin, mmax, 800) qs = np.linspace(mmin / mmax, 1, 800) @@ -247,7 +247,7 @@ def calc_pdf(mcs, r, pop_frac, beta): calc_pdf = jit(calc_pdf) # loop through hyperposterior samples for ii in trange(mcoefs.shape[0]): - mpdfs[ii], qpdfs[ii] = calc_pdf(mcoefs[ii], rate[ii], pop_frac[ii][2], model_kwargs["beta"][ii]) + mpdfs[ii], qpdfs[ii] = calc_pdf(mcoefs[ii], rate[ii], pop_frac[ii][pop_num], model_kwargs["beta"][ii]) return mpdfs, qpdfs, ms, qs @@ -273,7 +273,7 @@ def calculate_m1m2_bspline_ppds( rate = jnp.ones(mcoefs.shape[0]) def calc_pdf(mcs1, r, pop_frac, beta): - p_m1m2 = mass_pdf(2, mcs1, beta=beta) + p_m1m2 = mass_pdf(mcs1, beta=beta) p_m1m2 = jnp.where(jnp.less(mm1, mmin) | jnp.less(mm2, mmin), 0, p_m1m2) p_m1 = jnp.trapz(p_m1m2, ms2, axis=0) p_m2 = jnp.trapz(p_m1m2, ms1, axis=1) From 44490e6e268c5b0431b7ca2e01408cee1f945247 Mon Sep 17 00:00:00 2001 From: Jaxen Godfrey Date: Tue, 11 Jul 2023 14:08:03 -0700 Subject: [PATCH 03/35] add xarray data util function --- gwinferno/utils.py | 193 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 193 insertions(+) create mode 100644 gwinferno/utils.py diff --git a/gwinferno/utils.py b/gwinferno/utils.py new file mode 100644 index 00000000..bc9b391e --- /dev/null +++ b/gwinferno/utils.py @@ -0,0 +1,193 @@ +import xarray as xr + + +from typing import ( + Any, + List, + Iterator, + Mapping, + Tuple, + Union, + Optional, + overload, +) + + +def _compressible_dtype(dtype): + """Check basic dtypes for automatic compression.""" + if dtype.kind == "V": + return all(_compressible_dtype(item) for item, _ in dtype.fields.values()) + return dtype.kind in {"b", "i", "u", "f", "c", "S"} + + +class DataSet(Mapping[str, xr.Dataset]): + """Adapted from Arviz InferenceData object https://python.arviz.org/en/stable/_modules/arviz/data/inference_data.html#InferenceData""" + def __init__( + self, + attrs: Union[None, Mapping[Any, Any]] = None, + **kwargs: Union[xr.Dataset, List[xr.Dataset], Tuple[xr.Dataset, xr.Dataset]], + ) -> None: + + self._groups: List[str] = [] + self._attrs: Union[None, dict] = dict(attrs) if attrs is not None else None + key_list = [key for key in kwargs] + for key in key_list: + dataset = kwargs[key] + setattr(self, key, dataset) + self._groups.append(key) + + def __len__(self) -> int: + """Return the number of groups in this InferenceData object.""" + return len(self._groups) + + def __iter__(self) -> Iterator[str]: + """Iterate over groups in InferenceData object.""" + for group in self._groups: + yield group + + def __getitem__(self, key: str) -> xr.Dataset: + """Get item by key.""" + if key not in self._groups: + raise KeyError(key) + return getattr(self, key) + + def to_netcdf( + self, + filename: str, + compress: bool = True, + groups: Optional[List[str]] = None, + engine: str = "h5netcdf", + ) -> str: + """Write InferenceData to netcdf4 file. + + Parameters + ---------- + filename : str + Location to write to + compress : bool, optional + Whether to compress result. Note this saves disk space, but may make + saving and loading somewhat slower (default: True). + groups : list, optional + Write only these groups to netcdf file. + engine : {"h5netcdf", "netcdf4"}, default "h5netcdf" + Library used to read the netcdf file. + + Returns + ------- + str + Location of netcdf file + """ + mode = "w" # overwrite first, then append + if self._attrs: + xr.Dataset(attrs=self._attrs).to_netcdf(filename, mode=mode, engine=engine) + mode = "a" + + if self._groups: # check's whether a group is present or not. + if groups is None: + groups = self._groups + else: + groups = [group for group in self._groups_all if group in groups] + + for group in groups: + data = getattr(self, group) + kwargs = {"engine": engine} + if compress: + kwargs["encoding"] = { + var_name: {"zlib": True} + for var_name, values in data.variables.items() + if _compressible_dtype(values.dtype) + } + data.to_netcdf(filename, mode=mode, group=group, **kwargs) + data.close() + mode = "a" + elif not self._attrs: # creates a netcdf file for an empty InferenceData object. + if engine == "h5netcdf": + import h5netcdf + + empty_netcdf_file = h5netcdf.File(filename, mode="w") + elif engine == "netcdf4": + import netCDF4 as nc + + empty_netcdf_file = nc.Dataset(filename, mode="w", format="NETCDF4") + empty_netcdf_file.close() + return filename + + @staticmethod + def from_netcdf( + filename, *, engine="h5netcdf", group_kwargs=None, regex=False + ) -> "InferenceData": + """Initialize object from a netcdf file. + + Expects that the file will have groups, each of which can be loaded by xarray. + By default, the datasets of the InferenceData object will be lazily loaded instead + of being loaded into memory. This + behaviour is regulated by the value of ``az.rcParams["data.load"]``. + + Parameters + ---------- + filename : str + location of netcdf file + engine : {"h5netcdf", "netcdf4"}, default "h5netcdf" + Library used to read the netcdf file. + group_kwargs : dict of {str: dict}, optional + Keyword arguments to be passed into each call of :func:`xarray.open_dataset`. + The keys of the higher level should be group names or regex matching group + names, the inner dicts re passed to ``open_dataset`` + This feature is currently experimental. + regex : bool, default False + Specifies where regex search should be used to extend the keyword arguments. + This feature is currently experimental. + + Returns + ------- + InferenceData + """ + groups = {} + attrs = {} + + if engine == "h5netcdf": + import h5netcdf + elif engine == "netcdf4": + import netCDF4 as nc + else: + raise ValueError( + f"Invalid value for engine: {engine}. Valid options are: h5netcdf or netcdf4" + ) + + try: + with h5netcdf.File(filename, mode="r") if engine == "h5netcdf" else nc.Dataset( + filename, mode="r" + ) as data: + data_groups = list(data.groups) + + for group in data_groups: + group_kws = {} + + group_kws = {} + if group_kwargs is not None and regex is False: + group_kws = group_kwargs.get(group, {}) + if group_kwargs is not None and regex is True: + for key, kws in group_kwargs.items(): + if re.search(key, group): + group_kws = kws + group_kws.setdefault("engine", engine) + with xr.open_dataset(filename, group=group, **group_kws) as data: + groups[group] = data + + with xr.open_dataset(filename, engine=engine) as data: + attrs.update(data.load().attrs) + + return DataSet(attrs=attrs, **groups) + except OSError as err: + if err.errno == -101: + raise type(err)( + str(err) + + ( + " while reading a NetCDF file. This is probably an error in HDF5, " + "which happens because your OS does not support HDF5 file locking. See " + "https://stackoverflow.com/questions/49317927/" + "errno-101-netcdf-hdf-error-when-opening-netcdf-file#49317928" + " for a possible solution." + ) + ) from err + raise err \ No newline at end of file From b15abba5257b1431ad6fc786389b566bc3eb835f Mon Sep 17 00:00:00 2001 From: Ben Farr Date: Mon, 17 Jul 2023 13:13:29 -0700 Subject: [PATCH 04/35] fix pre-commits --- gwinferno/postprocess/calculate_ppds.py | 2 +- gwinferno/utils.py | 50 ++++++++++--------------- 2 files changed, 21 insertions(+), 31 deletions(-) diff --git a/gwinferno/postprocess/calculate_ppds.py b/gwinferno/postprocess/calculate_ppds.py index e29c6d98..3121db7f 100644 --- a/gwinferno/postprocess/calculate_ppds.py +++ b/gwinferno/postprocess/calculate_ppds.py @@ -99,7 +99,7 @@ def calculate_iid_spin_bspline_ppds(coefs, model, nknots, rate=None, xmin=0, xma rate = jnp.ones(coefs.shape[0]) def calc_pdf(cs, r): - return pdf.primary_model(cs) #* r + return pdf.primary_model(cs) # * r calc_pdf = jit(calc_pdf) _ = calc_pdf(coefs[0], rate[0]) diff --git a/gwinferno/utils.py b/gwinferno/utils.py index bc9b391e..5b41cff6 100644 --- a/gwinferno/utils.py +++ b/gwinferno/utils.py @@ -1,16 +1,13 @@ -import xarray as xr - +import re +from typing import Any +from typing import Iterator +from typing import List +from typing import Mapping +from typing import Optional +from typing import Tuple +from typing import Union -from typing import ( - Any, - List, - Iterator, - Mapping, - Tuple, - Union, - Optional, - overload, -) +import xarray as xr def _compressible_dtype(dtype): @@ -22,12 +19,13 @@ def _compressible_dtype(dtype): class DataSet(Mapping[str, xr.Dataset]): """Adapted from Arviz InferenceData object https://python.arviz.org/en/stable/_modules/arviz/data/inference_data.html#InferenceData""" + def __init__( self, attrs: Union[None, Mapping[Any, Any]] = None, **kwargs: Union[xr.Dataset, List[xr.Dataset], Tuple[xr.Dataset, xr.Dataset]], ) -> None: - + self._groups: List[str] = [] self._attrs: Union[None, dict] = dict(attrs) if attrs is not None else None key_list = [key for key in kwargs] @@ -39,7 +37,7 @@ def __init__( def __len__(self) -> int: """Return the number of groups in this InferenceData object.""" return len(self._groups) - + def __iter__(self) -> Iterator[str]: """Iterate over groups in InferenceData object.""" for group in self._groups: @@ -50,7 +48,7 @@ def __getitem__(self, key: str) -> xr.Dataset: if key not in self._groups: raise KeyError(key) return getattr(self, key) - + def to_netcdf( self, filename: str, @@ -93,9 +91,7 @@ def to_netcdf( kwargs = {"engine": engine} if compress: kwargs["encoding"] = { - var_name: {"zlib": True} - for var_name, values in data.variables.items() - if _compressible_dtype(values.dtype) + var_name: {"zlib": True} for var_name, values in data.variables.items() if _compressible_dtype(values.dtype) } data.to_netcdf(filename, mode=mode, group=group, **kwargs) data.close() @@ -111,11 +107,9 @@ def to_netcdf( empty_netcdf_file = nc.Dataset(filename, mode="w", format="NETCDF4") empty_netcdf_file.close() return filename - + @staticmethod - def from_netcdf( - filename, *, engine="h5netcdf", group_kwargs=None, regex=False - ) -> "InferenceData": + def from_netcdf(filename, *, engine="h5netcdf", group_kwargs=None, regex=False) -> "DataSet": """Initialize object from a netcdf file. Expects that the file will have groups, each of which can be loaded by xarray. @@ -140,7 +134,7 @@ def from_netcdf( Returns ------- - InferenceData + DataSet """ groups = {} attrs = {} @@ -150,14 +144,10 @@ def from_netcdf( elif engine == "netcdf4": import netCDF4 as nc else: - raise ValueError( - f"Invalid value for engine: {engine}. Valid options are: h5netcdf or netcdf4" - ) + raise ValueError(f"Invalid value for engine: {engine}. Valid options are: h5netcdf or netcdf4") try: - with h5netcdf.File(filename, mode="r") if engine == "h5netcdf" else nc.Dataset( - filename, mode="r" - ) as data: + with h5netcdf.File(filename, mode="r") if engine == "h5netcdf" else nc.Dataset(filename, mode="r") as data: data_groups = list(data.groups) for group in data_groups: @@ -190,4 +180,4 @@ def from_netcdf( " for a possible solution." ) ) from err - raise err \ No newline at end of file + raise err From 006047e04ec7875ba0616ccafea0ddfb72b192ce Mon Sep 17 00:00:00 2001 From: Ben Farr Date: Mon, 17 Jul 2023 13:40:35 -0700 Subject: [PATCH 05/35] fix pre-commit warning about ambiguous flags --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5d5a9cf2..b3f52d61 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: rev: 5.12.0 hooks: - id: isort # sort imports alphabetically and separates import into sections - args: [-w=150, -sl, --gitignore] + args: [-w=150, --sl, --gitignore] - repo: https://github.com/ambv/black rev: 22.6.0 hooks: From eadac6bae4a0faf6f184bf272374d7e08a545bbc Mon Sep 17 00:00:00 2001 From: Ben Farr Date: Mon, 24 Jul 2023 14:35:15 -0700 Subject: [PATCH 06/35] fix method name --- gwinferno/interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gwinferno/interpolation.py b/gwinferno/interpolation.py index 71ce40e9..4fb4a932 100644 --- a/gwinferno/interpolation.py +++ b/gwinferno/interpolation.py @@ -520,7 +520,7 @@ def bases(self, xs, ys): out = jnp.array([[self.x_bases[i] * self.y_bases[j] for i in range(self.xdf)] for j in range(self.ydf)]).reshape( self.xdf, self.ydf, *xs.shape ) - self.reset_bases() + self._reset_bases() return out def _project(self, bases, coefs): From 46b6d1ec5f1045432ec09329b5b89b57bb6395d5 Mon Sep 17 00:00:00 2001 From: Ben Farr Date: Mon, 24 Jul 2023 14:48:00 -0700 Subject: [PATCH 07/35] add vectorized option in comment --- gwinferno/interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gwinferno/interpolation.py b/gwinferno/interpolation.py index 4fb4a932..84da3e03 100644 --- a/gwinferno/interpolation.py +++ b/gwinferno/interpolation.py @@ -516,7 +516,7 @@ def bases(self, xs, ys): """ self.x_bases = self.x_interpolator.bases(xs) self.y_bases = self.y_interpolator.bases(ys) - #Check that out is a proper tensor product + #Check that out is a proper tensor product (could be replaces with x_bases * y_bases[:, jnp.newaxis, ...]) out = jnp.array([[self.x_bases[i] * self.y_bases[j] for i in range(self.xdf)] for j in range(self.ydf)]).reshape( self.xdf, self.ydf, *xs.shape ) From f191925755e00ef3f5b4f44ecc5d0127a382b93f Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Mon, 14 Aug 2023 09:43:58 -0700 Subject: [PATCH 08/35] made an edit to call function of Base2DBSplineModel --- gwinferno/models/bsplines/joint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gwinferno/models/bsplines/joint.py b/gwinferno/models/bsplines/joint.py index b209d309..09bec3a0 100644 --- a/gwinferno/models/bsplines/joint.py +++ b/gwinferno/models/bsplines/joint.py @@ -43,8 +43,8 @@ def pe_pdf(self, coefs): def inj_pdf(self, coefs): return self.eval_spline(self.inj_design_matrix, coefs) - def __call__(self, ndim, coefs): - return self.funcs[ndim - 1](coefs) + def __call__(self, coefs, pe_samples = True): + return self.funcs[1](coefs) if pe_samples else self.funcs[0](coefs) class BSplineJointMassRatioChiEffective(Base2DBSplineModel): From a829bd9333067c9e486d1e482ad8901c70427954 Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Mon, 14 Aug 2023 09:47:37 -0700 Subject: [PATCH 09/35] 2d b-spline development in RectBivariateBasisSpline class --- gwinferno/interpolation.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/gwinferno/interpolation.py b/gwinferno/interpolation.py index 71ce40e9..3bd792b6 100644 --- a/gwinferno/interpolation.py +++ b/gwinferno/interpolation.py @@ -514,13 +514,12 @@ def bases(self, xs, ys): Returns: array_like: the design matrix evaluated at xs. shape (xdf, ydf, *xs.shape) """ - self.x_bases = self.x_interpolator.bases(xs) + self.x_bases = self.x_interpolator.bases(xs) self.y_bases = self.y_interpolator.bases(ys) - #Check that out is a proper tensor product out = jnp.array([[self.x_bases[i] * self.y_bases[j] for i in range(self.xdf)] for j in range(self.ydf)]).reshape( - self.xdf, self.ydf, *xs.shape - ) - self.reset_bases() + self.xdf, self.ydf, *xs.shape) + self._reset_bases() + return out def _project(self, bases, coefs): @@ -534,6 +533,7 @@ def _project(self, bases, coefs): Returns: array_like: The linear combination of the basis components given the coefficients """ + #NOTE this would be for the logz2dbasis class return jnp.exp(jnp.einsum("ij...,ij->...", bases, coefs)) def project(self, bases, coefs): @@ -547,4 +547,4 @@ def project(self, bases, coefs): Returns: array_like: The linear combination of the basis components given the coefficients """ - return self._project(bases, coefs) * self.norm_2d(coefs) + return self._project(bases, coefs)* self.norm_2d(coefs) From 4684355abe2eba4a5b44cb2c37bef8c2a0e7f6ab Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Mon, 14 Aug 2023 11:11:57 -0700 Subject: [PATCH 10/35] changed apply_twod_difference_prior method name --- gwinferno/models/bsplines/smoothing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gwinferno/models/bsplines/smoothing.py b/gwinferno/models/bsplines/smoothing.py index 2c1bfb76..70afb6c3 100644 --- a/gwinferno/models/bsplines/smoothing.py +++ b/gwinferno/models/bsplines/smoothing.py @@ -23,7 +23,7 @@ def apply_difference_prior(coefs, inv_var, degree=1): return -0.5 * inv_var * jnp.dot(delta_c, delta_c.T) -def apply_twod_difference_prior(coefs, inv_var, degree=1): +def apply_2d_difference_prior(coefs, inv_var, degree=1): D = jnp.diff(jnp.eye(len(coefs)), n=degree) delta_c = jnp.dot(coefs, D) return -0.5 * inv_var * jnp.sum(jnp.dot(delta_c, delta_c.T).flatten()) From 1360f7f149c345e93be7c3086f5d8144a6a55561 Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Tue, 15 Aug 2023 12:10:36 -0700 Subject: [PATCH 11/35] fixed bug in saving posterior_samples_and_injection_chi_effff.h5 --- gwinferno/preprocess/data_collection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gwinferno/preprocess/data_collection.py b/gwinferno/preprocess/data_collection.py index 7c495a88..6aebe0b8 100644 --- a/gwinferno/preprocess/data_collection.py +++ b/gwinferno/preprocess/data_collection.py @@ -246,7 +246,7 @@ def setup_posterior_samples_and_injections(data_dir, inj_file, param_names=None, injdata, new_pmap = convert_component_spin_injections_to_chieff(injdata, param_map, chip=chi_p) param_map = new_pmap pedata = jnp.array(pedata) - injdata = jnp.array(pedata) + injdata = jnp.array(injdata) if save: mag_data = { "injdata": injdata, From e5efd21c50335c21e5c8d1e948c6c7f0f86de976 Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Thu, 31 Aug 2023 16:03:51 -0700 Subject: [PATCH 12/35] Separated the RectBivariateBasisSpline class into a linear and log Z version --- gwinferno/interpolation.py | 55 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/gwinferno/interpolation.py b/gwinferno/interpolation.py index 023f11d9..bb60109e 100644 --- a/gwinferno/interpolation.py +++ b/gwinferno/interpolation.py @@ -521,8 +521,8 @@ def _project(self, bases, coefs): Returns: array_like: The linear combination of the basis components given the coefficients """ - #NOTE this would be for the logz2dbasis class - return jnp.exp(jnp.einsum("ij...,ij->...", bases, coefs)) + #NOTE jnp.exp(jnp.einsum("ij...,ij->...", bases, coefs))this would be for the logz2dbasis class + return jnp.einsum("ij...,ij->...", bases, coefs) def project(self, bases, coefs): """ @@ -536,3 +536,54 @@ def project(self, bases, coefs): array_like: The linear combination of the basis components given the coefficients """ return self._project(bases, coefs)* self.norm_2d(coefs) + +class LogZRectBivariateBasisSpline(RectBivariateBasisSpline): + def __init__( + self, + xdf, + ydf, + xrange=(0, 1), + yrange=(0, 1), + kx=4, + ky=4, + xbasis=BSpline, + ybasis=BSpline, + normalize=True, + ): + """ + Class to construct a 2D (bivariate) rectangular basis spline + + Args: + xdf (int): number of degrees of freedom for the spline in the X direction + ydf (int): number of degrees of freedom for the spline in the Y direction + xrange (tuple, optional): domain of X spline. Defaults to (0, 1). + yrange (tuple, optional): domain of Y spline. Defaults to (0, 1). + kx (int, optional): order of the X spline +1, i.e. cubcic splines->k=4. Defaults to 4 (cubic spline). + ky (int, optional): order of the Y spline +1, i.e. cubcic splines->k=4. Defaults to 4 (cubic spline). + xbasis (object, optional): Choice of basis to use for the X spline. Defaults to BSpline. + ybasis (object, optional): Choice of basis to use for the Y spline. Defaults to BSpline. + normalize (bool, optional): flag whether or not to numerically normalize the spline. Defaults to True. + """ + super().__init__( + xdf, + ydf, + xrange=xrange, + yrange=yrange, + kx=kx, + ky=ky, + xbasis=xbasis, + ybasis=ybasis, + normalize=normalize, + ) + def _project(self, bases, coefs): + """ + _project given a design matrix (or bases) and coefficients, project the coefficients onto the spline + + Args: + bases (array_like): The set of basis components or design matrix to project onto + coefs (array_like): coefficients for the basis components + + Returns: + array_like: The linear combination of the basis components given the coefficients + """ + return jnp.exp(jnp.einsum("ij...,ij->...", bases, coefs)) \ No newline at end of file From f1f223d7daa8bea8f986639270f2a0670ae2fbb8 Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Mon, 11 Sep 2023 12:04:05 -0700 Subject: [PATCH 13/35] added chieff_range and q_range to BSplineJointMassRatioChiEffective --- gwinferno/models/bsplines/joint.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/gwinferno/models/bsplines/joint.py b/gwinferno/models/bsplines/joint.py index 09bec3a0..d15e5f49 100644 --- a/gwinferno/models/bsplines/joint.py +++ b/gwinferno/models/bsplines/joint.py @@ -4,7 +4,7 @@ import jax.numpy as jnp -from ...interpolation import RectBivariateBasisSpline, BSpline +from ...interpolation import RectBivariateBasisSpline, BSpline, LogZRectBivariateBasisSpline class Base2DBSplineModel(object): @@ -56,6 +56,8 @@ def __init__( q, chieff_inj, q_inj, + chieff_range=(-1,1), + q_range=(0,1), **kwargs, ): super().__init__( @@ -65,8 +67,8 @@ def __init__( yy=q, xx_inj=chieff_inj, yy_inj=q_inj, - xrange=(-1, 1), - yrange=(0, 1), + xrange=chieff_range, + yrange=q_range, **kwargs, ) class BSplineJointMassRedshift(Base2DBSplineModel): From a79283fe6313991ac9cf2117c0913bc62ab7d147 Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Tue, 3 Oct 2023 10:57:49 -0700 Subject: [PATCH 14/35] raise n_eff cutoff from Nobs to 4*Nobs --- gwinferno/pipeline/analysis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gwinferno/pipeline/analysis.py b/gwinferno/pipeline/analysis.py index f35738f8..10a611e4 100644 --- a/gwinferno/pipeline/analysis.py +++ b/gwinferno/pipeline/analysis.py @@ -246,7 +246,7 @@ def hierarchical_likelihood( numpyro.factor( "log_likelihood", jnp.where( - jnp.isnan(log_l) | jnp.less_equal(jnp.exp(jnp.min(logn_effs)), Nobs), + jnp.isnan(log_l) | jnp.less_equal(jnp.exp(jnp.min(logn_effs)), 4*Nobs), jnp.nan_to_num(-jnp.inf), jnp.nan_to_num(log_l), ), @@ -367,7 +367,7 @@ def hierarchical_likelihood_in_log( numpyro.factor( "log_likelihood", jnp.where( - jnp.isnan(log_l) | jnp.less_equal(jnp.exp(jnp.min(logn_effs)), 10), + jnp.isnan(log_l) | jnp.less_equal(jnp.exp(jnp.min(logn_effs)), Nobs), jnp.nan_to_num(-jnp.inf), jnp.nan_to_num(log_l), ), From 127eff23c97053ff2aff0cbb3896fd067f797d8f Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Tue, 3 Oct 2023 10:58:22 -0700 Subject: [PATCH 15/35] bug fix 2d spline: norm grids and order --- gwinferno/interpolation.py | 30 ++++++++++++++++++------------ gwinferno/models/bsplines/joint.py | 8 +++++--- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/gwinferno/interpolation.py b/gwinferno/interpolation.py index bb60109e..a4e3f62b 100644 --- a/gwinferno/interpolation.py +++ b/gwinferno/interpolation.py @@ -76,6 +76,7 @@ def __init__( k=4, proper=True, normalize=True, + norm_grid = 1000 ): """ Class to construct a basis spline (with the M-Spline basis) @@ -107,7 +108,7 @@ def __init__( self.normalize = normalize self.basis_vols = np.ones(self.N) if normalize: - self.grid = jnp.linspace(*xrange, 1000) + self.grid = jnp.linspace(*xrange, norm_grid) self.grid_bases = jnp.array(self.bases(self.grid)) self.basis_vols = jnp.array([jnp.trapz(self.grid_bases[i, :], self.grid) for i in range(self.N)]) @@ -240,6 +241,7 @@ def __init__( k=4, proper=True, normalize=False, + **kwargs ): """ Class to construct a basis spline (B-Spline) @@ -263,6 +265,7 @@ def __init__( k=k, proper=proper, normalize=normalize, + **kwargs ) def _bases(self, xs): @@ -440,11 +443,12 @@ def __init__( ydf, xrange=(0, 1), yrange=(0, 1), - kx=4, - ky=4, + xorder=4, + yorder=4, xbasis=BSpline, ybasis=BSpline, normalize=True, + norm_grid=(1000, 1000) ): """ Class to construct a 2D (bivariate) rectangular basis spline @@ -462,14 +466,14 @@ def __init__( """ self.xdf = xdf self.ydf = ydf - self.x_interpolator = xbasis(xdf, xrange=xrange, k=kx, normalize=False) - self.y_interpolator = ybasis(ydf, xrange=yrange, k=ky, normalize=False) + self.x_interpolator = xbasis(xdf, xrange=xrange, k=xorder, normalize=False) + self.y_interpolator = ybasis(ydf, xrange=yrange, k=yorder, normalize=False) self.normalize = normalize self.x_bases = None self.y_bases = None if self.normalize: - self.gridx = jnp.linspace(*xrange, 750) - self.gridy = jnp.linspace(*yrange, 750) + self.gridx = jnp.linspace(*xrange, norm_grid[0]) + self.gridy = jnp.linspace(*yrange, norm_grid[1]) self.gxx, self.gyy = jnp.meshgrid(self.gridx, self.gridy) self.grid_bases = self.bases(self.gxx, self.gyy) @@ -535,7 +539,7 @@ def project(self, bases, coefs): Returns: array_like: The linear combination of the basis components given the coefficients """ - return self._project(bases, coefs)* self.norm_2d(coefs) + return self._project(bases, coefs) * self.norm_2d(coefs) class LogZRectBivariateBasisSpline(RectBivariateBasisSpline): def __init__( @@ -544,11 +548,12 @@ def __init__( ydf, xrange=(0, 1), yrange=(0, 1), - kx=4, - ky=4, + xorder=4, + yorder=4, xbasis=BSpline, ybasis=BSpline, normalize=True, + norm_grid=(1000, 1000) ): """ Class to construct a 2D (bivariate) rectangular basis spline @@ -569,11 +574,12 @@ def __init__( ydf, xrange=xrange, yrange=yrange, - kx=kx, - ky=ky, + xorder=xorder, + yorder=yorder, xbasis=xbasis, ybasis=ybasis, normalize=normalize, + norm_grid = norm_grid ) def _project(self, bases, coefs): """ diff --git a/gwinferno/models/bsplines/joint.py b/gwinferno/models/bsplines/joint.py index d15e5f49..0191362d 100644 --- a/gwinferno/models/bsplines/joint.py +++ b/gwinferno/models/bsplines/joint.py @@ -16,8 +16,6 @@ def __init__( yy, xx_inj, yy_inj, - xorder = 3, - yorder = 3, xrange=(0, 1), yrange=(0, 1), xbasis = BSpline, @@ -29,7 +27,7 @@ def __init__( self.yknots = ynknots self.xmin, self.xmax = xrange self.ymin, self.ymax = yrange - self.interpolator = basis(xnknots, ynknots, xrange=xrange, yrange=yrange, xbasis=xbasis, ybasis=ybasis, kx=xorder, ky=yorder, **kwargs) + self.interpolator = basis(xnknots, ynknots, xrange=xrange, yrange=yrange, xbasis=xbasis, ybasis=ybasis, **kwargs) self.pe_design_matrix = jnp.array(self.interpolator.bases(xx, yy)) self.inj_design_matrix = jnp.array(self.interpolator.bases(xx_inj, yy_inj)) self.funcs = [self.inj_pdf, self.pe_pdf] @@ -58,6 +56,8 @@ def __init__( q_inj, chieff_range=(-1,1), q_range=(0,1), + chi_order = 4, + q_order = 4, **kwargs, ): super().__init__( @@ -69,6 +69,8 @@ def __init__( yy_inj=q_inj, xrange=chieff_range, yrange=q_range, + xorder = chi_order, + yorder = q_order, **kwargs, ) class BSplineJointMassRedshift(Base2DBSplineModel): From 041800404cebca10387264238c85ea3d07d11543 Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Tue, 3 Oct 2023 11:00:04 -0700 Subject: [PATCH 16/35] updated gitignore --- .gitignore | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 75bef041..6f455334 100644 --- a/.gitignore +++ b/.gitignore @@ -31,4 +31,6 @@ docs/_build docs/api/** docs/_build/** docs/_build/_sources -docs/_build/_static \ No newline at end of file +docs/_build/_static +run-experimentation +**.h5 \ No newline at end of file From bbba67d524cf8179f97a9df1ed907b38d90f011d Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Fri, 6 Oct 2023 14:40:55 -0700 Subject: [PATCH 17/35] updated gitignore --- .gitignore | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 6f455334..8b5c3b87 100644 --- a/.gitignore +++ b/.gitignore @@ -33,4 +33,7 @@ docs/_build/** docs/_build/_sources docs/_build/_static run-experimentation -**.h5 \ No newline at end of file +**.h5 +**.ipynb +/posterior_samples_and_injections_spin_magnitude +**.txt \ No newline at end of file From 4b517c6a1b6ed0059b15e45128a4a3bdd0ba2950 Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Fri, 20 Oct 2023 10:35:24 -0700 Subject: [PATCH 18/35] Fixed directory locations for PE samples and injections. --- gwinferno/pipeline/parser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gwinferno/pipeline/parser.py b/gwinferno/pipeline/parser.py index 75617325..d61d9cec 100644 --- a/gwinferno/pipeline/parser.py +++ b/gwinferno/pipeline/parser.py @@ -120,11 +120,11 @@ def add_mixture_model(self, param, subd): def load_base_parser(): parser = ArgumentParser() - parser.add_argument("--data-dir", type=str, default="/home/bruce.edelman/projects/GWTC3_allevents/") + parser.add_argument("--data-dir", type=str, default="/projects/farr_lab/shared/GWTC3/all_events") parser.add_argument( "--inj-file", type=str, - default="/home/bruce.edelman/projects/GWTC3_allevents/o1o2o3_mixture_injections.hdf5", + default="/projects/farr_lab/shared/GWTC3/o1o2o3_mixture_injections.hdf5", ) parser.add_argument("--outdir", type=str, default="results") parser.add_argument("--mmin", type=float, default=3.0) From 01145601e1a8fdd53dd780df835579bbd866fcb4 Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Fri, 20 Oct 2023 10:39:46 -0700 Subject: [PATCH 19/35] Added n_splines attribute to the class BSplineRatio. --- gwinferno/models/bsplines/single.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gwinferno/models/bsplines/single.py b/gwinferno/models/bsplines/single.py index 93c20f04..1ae8304d 100644 --- a/gwinferno/models/bsplines/single.py +++ b/gwinferno/models/bsplines/single.py @@ -314,6 +314,7 @@ def __init__( xrange=(qmin, 1), **kwargs, ) + self.n_splines = n_splines class BSplineMass(Base1DBSplineModel): From 12579327c007730712b4446cf56e9470e6f7e173 Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Tue, 24 Oct 2023 10:40:40 -0700 Subject: [PATCH 20/35] Fixed typos --- gwinferno/models/gwpopulation/gwpopulation.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/gwinferno/models/gwpopulation/gwpopulation.py b/gwinferno/models/gwpopulation/gwpopulation.py index bf092b81..43181dd9 100644 --- a/gwinferno/models/gwpopulation/gwpopulation.py +++ b/gwinferno/models/gwpopulation/gwpopulation.py @@ -96,32 +96,32 @@ def __init__(self, z_pe, z_inj): self.zmin = jnp.max(jnp.array([jnp.min(z_pe), jnp.min(z_inj)])) self.zmax = jnp.min(jnp.array([jnp.max(z_pe), jnp.max(z_inj)])) self.zs = jnp.linspace(self.zmin, self.zmax, 1000) - self.dVdc_ = jnp.array(Planck15.differential_comoving_volume(np.array(self.zs)).value * 4.0 * np.pi) - self.dVdcs = [ + self.dVdz_ = jnp.array(Planck15.differential_comoving_volume(np.array(self.zs)).value * 4.0 * np.pi) + self.dVdzs = [ jnp.array(Planck15.differential_comoving_volume(np.array(z_inj)).value * 4.0 * np.pi), jnp.array(Planck15.differential_comoving_volume(np.array(z_pe)).value * 4.0 * np.pi), ] def normalization(self, lamb): - return jnp.trapz(self.prob(self.zs, self.dVdc_, lamb), self.zs) + return jnp.trapz(self.prob(self.zs, self.dVdz_, lamb), self.zs) - def prob(self, z, dVdc, lamb): - return dVdc * jnp.power(1.0 + z, lamb - 1.0) + def prob(self, z, dVdz, lamb): + return dVdz * jnp.power(1.0 + z, lamb - 1.0) def log_prob(self, z, lamb): ndim = len(z.shape) - dVdc = self.dVdcs[ndim - 1] + dVdz = self.dVdzs[ndim - 1] return jnp.where( jnp.less_equal(z, self.zmax), - jnp.log(dVdc) + (lamb - 1.0) * jnp.log(1.0 + z) - jnp.log(self.normalization(lamb)), + jnp.log(dVdz) + (lamb - 1.0) * jnp.log(1.0 + z) - jnp.log(self.normalization(lamb)), jnp.nan_to_num(-jnp.inf), ) def __call__(self, z, lamb): ndim = len(z.shape) - dVdc = self.dVdcs[ndim - 1] + dVdz = self.dVdzs[ndim - 1] return jnp.where( jnp.less_equal(z, self.zmax), - self.prob(z, dVdc, lamb) / self.normalization(lamb), + self.prob(z, dVdz, lamb) / self.normalization(lamb), 0, - ) + ) \ No newline at end of file From f6617c249e733970e44ca28ebd3cc4afcd51424e Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Tue, 24 Oct 2023 10:41:51 -0700 Subject: [PATCH 21/35] Deleted extra line spacings and extra line of code --- gwinferno/models/bsplines/single.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/gwinferno/models/bsplines/single.py b/gwinferno/models/bsplines/single.py index 1ae8304d..f4264653 100644 --- a/gwinferno/models/bsplines/single.py +++ b/gwinferno/models/bsplines/single.py @@ -140,7 +140,6 @@ def __init__( **kwargs, ) - class BSplineSpinTilt(Base1DBSplineModel): """Class to construct a cosine tilt (cos(theta)) B-Spline model for a single binary component @@ -314,7 +313,6 @@ def __init__( xrange=(qmin, 1), **kwargs, ) - self.n_splines = n_splines class BSplineMass(Base1DBSplineModel): From c06b9ae8b1a196f82b7f3810175d316e3f64d01b Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Tue, 24 Oct 2023 10:42:59 -0700 Subject: [PATCH 22/35] typo fixes --- gwinferno/models/spline_perturbation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gwinferno/models/spline_perturbation.py b/gwinferno/models/spline_perturbation.py index ab713170..d0a5cba3 100644 --- a/gwinferno/models/spline_perturbation.py +++ b/gwinferno/models/spline_perturbation.py @@ -328,7 +328,7 @@ def normalization(self, lamb: float, cs: jnp.ndarray): Returns: _type_: """ - pz = self.dVdc_ * jnp.power(1.0 + self.zs, lamb - 1) + pz = self.dVdz_ * jnp.power(1.0 + self.zs, lamb - 1) pz *= jnp.exp(self.interpolator.project(self.norm_design_matrix, cs)) return jnp.trapz(pz, self.zs) @@ -361,9 +361,9 @@ def __call__(self, z: jnp.ndarray, lamb: float, cs: jnp.ndarray) -> jnp.ndarray: jnp.ndarray: """ ndim = len(z.shape) - dV_cdz = self.dV_cdz[ndim - 1] + dVdz = self.dVdzs[ndim - 1] return jnp.where( jnp.less_equal(z, self.zmax), - self.prob(z, dV_cdz, lamb, cs) / self.normalization(lamb, cs), + self.prob(z, dVdz, lamb, cs) / self.normalization(lamb, cs), 0, ) From 985a051442f8c4e6d5237f38f10a097d8364ca67 Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Wed, 25 Oct 2023 11:05:40 -0700 Subject: [PATCH 23/35] deleted line of code to get it working again --- gwinferno/postprocess/calculate_ppds.py | 1 - 1 file changed, 1 deletion(-) diff --git a/gwinferno/postprocess/calculate_ppds.py b/gwinferno/postprocess/calculate_ppds.py index 3121db7f..40d7f0ab 100644 --- a/gwinferno/postprocess/calculate_ppds.py +++ b/gwinferno/postprocess/calculate_ppds.py @@ -171,7 +171,6 @@ def calc_pdf(mcs, qcs, r, pop_frac): return r * p_m * pop_frac / jnp.trapz(p_m, ms), r * p_q * pop_frac / jnp.trapz(p_q, qs) calc_pdf = jit(calc_pdf) - _ = calc_pdf(mcoefs[0], qcoefs[0], rate[0], pop_frac[0][0]) # loop through hyperposterior samples if isinstance(pop_frac, int): for ii in trange(mcoefs.shape[0]): From 4de5df8067a64ad0a3cd4af4851defb6c54881c1 Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Fri, 27 Oct 2023 14:38:53 -0700 Subject: [PATCH 24/35] fixed bug --- gwinferno/postprocess/calculate_ppds.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gwinferno/postprocess/calculate_ppds.py b/gwinferno/postprocess/calculate_ppds.py index 40d7f0ab..63f89c36 100644 --- a/gwinferno/postprocess/calculate_ppds.py +++ b/gwinferno/postprocess/calculate_ppds.py @@ -82,7 +82,7 @@ def calculate_powerbspline_rate_of_z_ppds(lamb, z_cs, rate, model): rs = np.zeros((len(lamb), len(zs))) def calc_rz(cs, la, r): - return r * jnp.power(1.0 + zs, la) * jnp.exp(model.interpolator.project(model.norm_design_matrix, (model.nknots, 1), cs)) + return r * jnp.power(1.0 + zs, la) * jnp.exp(model.interpolator.project(model.norm_design_matrix, cs)) calc_rz = jit(calc_rz) _ = calc_rz(z_cs[0], lamb[0], rate[0]) From aba83c3d922ccef804407748a1f8da8a24b94938 Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Fri, 27 Oct 2023 14:39:39 -0700 Subject: [PATCH 25/35] testing funcs. Not intended to be a real commit --- gwinferno/models/bsplines/separable.py | 47 ++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/gwinferno/models/bsplines/separable.py b/gwinferno/models/bsplines/separable.py index 0bd3c7f7..a64c3618 100644 --- a/gwinferno/models/bsplines/separable.py +++ b/gwinferno/models/bsplines/separable.py @@ -729,3 +729,50 @@ def __init__( def __call__(self, ndim, mcoefs, qcoefs): return self.ratio_model(ndim, qcoefs) * self.primary_model(ndim, mcoefs) + +#Testing +class BSplineSeperableMassRatioChiEff(object): + def __init__( + self, + nknots_q, + nknots_chieff, + q, + q_inj, + chieff, + chieff_inj, + order_q=4, + order_chieff=4, + qknots = None, + chieffknots = None, + m1min=3.0, + m2min=3.0, + mmax=100.0, + basis_q=BSpline, + basis_chieff=BSpline, + **kwargs, + ): + self.nknots_q = nknots_q + self.nknots_chieff = nknots_chieff + self.ratio_model = BSplineRatio( + nknots_q, + q, + q_inj, + qmin=m2min / mmax, + knots=qknots, + degree=order_q-1, + basis=basis_q, + **kwargs, + ) + + self.chieff_model = BSplineChiEffective( + nknots_chieff, + chieff, + chieff_inj, + knots=chieffknots, + degree=order_chieff-1, + basis=basis_chieff, + **kwargs, + ) + + def __call__(self, chieffcoefs, qcoefs, pe_samples): + return self.ratio_model(qcoefs, pe_samples=pe_samples) * self.chieff_model(chieffcoefs, pe_samples=pe_samples) \ No newline at end of file From c4944dd2d808efb3f2fbe696f3a1ff9449ed1b87 Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Tue, 31 Oct 2023 12:16:15 -0700 Subject: [PATCH 26/35] Fixed bugs to get stable sampling --- examples/basis_spline_example.py | 156 ++++++++++++++++++------------- 1 file changed, 89 insertions(+), 67 deletions(-) diff --git a/examples/basis_spline_example.py b/examples/basis_spline_example.py index 000d6d30..5f5d37cd 100755 --- a/examples/basis_spline_example.py +++ b/examples/basis_spline_example.py @@ -39,12 +39,13 @@ def load_parser(): parser.add_argument("--q-knots", type=int, default=30) parser.add_argument("--tilt-knots", type=int, default=25) parser.add_argument("--z-knots", type=int, default=20) + parser.add_argument("--skip-prior", action="store_true", default=True) return parser.parse_args() def setup_mass_BSpline_model(injdata, pedata, pmap, nknots, qknots, mmin=3.0, mmax=100.0): - print(f"Basis Spline model in m1 w/ {nknots} knots logspaced from {mmin} to {mmax}...") - print(f"Basis Spline model in q w/ {qknots} knots linspaced from {mmin/mmax} to 1...") + print(f"Basis Spline model in m1 w/ {nknots} number of bases. Knots are logspaced from {mmin} to {mmax}...") + print(f"Basis Spline model in q w/ {qknots} number of bases. Knots are linspaced from {mmin/mmax} to 1...") model = BSplinePrimaryBSplineRatio( nknots, @@ -152,48 +153,53 @@ def model( Tobs, sample_prior=False, ): - mass_knots = mass_model.primary_model.nknots - q_knots = mass_model.ratio_model.nknots + mass_knots = mass_model.primary_model.n_splines + q_knots = mass_model.ratio_model.n_splines mag_model = spin_models["mag"] tilt_model = spin_models["tilt"] - mag_knots = mag_model.primary_model.nknots - tilt_knots = tilt_model.primary_model.nknots + mag_knots = mag_model.primary_model.n_splines + tilt_knots = tilt_model.primary_model.n_splines z_knots = z_model.nknots mass_cs = numpyro.sample("mass_cs", dist.Normal(0, 6), sample_shape=(mass_knots,)) - mass_tau = numpyro.sample("mass_tau", dist.Uniform(1, 1000)) - numpyro.factor("mass_log_smoothing_prior", apply_difference_prior(mass_cs, mass_tau, degree=2)) + mass_tau_squared = numpyro.sample("mass_tau_squared", dist.TruncatedDistribution(dist.Normal(scale = 0.01), low = 0, high = 1)) + mass_lambda = numpyro.deterministic("mass_lambda", 1/mass_tau_squared) + numpyro.factor("mass_log_smoothing_prior", apply_difference_prior(mass_cs, mass_lambda, degree=2)) q_cs = numpyro.sample("q_cs", dist.Normal(0, 4), sample_shape=(q_knots,)) - q_tau = numpyro.sample("q_tau", dist.Uniform(1, 25)) - numpyro.factor("q_log_smoothing_prior", apply_difference_prior(q_cs, q_tau, degree=2)) + q_tau_squared = numpyro.sample("q_tau_squared", dist.TruncatedDistribution(dist.Normal(scale = 0.1), low = 0, high = 1)) + q_lambda = numpyro.deterministic("q_lambda", 1/q_tau_squared) + numpyro.factor("q_log_smoothing_prior", apply_difference_prior(q_cs, q_lambda, degree=2)) mag_cs = numpyro.sample("mag_cs", dist.Normal(0, 2), sample_shape=(mag_knots,)) - mag_tau = numpyro.sample("mag_tau", dist.Uniform(1, 10)) - numpyro.factor("mag_log_smoothing_prior", apply_difference_prior(mag_cs, mag_tau, degree=2)) + mag_tau_squared = numpyro.sample("mag_tau_squared", dist.TruncatedDistribution(dist.Normal(scale = 0.1), low = 0, high = 1)) + mag_lambda = numpyro.deterministic("mag_lambda", 1/mag_tau_squared) + numpyro.factor("mag_log_smoothing_prior", apply_difference_prior(mag_cs, mag_lambda, degree=2)) tilt_cs = numpyro.sample("tilt_cs", dist.Normal(0, 2), sample_shape=(tilt_knots,)) - tilt_tau = numpyro.sample("tilt_tau", dist.Uniform(1, 10)) - numpyro.factor("tilt_log_smoothing_prior", apply_difference_prior(tilt_cs, tilt_tau, degree=2)) + tilt_tau_squared = numpyro.sample("tilt_tau_squared", dist.TruncatedDistribution(dist.Normal(scale = 0.1), low = 0, high = 1)) + tilt_lambda = numpyro.deterministic("tilt_lambda", 1/tilt_tau_squared) + numpyro.factor("tilt_log_smoothing_prior", apply_difference_prior(tilt_cs, tilt_lambda, degree=2)) lamb = numpyro.sample("lamb", dist.Normal(0, 3)) z_cs = numpyro.sample("z_cs", dist.Normal(), sample_shape=(z_knots,)) - z_tau = numpyro.sample("z_tau", dist.Uniform(1, 5)) - numpyro.factor("z_log_smoothing_prior", apply_difference_prior(z_cs, z_tau, degree=2)) + z_tau_squared = numpyro.sample("z_tau_squared", dist.Uniform(1, 10)) + z_lambda = numpyro.deterministic("z_lambda", 1/z_tau_squared) + numpyro.factor("z_log_smoothing_prior", apply_difference_prior(z_cs, z_lambda, degree=2)) if not sample_prior: - def get_weights(z, prior): - p_m1q = mass_model(len(z.shape), mass_cs, q_cs) - p_a1a2 = mag_model(len(z.shape), mag_cs) - p_ct1ct2 = tilt_model(len(z.shape), tilt_cs) + def get_weights(z, prior, pe_samples = True): + p_m1q = mass_model(mass_cs, q_cs, pe_samples) + p_a1a2 = mag_model(mag_cs, pe_samples) + p_ct1ct2 = tilt_model(tilt_cs, pe_samples) p_z = z_model(z, lamb, z_cs) wts = p_m1q * p_a1a2 * p_ct1ct2 * p_z / prior return jnp.where(jnp.isnan(wts) | jnp.isinf(wts), 0, wts) peweights = get_weights(pedict["redshift"], pedict["prior"]) - injweights = get_weights(injdict["redshift"], injdict["prior"]) + injweights = get_weights(injdict["redshift"], injdict["prior"], pe_samples=False) hierarchical_likelihood( peweights, injweights, @@ -203,7 +209,7 @@ def get_weights(z, prior): surv_hypervolume_fct=z_model.normalization, vtfct_kwargs=dict(lamb=lamb, cs=z_cs), marginalize_selection=False, - min_neff_cut=True, + min_neff_cut=False, posterior_predictive_check=True, pedata=pedict, injdata=injdict, @@ -282,17 +288,17 @@ def main(): "logBFs", "log_l", "mag_cs", - "mag_tau", + "mag_lambda", "mass_cs", - "mass_tau", + "mass_lambda", "q_cs", - "q_tau", + "q_lambda", "rate", "surveyed_hypervolume", "tilt_cs", - "tilt_tau", + "tilt_lambda", "z_cs", - "z_tau", + "z_lambda", ] fig = az.plot_trace(az.from_numpyro(mcmc), var_names=plot_params) plt.savefig(f"{label}_trace_plot.png") @@ -312,6 +318,8 @@ def main(): mmin=args.mmin, m1mmin=args.mmin, mmax=args.mmax, + basis_m=LogXLogYBSpline, + basis_q=LogYBSpline, ) print("calculating mass posterior ppds...") pm1s, pqs, ms, qs = calculate_m1q_bspline_ppds( @@ -323,52 +331,63 @@ def main(): mmin=args.mmin, m1mmin=args.mmin, mmax=args.mmax, + basis_m=LogXLogYBSpline, + basis_q=LogYBSpline, ) - print("calculating mag prior ppds...") - prior_pmags, mags = calculate_iid_spin_bspline_ppds(prior["mag_cs"], BSplineIIDSpinMagnitudes, args.mag_knots, xmin=0, xmax=1) + if not args.skip_prior: + print("calculating mag prior ppds...") + prior_pmags, mags = calculate_iid_spin_bspline_ppds(prior["mag_cs"], BSplineIIDSpinMagnitudes, args.mag_knots, xmin=0, xmax=1, basis=LogYBSpline) print("calculating mag posterior ppds...") - pmags, mags = calculate_iid_spin_bspline_ppds(posterior["mag_cs"], BSplineIIDSpinMagnitudes, args.mag_knots, xmin=0, xmax=1) + pmags, mags = calculate_iid_spin_bspline_ppds(posterior["mag_cs"], BSplineIIDSpinMagnitudes, args.mag_knots, xmin=0, xmax=1, basis=LogYBSpline) - print("calculating tilt prior ppds...") - prior_ptilts, tilts = calculate_iid_spin_bspline_ppds(prior["tilt_cs"], BSplineIIDSpinTilts, args.tilt_knots, xmin=-1, xmax=1) + if not args.skip_prior: + print("calculating tilt prior ppds...") + prior_ptilts, tilts = calculate_iid_spin_bspline_ppds(prior["tilt_cs"], BSplineIIDSpinTilts, args.tilt_knots, xmin=-1, xmax=1, basis=LogYBSpline) print("calculating tilt posterior ppds...") - ptilts, tilts = calculate_iid_spin_bspline_ppds(posterior["tilt_cs"], BSplineIIDSpinTilts, args.tilt_knots, xmin=-1, xmax=1) + ptilts, tilts = calculate_iid_spin_bspline_ppds(posterior["tilt_cs"], BSplineIIDSpinTilts, args.tilt_knots, xmin=-1, xmax=1, basis=LogYBSpline) - print("calculating rate prior ppds...") - prior_Rofz, zs = calculate_powerbspline_rate_of_z_ppds(prior["lamb"], prior["z_cs"], jnp.ones_like(prior["lamb"]), z) + if not args.skip_prior: + print("calculating rate prior ppds...") + prior_Rofz, zs = calculate_powerbspline_rate_of_z_ppds(prior["lamb"], prior["z_cs"], jnp.ones_like(prior["lamb"]), z) print("calculating rate posterior ppds...") Rofz, zs = calculate_powerbspline_rate_of_z_ppds(posterior["lamb"], posterior["z_cs"], posterior["rate"], z) - ppd_dict = { - "dRdm1": pm1s, - "dRdq": pqs, - "m1s": ms, - "qs": qs, - "dRda": pmags, - "mags": mags, - "dRdct": ptilts, - "tilts": tilts, - "Rofz": Rofz, - "zs": zs, - } - dd.io.save(f"{label}_ppds.h5", ppd_dict) - prior_ppd_dict = { - "pm1": prior_pm1s, - "pq": prior_pqs, - "pa": prior_pmags, - "pct": prior_ptilts, - "m1s": ms, - "qs": qs, - "mags": mags, - "tilts": tilts, - "Rofz": prior_Rofz, - "zs": zs, - } - dd.io.save(f"{label}_prior_ppds.h5", prior_ppd_dict) - del ppd_dict, prior_ppd_dict +# Lines (357-383) are commented out due to deepdish errors + # if not args.skip_prior: + # prior_ppd_dict = { + # "pm1": prior_pm1s, + # "pq": prior_pqs, + # "pa": prior_pmags, + # "pct": prior_ptilts, + # "m1s": ms, + # "qs": qs, + # "mags": mags, + # "tilts": tilts, + # "Rofz": prior_Rofz, + # "zs": zs, + # } + + # dd.io.save(f"{label}_prior_ppds.h5", prior_ppd_dict) + # del prior_ppd_dict + + # ppd_dict = { + # "dRdm1": pm1s, + # "dRdq": pqs, + # "m1s": ms, + # "qs": qs, + # "dRda": pmags, + # "mags": mags, + # "dRdct": ptilts, + # "tilts": tilts, + # "Rofz": Rofz, + # "zs": zs, + # } + # dd.io.save(f"{label}_ppds.h5", ppd_dict) + # del ppd_dict print("plotting mass distribution...") + priors = None if args.skip_prior else {"m1": prior_pm1s, "q": prior_pqs} fig = plot_mass_dist( pm1s, pqs, @@ -376,21 +395,24 @@ def main(): qs, mmin=5.0, mmax=args.mmax, - priors={"m1": prior_pm1s, "q": prior_pqs}, + priors=priors, ) plt.savefig(f"{label}_mass_distribution.png") del fig print("plotting spin distributions...") - fig = plot_iid_spin_dist(pmags, ptilts, mags, tilts, priors={"mags": prior_pmags, "tilts": prior_ptilts}) + priors = None if args.skip_prior else {"mags": prior_pmags, "tilts": prior_ptilts} + fig = plot_iid_spin_dist(pmags, ptilts, mags, tilts, priors=priors) plt.savefig(f"{label}_iid_component_spin_distribution.png") del fig print("plotting R(z)...") - fig = plot_rofz(Rofz, zs, prior=prior_Rofz) + prior = None if args.skip_prior else prior_Rofz + fig = plot_rofz(Rofz, zs, prior=prior) plt.savefig(f"{label}_rate_vs_z.png") del fig - fig = plot_rofz(Rofz, zs, logx=True, prior=prior_Rofz) + prior = None if args.skip_prior else prior_Rofz + fig = plot_rofz(Rofz, zs, logx=True, prior=prior) plt.savefig(f"{label}_rate_vs_z_logscale.png") del fig @@ -406,4 +428,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file From 0f0bc70e41aefe680c3985c710528a7455cf468e Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Fri, 3 Nov 2023 11:36:51 -0700 Subject: [PATCH 27/35] Uncommented a piece of code that is used in cosmology class --- gwinferno/cosmology.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gwinferno/cosmology.py b/gwinferno/cosmology.py index 16b1c5fd..05d397d6 100644 --- a/gwinferno/cosmology.py +++ b/gwinferno/cosmology.py @@ -48,7 +48,7 @@ def __init__(self, Ho, omega_matter, omega_radiation, omega_lambda, distance_uni self.z = jnp.array([0.0]) self.Dc = jnp.array([0.0]) self.Vc = jnp.array([0.0]) - # self.extend(max_z=2.3, dz=DEFAULT_DZ) + self.extend(max_z=2.3, dz=DEFAULT_DZ) @property def DL(self): From 2807b6db1d292b4e69dded83b5a95c4ef2a7268e Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Fri, 3 Nov 2023 11:39:49 -0700 Subject: [PATCH 28/35] Changed neff cut from 4*Nobs back to Nobs --- gwinferno/pipeline/analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gwinferno/pipeline/analysis.py b/gwinferno/pipeline/analysis.py index 4ad146a1..30a2a737 100644 --- a/gwinferno/pipeline/analysis.py +++ b/gwinferno/pipeline/analysis.py @@ -246,7 +246,7 @@ def hierarchical_likelihood( numpyro.factor( "log_likelihood", jnp.where( - jnp.isnan(log_l) | jnp.less_equal(jnp.exp(jnp.min(logn_effs)), 4*Nobs), + jnp.isnan(log_l) | jnp.less_equal(jnp.exp(jnp.min(logn_effs)), Nobs), jnp.nan_to_num(-jnp.inf), jnp.nan_to_num(log_l), ), From adcb82e6b432a4df7a5e4cb3ccc886d76af6fec4 Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Thu, 16 Nov 2023 16:27:42 -0800 Subject: [PATCH 29/35] Added another BSpline example script. This model includes a chi_eff B-Spline model with simple power-law model for redshift. --- examples/basis_spline_example_chieff.py | 356 ++++++++++++++++++++++++ 1 file changed, 356 insertions(+) create mode 100755 examples/basis_spline_example_chieff.py diff --git a/examples/basis_spline_example_chieff.py b/examples/basis_spline_example_chieff.py new file mode 100755 index 00000000..9406984b --- /dev/null +++ b/examples/basis_spline_example_chieff.py @@ -0,0 +1,356 @@ +import arviz as az +import deepdish as dd +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpyro +from jax import random +from numpyro import distributions as dist +from numpyro.infer import MCMC +from numpyro.infer import NUTS + +from gwinferno.interpolation import LogXLogYBSpline +from gwinferno.interpolation import LogYBSpline +from gwinferno.models.bsplines.separable import BSplinePrimaryBSplineRatio +from gwinferno.models.bsplines.single import BSplineChiEffective +from gwinferno.models.bsplines.smoothing import apply_difference_prior +from gwinferno.models.gwpopulation.gwpopulation import PowerlawRedshiftModel +from gwinferno.pipeline.analysis import hierarchical_likelihood +from gwinferno.pipeline.parser import load_base_parser +from gwinferno.postprocess.calculate_ppds import calculate_m1q_bspline_ppds +from gwinferno.postprocess.calculate_ppds import calculate_powerlaw_rate_of_z_ppds +from gwinferno.postprocess.calculate_ppds import calculate_chieff_bspline_ppds +from gwinferno.postprocess.plotting import plot_m1_vs_z_ppc +from gwinferno.postprocess.plotting import plot_mass_dist +from gwinferno.postprocess.plotting import plot_rofz +from gwinferno.postprocess.plotting import plot_chieff_dist +from gwinferno.preprocess.data_collection import load_injections +from gwinferno.preprocess.data_collection import load_posterior_samples + +az.style.use("arviz-darkgrid") + + +def load_parser(): + parser = load_base_parser() + parser.add_argument("--mass-knots", type=int, default=100) + parser.add_argument("--mag-knots", type=int, default=30) + parser.add_argument("--q-knots", type=int, default=30) + parser.add_argument("--tilt-knots", type=int, default=25) + parser.add_argument("--z-knots", type=int, default=20) + parser.add_argument("--chieff-nsplines", type=int, default=30) + parser.add_argument("--skip-prior", action="store_true", default=True) + return parser.parse_args() + + +def setup_mass_BSpline_model(injdata, pedata, pmap, nknots, qknots, mmin=3.0, mmax=100.0): + print(f"Basis Spline model in m1 w/ {nknots} number of bases. Knots are logspaced from {mmin} to {mmax}...") + print(f"Basis Spline model in q w/ {qknots} number of bases. Knots are linspaced from {mmin/mmax} to 1...") + + model = BSplinePrimaryBSplineRatio( + nknots, + qknots, + pedata[pmap["mass_1"]], + injdata[pmap["mass_1"]], + pedata[pmap["mass_ratio"]], + injdata[pmap["mass_ratio"]], + m1min=mmin, + m2min=mmin, + mmax=mmax, + basis_m=LogXLogYBSpline, + basis_q=LogYBSpline, + ) + return model + +def setup_chieff_BSpline_model(nsplines, injdata, pedata, pmap): + print(f"Basis spline model in chieff w/ {nsplines} bases. Knots are linearly spaced.") + model = BSplineChiEffective( + n_splines=nsplines, + chieff=pedata[pmap['chi_eff']], + chieff_inj=injdata[pmap['chi_eff']], + basis=LogYBSpline, + ) + return model + + +def setup_redshift_model(injdata, pedata, pmap): + print(f"Powerlaw redshift model set up.") + z_pe = pedata[pmap["redshift"]] + z_inj = injdata[pmap["redshift"]] + model = PowerlawRedshiftModel(z_pe, z_inj) + return model + + +def setup(args): + df = dd.io.load("./saved-pe-and-injs/posterior_samples_and_injections_chi_effective.h5") + pedata = df['pedata'] + injdata = df['injdata'] + param_map = df['param_map'] + param_names = [ + "mass_1", "mass_ratio", "redshift", "chi_eff", "prior" + ] + param_map = {p: i for i, p in enumerate(param_names)} + injdict = {k: injdata[param_map[k]] for k in param_names} + pedict = {k: pedata[param_map[k]] for k in param_names} + nObs = pedata.shape[1] + total_inj = df["total_generated"] + obs_time = df["analysis_time"] + + mass_model = setup_mass_BSpline_model( + injdata, + pedata, + param_map, + args.mass_knots, + args.q_knots, + mmin=args.mmin, + mmax=args.mmax, + ) + z_model = setup_redshift_model(injdata, pedata, param_map) + chieff_model = setup_chieff_BSpline_model(args.chieff_nsplines, injdata, pedata, param_map) + injdict = {k: injdata[param_map[k]] for k in param_names} + pedict = {k: pedata[param_map[k]] for k in param_names} + + print(f"{len(injdict['redshift'])} found injections out of {total_inj} total") + print(f"Observed {nObs} events, each with {pedict['redshift'].shape[1]} samples, over an observing time of {obs_time} yrs") + + return ( + mass_model, + chieff_model, + z_model, + pedict, + injdict, + total_inj, + nObs, + obs_time, + ) + + +def model( + mass_model, + chieff_model, + z_model, + pedict, + injdict, + total_inj, + Nobs, + Tobs, + sample_prior=False, +): + mass_knots = mass_model.primary_model.n_splines + q_knots = mass_model.ratio_model.n_splines + chieff_nsplines = chieff_model.n_splines + + mass_cs = numpyro.sample("mass_cs", dist.Normal(0, 6), sample_shape=(mass_knots,)) + mass_tau_squared = numpyro.sample("mass_tau_squared", dist.TruncatedDistribution(dist.Normal(scale = 0.01), low = 0, high = 1)) + mass_lambda = numpyro.deterministic("mass_lambda", 1/mass_tau_squared) + numpyro.factor("mass_log_smoothing_prior", apply_difference_prior(mass_cs, mass_lambda, degree=2)) + + q_cs = numpyro.sample("q_cs", dist.Normal(0, 4), sample_shape=(q_knots,)) + q_tau_squared = numpyro.sample("q_tau_squared", dist.TruncatedDistribution(dist.Normal(scale = 0.1), low = 0, high = 1)) + q_lambda = numpyro.deterministic("q_lambda", 1/q_tau_squared) + numpyro.factor("q_log_smoothing_prior", apply_difference_prior(q_cs, q_lambda, degree=2)) + + chieff_cs = numpyro.sample("chieff_cs", dist.Normal(0,4), sample_shape=(chieff_nsplines,)) + chieff_tau_squared = numpyro.sample("chieff_tau_squared", dist.TruncatedDistribution(dist.Normal(scale = 0.1), low = 0, high = 1)) + chieff_lambda = numpyro.deterministic("chieff_lambda", 1/chieff_tau_squared) + numpyro.factor("chieff_log_smoothing_prior", apply_difference_prior(chieff_cs, chieff_lambda, degree=2)) + + lamb = numpyro.sample("lamb", dist.Normal(0, 3)) + + if not sample_prior: + + def get_weights(z, prior, pe_samples = True): + p_m1q = mass_model(mass_cs, q_cs, pe_samples) + p_chieff = chieff_model(chieff_cs, pe_samples) + p_z = z_model(z, lamb) + wts = p_m1q * p_chieff * p_z / prior + + return jnp.where(jnp.isnan(wts) | jnp.isinf(wts), 0, wts) + + peweights = get_weights(pedict["redshift"], pedict["prior"]) + injweights = get_weights(injdict["redshift"], injdict["prior"], pe_samples=False) + hierarchical_likelihood( + peweights, + injweights, + total_inj=total_inj, + Nobs=Nobs, + Tobs=Tobs, + vtfct_kwargs=dict(lamb=lamb), + marginalize_selection=False, + min_neff_cut=False, + posterior_predictive_check=True, + pedata=pedict, + injdata=injdict, + param_names=[ + "mass_1", + "mass_ratio", + "redshift", + "chi_eff", + ], + ) + + +def main(): + args = load_parser() + label = f"{args.outdir}/bsplines_{args.chieff_nsplines}chieff_{args.mass_knots}m1_{args.q_knots}q_z" + mass, chieff, z, pedict, injdict, total_inj, nObs, obs_time = setup(args) + if not args.skip_inference: + RNG = random.PRNGKey(0) + MCMC_RNG, PRIOR_RNG, _RNG = random.split(RNG, num=3) + kernel = NUTS(model) + mcmc = MCMC( + kernel, + thinning=args.thinning, + num_warmup=args.warmup, + num_samples=args.samples, + num_chains=args.chains, + ) + print("running mcmc: sampling prior...") + mcmc.run( + PRIOR_RNG, + mass, + chieff, + z, + pedict, + injdict, + float(total_inj), + nObs, + obs_time, + sample_prior=True, + ) + prior = mcmc.get_samples() + dd.io.save(f"{label}_prior_samples.h5", prior) + + kernel = NUTS(model) + mcmc = MCMC( + kernel, + thinning=args.thinning, + num_warmup=args.warmup, + num_samples=args.samples, + num_chains=args.chains, + ) + print("running mcmc: sampling posterior...") + mcmc.run( + MCMC_RNG, + mass, + chieff, + z, + pedict, + injdict, + float(total_inj), + nObs, + obs_time, + sample_prior=False, + ) + mcmc.print_summary() + posterior = mcmc.get_samples() + dd.io.save(f"{label}_posterior_samples.h5", posterior) + plot_params = [ + "detection_efficency", + "lamb", + "log_nEff_inj", + "log_nEffs", + "logBFs", + "log_l", + "chieff_cs", + "chieff_lambda", + "mass_cs", + "mass_lambda", + "q_cs", + "q_lambda", + "rate", + "surveyed_hypervolume", + ] + fig = az.plot_trace(az.from_numpyro(mcmc), var_names=plot_params) + plt.savefig(f"{label}_trace_plot.png") + del fig, mcmc, pedict, injdict, total_inj, obs_time + else: + print(f"loading prior and posterior samples from run with label: {label}...") + prior = dd.io.load(f"{label}_prior_samples.h5") + posterior = dd.io.load(f"{label}_posterior_samples.h5") + + print("calculating mass prior ppds...") + prior_pm1s, prior_pqs, ms, qs = calculate_m1q_bspline_ppds( + prior["mass_cs"], + prior["q_cs"], + BSplinePrimaryBSplineRatio, + args.mass_knots, + args.q_knots, + mmin=args.mmin, + m1mmin=args.mmin, + mmax=args.mmax, + basis_m=LogXLogYBSpline, + basis_q=LogYBSpline, + ) + print("calculating mass posterior ppds...") + pm1s, pqs, ms, qs = calculate_m1q_bspline_ppds( + posterior["mass_cs"], + posterior["q_cs"], + BSplinePrimaryBSplineRatio, + args.mass_knots, + args.q_knots, + mmin=args.mmin, + m1mmin=args.mmin, + mmax=args.mmax, + basis_m=LogXLogYBSpline, + basis_q=LogYBSpline, + ) + + if not args.skip_prior: + print("calculating rate prior ppds...") + prior_Rofz, zs = calculate_powerlaw_rate_of_z_ppds(prior["lamb"], jnp.ones_like(prior["lamb"]), z) + print("calculating rate posterior ppds...") + Rofz, zs = calculate_powerlaw_rate_of_z_ppds(posterior["lamb"], posterior["rate"], z) + + if not args.skip_prior: + print("calculating chieff prior ppds...") + prior_pchieff, xs = calculate_chieff_bspline_ppds( + coefs=prior["chieff_cs"], + model=chieff, + nknots=args.chieff_nsplines, + basis=LogYBSpline, + ) + + print("calculating chieff posterior ppds...") + pchieff, xs = calculate_chieff_bspline_ppds( + coefs=posterior["chieff_cs"], + model=BSplineChiEffective, + nknots=args.chieff_nsplines, + basis=LogYBSpline, + ) + + print("plotting mass distribution...") + priors = None if args.skip_prior else {"m1": prior_pm1s, "q": prior_pqs} + fig = plot_mass_dist( + pm1s, + pqs, + ms, + qs, + mmin=5.0, + mmax=args.mmax, + priors=priors, + ) + plt.savefig(f"{label}_mass_distribution.png") + del fig + + print("plotting chieff distribution...") + prior = None if args.skip_prior else prior_pchieff + fig = plot_chieff_dist(pchieff, xs, prior=prior) + plt.savefig(f"{label}_chieff_distribution.png") + del fig + + print("plotting R(z)...") + prior = None if args.skip_prior else prior_Rofz + fig = plot_rofz(Rofz, zs, prior=prior) + plt.savefig(f"{label}_rate_vs_z.png") + del fig + prior = None if args.skip_prior else prior_Rofz + fig = plot_rofz(Rofz, zs, logx=True, prior=prior) + plt.savefig(f"{label}_rate_vs_z_logscale.png") + del fig + + print("plotting m1/z PPC...") + fig = plot_m1_vs_z_ppc(posterior, nObs, 5.0, args.mmax, z.zmax) + plt.savefig(f"{label}_m1_vs_z_ppc.png") + del fig + +if __name__ == "__main__": + main() \ No newline at end of file From 69a59b0f8620297754874211f08a8a08e09ae8c7 Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Thu, 16 Nov 2023 16:34:31 -0800 Subject: [PATCH 30/35] Removed unused functions and added a comment. --- examples/basis_spline_example_chieff.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/basis_spline_example_chieff.py b/examples/basis_spline_example_chieff.py index 9406984b..f56aa02a 100755 --- a/examples/basis_spline_example_chieff.py +++ b/examples/basis_spline_example_chieff.py @@ -23,12 +23,9 @@ from gwinferno.postprocess.plotting import plot_mass_dist from gwinferno.postprocess.plotting import plot_rofz from gwinferno.postprocess.plotting import plot_chieff_dist -from gwinferno.preprocess.data_collection import load_injections -from gwinferno.preprocess.data_collection import load_posterior_samples az.style.use("arviz-darkgrid") - def load_parser(): parser = load_base_parser() parser.add_argument("--mass-knots", type=int, default=100) @@ -80,7 +77,9 @@ def setup_redshift_model(injdata, pedata, pmap): def setup(args): - df = dd.io.load("./saved-pe-and-injs/posterior_samples_and_injections_chi_effective.h5") + #Provide location to PE and injection samples below. + inj_pe_path = "./saved-pe-and-injs/posterior_samples_and_injections_chi_effective.h5" + df = dd.io.load(inj_pe_path) pedata = df['pedata'] injdata = df['injdata'] param_map = df['param_map'] From 1c4d71b40bb38d5451bb435ec6244913136b5c84 Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Thu, 16 Nov 2023 16:39:52 -0800 Subject: [PATCH 31/35] Removed unused separable object that was used during developing 2d model. --- gwinferno/models/bsplines/separable.py | 49 +------------------------- 1 file changed, 1 insertion(+), 48 deletions(-) diff --git a/gwinferno/models/bsplines/separable.py b/gwinferno/models/bsplines/separable.py index a64c3618..8fc546e3 100644 --- a/gwinferno/models/bsplines/separable.py +++ b/gwinferno/models/bsplines/separable.py @@ -728,51 +728,4 @@ def __init__( ) def __call__(self, ndim, mcoefs, qcoefs): - return self.ratio_model(ndim, qcoefs) * self.primary_model(ndim, mcoefs) - -#Testing -class BSplineSeperableMassRatioChiEff(object): - def __init__( - self, - nknots_q, - nknots_chieff, - q, - q_inj, - chieff, - chieff_inj, - order_q=4, - order_chieff=4, - qknots = None, - chieffknots = None, - m1min=3.0, - m2min=3.0, - mmax=100.0, - basis_q=BSpline, - basis_chieff=BSpline, - **kwargs, - ): - self.nknots_q = nknots_q - self.nknots_chieff = nknots_chieff - self.ratio_model = BSplineRatio( - nknots_q, - q, - q_inj, - qmin=m2min / mmax, - knots=qknots, - degree=order_q-1, - basis=basis_q, - **kwargs, - ) - - self.chieff_model = BSplineChiEffective( - nknots_chieff, - chieff, - chieff_inj, - knots=chieffknots, - degree=order_chieff-1, - basis=basis_chieff, - **kwargs, - ) - - def __call__(self, chieffcoefs, qcoefs, pe_samples): - return self.ratio_model(qcoefs, pe_samples=pe_samples) * self.chieff_model(chieffcoefs, pe_samples=pe_samples) \ No newline at end of file + return self.ratio_model(ndim, qcoefs) * self.primary_model(ndim, mcoefs) \ No newline at end of file From c6e1d5cddc666fd30914b11ec312a32ee271c71c Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Mon, 27 Nov 2023 15:27:48 -0800 Subject: [PATCH 32/35] Removed nan to num jnp in logic for min_neff_cut and added an additional condition --- gwinferno/pipeline/analysis.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gwinferno/pipeline/analysis.py b/gwinferno/pipeline/analysis.py index 30a2a737..bed5e7fd 100644 --- a/gwinferno/pipeline/analysis.py +++ b/gwinferno/pipeline/analysis.py @@ -246,9 +246,9 @@ def hierarchical_likelihood( numpyro.factor( "log_likelihood", jnp.where( - jnp.isnan(log_l) | jnp.less_equal(jnp.exp(jnp.min(logn_effs)), Nobs), - jnp.nan_to_num(-jnp.inf), - jnp.nan_to_num(log_l), + jnp.isnan(log_l) or jnp.isnan(jnp.min(logn_effs)) | jnp.less_equal(jnp.exp(jnp.min(logn_effs)), Nobs), + -jnp.inf, + log_l, ), ) else: From 3320c75b688e54f27ff3c0defe25960a86dfac18 Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Tue, 19 Dec 2023 11:37:30 -0800 Subject: [PATCH 33/35] Changed the nEffs cut back to 4*Nobs as done in Cover Your Basis paper. --- gwinferno/pipeline/analysis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gwinferno/pipeline/analysis.py b/gwinferno/pipeline/analysis.py index bed5e7fd..62fd667c 100644 --- a/gwinferno/pipeline/analysis.py +++ b/gwinferno/pipeline/analysis.py @@ -246,8 +246,8 @@ def hierarchical_likelihood( numpyro.factor( "log_likelihood", jnp.where( - jnp.isnan(log_l) or jnp.isnan(jnp.min(logn_effs)) | jnp.less_equal(jnp.exp(jnp.min(logn_effs)), Nobs), - -jnp.inf, + jnp.isnan(log_l) | jnp.isnan(jnp.min(logn_effs)) | jnp.less_equal(jnp.exp(jnp.min(logn_effs)), 4*Nobs), + -1000, log_l, ), ) From 52cb30adbad933e542a64446be7d39d1751673a3 Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Thu, 18 Jan 2024 14:03:31 -0800 Subject: [PATCH 34/35] Changed min_neff_cut to Nobs. Order 10 Nobs should be enough --- gwinferno/pipeline/analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gwinferno/pipeline/analysis.py b/gwinferno/pipeline/analysis.py index 62fd667c..9bc50d99 100644 --- a/gwinferno/pipeline/analysis.py +++ b/gwinferno/pipeline/analysis.py @@ -246,7 +246,7 @@ def hierarchical_likelihood( numpyro.factor( "log_likelihood", jnp.where( - jnp.isnan(log_l) | jnp.isnan(jnp.min(logn_effs)) | jnp.less_equal(jnp.exp(jnp.min(logn_effs)), 4*Nobs), + jnp.isnan(log_l) | jnp.isnan(jnp.min(logn_effs)) | jnp.less_equal(jnp.exp(jnp.min(logn_effs)), Nobs), -1000, log_l, ), From 65f18631e0299c2f21c72aa1f2484294d4524727 Mon Sep 17 00:00:00 2001 From: Gino Carrillo Date: Thu, 18 Jan 2024 14:04:05 -0800 Subject: [PATCH 35/35] Removed an old note and random spacing indents --- gwinferno/interpolation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gwinferno/interpolation.py b/gwinferno/interpolation.py index c6eec00f..79561648 100644 --- a/gwinferno/interpolation.py +++ b/gwinferno/interpolation.py @@ -500,7 +500,7 @@ def bases(self, xs, ys): self.x_bases = self.x_interpolator.bases(xs) self.y_bases = self.y_interpolator.bases(ys) out = jnp.array([[self.x_bases[i] * self.y_bases[j] for i in range(self.xdf)] for j in range(self.ydf)]).reshape( - self.xdf, self.ydf, *xs.shape + self.xdf, self.ydf, *xs.shape ) self._reset_bases() @@ -517,7 +517,6 @@ def _project(self, bases, coefs): Returns: array_like: The linear combination of the basis components given the coefficients """ - #NOTE jnp.exp(jnp.einsum("ij...,ij->...", bases, coefs))this would be for the logz2dbasis class return jnp.einsum("ij...,ij->...", bases, coefs) def project(self, bases, coefs):