diff --git a/.test-conda-env-py3.yml b/.test-conda-env-py3.yml index 3a4f0a8f..c4981063 100644 --- a/.test-conda-env-py3.yml +++ b/.test-conda-env-py3.yml @@ -12,6 +12,7 @@ dependencies: - pyopencl - python=3 - gmsh +- jax # test scripts use ompi-specific arguments - openmpi diff --git a/doc/conf.py b/doc/conf.py index 1597d1c6..7f181517 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -18,6 +18,7 @@ intersphinx_mapping = { "arraycontext": ("https://documen.tician.de/arraycontext/", None), "loopy": ("https://documen.tician.de/loopy/", None), + "jax": ("https://docs.jax.dev/en/latest/", None), "meshmode": ("https://documen.tician.de/meshmode/", None), "modepy": ("https://documen.tician.de/modepy/", None), "mpi4py": ("https://mpi4py.readthedocs.io/en/stable", None), @@ -33,6 +34,7 @@ os.environ["PYOPENCL_TEST"] = "port:cpu" nitpick_ignore_regex = [ + ["py:mod", r"jax"], # FIXME: not sure why this does not work ["py:class", r"np\.ndarray"], ["py:data|py:class", r"arraycontext.*ContainerTc"], ] diff --git a/grudge/array_context.py b/grudge/array_context.py index 9f1074bc..0202333d 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -5,6 +5,7 @@ .. autoclass:: MPIPyOpenCLArrayContext .. autoclass:: MPINumpyArrayContext .. class:: MPIPytatoArrayContext +.. autoclass:: MPIEagerJAXArrayContext .. autofunction:: get_reasonable_array_context_class """ @@ -75,9 +76,10 @@ _HAVE_FUSION_ACTX = False -from arraycontext import ArrayContext, NumpyArrayContext +from arraycontext import ArrayContext, EagerJAXArrayContext, NumpyArrayContext from arraycontext.impl.pytato.compile import LazilyPyOpenCLCompilingFunctionCaller from arraycontext.pytest import ( + _PytestEagerJaxArrayContextFactory, _PytestNumpyArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoPyOpenCLArrayContextFactory, @@ -443,6 +445,26 @@ def clone(self) -> Self: # }}} +# {{{ distributed + eager jax + +class MPIEagerJAXArrayContext(EagerJAXArrayContext, MPIBasedArrayContext): + """An array context for using distributed computation with :mod:`jax` + eager evaluation. + + .. autofunction:: __init__ + """ + + def __init__(self, mpi_communicator) -> None: + super().__init__() + + self.mpi_communicator = mpi_communicator + + def clone(self) -> Self: + return type(self)(self.mpi_communicator) + +# }}} + + # {{{ distributed + pytato array context subclasses class MPIBasePytatoPyOpenCLArrayContext( @@ -542,12 +564,23 @@ def __call__(self): return self.actx_class() +class PytestEagerJAXArrayContextFactory(_PytestEagerJaxArrayContextFactory): + actx_class = EagerJAXArrayContext + + def __call__(self): + import jax + jax.config.update("jax_enable_x64", True) + return self.actx_class() + + register_pytest_array_context_factory("grudge.pyopencl", PytestPyOpenCLArrayContextFactory) register_pytest_array_context_factory("grudge.pytato-pyopencl", PytestPytatoPyOpenCLArrayContextFactory) register_pytest_array_context_factory("grudge.numpy", PytestNumpyArrayContextFactory) +register_pytest_array_context_factory("grudge.eager-jax", + PytestEagerJAXArrayContextFactory) # }}} diff --git a/grudge/geometry/metrics.py b/grudge/geometry/metrics.py index 04aa11b9..15e73e6b 100644 --- a/grudge/geometry/metrics.py +++ b/grudge/geometry/metrics.py @@ -597,22 +597,24 @@ def _signed_face_ones( dd_base.untrace(), dd_base ) assert isinstance(all_faces_conn, DirectDiscretizationConnection) - signed_ones = dcoll.discr_from_dd(dd.with_discr_tag(DISCR_TAG_BASE)).zeros( - actx, dtype=dcoll.real_dtype - ) + 1 - signed_face_ones_numpy = actx.to_numpy(signed_ones) + discr = dcoll.discr_from_dd(dd.with_discr_tag(DISCR_TAG_BASE)) + + new_group_arrays = [] + + for dgrp, grp in zip(discr.groups, all_faces_conn.groups, strict=True): + sign = np.ones((dgrp.nelements, dgrp.nunit_dofs), + dtype=discr.real_dtype) - for igrp, grp in enumerate(all_faces_conn.groups): for batch in grp.batches: assert batch.to_element_face is not None i = actx.to_numpy(actx.thaw(batch.to_element_indices)) - grp_field = signed_face_ones_numpy[igrp].reshape(-1) - grp_field[i] = ( # pyright: ignore[reportIndexIssue] - (2.0 * (batch.to_element_face % 2) - 1.0) * grp_field[i] - ) + sign[i, :] = 2.0 * (batch.to_element_face % 2) - 1.0 + + new_group_arrays.append(sign) - return actx.from_numpy(signed_face_ones_numpy) + from meshmode.dof_array import DOFArray + return actx.from_numpy(DOFArray(actx, tuple(new_group_arrays))) def parametrization_derivative( diff --git a/test/test_dt_utils.py b/test/test_dt_utils.py index 59e3d4db..41cc6806 100644 --- a/test/test_dt_utils.py +++ b/test/test_dt_utils.py @@ -31,6 +31,7 @@ from arraycontext import ArrayContextFactory, pytest_generate_tests_for_array_contexts from grudge.array_context import ( + PytestEagerJAXArrayContextFactory, PytestNumpyArrayContextFactory, PytestPyOpenCLArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, @@ -40,7 +41,8 @@ pytest_generate_tests = pytest_generate_tests_for_array_contexts( [PytestPyOpenCLArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, - PytestNumpyArrayContextFactory]) + PytestNumpyArrayContextFactory, + PytestEagerJAXArrayContextFactory]) import logging diff --git a/test/test_euler_model.py b/test/test_euler_model.py index 4d8504f5..126576db 100644 --- a/test/test_euler_model.py +++ b/test/test_euler_model.py @@ -35,12 +35,18 @@ ) from grudge import op -from grudge.array_context import PytestPyOpenCLArrayContextFactory +from grudge.array_context import ( + PytestEagerJAXArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestPyOpenCLArrayContextFactory, +) logger = logging.getLogger(__name__) pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) + [PytestPyOpenCLArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestEagerJAXArrayContextFactory]) @pytest.mark.parametrize("order", [1, 2, 3]) diff --git a/test/test_grudge.py b/test/test_grudge.py index dbf80dab..e5ccf372 100644 --- a/test/test_grudge.py +++ b/test/test_grudge.py @@ -45,13 +45,19 @@ from meshmode.mesh import TensorProductElementGroup from grudge import dof_desc, geometry, op -from grudge.array_context import PytestPyOpenCLArrayContextFactory +from grudge.array_context import ( + PytestEagerJAXArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestPyOpenCLArrayContextFactory, +) from grudge.discretization import make_discretization_collection logger = logging.getLogger(__name__) pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) + [PytestPyOpenCLArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestEagerJAXArrayContextFactory]) # {{{ mass operator trig integration diff --git a/test/test_metrics.py b/test/test_metrics.py index 21a7934f..cf3c3035 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -36,6 +36,7 @@ from meshmode.dof_array import flat_norm from grudge.array_context import ( + PytestEagerJAXArrayContextFactory, PytestNumpyArrayContextFactory, PytestPyOpenCLArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, @@ -47,7 +48,8 @@ pytest_generate_tests = pytest_generate_tests_for_array_contexts( [PytestPyOpenCLArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, - PytestNumpyArrayContextFactory]) + PytestNumpyArrayContextFactory, + PytestEagerJAXArrayContextFactory]) # {{{ inverse metric diff --git a/test/test_modal_connections.py b/test/test_modal_connections.py index a7ae866b..38847c59 100644 --- a/test/test_modal_connections.py +++ b/test/test_modal_connections.py @@ -26,12 +26,18 @@ from arraycontext import ArrayContextFactory, pytest_generate_tests_for_array_contexts -from grudge.array_context import PytestPyOpenCLArrayContextFactory +from grudge.array_context import ( + PytestEagerJAXArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestPyOpenCLArrayContextFactory, +) from grudge.discretization import make_discretization_collection pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) + [PytestPyOpenCLArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestEagerJAXArrayContextFactory]) import pytest diff --git a/test/test_mpi_communication.py b/test/test_mpi_communication.py index 8a71c725..3fa4fcae 100644 --- a/test/test_mpi_communication.py +++ b/test/test_mpi_communication.py @@ -38,7 +38,12 @@ from meshmode.dof_array import flat_norm from grudge import dof_desc, op -from grudge.array_context import MPIPyOpenCLArrayContext, MPIPytatoArrayContext +from grudge.array_context import ( + MPIEagerJAXArrayContext, + MPINumpyArrayContext, + MPIPyOpenCLArrayContext, + MPIPytatoArrayContext, +) from grudge.discretization import make_discretization_collection from grudge.shortcuts import compiled_lsrk45_step @@ -52,7 +57,8 @@ class SimpleTag: # {{{ mpi test infrastructure -DISTRIBUTED_ACTXS = [MPIPyOpenCLArrayContext, MPIPytatoArrayContext] +DISTRIBUTED_ACTXS = [MPIPyOpenCLArrayContext, MPIPytatoArrayContext, + MPIEagerJAXArrayContext, MPINumpyArrayContext] def run_test_with_mpi(num_ranks, f, *args): @@ -90,6 +96,10 @@ def run_test_with_mpi_inner(): actx = actx_class(comm, queue, mpi_base_tag=15000) elif actx_class is MPIPyOpenCLArrayContext: actx = actx_class(comm, queue) + elif actx_class is MPIEagerJAXArrayContext: + actx = actx_class(comm) + elif actx_class is MPINumpyArrayContext: + actx = actx_class(comm) else: raise ValueError("unknown actx_class") diff --git a/test/test_op.py b/test/test_op.py index 18c2f2a2..588797ec 100644 --- a/test/test_op.py +++ b/test/test_op.py @@ -39,7 +39,11 @@ from meshmode.mesh import BTAG_ALL from grudge import geometry, op -from grudge.array_context import PytestPyOpenCLArrayContextFactory +from grudge.array_context import ( + PytestEagerJAXArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestPyOpenCLArrayContextFactory, +) from grudge.discretization import make_discretization_collection from grudge.dof_desc import ( DISCR_TAG_BASE, @@ -55,7 +59,9 @@ logger = logging.getLogger(__name__) pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) + [PytestPyOpenCLArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestEagerJAXArrayContextFactory]) # {{{ gradient diff --git a/test/test_reductions.py b/test/test_reductions.py index 03417c20..f1b03e0d 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -44,13 +44,19 @@ from meshmode.dof_array import DOFArray from grudge import op -from grudge.array_context import PytestPyOpenCLArrayContextFactory +from grudge.array_context import ( + PytestEagerJAXArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestPyOpenCLArrayContextFactory, +) from grudge.discretization import make_discretization_collection logger = logging.getLogger(__name__) pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) + [PytestPyOpenCLArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestEagerJAXArrayContextFactory]) @pytest.mark.parametrize(("mesh_size", "with_initial"), [ diff --git a/test/test_tools.py b/test/test_tools.py index f7823b39..9fcfc8a3 100644 --- a/test/test_tools.py +++ b/test/test_tools.py @@ -37,16 +37,12 @@ pytest_generate_tests = pytest_generate_tests_for_array_contexts( [PytestPyOpenCLArrayContextFactory]) -import logging import pytest import pytools.obj_array as obj_array -logger = logging.getLogger(__name__) - - # {{{ map_subarrays and rec_map_subarrays @dataclass(frozen=True, eq=True) diff --git a/test/test_trace_pair.py b/test/test_trace_pair.py index a49385a3..b3c86c8b 100644 --- a/test/test_trace_pair.py +++ b/test/test_trace_pair.py @@ -32,14 +32,20 @@ from arraycontext import ArrayContextFactory, pytest_generate_tests_for_array_contexts from meshmode.dof_array import DOFArray -from grudge.array_context import PytestPyOpenCLArrayContextFactory +from grudge.array_context import ( + PytestEagerJAXArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestPyOpenCLArrayContextFactory, +) from grudge.discretization import make_discretization_collection from grudge.trace_pair import TracePair logger = logging.getLogger(__name__) pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) + [PytestPyOpenCLArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestEagerJAXArrayContextFactory]) def test_trace_pair(actx_factory: ArrayContextFactory):