From 152af4c7c9c7c0ca0940755916aeee48b0a4b83f Mon Sep 17 00:00:00 2001 From: = Date: Mon, 23 Jun 2025 15:05:06 -0400 Subject: [PATCH 1/4] draft of MCLMC --- mclmc.py | 207 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 mclmc.py diff --git a/mclmc.py b/mclmc.py new file mode 100644 index 000000000..80bd54b81 --- /dev/null +++ b/mclmc.py @@ -0,0 +1,207 @@ + + +import argparse +from collections import namedtuple +import os + +import matplotlib.pyplot as plt + +import jax +import jax.numpy as jnp +from jax import random + + +import numpyro +import numpyro.distributions as dist +from numpyro.infer import MCMC +from numpyro.infer.mcmc import MCMCKernel +import blackjax +from numpyro.infer.util import initialize_model +from blackjax.util import pytree_size +from blackjax.mcmc.integrators import ( + IntegratorState) + + +FullState = namedtuple("FullState", ["position", "momentum", "logdensity", "logdensity_grad", "rng_key"]) + +class MCLMC(MCMCKernel): + """ + Microcanonical Langevin Monte Carlo (MCLMC) kernel. + + :param model: Python callable containing Pyro primitives. + :param step_size: Initial step size for the Langevin dynamics. + :param num_steps: Number of steps to take in each MCMC iteration. + :param integrator_type: Type of integrator to use (e.g. "mclachlan"). + :param diagonal_preconditioning: Whether to use diagonal preconditioning. + :param num_tuning_steps: Number of tuning steps to use. + :param desired_energy_var: Desired energy variance for tuning. + """ + + + + def __init__( + self, + model=None, + desired_energy_var=5e-4, + diagonal_preconditioning=True, + ): + if model is None: + raise ValueError("Model must be specified for MCLMC") + self._model = model + self._diagonal_preconditioning = diagonal_preconditioning + self._desired_energy_var = desired_energy_var + self._init_fn = None + self._sample_fn = None + self._postprocess_fn = None + + @property + def model(self): + return self._model + + @property + def sample_field(self): + return "position" + + def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): + """ + Initialize the MCLMC kernel. + + :param rng_key: Random number generator key + :param num_warmup: Number of warmup steps + :param init_params: Initial parameters + :param model_args: Model arguments + :param model_kwargs: Model keyword arguments + :return: Initial state + """ + + init_model_key, init_state_key, run_key, rng_key_tune = jax.random.split(rng_key, 4) + + init_params, potential_fn_gen, _, _ = initialize_model( + init_model_key, + self._model, + model_args=(), + dynamic_args=True, + ) + + logdensity_fn = lambda position: -potential_fn_gen()(position) + initial_position = init_params.z + self.logdensity_fn = logdensity_fn + + sampler_state = blackjax.mcmc.mclmc.init( + position=initial_position, + logdensity_fn=self.logdensity_fn, + rng_key=init_state_key, + ) + + kernel = lambda inverse_mass_matrix: blackjax.mcmc.mclmc.build_kernel( + logdensity_fn=logdensity_fn, + integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, + inverse_mass_matrix=inverse_mass_matrix, + ) + + self.dim = pytree_size(initial_position) + + # num_steps is a dummy param here + ( + blackjax_state_after_tuning, + blackjax_mclmc_sampler_params, + num_tuning_integrator_steps, + ) = blackjax.mclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=100, + state=sampler_state, + rng_key=rng_key_tune, + diagonal_preconditioning=True, + frac_tune3=num_warmup / (3 * 100), + frac_tune2=num_warmup / (3 * 100), + frac_tune1=num_warmup / (3 * 100), + desired_energy_var=5e-4 + ) + + self.adapt_state = blackjax_mclmc_sampler_params + + return FullState(blackjax_state_after_tuning.position, blackjax_state_after_tuning.momentum, blackjax_state_after_tuning.logdensity, blackjax_state_after_tuning.logdensity_grad, run_key) + + + def sample(self, state, model_args, model_kwargs): + """ + Run MCLMC from the given state and return the resulting state. + + :param state: Current state + :param model_args: Model arguments + :param model_kwargs: Model keyword arguments + :return: Next state after running MCLMC + """ + + mclmc_state = IntegratorState(state.position, state.momentum, state.logdensity, state.logdensity_grad) + rng_key, rng_key_sample = jax.random.split(state.rng_key, 2) + + kernel = blackjax.mcmc.mclmc.build_kernel( + logdensity_fn=self.logdensity_fn, + integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, + inverse_mass_matrix=self.adapt_state.inverse_mass_matrix, + ) + + new_state, info = kernel( + rng_key=rng_key_sample, + state=mclmc_state, + step_size=self.adapt_state.step_size, + L=self.adapt_state.L + ) + + return FullState(new_state.position, new_state.momentum, new_state.logdensity, new_state.logdensity_grad, rng_key) + +if __name__ == "__main__": + + def gaussian_2d_model(): + """ + A simple 2D Gaussian model with mean [0, 0] and covariance [[1, 0.5], [0.5, 1]]. + """ + x = numpyro.sample("x", dist.Normal(0.0, 1.0)) + y = numpyro.sample("y", dist.Normal(0.0, 1.0)) + numpyro.sample("obs", dist.Normal(x + y, 0.5), obs=jnp.array([0.0])) + return x + y + + + def run_inference(model, args, rng_key): + """ + Run MCMC inference on the given model. + + :param model: The model to run inference on + :param args: Command line arguments + :param rng_key: Random number generator key + :return: MCMC object + """ + kernel = MCLMC( + model=model, + diagonal_preconditioning=True, + desired_energy_var=5e-4, + ) + + mcmc = MCMC( + kernel, + num_warmup=1000, + num_samples=1000, + num_chains=1, + progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True, + ) + + mcmc.run(rng_key) + mcmc.print_summary(exclude_deterministic=False) + + samples = mcmc.get_samples() + plt.figure(figsize=(8, 8)) + plt.scatter(samples['x'], samples['y'], alpha=0.5) + plt.xlabel('x') + plt.ylabel('y') + plt.title('MCLMC samples from 2D Gaussian') + plt.grid(True) + plt.savefig('mclmc_samples.png') + plt.close() + + return mcmc + + + rng_key = random.PRNGKey(0) + mcmc = run_inference(gaussian_2d_model, args=None, rng_key=rng_key) + From a1e7e0ba510c52791cab450355d2fdc3540f0d3d Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Tue, 20 Jan 2026 23:27:14 +0100 Subject: [PATCH 2/4] feat: Add Microcanonical Langevin Monte Carlo (MCLMC) kernel Add MCLMC inference algorithm as a new MCMCKernel that wraps blackjax's MCLMC implementation. This provides an alternative gradient-based MCMC method to NUTS/HMC. Features: - MCLMC kernel with automatic step size and trajectory length tuning - Optional blackjax dependency with informative error message - postprocess_fn for constrained/unconstrained transformations - Diagnostics string for progress bar - Comprehensive test suite References: - Microcanonical Hamiltonian Monte Carlo (arXiv:2212.08549) --- mclmc.py | 207 ------------------------------------ numpyro/infer/mclmc.py | 219 +++++++++++++++++++++++++++++++++++++++ test/infer/test_mclmc.py | 155 +++++++++++++++++++++++++++ 3 files changed, 374 insertions(+), 207 deletions(-) delete mode 100644 mclmc.py create mode 100644 numpyro/infer/mclmc.py create mode 100644 test/infer/test_mclmc.py diff --git a/mclmc.py b/mclmc.py deleted file mode 100644 index 80bd54b81..000000000 --- a/mclmc.py +++ /dev/null @@ -1,207 +0,0 @@ - - -import argparse -from collections import namedtuple -import os - -import matplotlib.pyplot as plt - -import jax -import jax.numpy as jnp -from jax import random - - -import numpyro -import numpyro.distributions as dist -from numpyro.infer import MCMC -from numpyro.infer.mcmc import MCMCKernel -import blackjax -from numpyro.infer.util import initialize_model -from blackjax.util import pytree_size -from blackjax.mcmc.integrators import ( - IntegratorState) - - -FullState = namedtuple("FullState", ["position", "momentum", "logdensity", "logdensity_grad", "rng_key"]) - -class MCLMC(MCMCKernel): - """ - Microcanonical Langevin Monte Carlo (MCLMC) kernel. - - :param model: Python callable containing Pyro primitives. - :param step_size: Initial step size for the Langevin dynamics. - :param num_steps: Number of steps to take in each MCMC iteration. - :param integrator_type: Type of integrator to use (e.g. "mclachlan"). - :param diagonal_preconditioning: Whether to use diagonal preconditioning. - :param num_tuning_steps: Number of tuning steps to use. - :param desired_energy_var: Desired energy variance for tuning. - """ - - - - def __init__( - self, - model=None, - desired_energy_var=5e-4, - diagonal_preconditioning=True, - ): - if model is None: - raise ValueError("Model must be specified for MCLMC") - self._model = model - self._diagonal_preconditioning = diagonal_preconditioning - self._desired_energy_var = desired_energy_var - self._init_fn = None - self._sample_fn = None - self._postprocess_fn = None - - @property - def model(self): - return self._model - - @property - def sample_field(self): - return "position" - - def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): - """ - Initialize the MCLMC kernel. - - :param rng_key: Random number generator key - :param num_warmup: Number of warmup steps - :param init_params: Initial parameters - :param model_args: Model arguments - :param model_kwargs: Model keyword arguments - :return: Initial state - """ - - init_model_key, init_state_key, run_key, rng_key_tune = jax.random.split(rng_key, 4) - - init_params, potential_fn_gen, _, _ = initialize_model( - init_model_key, - self._model, - model_args=(), - dynamic_args=True, - ) - - logdensity_fn = lambda position: -potential_fn_gen()(position) - initial_position = init_params.z - self.logdensity_fn = logdensity_fn - - sampler_state = blackjax.mcmc.mclmc.init( - position=initial_position, - logdensity_fn=self.logdensity_fn, - rng_key=init_state_key, - ) - - kernel = lambda inverse_mass_matrix: blackjax.mcmc.mclmc.build_kernel( - logdensity_fn=logdensity_fn, - integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, - inverse_mass_matrix=inverse_mass_matrix, - ) - - self.dim = pytree_size(initial_position) - - # num_steps is a dummy param here - ( - blackjax_state_after_tuning, - blackjax_mclmc_sampler_params, - num_tuning_integrator_steps, - ) = blackjax.mclmc_find_L_and_step_size( - mclmc_kernel=kernel, - num_steps=100, - state=sampler_state, - rng_key=rng_key_tune, - diagonal_preconditioning=True, - frac_tune3=num_warmup / (3 * 100), - frac_tune2=num_warmup / (3 * 100), - frac_tune1=num_warmup / (3 * 100), - desired_energy_var=5e-4 - ) - - self.adapt_state = blackjax_mclmc_sampler_params - - return FullState(blackjax_state_after_tuning.position, blackjax_state_after_tuning.momentum, blackjax_state_after_tuning.logdensity, blackjax_state_after_tuning.logdensity_grad, run_key) - - - def sample(self, state, model_args, model_kwargs): - """ - Run MCLMC from the given state and return the resulting state. - - :param state: Current state - :param model_args: Model arguments - :param model_kwargs: Model keyword arguments - :return: Next state after running MCLMC - """ - - mclmc_state = IntegratorState(state.position, state.momentum, state.logdensity, state.logdensity_grad) - rng_key, rng_key_sample = jax.random.split(state.rng_key, 2) - - kernel = blackjax.mcmc.mclmc.build_kernel( - logdensity_fn=self.logdensity_fn, - integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, - inverse_mass_matrix=self.adapt_state.inverse_mass_matrix, - ) - - new_state, info = kernel( - rng_key=rng_key_sample, - state=mclmc_state, - step_size=self.adapt_state.step_size, - L=self.adapt_state.L - ) - - return FullState(new_state.position, new_state.momentum, new_state.logdensity, new_state.logdensity_grad, rng_key) - -if __name__ == "__main__": - - def gaussian_2d_model(): - """ - A simple 2D Gaussian model with mean [0, 0] and covariance [[1, 0.5], [0.5, 1]]. - """ - x = numpyro.sample("x", dist.Normal(0.0, 1.0)) - y = numpyro.sample("y", dist.Normal(0.0, 1.0)) - numpyro.sample("obs", dist.Normal(x + y, 0.5), obs=jnp.array([0.0])) - return x + y - - - def run_inference(model, args, rng_key): - """ - Run MCMC inference on the given model. - - :param model: The model to run inference on - :param args: Command line arguments - :param rng_key: Random number generator key - :return: MCMC object - """ - kernel = MCLMC( - model=model, - diagonal_preconditioning=True, - desired_energy_var=5e-4, - ) - - mcmc = MCMC( - kernel, - num_warmup=1000, - num_samples=1000, - num_chains=1, - progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True, - ) - - mcmc.run(rng_key) - mcmc.print_summary(exclude_deterministic=False) - - samples = mcmc.get_samples() - plt.figure(figsize=(8, 8)) - plt.scatter(samples['x'], samples['y'], alpha=0.5) - plt.xlabel('x') - plt.ylabel('y') - plt.title('MCLMC samples from 2D Gaussian') - plt.grid(True) - plt.savefig('mclmc_samples.png') - plt.close() - - return mcmc - - - rng_key = random.PRNGKey(0) - mcmc = run_inference(gaussian_2d_model, args=None, rng_key=rng_key) - diff --git a/numpyro/infer/mclmc.py b/numpyro/infer/mclmc.py new file mode 100644 index 000000000..36c1371c1 --- /dev/null +++ b/numpyro/infer/mclmc.py @@ -0,0 +1,219 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from collections import namedtuple + +import jax + +from numpyro.infer.mcmc import MCMCKernel +from numpyro.infer.util import initialize_model +from numpyro.util import identity + +try: + import blackjax + from blackjax.mcmc.integrators import IntegratorState + from blackjax.util import pytree_size + + _BLACKJAX_AVAILABLE = True +except ImportError: + _BLACKJAX_AVAILABLE = False + blackjax = None + IntegratorState = None + pytree_size = None + +FullState = namedtuple( + "FullState", ["position", "momentum", "logdensity", "logdensity_grad", "rng_key"] +) + + +class MCLMC(MCMCKernel): + """ + Microcanonical Langevin Monte Carlo (MCLMC) kernel. + + MCLMC is a gradient-based MCMC algorithm that uses Hamiltonian dynamics + on an extended state space. It requires the `blackjax` package. + + **References:** + + 1. *Microcanonical Hamiltonian Monte Carlo*, + Jakob Robnik, G. Bruno De Luca, Eva Silverstein, Uroš Seljak + https://arxiv.org/abs/2212.08549 + + .. note:: The model must have at least 2 latent dimensions for MCLMC to work + (this is a limitation of the blackjax implementation). + + :param model: Python callable containing Pyro :mod:`~numpyro.primitives`. + :param float desired_energy_var: Target energy variance for step size and + trajectory length tuning. Smaller values lead to more conservative + step sizes. Defaults to 5e-4. + :param bool diagonal_preconditioning: Whether to use diagonal preconditioning + for the mass matrix. Defaults to True. + """ + + def __init__( + self, + model=None, + desired_energy_var=5e-4, + diagonal_preconditioning=True, + ): + if not _BLACKJAX_AVAILABLE: + raise ImportError( + "MCLMC requires the 'blackjax' package. " + "Please install it with: pip install blackjax" + ) + if model is None: + raise ValueError("Model must be specified for MCLMC") + self._model = model + self._diagonal_preconditioning = diagonal_preconditioning + self._desired_energy_var = desired_energy_var + self._init_fn = None + self._sample_fn = None + self._postprocess_fn = None + + @property + def model(self): + return self._model + + @property + def sample_field(self): + return "position" + + @property + def default_fields(self): + return (self.sample_field,) + + def get_diagnostics_str(self, state): + """ + Return a diagnostics string for the progress bar. + """ + return "step_size={:.2e}, L={:.2e}".format( + self.adapt_state.step_size, self.adapt_state.L + ) + + def postprocess_fn(self, args, kwargs): + """ + Get a function that transforms unconstrained values at sample sites to values + constrained to the site's support, in addition to returning deterministic + sites in the model. + + :param args: Arguments to the model. + :param kwargs: Keyword arguments to the model. + """ + if self._postprocess_fn is None: + return identity + return self._postprocess_fn(*args, **kwargs) + + def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): + """ + Initialize the MCLMC kernel. + + :param rng_key: Random number generator key + :param num_warmup: Number of warmup steps + :param init_params: Initial parameters + :param model_args: Model arguments + :param model_kwargs: Model keyword arguments + :return: Initial state + """ + + init_model_key, init_state_key, run_key, rng_key_tune = jax.random.split( + rng_key, 4 + ) + + init_params, potential_fn_gen, postprocess_fn, _ = initialize_model( + init_model_key, + self._model, + model_args=model_args, + model_kwargs=model_kwargs, + dynamic_args=True, + ) + self._postprocess_fn = postprocess_fn + + def logdensity_fn(position): + return -potential_fn_gen(*model_args, **model_kwargs)(position) + + initial_position = init_params.z + self.logdensity_fn = logdensity_fn + + sampler_state = blackjax.mcmc.mclmc.init( + position=initial_position, + logdensity_fn=self.logdensity_fn, + rng_key=init_state_key, + ) + + def kernel(inverse_mass_matrix): + return blackjax.mcmc.mclmc.build_kernel( + logdensity_fn=logdensity_fn, + integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, + inverse_mass_matrix=inverse_mass_matrix, + ) + + self.dim = pytree_size(initial_position) + + # num_steps is a dummy param here (used for tuning fractions) + num_tuning_steps = 100 + ( + blackjax_state_after_tuning, + blackjax_mclmc_sampler_params, + _, + ) = blackjax.mclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=num_tuning_steps, + state=sampler_state, + rng_key=rng_key_tune, + diagonal_preconditioning=self._diagonal_preconditioning, + frac_tune3=num_warmup / (3 * num_tuning_steps), + frac_tune2=num_warmup / (3 * num_tuning_steps), + frac_tune1=num_warmup / (3 * num_tuning_steps), + desired_energy_var=self._desired_energy_var, + ) + + self.adapt_state = blackjax_mclmc_sampler_params + + return FullState( + blackjax_state_after_tuning.position, + blackjax_state_after_tuning.momentum, + blackjax_state_after_tuning.logdensity, + blackjax_state_after_tuning.logdensity_grad, + run_key, + ) + + def sample(self, state, model_args, model_kwargs): + """ + Run MCLMC from the given state and return the resulting state. + + :param state: Current state + :param model_args: Model arguments + :param model_kwargs: Model keyword arguments + :return: Next state after running MCLMC + """ + + mclmc_state = IntegratorState( + state.position, state.momentum, state.logdensity, state.logdensity_grad + ) + rng_key, rng_key_sample = jax.random.split(state.rng_key, 2) + + kernel = blackjax.mcmc.mclmc.build_kernel( + logdensity_fn=self.logdensity_fn, + integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, + inverse_mass_matrix=self.adapt_state.inverse_mass_matrix, + ) + + new_state, info = kernel( + rng_key=rng_key_sample, + state=mclmc_state, + step_size=self.adapt_state.step_size, + L=self.adapt_state.L, + ) + + return FullState( + new_state.position, + new_state.momentum, + new_state.logdensity, + new_state.logdensity_grad, + rng_key, + ) + + def __getstate__(self): + state = self.__dict__.copy() + state["_postprocess_fn"] = None + return state diff --git a/test/infer/test_mclmc.py b/test/infer/test_mclmc.py new file mode 100644 index 000000000..b3fd1af9d --- /dev/null +++ b/test/infer/test_mclmc.py @@ -0,0 +1,155 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from numpy.testing import assert_allclose +import pytest + +from jax import random +import jax.numpy as jnp + +import numpyro +import numpyro.distributions as dist +from numpyro.infer import MCMC +from numpyro.infer.mclmc import MCLMC + + +def test_mclmc_model_required(): + """Test that ValueError is raised when model is None.""" + with pytest.raises(ValueError, match="Model must be specified"): + MCLMC(model=None) + + +def test_mclmc_blackjax_not_installed(monkeypatch): + """Test that ImportError is raised with informative message when blackjax is not installed.""" + import numpyro.infer.mclmc as mclmc_module + + # Temporarily set _BLACKJAX_AVAILABLE to False + monkeypatch.setattr(mclmc_module, "_BLACKJAX_AVAILABLE", False) + + def dummy_model(): + numpyro.sample("x", dist.Normal(0, 1)) + + with pytest.raises(ImportError, match="MCLMC requires the 'blackjax' package"): + MCLMC(model=dummy_model) + + +def test_mclmc_normal(): + """Test MCLMC with a 2D normal distribution. + + Note: MCLMC requires at least 2 dimensions (blackjax limitation). + """ + true_mean = jnp.array([1.0, 2.0]) + true_std = jnp.array([0.5, 1.0]) + num_warmup, num_samples = 1000, 2000 + + def model(): + numpyro.sample("x", dist.Normal(true_mean, true_std).to_event(1)) + + kernel = MCLMC(model=model) + mcmc = MCMC( + kernel, + num_warmup=num_warmup, + num_samples=num_samples, + num_chains=1, + progress_bar=False, + ) + mcmc.run(random.PRNGKey(0)) + samples = mcmc.get_samples() + + assert "x" in samples + assert samples["x"].shape == (num_samples, 2) + assert_allclose(jnp.mean(samples["x"], axis=0), true_mean, atol=0.1) + assert_allclose(jnp.std(samples["x"], axis=0), true_std, atol=0.2) + + +def test_mclmc_gaussian_2d(): + """Test MCLMC with a 2D Gaussian model with observation.""" + num_warmup, num_samples = 1000, 1000 + + def model(): + x = numpyro.sample("x", dist.Normal(0.0, 1.0)) + y = numpyro.sample("y", dist.Normal(0.0, 1.0)) + numpyro.sample("obs", dist.Normal(x + y, 0.5), obs=jnp.array(0.0)) + + kernel = MCLMC( + model=model, + diagonal_preconditioning=True, + desired_energy_var=5e-4, + ) + mcmc = MCMC( + kernel, + num_warmup=num_warmup, + num_samples=num_samples, + num_chains=1, + progress_bar=False, + ) + mcmc.run(random.PRNGKey(0)) + samples = mcmc.get_samples() + + assert "x" in samples + assert "y" in samples + assert samples["x"].shape == (num_samples,) + assert samples["y"].shape == (num_samples,) + # With obs=0, x+y should be close to 0, so means should be near 0 + assert_allclose(jnp.mean(samples["x"]) + jnp.mean(samples["y"]), 0.0, atol=0.2) + + +def test_mclmc_logistic_regression(): + """Test MCLMC with a logistic regression model. + + Note: MCLMC currently doesn't pass model_args, so we use a closure pattern. + """ + N, dim = 1000, 3 + num_warmup, num_samples = 1000, 2000 + + key1, key2, key3 = random.split(random.PRNGKey(0), 3) + data = random.normal(key1, (N, dim)) + true_coefs = jnp.arange(1.0, dim + 1.0) + logits = jnp.sum(true_coefs * data, axis=-1) + labels = dist.Bernoulli(logits=logits).sample(key2) + + # Use closure pattern since MCLMC doesn't pass model_args + def model(): + coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim))) + logits = jnp.sum(coefs * data, axis=-1) + numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels) + + kernel = MCLMC(model=model) + mcmc = MCMC( + kernel, + num_warmup=num_warmup, + num_samples=num_samples, + num_chains=1, + progress_bar=False, + ) + mcmc.run(key3) + samples = mcmc.get_samples() + + assert "coefs" in samples + assert samples["coefs"].shape == (num_samples, dim) + assert_allclose(jnp.mean(samples["coefs"], 0), true_coefs, atol=0.5) + + +def test_mclmc_sample_shape(): + """Test that MCLMC produces samples with expected shapes.""" + num_warmup, num_samples = 500, 500 + + def model(): + numpyro.sample("a", dist.Normal(0, 1)) + numpyro.sample("b", dist.Normal(0, 1).expand([3])) + numpyro.sample("c", dist.Normal(0, 1).expand([2, 4])) + + kernel = MCLMC(model=model) + mcmc = MCMC( + kernel, + num_warmup=num_warmup, + num_samples=num_samples, + num_chains=1, + progress_bar=False, + ) + mcmc.run(random.PRNGKey(0)) + samples = mcmc.get_samples() + + assert samples["a"].shape == (num_samples,) + assert samples["b"].shape == (num_samples, 3) + assert samples["c"].shape == (num_samples, 2, 4) From 59d7f417fbe1cf62a7a178475a9a5e454e4d8022 Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Tue, 20 Jan 2026 23:36:18 +0100 Subject: [PATCH 3/4] coauthor Co-authored-by: reubenharry From af9f00101eeeec13803eec95aa324d4c9d6fd66f Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Wed, 21 Jan 2026 00:11:49 +0100 Subject: [PATCH 4/4] add blcakjax to test --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index f0bef9f5a..972a9dad1 100644 --- a/setup.py +++ b/setup.py @@ -57,6 +57,7 @@ "scikit-learn", "scipy>=1.9", "ty>=0.0.4", + "blackjax>=1.3", ], "dev": [ "dm-haiku>=0.0.14",