diff --git a/setup.py b/setup.py index adf15ab2e3..dbc4e659e5 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ "plotly>=2.2.1", "scipy>=0.16", "statsmodels>=0.12.0", - "torch>=1.9.0, <2.0", + "torch>=1.9.0", "tqdm>=4.46.0", "typing-extensions>=3.10", "xarray>=0.16.0", diff --git a/src/beanmachine/ppl/inference/compositional_infer.py b/src/beanmachine/ppl/inference/compositional_infer.py index 7eb573ada9..68680504af 100644 --- a/src/beanmachine/ppl/inference/compositional_infer.py +++ b/src/beanmachine/ppl/inference/compositional_infer.py @@ -48,7 +48,7 @@ class _DefaultInference(BaseInference): Mixed inference class that handles both discrete and continuous RVs """ - def __init__(self, nnc_compile: bool = True): + def __init__(self, nnc_compile: bool = False): self._disc_proposers = {} self._cont_proposer = None self._continuous_rvs = set() @@ -155,7 +155,7 @@ def __init__( Union[BaseInference, Tuple[BaseInference, ...], EllipsisClass], ] ] = None, - nnc_compile: bool = True, + nnc_compile: bool = False, ): self.config: Dict[Union[Callable, Tuple[Callable, ...]], BaseInference] = {} # create a set for the RV families that are being covered in the config; this is diff --git a/src/beanmachine/ppl/inference/hmc_inference.py b/src/beanmachine/ppl/inference/hmc_inference.py index 8332d473f2..95ebc7a019 100644 --- a/src/beanmachine/ppl/inference/hmc_inference.py +++ b/src/beanmachine/ppl/inference/hmc_inference.py @@ -108,7 +108,7 @@ def __init__( adapt_mass_matrix: bool = True, full_mass_matrix: bool = False, target_accept_prob: float = 0.8, - nnc_compile: bool = True, + nnc_compile: bool = False, experimental_inductor_compile: bool = False, ): self.trajectory_length = trajectory_length diff --git a/src/beanmachine/ppl/inference/nuts_inference.py b/src/beanmachine/ppl/inference/nuts_inference.py index 3754a1a2a8..e668c9f0c9 100644 --- a/src/beanmachine/ppl/inference/nuts_inference.py +++ b/src/beanmachine/ppl/inference/nuts_inference.py @@ -52,7 +52,7 @@ def __init__( full_mass_matrix: bool = False, multinomial_sampling: bool = True, target_accept_prob: float = 0.8, - nnc_compile: bool = True, + nnc_compile: bool = False, experimental_inductor_compile: bool = False, ): self.max_tree_depth = max_tree_depth diff --git a/src/beanmachine/ppl/inference/proposer/hmc_proposer.py b/src/beanmachine/ppl/inference/proposer/hmc_proposer.py index 6b24c7496e..b190c1c4c7 100644 --- a/src/beanmachine/ppl/inference/proposer/hmc_proposer.py +++ b/src/beanmachine/ppl/inference/proposer/hmc_proposer.py @@ -61,7 +61,7 @@ def __init__( adapt_mass_matrix: bool = True, full_mass_matrix: bool = False, target_accept_prob: float = 0.8, - jit_backend: TorchJITBackend = TorchJITBackend.NNC, + jit_backend: TorchJITBackend = TorchJITBackend.NONE, ): self.world = initial_world self._target_rvs = target_rvs diff --git a/src/beanmachine/ppl/inference/proposer/nnc/__init__.py b/src/beanmachine/ppl/inference/proposer/nnc/__init__.py index f174563b3f..18816fca7a 100644 --- a/src/beanmachine/ppl/inference/proposer/nnc/__init__.py +++ b/src/beanmachine/ppl/inference/proposer/nnc/__init__.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Callable, Optional, Tuple, TypeVar +from typing import Callable, TypeVar from typing_extensions import ParamSpec @@ -14,9 +14,7 @@ R = TypeVar("R") -def nnc_jit( - f: Callable[P, R], static_argnums: Optional[Tuple[int]] = None -) -> Callable[P, R]: +def nnc_jit(f: Callable[P, R]) -> Callable[P, R]: """ A helper function that lazily imports the NNC utils, which initialize the compiler and displaying a experimental warning, then invoke the underlying nnc_jit on @@ -33,7 +31,7 @@ def nnc_jit( # return original function without change return f - return raw_nnc_jit(f, static_argnums) + return raw_nnc_jit(f) __all__ = ["nnc_jit"] diff --git a/src/beanmachine/ppl/inference/proposer/nuts_proposer.py b/src/beanmachine/ppl/inference/proposer/nuts_proposer.py index a02b543416..da8239071a 100644 --- a/src/beanmachine/ppl/inference/proposer/nuts_proposer.py +++ b/src/beanmachine/ppl/inference/proposer/nuts_proposer.py @@ -83,7 +83,7 @@ def __init__( full_mass_matrix: bool = False, multinomial_sampling: bool = True, target_accept_prob: float = 0.8, - jit_backend: TorchJITBackend = TorchJITBackend.NNC, + jit_backend: TorchJITBackend = TorchJITBackend.NONE, ): # note that trajectory_length is not used in NUTS super().__init__( diff --git a/tests/ppl/inference/nnc_test.py b/tests/ppl/inference/nnc_test.py index e379dbe985..d3a92a9d9c 100644 --- a/tests/ppl/inference/nnc_test.py +++ b/tests/ppl/inference/nnc_test.py @@ -20,7 +20,7 @@ def foo(self): def bar(self): return dist.Normal(self.foo(), 1.0) - +@pytest.mark.skip(reason="disable NNC test until we fix the compatibility issue with PyTorch 2.0") @pytest.mark.parametrize( "algorithm", [