# --- dimod/ess.py (new file, mode 100755) ---
# Copyright 2026 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from math import floor

import numpy as np

__all__ = ['estimate_effective_sample_size']


def estimate_effective_sample_size(x: np.ndarray, b: int | None = None) -> float:
    """Estimate the effective sample size of ``x``.

    The effective sample size (ESS) is the number of effectively independent samples drawn from
    Markov chains' stationary distribution. The univariate estimator implemented here follows
    the (multivariate) estimator defined in the first equation at the top of page 14 in
    `Revisiting the Gelman-Rubin Diagnostic <https://arxiv.org/abs/1812.09384>`_.

    Args:
        x: An (m, n) matrix where rows index independent Markov chains and columns index
            time steps. Array-likes (e.g. nested lists) are converted with ``np.asarray``.
        b: Batch size of the estimator. If ``None``, then ``b`` is set to the floor of the
            square root of ``n``. Defaults to None.

    Returns:
        float: An estimate of the effective sample size of ``x``. When both variance estimates
        are zero (e.g. every chain is constant) the ratio is 0/0 and the result is NaN.

    Raises:
        ValueError: If ``x`` is not two-dimensional, or if the batch size is smaller than three
            or larger than the chain length ``n``.
    """
    x = np.asarray(x)
    if x.ndim != 2:
        raise ValueError("The input matrix ``x`` should have shape (m, n) where m indexes "
                         f"independent Markov chains and n indexes time. ``x`` has shape {x.shape}")
    m, n = x.shape
    if b is None:
        b = floor(n ** 0.5)
    if b > n or b < 3:
        # b >= 3 guarantees that the reduced batch size ``b // 3`` used below is at least one.
        raise ValueError(
            "Batch size should be at least three but no more than the chain length of the Markov "
            f"chain. Batch size is {b} and chain length is {n}. If size was not given, it defaults "
            "to the floor of square-root of the chain length."
        )

    s_squared = x.var(1, ddof=1).mean()
    # = second equation at the top of page 7
    # = average of "sample variance within series"

    # This estimator $\hat{\tau}^2_L$ is defined in equation (5) of
    # Revisiting the Gelman-Rubin Diagnostic (https://arxiv.org/abs/1812.09384)
    tau_squared = (2 * _estimate_replicated_batch_means(x, b)
                   - _estimate_replicated_batch_means(x, b // 3))
    # = equation (5)
    # = nVar(xbar_i.) = total variance of the mean-within-series

    sigma_squared = ((n - 1) / n) * s_squared + tau_squared / n
    # = first equation at the top of page 10
    # = estimate of the distribution's variance

    ess = m * n * sigma_squared / tau_squared
    # = estimate of the effective sample size
    # = first equation at the top of page 14
    return float(ess)


def _estimate_replicated_batch_means(x: np.ndarray, b: int) -> float:
    """Compute the replicated batch means estimate.

    This estimator (:math:`\\hat\\tau^2_b`) is defined in the equation above equation (5) of
    `Revisiting the Gelman-Rubin Diagnostic <https://arxiv.org/abs/1812.09384>`_.

    The estimator batches each Markov chain into batches of size ``b``, estimates the mean of each
    batch, and computes the sample variance of these batched means (pooled over all chains),
    scaled by ``b``.

    The first few columns of ``x`` may be dropped in the estimation process in order to satisfy the
    requirement that the length of the Markov chain is divisible by the batch size.

    Args:
        x: An (m, n) matrix where rows index independent Markov chains and columns index
            time steps.
        b: Batch size of the estimator; must satisfy ``1 <= b <= n``.

    Returns:
        float: Replicated batch means estimate (a NumPy floating-point scalar).

    Raises:
        ValueError: If ``b`` is smaller than one or larger than the chain length.
    """
    x = np.asarray(x)
    m, n = x.shape
    if not 1 <= b <= n:
        raise ValueError(
            f"Batch size should be between 1 and the chain length {n}. Batch size is {b}."
        )

    n_batches = n // b
    trimmed_length = b * n_batches

    # Drop the leading columns that do not fill a whole batch, then view every chain as
    # ``n_batches`` consecutive batches of length ``b``. A reshape gives the same (m, n_batches)
    # matrix of batch means as splitting and transposing, without relying on ``ndarray.mT``,
    # which is only available in NumPy >= 2.0.
    batches = x[:, (n - trimmed_length):].reshape(m, n_batches, b)
    ybar = batches.mean(axis=2)

    # NOTE: this is equivalent to
    #   ((ybar - muhat) ** 2).sum() * b / (n_batches * m - 1)
    # where
    #   muhat = x.mean()
    # (with ``x`` being the trimmed matrix).
    return b * np.var(ybar, ddof=1)


# --- releasenotes/notes/ess-f32b7e8d5a4ad67a.yaml (new file, mode 100755) ---
# ---
# features:
#   - |
#     Add estimator for effective sample size based on
#     `Revisiting the Gelman-Rubin Diagnostic <https://arxiv.org/abs/1812.09384>`_.

# --- tests/test_ess.py (new file, mode 100755) ---
# Copyright 2026 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test effective sample size estimation."""
import unittest
from math import isnan

import numpy as np

from dimod.ess import _estimate_replicated_batch_means
from dimod.ess import estimate_effective_sample_size as estimate_ess


class TestEffectiveSampleSize(unittest.TestCase):
    """Checks input validation and hand-computed values of the ESS estimator."""

    def test_estimate_ess(self):
        with self.subTest("ESS estimate should be undefined for constant input"):
            # Constant chains have zero variance, so the estimate is 0/0; silence the
            # floating-point warning NumPy emits for that division.
            with np.errstate(invalid="ignore"):
                self.assertTrue(isnan(estimate_ess(np.ones((100, 1000)))))

        with self.subTest("Default batch size below three should raise an error."):
            # floor(sqrt(8)) == 2, which is below the minimum batch size.
            self.assertRaisesRegex(ValueError, "Batch size should be at least three",
                                   estimate_ess, np.ones((100, 8)))

        with self.subTest("Batch size of zero should raise an error."):
            self.assertRaisesRegex(ValueError, "Batch size should be at least three",
                                   estimate_ess, np.ones((100, 1000)), 0)

        with self.subTest("Batch size larger than chain length should raise an error."):
            self.assertRaisesRegex(ValueError, "Batch size should be at least three",
                                   estimate_ess, np.ones((100, 1000)), 1001)

        with self.subTest("Inputs that are not 2D should raise an error."):
            self.assertRaisesRegex(ValueError, "The input matrix ``x`` should have shape",
                                   estimate_ess, np.ones((123, 100, 1000)), 234)

        with self.subTest("Single-batch estimate should match the manual computation."):
            x = np.array([[0, 1, 2],
                          [0, 2, 4]])
            s_squared = (np.var([0, 1, 2], ddof=1) + np.var([0, 2, 4], ddof=1))/2
            tau_squared = (2 * _estimate_replicated_batch_means(x, 3)
                           - _estimate_replicated_batch_means(x, 1))
            sigma_squared = 2/3*s_squared + tau_squared/3
            answer = 2*3*sigma_squared/tau_squared
            self.assertAlmostEqual(answer, estimate_ess(x, 3))

        with self.subTest("Two-batch estimate should match the manual computation."):
            # Chain length 7 with batch size 3: the leading column is dropped when batching.
            x = np.array([[999, 0, 1, 2, 3, 6, 7],
                          [999, 0, 2, 4, 4, 6, 7]])
            s_squared = (np.var([999, 0, 1, 2, 3, 6, 7], ddof=1)
                         + np.var([999, 0, 2, 4, 4, 6, 7], ddof=1))/2
            tau_squared = (2 * _estimate_replicated_batch_means(x, 3)
                           - _estimate_replicated_batch_means(x, 1))
            sigma_squared = 6/7*s_squared + tau_squared/7
            answer = 2*7*sigma_squared/tau_squared
            self.assertAlmostEqual(answer, estimate_ess(x, 3))