Skip to content
This repository was archived by the owner on Dec 18, 2023. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"plotly>=2.2.1",
"scipy>=0.16",
"statsmodels>=0.12.0",
"torch>=1.9.0, <2.0",
"torch>=1.9.0",
"tqdm>=4.46.0",
"typing-extensions>=3.10",
"xarray>=0.16.0",
Expand Down
4 changes: 2 additions & 2 deletions src/beanmachine/ppl/inference/compositional_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class _DefaultInference(BaseInference):
Mixed inference class that handles both discrete and continuous RVs
"""

def __init__(self, nnc_compile: bool = True):
def __init__(self, nnc_compile: bool = False):
self._disc_proposers = {}
self._cont_proposer = None
self._continuous_rvs = set()
Expand Down Expand Up @@ -155,7 +155,7 @@ def __init__(
Union[BaseInference, Tuple[BaseInference, ...], EllipsisClass],
]
] = None,
nnc_compile: bool = True,
nnc_compile: bool = False,
):
self.config: Dict[Union[Callable, Tuple[Callable, ...]], BaseInference] = {}
# create a set for the RV families that are being covered in the config; this is
Expand Down
2 changes: 1 addition & 1 deletion src/beanmachine/ppl/inference/hmc_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(
adapt_mass_matrix: bool = True,
full_mass_matrix: bool = False,
target_accept_prob: float = 0.8,
nnc_compile: bool = True,
nnc_compile: bool = False,
experimental_inductor_compile: bool = False,
):
self.trajectory_length = trajectory_length
Expand Down
2 changes: 1 addition & 1 deletion src/beanmachine/ppl/inference/nuts_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
full_mass_matrix: bool = False,
multinomial_sampling: bool = True,
target_accept_prob: float = 0.8,
nnc_compile: bool = True,
nnc_compile: bool = False,
experimental_inductor_compile: bool = False,
):
self.max_tree_depth = max_tree_depth
Expand Down
2 changes: 1 addition & 1 deletion src/beanmachine/ppl/inference/proposer/hmc_proposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
adapt_mass_matrix: bool = True,
full_mass_matrix: bool = False,
target_accept_prob: float = 0.8,
jit_backend: TorchJITBackend = TorchJITBackend.NNC,
jit_backend: TorchJITBackend = TorchJITBackend.NONE,
):
self.world = initial_world
self._target_rvs = target_rvs
Expand Down
8 changes: 3 additions & 5 deletions src/beanmachine/ppl/inference/proposer/nnc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import Callable, Optional, Tuple, TypeVar
from typing import Callable, TypeVar

from typing_extensions import ParamSpec

Expand All @@ -14,9 +14,7 @@
R = TypeVar("R")


def nnc_jit(
f: Callable[P, R], static_argnums: Optional[Tuple[int]] = None
) -> Callable[P, R]:
def nnc_jit(f: Callable[P, R]) -> Callable[P, R]:
"""
A helper function that lazily imports the NNC utils, which initialize the compiler
and display an experimental warning, then invokes the underlying nnc_jit on
Expand All @@ -33,7 +31,7 @@ def nnc_jit(
# return original function without change
return f

return raw_nnc_jit(f, static_argnums)
return raw_nnc_jit(f)


__all__ = ["nnc_jit"]
2 changes: 1 addition & 1 deletion src/beanmachine/ppl/inference/proposer/nuts_proposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(
full_mass_matrix: bool = False,
multinomial_sampling: bool = True,
target_accept_prob: float = 0.8,
jit_backend: TorchJITBackend = TorchJITBackend.NNC,
jit_backend: TorchJITBackend = TorchJITBackend.NONE,
):
# note that trajectory_length is not used in NUTS
super().__init__(
Expand Down
2 changes: 1 addition & 1 deletion tests/ppl/inference/nnc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def foo(self):
def bar(self):
return dist.Normal(self.foo(), 1.0)


@pytest.mark.skip(reason="disable NNC test until we fix the compatibility issue with PyTorch 2.0")
@pytest.mark.parametrize(
"algorithm",
[
Expand Down