From 60fbd73cb5502abe3d57653439ac35ba80b58ecb Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 23 May 2025 11:36:26 -0500 Subject: [PATCH 01/14] add MPIEagerJaxArrayContext --- grudge/array_context.py | 34 +++++++++++++++++++++++++++++++++- test/test_dt_utils.py | 4 +++- test/test_metrics.py | 4 +++- 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/grudge/array_context.py b/grudge/array_context.py index 908c9cb00..fadc95634 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -76,10 +76,11 @@ _HAVE_FUSION_ACTX = False -from arraycontext import ArrayContext, NumpyArrayContext +from arraycontext import ArrayContext, EagerJAXArrayContext, NumpyArrayContext from arraycontext.container import ArrayContainer from arraycontext.impl.pytato.compile import LazilyPyOpenCLCompilingFunctionCaller from arraycontext.pytest import ( + _PytestEagerJaxArrayContextFactory, _PytestNumpyArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoPyOpenCLArrayContextFactory, @@ -428,6 +429,26 @@ def clone(self) -> Self: # }}} +# {{{ distributed + eager jax + +class MPIEagerJaxArrayContext(EagerJAXArrayContext, MPIBasedArrayContext): + """An array context for using distributed computation with :mod:`jax` + eager evaluation. + + .. autofunction:: __init__ + """ + + def __init__(self, mpi_communicator) -> None: + super().__init__() + + self.mpi_communicator = mpi_communicator + + def clone(self) -> Self: + return type(self)(self.mpi_communicator) + +# }}} + + # {{{ distributed + pytato array context subclasses class MPIBasePytatoPyOpenCLArrayContext( @@ -521,12 +542,23 @@ def __call__(self): return self.actx_class() +class PytestEagerJAXArrayContextFactory(_PytestEagerJaxArrayContextFactory): + actx_class = EagerJAXArrayContext + + def __call__(self): + import jax + jax.config.update("jax_enable_x64", True) + return self.actx_class() + + register_pytest_array_context_factory("grudge.pyopencl", PytestPyOpenCLArrayContextFactory) register_pytest_array_context_factory("grudge.pytato-pyopencl", PytestPytatoPyOpenCLArrayContextFactory) register_pytest_array_context_factory("grudge.numpy", PytestNumpyArrayContextFactory) +register_pytest_array_context_factory("grudge.eager-jax", + PytestEagerJAXArrayContextFactory) # }}} diff --git a/test/test_dt_utils.py b/test/test_dt_utils.py index cf3ac2021..81b1d09da 100644 --- a/test/test_dt_utils.py +++ b/test/test_dt_utils.py @@ -27,6 +27,7 @@ from arraycontext import pytest_generate_tests_for_array_contexts from grudge.array_context import ( + PytestEagerJAXArrayContextFactory, PytestNumpyArrayContextFactory, PytestPyOpenCLArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, @@ -36,7 +37,8 @@ pytest_generate_tests = pytest_generate_tests_for_array_contexts( [PytestPyOpenCLArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, - PytestNumpyArrayContextFactory]) + PytestNumpyArrayContextFactory, + PytestEagerJAXArrayContextFactory]) import logging diff --git a/test/test_metrics.py b/test/test_metrics.py index 1ee043b8a..5b21f15b8 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -33,6 +33,7 @@ from meshmode.dof_array import flat_norm from grudge.array_context import ( + PytestEagerJAXArrayContextFactory, PytestNumpyArrayContextFactory, PytestPyOpenCLArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, @@ -44,7 +45,8 @@ pytest_generate_tests = pytest_generate_tests_for_array_contexts( [PytestPyOpenCLArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, - PytestNumpyArrayContextFactory]) + PytestNumpyArrayContextFactory, + PytestEagerJAXArrayContextFactory]) # {{{ inverse metric From 07584bf435496c44bc530dcb8ea39482e069b26a Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 27 May 2025 15:46:26 -0500 Subject: [PATCH 02/14] test with jax --- .test-conda-env-py3.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.test-conda-env-py3.yml b/.test-conda-env-py3.yml index 3a4f0a8f6..c4981063f 100644 --- a/.test-conda-env-py3.yml +++ b/.test-conda-env-py3.yml @@ -12,6 +12,7 @@ dependencies: - pyopencl - python=3 - gmsh +- jax # test scripts use ompi-specific arguments - openmpi From ec629d895a7300e5c784601a552cab362a4a0027 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 28 May 2025 15:06:07 -0500 Subject: [PATCH 03/14] reenable more tests --- test/test_euler_model.py | 10 ++++++++-- test/test_grudge.py | 10 ++++++++-- test/test_modal_connections.py | 6 ++++-- test/test_op.py | 6 ++++-- test/test_reductions.py | 6 ++++-- test/test_tools.py | 13 ------------- test/test_trace_pair.py | 6 ++++-- 7 files changed, 32 insertions(+), 25 deletions(-) diff --git a/test/test_euler_model.py b/test/test_euler_model.py index 6e9e9c1d2..787f5b035 100644 --- a/test/test_euler_model.py +++ b/test/test_euler_model.py @@ -31,12 +31,18 @@ ) from grudge import op -from grudge.array_context import PytestPyOpenCLArrayContextFactory +from grudge.array_context import ( + PytestEagerJAXArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestPyOpenCLArrayContextFactory, +) logger = logging.getLogger(__name__) pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) + [PytestPyOpenCLArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestEagerJAXArrayContextFactory]) @pytest.mark.parametrize("order", [1, 2, 3]) diff --git a/test/test_grudge.py b/test/test_grudge.py index b0849310b..de3904a47 100644 --- a/test/test_grudge.py +++ b/test/test_grudge.py @@ -42,13 +42,19 @@ from pytools.obj_array import flat_obj_array from grudge import dof_desc, geometry, op -from grudge.array_context import PytestPyOpenCLArrayContextFactory +from grudge.array_context import ( + PytestEagerJAXArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestPyOpenCLArrayContextFactory, +) from grudge.discretization import make_discretization_collection logger = logging.getLogger(__name__) pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) + [PytestPyOpenCLArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestEagerJAXArrayContextFactory]) # {{{ mass operator trig integration diff --git a/test/test_modal_connections.py b/test/test_modal_connections.py index d6bebdbd5..f5c99fa6e 100644 --- a/test/test_modal_connections.py +++ b/test/test_modal_connections.py @@ -23,12 +23,14 @@ from arraycontext import pytest_generate_tests_for_array_contexts -from grudge.array_context import PytestPyOpenCLArrayContextFactory +from grudge.array_context import PytestPyOpenCLArrayContextFactory, PytestNumpyArrayContextFactory, PytestEagerJAXArrayContextFactory from grudge.discretization import make_discretization_collection pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) + [PytestPyOpenCLArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestEagerJAXArrayContextFactory]) import pytest diff --git a/test/test_op.py b/test/test_op.py index 17f49a074..011e9cbb1 100644 --- a/test/test_op.py +++ b/test/test_op.py @@ -36,7 +36,7 @@ from pytools.obj_array import make_obj_array from grudge import geometry, op -from grudge.array_context import PytestPyOpenCLArrayContextFactory +from grudge.array_context import PytestPyOpenCLArrayContextFactory, PytestNumpyArrayContextFactory, PytestEagerJAXArrayContextFactory from grudge.discretization import make_discretization_collection from grudge.dof_desc import ( DISCR_TAG_BASE, @@ -52,7 +52,9 @@ logger = logging.getLogger(__name__) pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) + [PytestPyOpenCLArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestEagerJAXArrayContextFactory]) # {{{ gradient diff --git a/test/test_reductions.py b/test/test_reductions.py index b8aeec14b..37df46181 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -39,13 +39,15 @@ from pytools.obj_array import make_obj_array from grudge import op -from grudge.array_context import PytestPyOpenCLArrayContextFactory +from grudge.array_context import PytestPyOpenCLArrayContextFactory, PytestNumpyArrayContextFactory, PytestEagerJAXArrayContextFactory from grudge.discretization import make_discretization_collection logger = logging.getLogger(__name__) pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) + [PytestPyOpenCLArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestEagerJAXArrayContextFactory]) @pytest.mark.parametrize(("mesh_size", "with_initial"), [ diff --git a/test/test_tools.py b/test/test_tools.py index efd5747d5..fe15b977e 100644 --- a/test/test_tools.py +++ b/test/test_tools.py @@ -27,24 +27,11 @@ import numpy as np import numpy.linalg as la # noqa -from arraycontext import pytest_generate_tests_for_array_contexts - -from grudge.array_context import PytestPyOpenCLArrayContextFactory - - -pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) - -import logging - import pytest from pytools.obj_array import make_obj_array -logger = logging.getLogger(__name__) - - # {{{ map_subarrays and rec_map_subarrays @dataclass(frozen=True, eq=True) diff --git a/test/test_trace_pair.py b/test/test_trace_pair.py index 5cb060c7e..f4f2cf8ba 100644 --- a/test/test_trace_pair.py +++ b/test/test_trace_pair.py @@ -29,14 +29,16 @@ from arraycontext import pytest_generate_tests_for_array_contexts from meshmode.dof_array import DOFArray -from grudge.array_context import PytestPyOpenCLArrayContextFactory +from grudge.array_context import PytestPyOpenCLArrayContextFactory, PytestNumpyArrayContextFactory, PytestEagerJAXArrayContextFactory from grudge.discretization import make_discretization_collection from grudge.trace_pair import TracePair logger = logging.getLogger(__name__) pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) + [PytestPyOpenCLArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestEagerJAXArrayContextFactory]) def test_trace_pair(actx_factory): From e1f9379ae18c8b49b1a919aeeb6e74a2d218c2a1 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 28 May 2025 15:10:43 -0500 Subject: [PATCH 04/14] enable MPI test --- grudge/array_context.py | 2 +- test/test_mpi_communication.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/grudge/array_context.py b/grudge/array_context.py index fadc95634..36b4c121b 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -431,7 +431,7 @@ def clone(self) -> Self: # {{{ distributed + eager jax -class MPIEagerJaxArrayContext(EagerJAXArrayContext, MPIBasedArrayContext): +class MPIEagerJAXArrayContext(EagerJAXArrayContext, MPIBasedArrayContext): """An array context for using distributed computation with :mod:`jax` eager evaluation. diff --git a/test/test_mpi_communication.py b/test/test_mpi_communication.py index 5b66be068..496dfd956 100644 --- a/test/test_mpi_communication.py +++ b/test/test_mpi_communication.py @@ -35,7 +35,7 @@ from pytools.obj_array import flat_obj_array from grudge import dof_desc, op -from grudge.array_context import MPIPyOpenCLArrayContext, MPIPytatoArrayContext +from grudge.array_context import MPIPyOpenCLArrayContext, MPIPytatoArrayContext, MPIEagerJAXArrayContext from grudge.discretization import make_discretization_collection from grudge.shortcuts import compiled_lsrk45_step @@ -49,7 +49,7 @@ class SimpleTag: # {{{ mpi test infrastructure -DISTRIBUTED_ACTXS = [MPIPyOpenCLArrayContext, MPIPytatoArrayContext] +DISTRIBUTED_ACTXS = [MPIPyOpenCLArrayContext, MPIPytatoArrayContext, MPIEagerJAXArrayContext] def run_test_with_mpi(num_ranks, f, *args): @@ -87,6 +87,8 @@ def run_test_with_mpi_inner(): actx = actx_class(comm, queue, mpi_base_tag=15000) elif actx_class is MPIPyOpenCLArrayContext: actx = actx_class(comm, queue) + elif actx_class is MPIEagerJAXArrayContext: + actx = actx_class(comm) else: raise ValueError("unknown actx_class") From 419e534ba6b9cfdaa8f868d1b1cf46b0116603eb Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 28 May 2025 15:47:24 -0500 Subject: [PATCH 05/14] ruff --- test/test_modal_connections.py | 6 +++++- test/test_mpi_communication.py | 9 +++++++-- test/test_op.py | 6 +++++- test/test_reductions.py | 6 +++++- test/test_tools.py | 1 - test/test_trace_pair.py | 6 +++++- 6 files changed, 27 insertions(+), 7 deletions(-) diff --git a/test/test_modal_connections.py b/test/test_modal_connections.py index f5c99fa6e..695a7d403 100644 --- a/test/test_modal_connections.py +++ b/test/test_modal_connections.py @@ -23,7 +23,11 @@ from arraycontext import pytest_generate_tests_for_array_contexts -from grudge.array_context import PytestPyOpenCLArrayContextFactory, PytestNumpyArrayContextFactory, PytestEagerJAXArrayContextFactory +from grudge.array_context import ( + PytestEagerJAXArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestPyOpenCLArrayContextFactory, +) from grudge.discretization import make_discretization_collection diff --git a/test/test_mpi_communication.py b/test/test_mpi_communication.py index 496dfd956..ceda62e19 100644 --- a/test/test_mpi_communication.py +++ b/test/test_mpi_communication.py @@ -35,7 +35,11 @@ from pytools.obj_array import flat_obj_array from grudge import dof_desc, op -from grudge.array_context import MPIPyOpenCLArrayContext, MPIPytatoArrayContext, MPIEagerJAXArrayContext +from grudge.array_context import ( + MPIEagerJAXArrayContext, + MPIPyOpenCLArrayContext, + MPIPytatoArrayContext, +) from grudge.discretization import make_discretization_collection from grudge.shortcuts import compiled_lsrk45_step @@ -49,7 +53,8 @@ class SimpleTag: # {{{ mpi test infrastructure -DISTRIBUTED_ACTXS = [MPIPyOpenCLArrayContext, MPIPytatoArrayContext, MPIEagerJAXArrayContext] +DISTRIBUTED_ACTXS = [MPIPyOpenCLArrayContext, MPIPytatoArrayContext, + MPIEagerJAXArrayContext] def run_test_with_mpi(num_ranks, f, *args): diff --git a/test/test_op.py b/test/test_op.py index 011e9cbb1..ff25a4f81 100644 --- a/test/test_op.py +++ b/test/test_op.py @@ -36,7 +36,11 @@ from pytools.obj_array import make_obj_array from grudge import geometry, op -from grudge.array_context import PytestPyOpenCLArrayContextFactory, PytestNumpyArrayContextFactory, PytestEagerJAXArrayContextFactory +from grudge.array_context import ( + PytestEagerJAXArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestPyOpenCLArrayContextFactory, +) from grudge.discretization import make_discretization_collection from grudge.dof_desc import ( DISCR_TAG_BASE, diff --git a/test/test_reductions.py b/test/test_reductions.py index 37df46181..7e735c610 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -39,7 +39,11 @@ from pytools.obj_array import make_obj_array from grudge import op -from grudge.array_context import PytestPyOpenCLArrayContextFactory, PytestNumpyArrayContextFactory, PytestEagerJAXArrayContextFactory +from grudge.array_context import ( + PytestEagerJAXArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestPyOpenCLArrayContextFactory, +) from grudge.discretization import make_discretization_collection diff --git a/test/test_tools.py b/test/test_tools.py index fe15b977e..eaf341894 100644 --- a/test/test_tools.py +++ b/test/test_tools.py @@ -26,7 +26,6 @@ import numpy as np import numpy.linalg as la # noqa - import pytest from pytools.obj_array import make_obj_array diff --git a/test/test_trace_pair.py b/test/test_trace_pair.py index f4f2cf8ba..d493f3c3e 100644 --- a/test/test_trace_pair.py +++ b/test/test_trace_pair.py @@ -29,7 +29,11 @@ from arraycontext import pytest_generate_tests_for_array_contexts from meshmode.dof_array import DOFArray -from grudge.array_context import PytestPyOpenCLArrayContextFactory, PytestNumpyArrayContextFactory, PytestEagerJAXArrayContextFactory +from grudge.array_context import ( + PytestEagerJAXArrayContextFactory, + PytestNumpyArrayContextFactory, + PytestPyOpenCLArrayContextFactory, +) from grudge.discretization import make_discretization_collection from grudge.trace_pair import TracePair From 14422b6ff287a255021fe4da6adb29b6d4d0bad7 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 28 May 2025 15:48:00 -0500 Subject: [PATCH 06/14] rewrite attempt _signed_face_ones --- grudge/geometry/metrics.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/grudge/geometry/metrics.py b/grudge/geometry/metrics.py index 952a03a1f..e978ac452 100644 --- a/grudge/geometry/metrics.py +++ b/grudge/geometry/metrics.py @@ -566,15 +566,21 @@ def _signed_face_ones( signed_face_ones_numpy = actx.to_numpy(signed_ones) + new_group_arrays = [] for igrp, grp in enumerate(all_faces_conn.groups): + grp_field = signed_face_ones_numpy[igrp] + sign_mask = np.ones_like(grp_field) + for batch in grp.batches: assert batch.to_element_face is not None i = actx.to_numpy(actx.thaw(batch.to_element_indices)) - grp_field = signed_face_ones_numpy[igrp].reshape(-1) - grp_field[i] = \ - (2.0 * (batch.to_element_face % 2) - 1.0) * grp_field[i] + sign = (2.0 * (batch.to_element_face % 2) - 1.0) + sign_mask[i, :] = sign + + new_group_arrays.append(grp_field * sign_mask) - return actx.from_numpy(signed_face_ones_numpy) + from meshmode.dof_array import DOFArray + return actx.from_numpy(DOFArray(actx, tuple(new_group_arrays))) def parametrization_derivative( From 4bdffa1882816c2d5612093d602343809d867288 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 28 May 2025 16:07:50 -0500 Subject: [PATCH 07/14] also test MPINumpyArrayContext --- test/test_mpi_communication.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_mpi_communication.py b/test/test_mpi_communication.py index ceda62e19..4a284c36f 100644 --- a/test/test_mpi_communication.py +++ b/test/test_mpi_communication.py @@ -37,6 +37,7 @@ from grudge import dof_desc, op from grudge.array_context import ( MPIEagerJAXArrayContext, + MPINumpyArrayContext, MPIPyOpenCLArrayContext, MPIPytatoArrayContext, ) @@ -54,7 +55,7 @@ class SimpleTag: # {{{ mpi test infrastructure DISTRIBUTED_ACTXS = [MPIPyOpenCLArrayContext, MPIPytatoArrayContext, - MPIEagerJAXArrayContext] + MPIEagerJAXArrayContext, MPINumpyArrayContext] def run_test_with_mpi(num_ranks, f, *args): @@ -94,6 +95,8 @@ def run_test_with_mpi_inner(): actx = actx_class(comm, queue) elif actx_class is MPIEagerJAXArrayContext: actx = actx_class(comm) + elif actx_class is MPINumpyArrayContext: + actx = actx_class(comm) else: raise ValueError("unknown actx_class") From 120e80211dee873f74a78ac3416ca4e68c78fdc0 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 28 May 2025 16:41:42 -0500 Subject: [PATCH 08/14] debug --- .github/workflows/ci.yml | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ed8346974..334bc5261 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -68,8 +68,15 @@ jobs: - uses: actions/checkout@v4 - name: "Main Script" run: | - curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project-within-miniconda.sh - . ./build-and-test-py-project-within-miniconda.sh + set -x + curl -L -O https://tiker.net/ci-support-v0 + + sed -i 's/export PYTEST_ADDOPTS=.*/export PYTEST_ADDOPTS="-v"/' ci-support-v0 + + source ci-support-v0 + + build_py_project_in_conda_env + test_py_project pyexamples3: name: Examples on Py3 From ef1b0b155fddd8e3d97412f492f38199aedb61bf Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 28 May 2025 17:14:01 -0500 Subject: [PATCH 09/14] Revert "debug" This reverts commit 120e80211dee873f74a78ac3416ca4e68c78fdc0. --- .github/workflows/ci.yml | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 334bc5261..ed8346974 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -68,15 +68,8 @@ jobs: - uses: actions/checkout@v4 - name: "Main Script" run: | - set -x - curl -L -O https://tiker.net/ci-support-v0 - - sed -i 's/export PYTEST_ADDOPTS=.*/export PYTEST_ADDOPTS="-v"/' ci-support-v0 - - source ci-support-v0 - - build_py_project_in_conda_env - test_py_project + curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project-within-miniconda.sh + . ./build-and-test-py-project-within-miniconda.sh pyexamples3: name: Examples on Py3 From 850d763030d80bbdeee4a552d9c90383c7f717fb Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 28 May 2025 17:57:00 -0500 Subject: [PATCH 10/14] add to docs --- grudge/array_context.py | 1 + 1 file changed, 1 insertion(+) diff --git a/grudge/array_context.py b/grudge/array_context.py index 36b4c121b..647b01188 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -5,6 +5,7 @@ .. autoclass:: MPIPyOpenCLArrayContext .. autoclass:: MPINumpyArrayContext .. class:: MPIPytatoArrayContext +.. autoclass:: MPIEagerJAXArrayContext .. autofunction:: get_reasonable_array_context_class """ From 3107b16419ab13f2291805a34b165f6675f362c2 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 28 May 2025 18:07:41 -0500 Subject: [PATCH 11/14] "add" intersphinx source --- doc/conf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/conf.py b/doc/conf.py index 9f475cc58..6bb2dc47b 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -19,6 +19,7 @@ intersphinx_mapping = { "arraycontext": ("https://documen.tician.de/arraycontext/", None), "loopy": ("https://documen.tician.de/loopy/", None), + "jax": ("https://docs.jax.dev/en/latest/", None), "meshmode": ("https://documen.tician.de/meshmode/", None), "modepy": ("https://documen.tician.de/modepy/", None), "mpi4py": ("https://mpi4py.readthedocs.io/en/stable", None), @@ -32,6 +33,7 @@ os.environ["PYOPENCL_TEST"] = "port:cpu" nitpick_ignore_regex = [ + ["py:mod", r"jax"], # FIXME: not sure why this does not work ["py:class", r"np\.ndarray"], ["py:data|py:class", r"arraycontext.*ContainerTc"], ] From 85a0d5494dc283165d7c6c9b1f6f46c7057603d3 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 29 May 2025 15:24:57 -0500 Subject: [PATCH 12/14] rewrite _signed_face_ones again Co-authored-by: Alex Fikl --- grudge/geometry/metrics.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/grudge/geometry/metrics.py b/grudge/geometry/metrics.py index e978ac452..85c05836f 100644 --- a/grudge/geometry/metrics.py +++ b/grudge/geometry/metrics.py @@ -560,24 +560,21 @@ def _signed_face_ones( dd_base.untrace(), dd_base ) assert isinstance(all_faces_conn, DirectDiscretizationConnection) - signed_ones = dcoll.discr_from_dd(dd.with_discr_tag(DISCR_TAG_BASE)).zeros( - actx, dtype=dcoll.real_dtype - ) + 1 - signed_face_ones_numpy = actx.to_numpy(signed_ones) + discr = dcoll.discr_from_dd(dd.with_discr_tag(DISCR_TAG_BASE)) new_group_arrays = [] - for igrp, grp in enumerate(all_faces_conn.groups): - grp_field = signed_face_ones_numpy[igrp] - sign_mask = np.ones_like(grp_field) + + for dgrp, grp in zip(discr.groups, all_faces_conn.groups): + sign = np.ones((dgrp.nelements, dgrp.nunit_dofs), + dtype=discr.real_dtype) for batch in grp.batches: assert batch.to_element_face is not None i = actx.to_numpy(actx.thaw(batch.to_element_indices)) - sign = (2.0 * (batch.to_element_face % 2) - 1.0) - sign_mask[i, :] = sign + sign[i, :] = 2.0 * (batch.to_element_face % 2) - 1.0 - new_group_arrays.append(grp_field * sign_mask) + new_group_arrays.append(sign) from meshmode.dof_array import DOFArray return actx.from_numpy(DOFArray(actx, tuple(new_group_arrays))) From e06ee242b3e1c33193cd08ba390083e332b0ae4f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 29 May 2025 15:28:31 -0500 Subject: [PATCH 13/14] add strict parameter --- grudge/geometry/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grudge/geometry/metrics.py b/grudge/geometry/metrics.py index 85c05836f..ba154b0a7 100644 --- a/grudge/geometry/metrics.py +++ b/grudge/geometry/metrics.py @@ -565,7 +565,7 @@ def _signed_face_ones( new_group_arrays = [] - for dgrp, grp in zip(discr.groups, all_faces_conn.groups): + for dgrp, grp in zip(discr.groups, all_faces_conn.groups, strict=True): sign = np.ones((dgrp.nelements, dgrp.nunit_dofs), dtype=discr.real_dtype) From 7be0ea2d31370e7c07efb1249af630ef70a9dd2f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 26 Aug 2025 09:27:11 -0500 Subject: [PATCH 14/14] ruff --- test/test_tools.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_tools.py b/test/test_tools.py index c224193f4..9fcfc8a31 100644 --- a/test/test_tools.py +++ b/test/test_tools.py @@ -37,7 +37,6 @@ pytest_generate_tests = pytest_generate_tests_for_array_contexts( [PytestPyOpenCLArrayContextFactory]) -import logging import pytest