diff --git a/docs/examples/plot_5_emcee_arviz_numpyro.py b/docs/examples/plot_5_emcee_arviz_numpyro.py index 0e70d3be..ecbd5ebd 100644 --- a/docs/examples/plot_5_emcee_arviz_numpyro.py +++ b/docs/examples/plot_5_emcee_arviz_numpyro.py @@ -1,5 +1,5 @@ """ -# Using external samples easily +# Using external samples `emcee`, `arviz`, and `numpyro` are all popular MCMC packages. ChainConsumer provides class methods to turn results from these packages into chains efficiently. diff --git a/docs/examples/plot_7_multimodal_chains.py b/docs/examples/plot_7_multimodal_chains.py new file mode 100644 index 00000000..c2e7971a --- /dev/null +++ b/docs/examples/plot_7_multimodal_chains.py @@ -0,0 +1,79 @@ +""" +# Multimodal distributions + +`ChainConsumer` can handle cases where the distributions of your chains are multimodal. +""" + +import numpy as np +import pandas as pd + +from chainconsumer import Chain, ChainConsumer +from chainconsumer.statistics import SummaryStatistic + +# %% +# First, let's build some dummy data + +rng = np.random.default_rng(42) +size = 60_000 + +eta = rng.normal(loc=0.0, scale=0.8, size=size) + +phi = np.asarray( + [rng.gamma(shape=2.5, scale=0.4, size=size // 2) - 3.0, 3.0 - rng.gamma(shape=5.0, scale=0.35, size=(size // 2))] +).flatten() + +rng.shuffle(phi) + +df = pd.DataFrame({"eta": eta, "phi": phi}) + +# %% +# To build a multimodal chain, you simply have to pass `multimodal=True` when building the chain. To work, it requires +# you to specify `SummaryStatistic.HDI` as the summary statistic. + +chain_multimodal = Chain( + samples=df.copy(), + name="posterior-multimodal", + statistics=SummaryStatistic.HDI, + multimodal=True, # <- Here +) + +# %% +# Now, if you add this `Chain` to a plotter, it will try to look for sub-intervals and display them. + +cc = ChainConsumer() +cc.add_chain(chain_multimodal) +fig = cc.plotter.plot() + +# %% +# Let's compare with what would happen if you don't use a multimodal chain. We use the same data as before but don't +# tell `ChainConsumer` that we expect the chains to be multimodal. + +chain_unimodal = Chain(samples=df.copy(), name="posterior-unimodal", statistics=SummaryStatistic.HDI, multimodal=False) + +cc.add_chain(chain_unimodal) +fig = cc.plotter.plot() + +# %% +# Let's try with even more modes. + +eta = np.asarray( + [ + rng.normal(loc=-3, scale=0.8, size=size // 3), + rng.normal(loc=0.0, scale=0.8, size=size // 3), + rng.normal(loc=+3, scale=0.8, size=size // 3), + ] +).flatten() + + +rng.shuffle(eta) + +df = pd.DataFrame({"eta": eta, "phi": phi}) + +chain_multimodal = Chain( + samples=df.copy(), name="posterior-multimodal", statistics=SummaryStatistic.HDI, multimodal=True +) + +cc = ChainConsumer() +cc.add_chain(chain_multimodal) +fig = cc.plotter.plot() +fig.tight_layout() diff --git a/docs/resources/generate_stats.py b/docs/resources/generate_stats.py new file mode 100644 index 00000000..b3ceb941 --- /dev/null +++ b/docs/resources/generate_stats.py @@ -0,0 +1,81 @@ +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from matplotlib import rc +from scipy.stats import gamma + +from chainconsumer import Chain, ChainConsumer +from chainconsumer.statistics import SummaryStatistic + +# Activate latex text rendering +rc("font", family="serif", serif=["Computer Modern Roman"], size=13) +rc("text", usetex=True) +matplotlib.rcParams["text.latex.preamble"] = r"\usepackage{amsmath}" + +x = np.linspace(0, 5, 100) + +loc = 4 +scale = 0.45 + +fig, axs = plt.subplots(nrows=2, ncols=1, sharex=True, height_ratios=[0.5, 0.5], figsize=(5, 5)) +axs[0].plot(x, gamma.pdf(x, a=loc, scale=scale), color="black") +axs[1].plot(x, gamma.cdf(x, a=loc, scale=scale), color="black") + + +axs[1].set_xlabel("$x$") +axs[0].set_ylabel("$P(x)$") +axs[1].set_ylabel("$C(x)$") +axs[0].set_xlim(0, 5.0) +axs[0].set_ylim(0, 0.6) +axs[1].set_ylim(0, 1) + +samples = pd.DataFrame.from_dict({"gamma": gamma.rvs(size=10_000_000, a=loc, scale=scale)}) + +summary_list = [ + (SummaryStatistic.MAX, "MAX"), + (SummaryStatistic.CUMULATIVE, "CUMULATIVE"), + (SummaryStatistic.MEAN, "MEAN"), + (SummaryStatistic.HDI, "HDI"), +] + +chains = [] + +for summary, name in summary_list: + chains.append(Chain(samples=samples, statistics=summary, name=name)) + +cc = ChainConsumer() + +summary_result = cc.analysis.get_summary(chains=chains, columns=["gamma"]) + +for (_summary, name), color, linestyle, marker_style in zip( + summary_list, + ["r", "g", "b", "y"], + [":", "--", "-", "-."], + ["o", "^", "s", "*"], + strict=False, +): + bound = summary_result[name]["gamma"] + + x_min, x_mid, x_max = bound.lower, bound.center, bound.upper + + axs[0].scatter(x_mid, gamma.pdf(x_mid, a=loc, scale=scale), label=name, zorder=10, color=color, marker=marker_style) + axs[1].scatter(x_mid, gamma.cdf(x_mid, a=loc, scale=scale), zorder=10, color=color, marker=marker_style) + + axs[0].vlines( + x=x_min, ymin=0, ymax=gamma.pdf(x_min, a=loc, scale=scale), color=color, linestyle=linestyle, alpha=0.5 + ) + axs[0].vlines( + x=x_max, ymin=0, ymax=gamma.pdf(x_max, a=loc, scale=scale), color=color, linestyle=linestyle, alpha=0.5 + ) + + axs[1].hlines( + xmin=0, xmax=x_min, y=gamma.cdf(x_min, a=loc, scale=scale), color=color, linestyle=linestyle, alpha=0.5 + ) + axs[1].hlines( + xmin=0, xmax=x_max, y=gamma.cdf(x_max, a=loc, scale=scale), color=color, linestyle=linestyle, alpha=0.5 + ) + +axs[0].legend(fontsize=8) +plt.tight_layout() +plt.savefig("stats.png", bbox_inches="tight") diff --git a/docs/resources/stats.png b/docs/resources/stats.png index b7e501c9..6c2cb938 100644 Binary files a/docs/resources/stats.png and b/docs/resources/stats.png differ diff --git a/docs/usage.md b/docs/usage.md index e0fd145d..20299c2c 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -12,7 +12,7 @@ In general, this is the flow: ## Statistics -When summarising chains, ChainConsumer offers several different methods. The below image shows the upper and lower bounds and central points for the "MEAN", "CUMULATIVE", and "MAX" methods respectively. The "MAX_CENTRAL" method is the blue central value and the red bounds. +When summarising chains, ChainConsumer offers several different methods. The below image shows the upper and lower bounds and central points for the `MAX`, `CUMULATIVE`, `MEAN` and `HDI` methods respectively, with their associated bounds. ![](resources/stats.png) diff --git a/src/chainconsumer/analysis.py b/src/chainconsumer/analysis.py index c078cc50..464e0672 100644 --- a/src/chainconsumer/analysis.py +++ b/src/chainconsumer/analysis.py @@ -1,7 +1,8 @@ from __future__ import annotations import logging -from collections.abc import Callable +import warnings +from collections.abc import Callable, Sequence from pathlib import Path import numpy as np @@ -9,6 +10,7 @@ from scipy.integrate import simpson as simps from scipy.interpolate import interp1d from scipy.ndimage import gaussian_filter +from scipy.optimize import root_scalar from .base import BetterBase from .chain import Chain, ChainName, ColumnName, MaxPosterior, Named2DMatrix @@ -17,6 +19,33 @@ from .statistics import SummaryStatistic +def _mask_to_intervals( + x: np.ndarray, + mask: np.ndarray, +) -> list[tuple[float, float]]: + """ + Turn a mask indexed on x to a list of intervals + """ + if mask.size == 0: + return [] + + change = np.diff(mask.astype(int)) + starts = np.where(change == 1)[0] + 1 # False -> True + ends = np.where(change == -1)[0] # True -> False + + # If we start inside an interval, prepend 0 + if mask[0]: + starts = np.concatenate(([0], starts)) + + # If we end inside an interval, append last index + if mask[-1]: + ends = np.concatenate((ends, [len(mask) - 1])) + + intervals = [(float(x[s]), float(x[e])) for s, e in zip(starts, ends, strict=True) if x[e] > x[s]] + + return intervals + + class Bound(BetterBase): lower: float | None = Field(default=None) center: float | None = Field(default=None) @@ -53,6 +82,7 @@ def __init__(self, parent: ChainConsumer): SummaryStatistic.MEAN: self.get_parameter_summary_mean, SummaryStatistic.CUMULATIVE: self.get_parameter_summary_cumulative, SummaryStatistic.MAX_CENTRAL: self.get_parameter_summary_max_central, + SummaryStatistic.HDI: self.get_parameter_summary_hdi, } def get_latex_table( @@ -163,11 +193,11 @@ def get_summary( """Gets a summary of the marginalised parameter distributions. Args: - parameters (list[str], optional): A list of parameters which to generate summaries for. + columns (list[str], optional): A list of parameters which to generate summaries for. chains (dict[str, Chain] | list[str], optional): A list of chains to generate summaries for. Returns: - dict[ChainName, dict[ColumnName, Bound]]: A map from chain name to column name to bound. + dict[ChainName, dict[ColumnName, Bound | list[Bound]]]: A map from chain name to column name to bound. """ results = {} if chains is None: @@ -183,6 +213,22 @@ def get_summary( continue summary = self.get_parameter_summary(chain, p) res[p] = summary + + if chain.multimodal: + intervals = self.get_parameter_hdi_intervals(chain, p) + # If there is a single interval, we skip + if len(intervals) >= 2: + multimodal_bounds = self.get_parameter_multimodal_bounds( + chain, + p, + intervals=intervals, + ) + if multimodal_bounds is not None: + res[p] = multimodal_bounds + continue + + res[p] = summary + results[chain.name] = res return results @@ -276,7 +322,12 @@ def get_covariance_table( return self._get_2d_latex_table(covariance, caption, label) def _get_smoothed_histogram( - self, chain: Chain, column: ColumnName, pad: bool = False + self, + chain: Chain, + column: ColumnName, + pad: bool = False, + *, + use_kde: bool | None = None, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: data = chain.get_data(column) if chain.grid: @@ -292,7 +343,10 @@ def _get_smoothed_histogram( if chain.smooth_value: hist = gaussian_filter(hist, chain.smooth_value, mode="reflect") - if chain.kde: + if use_kde is None: + use_kde = bool(chain.kde) + + if use_kde: kde_xs = np.linspace(edge_centers[0], edge_centers[-1], max(200, int(bins.max()))) factor = chain.kde if isinstance(chain.kde, int | float) else 1.0 ys = MegKDE(data.to_numpy(), chain.weights, factor=factor).evaluate(kde_xs) @@ -325,27 +379,69 @@ def _get_2d_latex_table(self, named_matrix: Named2DMatrix, caption: str, label: table += hline_text return latex_table % (column_def, table) - def get_parameter_text(self, bound: Bound, wrap: bool = False): - """Generates LaTeX appropriate text from marginalised parameter bounds. - - Parameters - ---------- - lower : float - The lower bound on the parameter - maximum : float - The value of the parameter with maximum probability - upper : float - The upper bound on the parameter - wrap : bool - Wrap output text in dollar signs for LaTeX - - Returns - ------- - str - The formatted text given the parameter bounds + def get_parameter_text( + self, + bound: Bound | Sequence[Bound], + wrap: bool = False, + *, + label: str | None = None, + ) -> str: + """Format marginal parameter bounds for display. + + Args: + bound: + The bound (or list of bounds) to format. + wrap: + Wrap each formatted expression in LaTeX dollar signs. + label: + Optional parameter label to prepend. For multimodal results the + label is placed on its own line. + + Returns: + The formatted string. Returns an empty string when the input contains + no finite limits. """ + + if bound is None: + return "" + + # Fallback to single bound behavior if there is only one mode identified + if isinstance(bound, Sequence) and len(bound) < 2: + bound = bound[0] + + if isinstance(bound, Sequence) and not isinstance(bound, Bound): + bounds = [b for b in bound if isinstance(b, Bound) and not b.all_none] + if not bounds: + return "" + + lines: list[str] = [] + if label: + lines.append(f"${label}$" if wrap else label) + + for index, sub_bound in enumerate(bounds, start=1): + entry = Analysis._format_single_bound(sub_bound, use_pm=False) + if not entry: + continue + if wrap: + entry = f"${entry}$" + lines.append(f"I{index}: {entry}") + + return "\n".join(lines) + if bound.lower is None or bound.upper is None or bound.center is None: return "" + + text = self._format_single_bound(bound, use_pm=True) + + if label: + text = f"{label} = {text}" + + if wrap: + return f"${text}$" + return text + + @staticmethod + def _format_single_bound(bound: Bound, *, use_pm: bool) -> str: upper_error = bound.upper - bound.center lower_error = bound.center - bound.lower if upper_error != 0 and lower_error != 0: @@ -389,14 +485,12 @@ def get_parameter_text(self, bound: Bound, wrap: bool = False): fmt = "%0.0f" upper_error_text = fmt % upper_error lower_error_text = fmt % lower_error - if upper_error_text == lower_error_text: + if use_pm and upper_error_text == lower_error_text: text = r"{}\pm {}".format(fmt, "%s") % (maximum, lower_error_text) else: text = r"{}^{{+{}}}_{{-{}}}".format(fmt, "%s", "%s") % (maximum, upper_error_text, lower_error_text) if factor != 0: text = r"\left( %s \right) \times 10^{%d}" % (text, -factor) - if wrap: - text = f"${text}$" return text def get_parameter_summary_mean(self, chain: Chain, column: ColumnName) -> Bound | None: @@ -412,6 +506,144 @@ def get_parameter_summary_cumulative(self, chain: Chain, column: ColumnName) -> bounds = interp1d(cs, xs)(vals) return Bound(lower=bounds[0], center=bounds[1], upper=bounds[2]) + def get_parameter_summary_hdi(self, chain: Chain, column: ColumnName) -> Bound: + data = chain.get_data(column).to_numpy() + n_samples = data.size + + if n_samples <= 512: # Arbitrary low sample warning + warnings.warn( + ( + f"Only {n_samples} samples available to compute an HDI for column '{column}' " + f"in chain '{chain.name}'. Results may be unreliable; consider enabling KDE or " + "providing more samples." + ), + UserWarning, + stacklevel=2, + ) + + xs, _, cs = self._get_smoothed_histogram(chain, column, pad=True) + + cdf_points = np.concatenate(([0.0], cs)) + x_points = np.concatenate(([xs[0]], xs)) + + eps = 1e-12 + best_width = float("inf") + best_lower = float(x_points[0]) + best_upper = float(x_points[-1]) + best_start_mass = 0.0 + best_end_mass = 1.0 + + for start_idx, start_mass in enumerate(cdf_points[:-1]): + required = start_mass + chain.summary_area + if required > 1.0 + eps: + break + + # Smallest index with cdf_points[end_idx] >= required + end_idx = np.searchsorted(cdf_points, required, side="left") + + # Ensure at least one point is in the interval + if end_idx <= start_idx: + end_idx = start_idx + 1 + if end_idx >= cdf_points.size: + break + + # If still slightly under target, move one step right if possible + if cdf_points[end_idx] - start_mass < chain.summary_area - eps and end_idx + 1 < cdf_points.size: + end_idx += 1 + + lower = float(x_points[start_idx]) + upper = float(x_points[end_idx]) + width = upper - lower + if width <= eps: + continue + + if width < best_width - eps: + best_width = width + best_lower = lower + best_upper = upper + best_start_mass = float(start_mass) + best_end_mass = float(cdf_points[end_idx]) + + interval_mass = best_end_mass - best_start_mass + + if interval_mass <= eps: + center = 0.5 * (best_lower + best_upper) + + else: + center_mass = best_start_mass + 0.5 * interval_mass + center = float(np.interp(center_mass, cdf_points, x_points, left=best_lower, right=best_upper)) + + return Bound(lower=best_lower, center=center, upper=best_upper) + + def get_parameter_hdi_intervals(self, chain: Chain, column: ColumnName) -> list[tuple[float, float]]: + """Return highest-density intervals for a marginal distribution. + + Multimodal chains yield one interval per disjoint density band, whereas unimodal chains + return a single contiguous interval. + """ + summary = self.get_parameter_summary_hdi(chain, column) + default_interval = [(summary.lower, summary.upper)] + xs, ys, _ = self._get_smoothed_histogram(chain, column, pad=True) + + # We look for the threshold that is the root of this function + def mass_diff(threshold, density, xs, target): + mask = density >= threshold + mass_above_threshold = float(simps(np.where(mask, density, 0.0), x=xs)) + return mass_above_threshold - target + + area = simps(ys, x=xs) + density = ys / area + + sol = root_scalar( + mass_diff, + bracket=(0.0, float(np.max(density))), + args=(density, xs, chain.summary_area), + method="bisect", + xtol=5e-4, + ) + + threshold = sol.root + mask = density >= threshold + + intervals = _mask_to_intervals(xs, mask) + + return intervals if intervals else default_interval + + def get_parameter_multimodal_bounds( + self, + chain: Chain, + column: ColumnName, + intervals: list[tuple[float, float]], + ) -> list[Bound]: + """ + Convert multimodal HDI bands into `Bound` instances. + """ + + xs, ys, _ = self._get_smoothed_histogram( + chain, + column, + pad=True, + ) + + lower_limit, upper_limit = float(xs.min()), float(xs.max()) + + bounds = [] + + for lower_raw, upper_raw in intervals: + lower, upper = max(lower_raw, lower_limit), min(upper_raw, upper_limit) + mask = (xs >= lower) & (xs <= upper) + + if np.any(mask): + idx = int(np.argmax(ys[mask])) + center = float(xs[mask][idx]) + + else: + center = float(0.5 * (lower + upper)) + + bounds.append(Bound(lower=float(lower), center=center, upper=float(upper))) + + return bounds + def get_parameter_summary_max(self, chain: Chain, column: ColumnName) -> Bound | None: xs, ys, cs = self._get_smoothed_histogram(chain, column) n_pad = 1000 diff --git a/src/chainconsumer/chain.py b/src/chainconsumer/chain.py index b770aed5..6f560c49 100644 --- a/src/chainconsumer/chain.py +++ b/src/chainconsumer/chain.py @@ -48,7 +48,7 @@ class ChainConfig(BetterBase): if you have two chains, you probably want them to be different colors. """ - statistics: SummaryStatistic = Field(default=SummaryStatistic.MAX, description="The summary statistic to use") + statistics: SummaryStatistic | None = Field(default=None, description="The summary statistic to use") summary_area: float = Field(default=0.6827, ge=0, le=1.0, description="The area to use for summary statistics") sigmas: list[float] = Field(default=[0, 1, 2], description="The sigmas to use for summary statistics") color: ColorInput | None = Field(default=None, description="The color of the chain") # type: ignore @@ -59,6 +59,7 @@ class ChainConfig(BetterBase): shade_alpha: float = Field(default=0.5, description="The alpha of the shading") shade_gradient: float = Field(default=1.0, description="The contrast between contour levels") bar_shade: bool = Field(default=True, description="Whether to shade marginalised distributions") + multimodal: bool = Field(default=False, description="Mark the chain as multimodal to enable HDI band splitting.") bins: int | None = Field(default=None, description="The number of bins to use for histograms.") kde: int | float | bool = Field(default=False, description="The bandwidth for KDEs") smooth: int | None = Field( @@ -281,6 +282,18 @@ def _validate_model(self) -> Chain: if self.num_free_params is not None: assert np.isfinite(self.num_free_params), "num_free_params is not finite" + if self.statistics is None: + if self.multimodal: + self.statistics = SummaryStatistic.HDI + else: + self.statistics = SummaryStatistic.MAX + + if self.multimodal and self.statistics is not SummaryStatistic.HDI: + raise ValueError( + f"Chain {self.name} is marked as multimodal but uses {self.statistics.value}; " + "set statistics=SummaryStatistic.HDI." + ) + return self def get_data(self, column: str) -> pd.Series[float]: diff --git a/src/chainconsumer/plotter.py b/src/chainconsumer/plotter.py index b0f0a78a..f4718651 100644 --- a/src/chainconsumer/plotter.py +++ b/src/chainconsumer/plotter.py @@ -1,5 +1,8 @@ +from __future__ import annotations + from enum import Enum from pathlib import Path +from typing import TYPE_CHECKING import matplotlib import matplotlib.pyplot as plt @@ -24,6 +27,9 @@ from .plotting import add_watermark, plot_surface from .plotting.config import PlotConfig +if TYPE_CHECKING: + from .chainconsumer import ChainConsumer + class PlottingBase(BetterBase): chains: list[Chain] @@ -42,7 +48,7 @@ class FigSize(Enum): @classmethod def get_size( - cls, input: "FigSize | float | int | tuple[float, float]", num_columns: int, has_cax: bool + cls, input: FigSize | float | int | tuple[float, float], num_columns: int, has_cax: bool ) -> tuple[float, float]: if input == FigSize.PAGE: return 10, 10 @@ -106,7 +112,7 @@ def get_artists_from_chains(chains: list[Chain]) -> list[Artist]: class Plotter: - def __init__(self, parent: "ChainConsumer") -> None: + def __init__(self, parent: ChainConsumer) -> None: self.parent: ChainConsumer = parent self._config: PlotConfig | None = None self._default_config = PlotConfig() @@ -201,7 +207,13 @@ def plot( continue do_summary = summarise and p1 not in base.blind - max_hist_val = self._plot_bars(ax, p1, chain, flip=do_flip, summary=do_summary) + max_hist_val = self._plot_bars( + ax, + p1, + chain, + flip=do_flip, + summary=do_summary, + ) if max_val is None or max_hist_val > max_val: max_val = max_hist_val @@ -905,7 +917,12 @@ def _sanitise_chains( return [c for c in final_chains if include_skip or not c.skip] def _plot_bars( - self, ax: Axes, column: str, chain: Chain, flip: bool = False, summary: bool = False + self, + ax: Axes, + column: str, + chain: Chain, + flip: bool = False, + summary: bool = False, ) -> float: # pragma: no cover # Get values from config data = chain.get_data(column) @@ -942,41 +959,62 @@ def _plot_bars( interpolator = interp1d(xs, ys, kind=interp_type) if chain.bar_shade: - fit_values = self.parent.analysis.get_parameter_summary(chain, column) - if fit_values is not None: - lower = fit_values.lower - upper = fit_values.upper - if lower is not None and upper is not None: - lower = max(lower, xs.min()) - upper = min(upper, xs.max()) - x = np.linspace(lower, upper, 1000) # type: ignore + base_bound = self.parent.analysis.get_parameter_summary(chain, column) + + if base_bound is not None and base_bound.lower is not None and base_bound.upper is not None: + if chain.multimodal: + intervals = self.parent.analysis.get_parameter_hdi_intervals(chain, column) + display_bounds = self.parent.analysis.get_parameter_multimodal_bounds( + chain, + column, + intervals, + ) + + # If we get a single interval, fallback to unimodal HDI + if len(intervals) < 2: + intervals = [(base_bound.lower, base_bound.upper)] + + else: + display_bounds = base_bound + intervals = [(display_bounds.lower, display_bounds.upper)] + intervals = [(max(lower, xs.min()), min(upper, xs.max())) for lower, upper in intervals] + + for lower, upper_ in intervals: + x = np.linspace(lower, upper_, 1000) + if flip: ax.fill_betweenx( x, - np.zeros(x.shape), + np.zeros_like(x), interpolator(x), color=chain.color, alpha=0.2, zorder=chain.zorder, ) + else: ax.fill_between( x, - np.zeros(x.shape), + np.zeros_like(x), interpolator(x), color=chain.color, alpha=0.2, zorder=chain.zorder, ) - if summary: - t = self.parent.analysis.get_parameter_text(fit_values) - label = self.config.get_label(column) - if isinstance(column, str): - ax.set_title( - r"${} = {}$".format(label.strip("$"), t), fontsize=self.config.summary_font_size - ) - else: - ax.set_title(rf"${t}$", fontsize=self.config.summary_font_size) + + if summary: + label = self.config.get_label(column) + label_text = label.strip("$") if isinstance(column, str) else None + + title = self.parent.analysis.get_parameter_text( + display_bounds, + wrap=True, + label=label_text, + ) + + if title: + ax.set_title(title, fontsize=self.config.summary_font_size) + return float(ys.max()) def _plot_walk( diff --git a/src/chainconsumer/statistics.py b/src/chainconsumer/statistics.py index 29d4afc2..b278a697 100644 --- a/src/chainconsumer/statistics.py +++ b/src/chainconsumer/statistics.py @@ -18,3 +18,6 @@ class SummaryStatistic(Enum): MEAN = "mean" """As per the cumulative method, except the central value is placed in the midpoint between the upper and lower boundary. Not recommended, but was requested.""" + + HDI = "hdi" + """Use the highest density interval. Finds the narrowest interval covering the requested mass.""" diff --git a/tests/test_analysis.py b/tests/test_analysis.py index 86051e4b..4914a066 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -1,8 +1,11 @@ +import arviz as az import numpy as np import pandas as pd +import pytest from scipy.stats import skewnorm from chainconsumer import Bound, Chain, ChainConfig, ChainConsumer +from chainconsumer.statistics import SummaryStatistic class TestChain: @@ -12,6 +15,9 @@ class TestChain: data2 = rng.normal(loc=3, scale=1.0, size=n) data_combined = np.vstack((data, data2)).T data_skew = skewnorm.rvs(5, loc=1, scale=1.5, size=n) + data_bimodal = np.concatenate( + [rng.normal(loc=-1.5, scale=0.3, size=n // 2), rng.normal(loc=1.5, scale=0.3, size=n // 2)] + ) chain = Chain(samples=pd.DataFrame(data, columns=["x"]), name="a") chain2 = Chain(samples=pd.DataFrame(data2, columns=["x"]), name="b") @@ -139,6 +145,107 @@ def test_output_format6(self): text = consumer.analysis.get_parameter_text(Bound.from_array(p1), wrap=True) assert text == r"$0.020^{+0.015}_{-0.010}$" + def test_output_multimodal_text(self): + intervals = [ + Bound(lower=-2.0, center=-1.5, upper=-1.0), + Bound(lower=1.0, center=1.4, upper=2.0), + ] + consumer = ChainConsumer() + text = consumer.analysis.get_parameter_text(intervals, wrap=True, label="x") + + assert text.startswith("$x$") + assert text.count("I1:") == 1 and text.count("I2:") == 1 + assert "\\pm" not in text + assert "\n" in text + + def test_summary_multimodal_returns_intervals(self): + rng = np.random.default_rng(42) + samples = np.concatenate( + [ + rng.normal(loc=-1.5, scale=0.2, size=5000), + rng.normal(loc=1.6, scale=0.25, size=5000), + ] + ) + df = pd.DataFrame({"x": samples}) + chain = Chain( + samples=df, + name="bimodal", + statistics=SummaryStatistic.HDI, + kde=True, + multimodal=True, + summary_area=0.5, + ) + consumer = ChainConsumer() + consumer.add_chain(chain) + + summary = consumer.analysis.get_summary() + result = summary["bimodal"]["x"] + + assert isinstance(result, list) + assert len(result) > 1 + assert all(isinstance(bound, Bound) for bound in result) + assert result[0].upper < 0 + assert result[-1].lower > 0 + + def test_hdi_weighted_interval(self): + samples = np.array([0.0, 1.0, 2.0, 2.5, 3.0]) + weights = np.array([0.05, 0.10, 0.50, 0.20, 0.15]) + df = pd.DataFrame({"x": samples, "weight": weights}) + chain = Chain(samples=df, name="weighted", statistics=SummaryStatistic.HDI, summary_area=0.5) + consumer = ChainConsumer() + consumer.add_chain(chain) + + bound = consumer.analysis.get_summary()["weighted"]["x"] + assert bound.lower < bound.center < bound.upper + assert np.isclose(bound.center, 2.0, atol=5e-3) + assert bound.upper - bound.lower < 0.3 + + def test_hdi_warns_when_samples_low_without_kde(self): + df = pd.DataFrame({"x": np.linspace(-1.0, 1.0, 10)}) + chain = Chain(samples=df, name="warn", statistics=SummaryStatistic.HDI, summary_area=0.5) + consumer = ChainConsumer() + consumer.add_chain(chain) + + with pytest.warns(UserWarning, match="Only 10 samples available"): + consumer.analysis.get_parameter_summary_hdi(chain, "x") + + def test_hdi_intervals_single(self): + chain = Chain( + samples=pd.DataFrame(self.data_skew[::20], columns=["x"]), + name="skew", + statistics=SummaryStatistic.HDI, + kde=True, + ) + consumer = ChainConsumer() + consumer.add_chain(chain) + + intervals = consumer.analysis.get_parameter_hdi_intervals(chain, "x") + assert len(intervals) == 1 + lower, upper = intervals[0] + assert lower < upper + + def test_hdi_intervals_multimodal(self): + df = pd.DataFrame(self.data_bimodal[::20], columns=["x"]) + chain_default = Chain(samples=df, name="bimodal", statistics=SummaryStatistic.HDI, kde=True) + consumer = ChainConsumer() + consumer.add_chain(chain_default) + + multimodal_chain = Chain( + samples=df, + name="bimodal_multi", + statistics=SummaryStatistic.HDI, + kde=True, + multimodal=True, + summary_area=chain_default.summary_area, + ) + consumer2 = ChainConsumer() + consumer2.add_chain(multimodal_chain) + + multi_intervals = consumer2.analysis.get_parameter_hdi_intervals(multimodal_chain, "x") + assert len(multi_intervals) == 2 + assert multi_intervals[0][1] < 0.0 + assert multi_intervals[1][0] > 0.0 + def test_output_format7(self): p1 = [None, 2.0e-2, 3.5e-2] consumer = ChainConsumer() @@ -250,6 +357,54 @@ def test_divide_chains_name(self): assert np.all(np.abs(array.center - means[i]) < 1e-1) assert np.abs(consumer.get_chain(name).get_data("x").mean() - means[i]) < 1e-2 + +class TestHDIParityWithArviz: + @staticmethod + def _make_chain( + samples: np.ndarray, + *, + area: float, + multimodal: bool = False, + smooth: int = 0, + kde: int | float | bool = False, + ) -> tuple[ChainConsumer, Chain]: + chain = Chain( + samples=pd.DataFrame({"x": samples}), + name="arviz", + statistics=SummaryStatistic.HDI, + summary_area=area, + smooth=smooth, + kde=kde, + multimodal=multimodal, + ) + consumer = ChainConsumer() + consumer.add_chain(chain) + return consumer, chain + + def test_hdi_matches_arviz_unimodal(self) -> None: + rng = np.random.default_rng(9876) + samples = rng.normal(loc=-0.25, scale=1.3, size=5000) + area = 0.5 + + consumer, chain = self._make_chain(samples, area=area) + bound = consumer.analysis.get_summary()[chain.name]["x"] + assert isinstance(bound, Bound) + + az_hdi = az.stats.hdi(samples, hdi_prob=area) + np.testing.assert_allclose([bound.lower, bound.upper], az_hdi, atol=5e-2) + + def test_hdi_matches_arviz_high_sample_multimodal(self) -> None: + rng = np.random.default_rng(0) + samples = np.concatenate([rng.normal(-2.0, 0.3, size=5000), rng.normal(2.0, 0.25, size=5000)]) + area = 0.5 + + consumer, chain = self._make_chain(samples, area=area, multimodal=True) + cc_intervals = np.array(consumer.analysis.get_parameter_hdi_intervals(chain, "x")) + az_intervals = np.asarray(az.stats.hdi(samples, hdi_prob=area, multimodal=True)) + + assert cc_intervals.shape == az_intervals.shape + np.testing.assert_allclose(cc_intervals, az_intervals, atol=8e-2) + # def test_stats_max_cliff(self): # tolerance = 5e-2 # n = 100000 diff --git a/tests/test_chain.py b/tests/test_chain.py index 6bf81be5..5d2f2187 100644 --- a/tests/test_chain.py +++ b/tests/test_chain.py @@ -4,6 +4,7 @@ from pydantic import ValidationError from chainconsumer.chain import Chain +from chainconsumer.statistics import SummaryStatistic class TestChain: @@ -105,3 +106,13 @@ def test_divide(self): for chain in result: assert chain.walkers == 1 + + def test_default_stat(self): + chain = Chain(samples=self.df, name=self.n) + assert chain.statistics is SummaryStatistic.MAX + + chain = Chain(samples=self.df, name=self.n, multimodal=True) + assert chain.statistics is SummaryStatistic.HDI + + with pytest.raises(ValueError): + Chain(samples=self.df, name=self.n, multimodal=True, statistics=SummaryStatistic.MAX) diff --git a/tests/test_plotter.py b/tests/test_plotter.py index 787b8144..3af88f05 100644 --- a/tests/test_plotter.py +++ b/tests/test_plotter.py @@ -1,8 +1,14 @@ +import matplotlib + +matplotlib.use("Agg") + +import matplotlib.pyplot as plt import numpy as np import pandas as pd from scipy.stats import norm from chainconsumer import Chain, ChainConsumer +from chainconsumer.statistics import SummaryStatistic class TestChain: @@ -69,3 +75,60 @@ def test_plotter_extents6(self): minv, maxv = c.plotter._get_parameter_extents("x", list(c._chains.values())) assert np.isclose(minv, -1, atol=0.01) assert np.isclose(maxv, 1, atol=0.01) + + def test_plotter_multimodal_fill(self): + samples = np.concatenate( + [ + self.rng.normal(loc=-1.5, scale=0.2, size=5000), + self.rng.normal(loc=1.7, scale=0.25, size=5000), + ] + ) + df = pd.DataFrame({"x": samples}) + chain = Chain(samples=df, name="bimodal", statistics=SummaryStatistic.HDI, kde=True, multimodal=True) + consumer = ChainConsumer() + consumer.add_chain(chain) + + fig, ax = plt.subplots() + consumer.plotter._plot_bars(ax, "x", chain) + + intervals_drawn: list[tuple[float, float]] = [] + for collection in ax.collections: + if not collection.get_paths(): + continue + path = collection.get_paths()[0] + xs = path.vertices[:, 0] + xmin, xmax = xs.min(), xs.max() + if xmax - xmin <= 0: + continue + intervals_drawn.append((xmin, xmax)) + + plt.close(fig) + + intervals_drawn.sort() + assert len(intervals_drawn) == 2 + first, second = intervals_drawn + assert first[1] < 0.0 + assert second[0] > 0.0 + + def test_plotter_multimodal_title(self): + samples = np.concatenate( + [ + self.rng.normal(loc=-1.0, scale=0.15, size=4000), + self.rng.normal(loc=1.0, scale=0.15, size=4000), + ] + ) + df = pd.DataFrame({"x": samples}) + chain = Chain(samples=df, name="bimodal", statistics=SummaryStatistic.HDI, kde=True, multimodal=True) + consumer = ChainConsumer() + consumer.add_chain(chain) + + fig, ax = plt.subplots() + consumer.plotter._plot_bars(ax, "x", chain, summary=True) + title = ax.get_title() + plt.close(fig) + + assert "I1:" in title + assert "I2:" in title + assert "\n" in title + assert "\\pm" not in title + assert "^{+" in title and "}_{-" in title diff --git a/tests/test_translators.py b/tests/test_translators.py index 183d5873..b6ec627c 100644 --- a/tests/test_translators.py +++ b/tests/test_translators.py @@ -29,7 +29,9 @@ def model(data): # Running MCMC kernel = NUTS(model) - mcmc = MCMC(kernel, num_warmup=500, num_samples=n_steps, num_chains=n_chains, progress_bar=False) + mcmc = MCMC( + kernel, num_warmup=500, num_samples=n_steps, num_chains=n_chains, progress_bar=False, chain_method="sequential" + ) rng_key = random.PRNGKey(0) mcmc.run(rng_key, data=observed_data)