diff --git a/numpyro/infer/mclmc.py b/numpyro/infer/mclmc.py new file mode 100644 index 000000000..36c1371c1 --- /dev/null +++ b/numpyro/infer/mclmc.py @@ -0,0 +1,219 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from collections import namedtuple + +import jax + +from numpyro.infer.mcmc import MCMCKernel +from numpyro.infer.util import initialize_model +from numpyro.util import identity + +try: + import blackjax + from blackjax.mcmc.integrators import IntegratorState + from blackjax.util import pytree_size + + _BLACKJAX_AVAILABLE = True +except ImportError: + _BLACKJAX_AVAILABLE = False + blackjax = None + IntegratorState = None + pytree_size = None + +FullState = namedtuple( + "FullState", ["position", "momentum", "logdensity", "logdensity_grad", "rng_key"] +) + + +class MCLMC(MCMCKernel): + """ + Microcanonical Langevin Monte Carlo (MCLMC) kernel. + + MCLMC is a gradient-based MCMC algorithm that uses Hamiltonian dynamics + on an extended state space. It requires the `blackjax` package. + + **References:** + + 1. *Microcanonical Hamiltonian Monte Carlo*, + Jakob Robnik, G. Bruno De Luca, Eva Silverstein, Uroš Seljak + https://arxiv.org/abs/2212.08549 + + .. note:: The model must have at least 2 latent dimensions for MCLMC to work + (this is a limitation of the blackjax implementation). + + :param model: Python callable containing Pyro :mod:`~numpyro.primitives`. + :param float desired_energy_var: Target energy variance for step size and + trajectory length tuning. Smaller values lead to more conservative + step sizes. Defaults to 5e-4. + :param bool diagonal_preconditioning: Whether to use diagonal preconditioning + for the mass matrix. Defaults to True. + """ + + def __init__( + self, + model=None, + desired_energy_var=5e-4, + diagonal_preconditioning=True, + ): + if not _BLACKJAX_AVAILABLE: + raise ImportError( + "MCLMC requires the 'blackjax' package. " + "Please install it with: pip install blackjax" + ) + if model is None: + raise ValueError("Model must be specified for MCLMC") + self._model = model + self._diagonal_preconditioning = diagonal_preconditioning + self._desired_energy_var = desired_energy_var + self._init_fn = None + self._sample_fn = None + self._postprocess_fn = None + + @property + def model(self): + return self._model + + @property + def sample_field(self): + return "position" + + @property + def default_fields(self): + return (self.sample_field,) + + def get_diagnostics_str(self, state): + """ + Return a diagnostics string for the progress bar. + """ + return "step_size={:.2e}, L={:.2e}".format( + self.adapt_state.step_size, self.adapt_state.L + ) + + def postprocess_fn(self, args, kwargs): + """ + Get a function that transforms unconstrained values at sample sites to values + constrained to the site's support, in addition to returning deterministic + sites in the model. + + :param args: Arguments to the model. + :param kwargs: Keyword arguments to the model. + """ + if self._postprocess_fn is None: + return identity + return self._postprocess_fn(*args, **kwargs) + + def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): + """ + Initialize the MCLMC kernel. + + :param rng_key: Random number generator key + :param num_warmup: Number of warmup steps + :param init_params: Initial parameters + :param model_args: Model arguments + :param model_kwargs: Model keyword arguments + :return: Initial state + """ + + init_model_key, init_state_key, run_key, rng_key_tune = jax.random.split( + rng_key, 4 + ) + + init_params, potential_fn_gen, postprocess_fn, _ = initialize_model( + init_model_key, + self._model, + model_args=model_args, + model_kwargs=model_kwargs, + dynamic_args=True, + ) + self._postprocess_fn = postprocess_fn + + def logdensity_fn(position): + return -potential_fn_gen(*model_args, **model_kwargs)(position) + + initial_position = init_params.z + self.logdensity_fn = logdensity_fn + + sampler_state = blackjax.mcmc.mclmc.init( + position=initial_position, + logdensity_fn=self.logdensity_fn, + rng_key=init_state_key, + ) + + def kernel(inverse_mass_matrix): + return blackjax.mcmc.mclmc.build_kernel( + logdensity_fn=logdensity_fn, + integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, + inverse_mass_matrix=inverse_mass_matrix, + ) + + self.dim = pytree_size(initial_position) + + # num_steps is a dummy param here (used for tuning fractions) + num_tuning_steps = 100 + ( + blackjax_state_after_tuning, + blackjax_mclmc_sampler_params, + _, + ) = blackjax.mclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=num_tuning_steps, + state=sampler_state, + rng_key=rng_key_tune, + diagonal_preconditioning=self._diagonal_preconditioning, + frac_tune3=num_warmup / (3 * num_tuning_steps), + frac_tune2=num_warmup / (3 * num_tuning_steps), + frac_tune1=num_warmup / (3 * num_tuning_steps), + desired_energy_var=self._desired_energy_var, + ) + + self.adapt_state = blackjax_mclmc_sampler_params + + return FullState( + blackjax_state_after_tuning.position, + blackjax_state_after_tuning.momentum, + blackjax_state_after_tuning.logdensity, + blackjax_state_after_tuning.logdensity_grad, + run_key, + ) + + def sample(self, state, model_args, model_kwargs): + """ + Run MCLMC from the given state and return the resulting state. + + :param state: Current state + :param model_args: Model arguments + :param model_kwargs: Model keyword arguments + :return: Next state after running MCLMC + """ + + mclmc_state = IntegratorState( + state.position, state.momentum, state.logdensity, state.logdensity_grad + ) + rng_key, rng_key_sample = jax.random.split(state.rng_key, 2) + + kernel = blackjax.mcmc.mclmc.build_kernel( + logdensity_fn=self.logdensity_fn, + integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, + inverse_mass_matrix=self.adapt_state.inverse_mass_matrix, + ) + + new_state, info = kernel( + rng_key=rng_key_sample, + state=mclmc_state, + step_size=self.adapt_state.step_size, + L=self.adapt_state.L, + ) + + return FullState( + new_state.position, + new_state.momentum, + new_state.logdensity, + new_state.logdensity_grad, + rng_key, + ) + + def __getstate__(self): + state = self.__dict__.copy() + state["_postprocess_fn"] = None + return state diff --git a/setup.py b/setup.py index f0bef9f5a..972a9dad1 100644 --- a/setup.py +++ b/setup.py @@ -57,6 +57,7 @@ "scikit-learn", "scipy>=1.9", "ty>=0.0.4", + "blackjax>=1.3", ], "dev": [ "dm-haiku>=0.0.14", diff --git a/test/infer/test_mclmc.py b/test/infer/test_mclmc.py new file mode 100644 index 000000000..b3fd1af9d --- /dev/null +++ b/test/infer/test_mclmc.py @@ -0,0 +1,155 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from numpy.testing import assert_allclose +import pytest + +from jax import random +import jax.numpy as jnp + +import numpyro +import numpyro.distributions as dist +from numpyro.infer import MCMC +from numpyro.infer.mclmc import MCLMC + + +def test_mclmc_model_required(): + """Test that ValueError is raised when model is None.""" + with pytest.raises(ValueError, match="Model must be specified"): + MCLMC(model=None) + + +def test_mclmc_blackjax_not_installed(monkeypatch): + """Test that ImportError is raised with informative message when blackjax is not installed.""" + import numpyro.infer.mclmc as mclmc_module + + # Temporarily set _BLACKJAX_AVAILABLE to False + monkeypatch.setattr(mclmc_module, "_BLACKJAX_AVAILABLE", False) + + def dummy_model(): + numpyro.sample("x", dist.Normal(0, 1)) + + with pytest.raises(ImportError, match="MCLMC requires the 'blackjax' package"): + MCLMC(model=dummy_model) + + +def test_mclmc_normal(): + """Test MCLMC with a 2D normal distribution. + + Note: MCLMC requires at least 2 dimensions (blackjax limitation). + """ + true_mean = jnp.array([1.0, 2.0]) + true_std = jnp.array([0.5, 1.0]) + num_warmup, num_samples = 1000, 2000 + + def model(): + numpyro.sample("x", dist.Normal(true_mean, true_std).to_event(1)) + + kernel = MCLMC(model=model) + mcmc = MCMC( + kernel, + num_warmup=num_warmup, + num_samples=num_samples, + num_chains=1, + progress_bar=False, + ) + mcmc.run(random.PRNGKey(0)) + samples = mcmc.get_samples() + + assert "x" in samples + assert samples["x"].shape == (num_samples, 2) + assert_allclose(jnp.mean(samples["x"], axis=0), true_mean, atol=0.1) + assert_allclose(jnp.std(samples["x"], axis=0), true_std, atol=0.2) + + +def test_mclmc_gaussian_2d(): + """Test MCLMC with a 2D Gaussian model with observation.""" + num_warmup, num_samples = 1000, 1000 + + def model(): + x = numpyro.sample("x", dist.Normal(0.0, 1.0)) + y = numpyro.sample("y", dist.Normal(0.0, 1.0)) + numpyro.sample("obs", dist.Normal(x + y, 0.5), obs=jnp.array(0.0)) + + kernel = MCLMC( + model=model, + diagonal_preconditioning=True, + desired_energy_var=5e-4, + ) + mcmc = MCMC( + kernel, + num_warmup=num_warmup, + num_samples=num_samples, + num_chains=1, + progress_bar=False, + ) + mcmc.run(random.PRNGKey(0)) + samples = mcmc.get_samples() + + assert "x" in samples + assert "y" in samples + assert samples["x"].shape == (num_samples,) + assert samples["y"].shape == (num_samples,) + # With obs=0, x+y should be close to 0, so means should be near 0 + assert_allclose(jnp.mean(samples["x"]) + jnp.mean(samples["y"]), 0.0, atol=0.2) + + +def test_mclmc_logistic_regression(): + """Test MCLMC with a logistic regression model. + + Note: MCLMC currently doesn't pass model_args, so we use a closure pattern. + """ + N, dim = 1000, 3 + num_warmup, num_samples = 1000, 2000 + + key1, key2, key3 = random.split(random.PRNGKey(0), 3) + data = random.normal(key1, (N, dim)) + true_coefs = jnp.arange(1.0, dim + 1.0) + logits = jnp.sum(true_coefs * data, axis=-1) + labels = dist.Bernoulli(logits=logits).sample(key2) + + # Use closure pattern since MCLMC doesn't pass model_args + def model(): + coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim))) + logits = jnp.sum(coefs * data, axis=-1) + numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels) + + kernel = MCLMC(model=model) + mcmc = MCMC( + kernel, + num_warmup=num_warmup, + num_samples=num_samples, + num_chains=1, + progress_bar=False, + ) + mcmc.run(key3) + samples = mcmc.get_samples() + + assert "coefs" in samples + assert samples["coefs"].shape == (num_samples, dim) + assert_allclose(jnp.mean(samples["coefs"], 0), true_coefs, atol=0.5) + + +def test_mclmc_sample_shape(): + """Test that MCLMC produces samples with expected shapes.""" + num_warmup, num_samples = 500, 500 + + def model(): + numpyro.sample("a", dist.Normal(0, 1)) + numpyro.sample("b", dist.Normal(0, 1).expand([3])) + numpyro.sample("c", dist.Normal(0, 1).expand([2, 4])) + + kernel = MCLMC(model=model) + mcmc = MCMC( + kernel, + num_warmup=num_warmup, + num_samples=num_samples, + num_chains=1, + progress_bar=False, + ) + mcmc.run(random.PRNGKey(0)) + samples = mcmc.get_samples() + + assert samples["a"].shape == (num_samples,) + assert samples["b"].shape == (num_samples, 3) + assert samples["c"].shape == (num_samples, 2, 4)