Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ repos:
^(.gitlab|.github|config|doc|tests|test_utils)/|
^src/(example_simulator_functions|queens_interfaces)/|
^src/queens/(data_processors|drivers|iterators|models)/|
^src/queens/(schedulers|stochastic_optimizers|variational_distributions|visualization)/|
^src/queens/(schedulers|stochastic_optimizers|visualization)/|
^src/queens/(main.py|global_settings.py)
).*$
- repo: https://github.com/kynan/nbstripout
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ exclude = '''(?x)(
^(.gitlab|.github|config|doc|tests|test_utils)/|
^src/(example_simulator_functions|queens_interfaces)/|
^src/queens/(data_processors|drivers|iterators|models)/|
^src/queens/(schedulers|stochastic_optimizers|variational_distributions|visualization)/|
^src/queens/(schedulers|stochastic_optimizers|visualization)/|
^src/queens/(main.py|global_settings.py)
).*$'''
[[tool.mypy.overrides]]
Expand Down
2 changes: 1 addition & 1 deletion src/queens/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,5 @@
class_module_map = extract_type_checking_imports(__file__)


def __getattr__(name: str) -> Distribution:
def __getattr__(name: str) -> type[Distribution]:
return import_class_from_class_module_map(name, class_module_map, __name__)
2 changes: 1 addition & 1 deletion src/queens/parameters/random_fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@
class_module_map = extract_type_checking_imports(__file__)


def __getattr__(name: str) -> RandomField:
def __getattr__(name: str) -> type[RandomField]:
return import_class_from_class_module_map(name, class_module_map, __name__)
5 changes: 4 additions & 1 deletion src/queens/variational_distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@

Modules containing probability distributions for variational inference.
"""
from __future__ import annotations

from typing import TYPE_CHECKING

from queens.utils.imports import extract_type_checking_imports, import_class_from_class_module_map

if TYPE_CHECKING:
from queens.variational_distributions._variational_distribution import Variational
from queens.variational_distributions.full_rank_normal import FullRankNormal
from queens.variational_distributions.joint import Joint
from queens.variational_distributions.mean_field_normal import MeanFieldNormal
Expand All @@ -30,5 +33,5 @@
class_module_map = extract_type_checking_imports(__file__)


def __getattr__(name):
def __getattr__(name: str) -> type[Variational]:
return import_class_from_class_module_map(name, class_module_map, __name__)
134 changes: 104 additions & 30 deletions src/queens/variational_distributions/_variational_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,95 +15,169 @@
"""Variational Distribution."""

import abc
from typing import Any, Literal, TypeAlias, TypeVar

import numpy as np

# pylint: disable=invalid-name

NDims = TypeVar("NDims", bound=int)
NSamples = TypeVar("NSamples", bound=int)
NParams = TypeVar("NParams", bound=int)
NParamsComponent = TypeVar("NParamsComponent", bound=int)

# Vectors
ArrayNDims: TypeAlias = np.ndarray[tuple[NDims], np.dtype[np.floating]]
ArrayNParams: TypeAlias = np.ndarray[tuple[NParams], np.dtype[np.floating]]
ArrayNParamsComponent: TypeAlias = np.ndarray[tuple[NParamsComponent], np.dtype[np.floating]]
ArrayNSamples: TypeAlias = np.ndarray[tuple[NSamples], np.dtype[np.floating]]

# Matrices
Array1XNParams: TypeAlias = np.ndarray[tuple[Literal[1], NParams], np.dtype[np.floating]]
ArrayNDimsX1: TypeAlias = np.ndarray[tuple[NDims, Literal[1]], np.dtype[np.floating]]
ArrayNDimsXNDims: TypeAlias = np.ndarray[tuple[NDims, NDims], np.dtype[np.floating]]
ArrayNParamsXNParams: TypeAlias = np.ndarray[tuple[NParams, NParams], np.dtype[np.floating]]
ArrayNParamsXNSamples: TypeAlias = np.ndarray[tuple[NParams, NSamples], np.dtype[np.floating]]
ArrayNSamplesXNDims: TypeAlias = np.ndarray[tuple[NSamples, NDims], np.dtype[np.floating]]
ArrayNSamplesXNParams: TypeAlias = np.ndarray[tuple[NSamples, NParams], np.dtype[np.floating]]
Comment on lines +24 to +42
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I love this!

I was always hoping we would be able to do something like this. It will help us a lot! 🙌



class Variational:
"""Base class for probability distributions for variational inference.

Attributes:
dimension (int): dimension of the distribution
dimension: Dimension of the distribution
n_parameters: Number of variational parameters
"""

def __init__(self, dimension):
"""Initialize variational distribution."""
def __init__(self, dimension: NDims, n_parameters: NParams) -> None:
"""Initialize variational distribution.

Args:
dimension: Dimension of the variational distribution
n_parameters: Number of variational parameters
"""
self.dimension = dimension
self.n_parameters = n_parameters

@abc.abstractmethod
def reconstruct_distribution_parameters(self, variational_parameters):
def construct_variational_parameters(self, *args: Any) -> ArrayNParams:
"""Construct variational parameters from distribution parameters.

Args:
args: Distribution parameters

Returns:
Variational parameters
"""

@abc.abstractmethod
def reconstruct_distribution_parameters(self, variational_parameters: ArrayNParams) -> Any:
"""Reconstruct distribution parameters from variational parameters.

Args:
variational_parameters (np.ndarray): Variational parameters
variational_parameters: Variational parameters

Returns:
Distribution parameters
"""

@abc.abstractmethod
def draw(self, variational_parameters, n_draws=1):
def draw(self, variational_parameters: ArrayNParams, n_draws: NSamples) -> ArrayNSamplesXNDims:
"""Draw *n_draws* samples from distribution.

Args:
variational_parameters (np.ndarray): variational parameters (1 x n_params)
n_draws (int): Number of samples
variational_parameters: Variational parameters
n_draws: Number of samples

Returns:
Drawn samples
"""

@abc.abstractmethod
def logpdf(self, variational_parameters, x):
"""Evaluate the natural logarithm of the logpdf at sample.
def logpdf(
self,
variational_parameters: ArrayNParams,
x: ArrayNSamplesXNDims,
) -> ArrayNSamples:
"""Evaluate the natural logarithm of the PDF.

Args:
variational_parameters (np.ndarray): variational parameters (1 x n_params)
x (np.ndarray): Locations to evaluate (n_samples x n_dim)
variational_parameters: Variational parameters
x: Locations to evaluate

Returns:
Log-PDF values
"""

@abc.abstractmethod
def pdf(self, variational_parameters, x):
"""Evaluate the probability density function (pdf) at sample.
def pdf(
self,
variational_parameters: ArrayNParams,
x: ArrayNSamplesXNDims,
) -> ArrayNSamples:
"""Evaluate the probability density function (PDF).

Args:
variational_parameters (np.ndarray): variational parameters (1 x n_params)
x (np.ndarray): Locations to evaluate (n_samples x n_dim)
variational_parameters: Variational parameters
x: Locations to evaluate

Returns:
PDF values
"""

@abc.abstractmethod
def grad_params_logpdf(self, variational_parameters, x):
"""Logpdf gradient w.r.t. the variational parameters.
def grad_params_logpdf(
self,
variational_parameters: ArrayNParams,
x: ArrayNSamplesXNDims,
) -> ArrayNParamsXNSamples:
"""Log-PDF gradient w.r.t. the variational parameters.

Evaluated at samples *x*. Also known as the score function.

Args:
variational_parameters (np.ndarray): variational parameters (1 x n_params)
x (np.ndarray): Locations to evaluate (n_samples x n_dim)
variational_parameters: Variational parameters
x: Locations to evaluate

Returns:
Gradient of the log-PDF w.r.t. the variational parameters
"""

@abc.abstractmethod
def fisher_information_matrix(self, variational_parameters):
"""Compute the fisher information matrix.
def fisher_information_matrix(
self, variational_parameters: ArrayNParams
) -> ArrayNParamsXNParams:
"""Compute the Fisher information matrix.

Depends on the variational distribution for the given
parameterization.

Args:
variational_parameters (np.ndarray): variational parameters (1 x n_params)
variational_parameters: Variational parameters

Returns:
Fisher information matrix
"""

@abc.abstractmethod
def initialize_variational_parameters(self, random=False):
def initialize_variational_parameters(self, random: bool = False) -> ArrayNParams:
"""Initialize variational parameters.

Args:
random (bool, optional): If True, a random initialization is used. Otherwise the
default is selected
random: If True, a random initialization is used. Otherwise the default is selected.

Returns:
variational_parameters (np.ndarray): variational parameters (1 x n_params)
Variational parameters
"""

@abc.abstractmethod
def export_dict(self, variational_parameters):
def export_dict(self, variational_parameters: ArrayNParams) -> dict:
"""Create a dict of the distribution based on the given parameters.

Args:
variational_parameters (np.ndarray): Variational parameters
variational_parameters: Variational parameters

Returns:
export_dict (dictionary): Dict containing distribution information
Dictionary containing distribution information
"""
Loading
Loading