Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 6 additions & 142 deletions pytensor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,168 +1,32 @@
"""
PyTensor is an optimizing compiler in Python, built to evaluate
complicated expressions (especially matrix-valued ones) as quickly as
possible. PyTensor compiles expression graphs (see :doc:`graph` ) that
are built by Python code. The expressions in these graphs are called
`Apply` nodes and the variables in these graphs are called `Variable`
nodes.

You compile a graph by calling `function`, which takes a graph, and
returns a callable object. One of pytensor's most important features is
that `function` can transform your graph before compiling it. It can
replace simple expressions with faster or more numerically stable
implementations.

To learn more, check out:

- Op List (:doc:`oplist`)

"""

__docformat__ = "restructuredtext en"

# Set a default logger. It is important to do this before importing some other
# pytensor code, since this code may want to log some messages.
import logging
import sys
import warnings
from functools import singledispatch
from pathlib import Path
from typing import Any, NoReturn, Optional

from pytensor import _version


__version__: str = _version.get_versions()["version"]

del _version

pytensor_logger = logging.getLogger("pytensor")
logging_default_handler = logging.StreamHandler()
logging_default_formatter = logging.Formatter(
fmt="%(levelname)s (%(name)s): %(message)s"
)
logging_default_handler.setFormatter(logging_default_formatter)
pytensor_logger.setLevel(logging.WARNING)

if not pytensor_logger.hasHandlers():
pytensor_logger.addHandler(logging_default_handler)


# Helper to remove the default log handler that is attached to the
# "pytensor" logger when this module is imported.
def disable_log_handler(logger=pytensor_logger, handler=logging_default_handler):
    """Detach ``handler`` from ``logger``.

    By default this removes the stream handler that PyTensor attaches to its
    package logger at import time. Removing a handler that is not attached is
    a no-op, so calling this repeatedly is safe.
    """
    has_any_handlers = logger.hasHandlers()
    if has_any_handlers:
        logger.removeHandler(handler)


# Raise a meaningful warning/error if the pytensor directory is in the Python
# path.
rpath = Path(__file__).parent.resolve()
if any(rpath == Path(p).resolve() for p in sys.path):
raise RuntimeError("You have the pytensor directory in your Python path.")

from pytensor.configdefaults import config


# This is the api version for ops that generate C code. External ops
# might need manual changes if this number goes up. An undefined
# __api_version__ can be understood to mean api version 0.
#
# This number is not tied to the release version and should change
# very rarely.
__api_version__ = 1

# isort: off
from pytensor.graph.basic import Variable
from pytensor.graph.replace import clone_replace, graph_replace

# isort: on


def as_symbolic(x: Any, name: str | None = None, **kwargs) -> Variable:
    """Convert `x` into an equivalent PyTensor `Variable`.

    Existing `Variable` instances are returned untouched (their ``name`` is
    not overwritten); anything else is routed through the single-dispatch
    helper `_as_symbolic`.

    Parameters
    ----------
    x
        The object to be converted into a ``Variable`` type. A
        ``numpy.ndarray`` argument will not be copied, but a list of numbers
        will be copied to make an ``numpy.ndarray``.
    name
        If a new ``Variable`` instance is created, it will be named with this
        string.
    kwargs
        Options passed to the appropriate sub-dispatch functions. For example,
        `ndim` and `dtype` can be passed when `x` is an `numpy.ndarray` or
        `Number` type.

    Raises
    ------
    TypeError
        If `x` cannot be converted to a `Variable`.

    """
    if not isinstance(x, Variable):
        x = _as_symbolic(x, **kwargs)
        x.name = name
    return x


@singledispatch
def _as_symbolic(x: Any, **kwargs) -> Variable:
    """Fallback dispatch: convert any unregistered type via tensor conversion."""
    # Imported lazily to avoid a circular import at module load time.
    from pytensor.tensor import as_tensor_variable as _to_tensor

    return _to_tensor(x, **kwargs)


# isort: off
from pytensor import scalar, tensor
from pytensor import tensor
from pytensor import sparse
from pytensor.compile import (
In,
Mode,
Out,
ProfileStats,
predefined_linkers,
predefined_modes,
predefined_optimizers,
shared,
wrap_py,
function,
)
from pytensor.compile.function import function, function_dump
from pytensor.compile.function.types import FunctionMaker
from pytensor.gradient import Lop, Rop, grad, subgraph_grad
from pytensor.gradient import Lop, Rop, grad
from pytensor.printing import debugprint as dprint
from pytensor.printing import pp, pprint
from pytensor.updates import OrderedUpdates

# isort: on


def get_underlying_scalar_constant(v):
    """Return the constant scalar (i.e. 0-D) value underlying variable `v`.

    .. deprecated::
        Use :func:`pytensor.tensor.basic.get_underlying_scalar_constant_value`
        instead.

    If `v` is the output of dim-shuffles, fills, allocs, cast, etc.
    this function digs through them.

    If ``pytensor.sparse`` is also there, we will look over CSM `Op`.

    If `v` is not some view of constant data, then raise a
    `NotScalarConstantError`.
    """
    warnings.warn(
        "get_underlying_scalar_constant is deprecated. Use tensor.get_underlying_scalar_constant_value instead.",
        FutureWarning,
        # stacklevel=2 attributes the warning to the caller of this shim,
        # not to this wrapper itself.
        stacklevel=2,
    )
    # Imported lazily: pytensor.tensor depends on this package at import time.
    from pytensor.tensor.basic import get_underlying_scalar_constant_value

    return get_underlying_scalar_constant_value(v)


# isort: off
import pytensor.tensor.random.var
import pytensor.sparse
from pytensor.ifelse import ifelse
from pytensor.scan import checkpoints
from pytensor.scan.basic import scan
from pytensor.scan.views import foldl, foldr, map, reduce
from pytensor.compile.builders import OpFromGraph
Expand Down
42 changes: 42 additions & 0 deletions pytensor/basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from functools import singledispatch
from typing import Any

from pytensor.graph import Variable


def as_symbolic(x: Any, name: str | None = None, **kwargs) -> Variable:
    """Convert `x` into an equivalent PyTensor `Variable`.

    Parameters
    ----------
    x
        The object to be converted into a ``Variable`` type. A
        ``numpy.ndarray`` argument will not be copied, but a list of numbers
        will be copied to make an ``numpy.ndarray``.
    name
        If a new ``Variable`` instance is created, it will be named with this
        string.
    kwargs
        Options passed to the appropriate sub-dispatch functions. For example,
        `ndim` and `dtype` can be passed when `x` is an `numpy.ndarray` or
        `Number` type.

    Raises
    ------
    TypeError
        If `x` cannot be converted to a `Variable`.

    """
    # Already symbolic: hand it back unchanged (its existing name is kept).
    if isinstance(x, Variable):
        return x

    variable = _as_symbolic(x, **kwargs)
    variable.name = name
    return variable


@singledispatch
def _as_symbolic(x: Any, **kwargs) -> Variable:
    """Default single-dispatch implementation: fall back to tensor conversion."""
    # Local import breaks the circular dependency with pytensor.tensor.
    from pytensor.tensor import as_tensor_variable

    converted = as_tensor_variable(x, **kwargs)
    return converted
27 changes: 2 additions & 25 deletions pytensor/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import logging
import warnings
from typing import Any, Literal
from typing import Any

from pytensor.compile.function.types import Supervisor
from pytensor.configdefaults import config
Expand Down Expand Up @@ -512,7 +512,7 @@ def get_mode(orig_string):
if upper_string == "FAST_RUN":
linker = config.linker
if linker == "auto":
return CVM if config.cxx else VM
return NUMBA
return fast_run_linkers_to_mode[linker]

global _CACHED_RUNTIME_MODES
Expand Down Expand Up @@ -565,26 +565,3 @@ def register_mode(name, mode):
if name in predefined_modes:
raise ValueError(f"Mode name already taken: {name}")
predefined_modes[name] = mode


def get_target_language(mode=None) -> tuple[Literal["py", "c", "numba", "jax"], ...]:
"""Get the compilation target language."""

if mode is None:
mode = get_default_mode()

linker = mode.linker

if isinstance(linker, NumbaLinker):
return ("numba",)
if isinstance(linker, JAXLinker):
return ("jax",)
if isinstance(linker, PerformLinker):
return ("py",)
if isinstance(linker, CLinker):
return ("c",)

if isinstance(linker, VMLinker | OpWiseCLinker):
return ("c", "py") if config.cxx else ("py",)

raise Exception(f"Unsupported Linker: {linker}")
Loading
Loading