From e956adb76c631e541064aa9100e9498810cb316a Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 9 Sep 2025 12:35:48 +0100 Subject: [PATCH 001/125] Matfree adjoint interpolation --- firedrake/assemble.py | 12 +-- firedrake/interpolation.py | 88 ++++++++++++++----- .../firedrake/regression/test_interpolate.py | 3 + tsfc/driver.py | 32 +++++-- 4 files changed, 103 insertions(+), 32 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 9ca4d17237..19ec557f34 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -591,17 +591,19 @@ def base_form_assembly_visitor(self, expr, tensor, *args): _, v1 = expr.arguments() operand = ufl.replace(operand, {v1: v1.reconstruct(number=0)}) # Get the interpolator - interp_data = expr.interp_data + interp_data = expr.interp_data.copy() default_missing_val = interp_data.pop('default_missing_val', None) - interpolator = firedrake.Interpolator(operand, V, **interp_data) + if (is_adjoint and rank == 1) or rank == 0: + interp_data["access"] = op2.INC + interpolator = firedrake.Interpolator(operand, v, **interp_data) # Assembly if rank == 0: - Iu = interpolator._interpolate(default_missing_val=default_missing_val) - return assemble(ufl.Action(v, Iu), tensor=tensor) + result = interpolator._interpolate(output=tensor, default_missing_val=default_missing_val) + return result.dat.data.item() if tensor is None else result elif rank == 1: # Assembling the action of the Jacobian adjoint. if is_adjoint: - return interpolator._interpolate(v, output=tensor, adjoint=True, default_missing_val=default_missing_val) + return interpolator._interpolate(v, output=tensor, default_missing_val=default_missing_val) # Assembling the Jacobian action. elif interpolator.nargs: return interpolator._interpolate(operand, output=tensor, default_missing_val=default_missing_val) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index d57758e624..a2dc1a35dc 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -552,6 +552,8 @@ def _interpolate( else: if isinstance(self.V, (firedrake.Function, firedrake.Cofunction)): V_dest = self.V.function_space() + elif isinstance(self.V, firedrake.Coargument): + V_dest = self.V.function_space().dual() else: V_dest = self.V if output: @@ -677,9 +679,10 @@ class SameMeshInterpolator(Interpolator): def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, bcs=None, matfree=True, allow_missing_dofs=False, **kwargs): if subset is None: - target = V.function_space().mesh().topology if isinstance(V, firedrake.Function) else V.mesh().topology - temp = extract_unique_domain(expr) - source = target if temp is None else temp.topology + target_mesh = as_domain(V) + source_mesh = extract_unique_domain(expr) + target = target_mesh.topology + source = target if source_mesh is None else source_mesh.topology if all(isinstance(m, firedrake.mesh.MeshTopology) for m in [target, source]) and target is not source: composed_map, result_integral_type = source.trans_mesh_entity_map(target, "cell", "everywhere", None) if result_integral_type != "cell": @@ -732,7 +735,7 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, ** # Interpolation action self.frozen_assembled_interpolator = assembled_interpolator.copy() - if self.nargs: + if self.nargs == 2: function, = function if not hasattr(function, "dat"): raise ValueError("The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!") @@ -770,20 +773,37 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, ** @PETSc.Log.EventDecorator() def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): assert isinstance(expr, ufl.classes.Expr) + + if isinstance(V, (ufl.Coargument, ufl.Cofunction)): + dual_arg = V + V = dual_arg.function_space().dual() + elif isinstance(V, (ufl.FunctionSpace, ufl.Coefficient)): + fs = V if isinstance(V, ufl.FunctionSpace) else V.function_space() + dual_arg = Coargument(fs.dual(), number=0) + arguments = extract_arguments(expr) + if isinstance(dual_arg, ufl.Coargument): + arguments.append(dual_arg) + rank = len(arguments) + target_mesh = as_domain(V) - if len(arguments) == 0: + if rank <= 1: source_mesh = extract_unique_domain(expr) or target_mesh vom_onto_other_vom = ( isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology) and isinstance(source_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology) and target_mesh is not source_mesh ) - if isinstance(V, firedrake.Function): + + if rank == 0: + R = firedrake.FunctionSpace(target_mesh, "Real", 0) + f = firedrake.Function(R) + elif isinstance(V, firedrake.Function): f = V V = f.function_space() else: - f = firedrake.Function(V) + V_dest = arguments[-1].function_space().dual() + f = firedrake.Function(V_dest) if access in {firedrake.MIN, firedrake.MAX}: finfo = numpy.finfo(f.dat.dtype) if access == firedrake.MIN: @@ -792,11 +812,12 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): val = firedrake.Constant(finfo.min) f.assign(val) tensor = f.dat - elif len(arguments) == 1: + elif rank == 2: if isinstance(V, firedrake.Function): raise ValueError("Cannot interpolate an expression with an argument into a Function") if len(V) > 1: raise NotImplementedError("Interpolation of mixed expressions with arguments is not supported") + argfs = arguments[0].function_space() source_mesh = argfs.mesh() argfs_map = argfs.cell_node_map() @@ -840,7 +861,7 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): tensor = op2.Mat(sparsity) f = tensor else: - raise ValueError("Cannot interpolate an expression with %d arguments" % len(arguments)) + raise ValueError("Cannot interpolate an expression with %d arguments" % rank) if vom_onto_other_vom: wrapper = VomOntoVomWrapper(V, source_mesh, target_mesh, expr, arguments, matfree) @@ -859,7 +880,7 @@ def callable(): wrapper.forward_operation(f.dat) return f else: - assert len(arguments) == 1 + assert rank == 2 assert tensor is None # we know we will be outputting either a function or a cofunction, # both of which will use a dat as a data carrier. At present, the @@ -888,7 +909,7 @@ def callable(): % (V.value_size, numpy.prod(expr.ufl_shape, dtype=int))) if len(V) == 1: - loops.extend(_interpolator(V, tensor, expr, subset, arguments, access, bcs=bcs)) + loops.extend(_interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=bcs)) else: if (hasattr(expr, "subfunctions") and len(expr.subfunctions) == len(V) and all(sub_expr.ufl_shape == Vsub.value_shape for Vsub, sub_expr in zip(V, expr.subfunctions))): @@ -905,11 +926,18 @@ def callable(): components = [expr[offset + j] for j in range(Vsub.value_size)] expressions.append(ufl.as_tensor(numpy.reshape(components, Vsub.value_shape))) offset += Vsub.value_size + + if isinstance(dual_arg, Cofunction): + duals = dual_arg.subfunctions + elif isinstance(dual_arg, Coargument): + duals = [Coargument(Vsub.dual(), number=dual_arg.number()) for Vsub in V] + else: + raise ValueError("dual_arg must be a Cofunction or Coargument") # Interpolate each sub expression into each function space - for Vsub, sub_tensor, sub_expr in zip(V, tensor, expressions): - loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs)) + for Vsub, sub_tensor, sub_expr, sub_dual in zip(V, tensor, expressions, duals): + loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, sub_dual, subset, arguments, access, bcs=bcs)) - if bcs and len(arguments) == 0: + if bcs and rank == 1: loops.extend(partial(bc.apply, f) for bc in bcs) def callable(loops, f): @@ -921,7 +949,7 @@ def callable(loops, f): @utils.known_pyop2_safe -def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): +def _interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=None): try: expr = ufl.as_ufl(expr) except ValueError: @@ -977,13 +1005,31 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): parameters = {} parameters['scalar_type'] = utils.ScalarType + needs_weight = isinstance(dual_arg, ufl.Cofunction) and not to_element.is_dg() + if needs_weight: + W = dual_arg.function_space() + shapes = (W.finat_element.space_dimension(), W.block_size) + domain = "{[i,j]: 0 <= i < %d and 0 <= j < %d}" % shapes + instructions = """ + for i, j + w[i,j] = w[i,j] + 1 + end + """ + weight = firedrake.Function(W) + firedrake.par_loop((domain, instructions), ufl.dx, {"w": (weight, op2.INC)}) + + tmp = firedrake.Function(W) + with weight.dat.vec as w, dual_arg.dat.vec as x, tmp.dat.vec as y: + y.pointwiseDivide(x, w) + dual_arg = tmp + # We need to pass both the ufl element and the finat element # because the finat elements might not have the right mapping # (e.g. L2 Piola, or tensor element with symmetries) # FIXME: for the runtime unknown point set (for cross-mesh # interpolation) we have to pass the finat element we construct # here. Ideally we would only pass the UFL element through. - kernel = compile_expression(cell_set.comm, expr, to_element, V.ufl_element(), + kernel = compile_expression(cell_set.comm, expr, dual_arg, to_element, V.ufl_element(), domain=source_mesh, parameters=parameters) ast = kernel.ast oriented = kernel.oriented @@ -996,7 +1042,8 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): parloop_args = [kernel, cell_set] - coefficients = tsfc_interface.extract_numbered_coefficients(expr, coefficient_numbers) + interp_expr = ufl.Interpolate(expr, dual_arg) + coefficients = tsfc_interface.extract_numbered_coefficients(interp_expr, coefficient_numbers) if needs_external_coords: coefficients = [source_mesh.coordinates] + coefficients @@ -1014,7 +1061,8 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): if isinstance(tensor, op2.Global): parloop_args.append(tensor(access)) elif isinstance(tensor, op2.Dat): - parloop_args.append(tensor(access, V.cell_node_map())) + V_dest = arguments[0].function_space() if isinstance(dual_arg, ufl.Cofunction) else V + parloop_args.append(tensor(access, V_dest.cell_node_map())) else: assert access == op2.WRITE # Other access descriptors not done for Matrices. rows_map = V.cell_node_map() @@ -1117,9 +1165,9 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): f"firedrake-tsfc-expression-kernel-cache-uid{os.getuid()}") -def _compile_expression_key(comm, expr, to_element, ufl_element, domain, parameters) -> tuple[Hashable, ...]: +def _compile_expression_key(comm, expr, dual_arg, to_element, ufl_element, domain, parameters) -> tuple[Hashable, ...]: """Generate a cache key suitable for :func:`tsfc.compile_expression_dual_evaluation`.""" - return (hash_expr(expr), hash(ufl_element), utils.tuplify(parameters)) + return (hash_expr(expr), type(dual_arg), hash(ufl_element), utils.tuplify(parameters)) @memory_and_disk_cache( diff --git a/tests/firedrake/regression/test_interpolate.py b/tests/firedrake/regression/test_interpolate.py index 3f3072a881..1bb6711f26 100644 --- a/tests/firedrake/regression/test_interpolate.py +++ b/tests/firedrake/regression/test_interpolate.py @@ -341,6 +341,7 @@ def test_adjoint_Pk(degree): v_adj_form = assemble(interpolate(TestFunction(Pk), v * dx)) + assert v_adj.function_space() == v_adj_form.function_space() assert np.allclose(v_adj_form.dat.data, v_adj.dat.data) @@ -353,6 +354,7 @@ def test_adjoint_quads(): u_P1 = assemble(conj(TestFunction(P1)) * dx) v_adj = assemble(interpolate(TestFunction(P1), assemble(v * dx))) + assert v_adj.function_space() == u_P1.function_space() assert np.allclose(u_P1.dat.data, v_adj.dat.data) @@ -365,6 +367,7 @@ def test_adjoint_dg(): u_cg = assemble(conj(TestFunction(cg1)) * dx) v_adj = assemble(interpolate(TestFunction(cg1), assemble(v * dx))) + assert v_adj.function_space() == u_cg.function_space() assert np.allclose(u_cg.dat.data, v_adj.dat.data) diff --git a/tsfc/driver.py b/tsfc/driver.py index 1cbbec56c3..55f4ab80fc 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -182,14 +182,16 @@ def preprocess_parameters(parameters): return parameters -def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, +def compile_expression_dual_evaluation(expression, dual_arg, + to_element, ufl_element, *, domain=None, interface=None, parameters=None): """Compile a UFL expression to be evaluated against a compile-time known reference element's dual basis. Useful for interpolating UFL expressions into e.g. N1curl spaces. - :arg expression: UFL expression + :arg expression: UFL expression to interpolate + :arg dual_arg: A Cofunction or Coargument to act on the interpolated expression :arg to_element: A FInAT element for the target space :arg ufl_element: The UFL element of the target space. :arg domain: optional UFL domain the expression is defined on (required when expression contains no domain). @@ -210,7 +212,11 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, if isinstance(to_element, (PhysicallyMappedElement, DirectlyDefinedElement)): raise NotImplementedError("Don't know how to interpolate onto zany spaces, sorry") - orig_expression = expression + if not isinstance(dual_arg, (ufl.Coargument, ufl.Cofunction)): + raise ValueError(f"Expecting a Coargument or Cofunction, not {type(dual_arg).__name__}") + + interp_expression = ufl.Interpolate(expression, dual_arg) + orig_expression = interp_expression # Map into reference space expression = apply_mapping(expression, ufl_element, domain) @@ -235,7 +241,7 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, assert domain is not None # Collect required coefficients and determine numbering - coefficients = extract_coefficients(expression) + coefficients = extract_coefficients(interp_expression) orig_coefficients = extract_coefficients(orig_expression) coefficient_numbers = tuple(map(orig_coefficients.index, coefficients)) builder.set_coefficient_numbers(coefficient_numbers) @@ -252,7 +258,7 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, needs_external_coords = True builder.set_coefficients(coefficients) - constants = extract_firedrake_constants(expression) + constants = extract_firedrake_constants(interp_expression) builder.set_constants(constants) # Split mixed coefficients @@ -281,11 +287,23 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, # indices needed for compilation of the expression evaluation, basis_indices = to_element.dual_evaluation(fn) + # Compute the adjoint by contracting against the dual argument + if dual_arg in coefficients: + beta = basis_indices + shape = tuple(i.extent for i in beta) + gem_dual = gem.Variable(f"w_{coefficients.index(dual_arg)}", shape) + evaluation = gem.IndexSum(evaluation * gem_dual[beta], beta) + basis_indices = () + # Build kernel body return_indices = basis_indices + tuple(chain(*argument_multiindices)) return_shape = tuple(i.extent for i in return_indices) - return_var = gem.Variable('A', return_shape) - return_expr = gem.Indexed(return_var, return_indices) + if return_shape: + return_var = gem.Variable('A', return_shape) + return_expr = gem.Indexed(return_var, return_indices) + else: + return_var = gem.Variable('A', (1,)) + return_expr = gem.Indexed(return_var, (0,)) # TODO: one should apply some GEM optimisations as in assembly, # but we don't for now. From 29a410dd46ab56c39b90590ee614c4ae4e05ee76 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 10 Sep 2025 11:32:06 +0100 Subject: [PATCH 002/125] Update interpolation.py --- firedrake/assemble.py | 20 ++- firedrake/interpolation.py | 163 +++++++++--------- .../firedrake/regression/test_interpolate.py | 3 - tsfc/driver.py | 47 +++-- 4 files changed, 129 insertions(+), 104 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 19ec557f34..2a5f43b079 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -590,12 +590,22 @@ def base_form_assembly_visitor(self, expr, tensor, *args): if not is_adjoint and rank == 2: _, v1 = expr.arguments() operand = ufl.replace(operand, {v1: v1.reconstruct(number=0)}) + + target_mesh = V.mesh() + source_mesh = extract_unique_domain(operand) or target_mesh + same_mesh = source_mesh.topology is target_mesh.topology + # Get the interpolator - interp_data = expr.interp_data.copy() + interp_data = expr.interp_data default_missing_val = interp_data.pop('default_missing_val', None) - if (is_adjoint and rank == 1) or rank == 0: + if same_mesh and ((is_adjoint and rank == 1) or rank == 0): + interp_data = interp_data.copy() interp_data["access"] = op2.INC - interpolator = firedrake.Interpolator(operand, v, **interp_data) + + dual_arg = v if same_mesh else V + interp_expr = firedrake.Interpolate(operand, v=dual_arg, **interp_data) + interpolator = firedrake.Interpolator(interp_expr, V, **interp_data) + # Assembly if rank == 0: result = interpolator._interpolate(output=tensor, default_missing_val=default_missing_val) @@ -603,7 +613,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args): elif rank == 1: # Assembling the action of the Jacobian adjoint. if is_adjoint: - return interpolator._interpolate(v, output=tensor, default_missing_val=default_missing_val) + return interpolator._interpolate(v, output=tensor, adjoint=True, default_missing_val=default_missing_val) # Assembling the Jacobian action. elif interpolator.nargs: return interpolator._interpolate(operand, output=tensor, default_missing_val=default_missing_val) @@ -611,7 +621,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args): elif tensor is None: return interpolator._interpolate(default_missing_val=default_missing_val) else: - return firedrake.Interpolator(operand, tensor, **interp_data)._interpolate(default_missing_val=default_missing_val) + return firedrake.Interpolator(interp_expr, tensor, **interp_data)._interpolate(default_missing_val=default_missing_val) elif rank == 2: res = tensor.petscmat if tensor else PETSc.Mat() # Get the interpolation matrix diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index a2dc1a35dc..5ce11ed457 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -241,6 +241,8 @@ class Interpolator(abc.ABC): """ def __new__(cls, expr, V, **kwargs): + if isinstance(expr, ufl.Interpolate): + expr, = expr.ufl_operands target_mesh = as_domain(V) source_mesh = extract_unique_domain(expr) or target_mesh submesh_interp_implemented = \ @@ -266,6 +268,8 @@ def __init__( allow_missing_dofs=False, matfree=True ): + if isinstance(expr, ufl.Interpolate): + expr, = expr.ufl_operands self.expr = expr self.V = V self.subset = subset @@ -374,6 +378,8 @@ def __init__( "Can only interpolate into spaces with point evaluation nodes." ) + if isinstance(expr, ufl.Interpolate): + expr, = expr.ufl_operands super().__init__(expr, V, subset, freeze_expr, access, bcs, allow_missing_dofs, matfree) self.arguments = extract_arguments(expr) @@ -540,7 +546,7 @@ def _interpolate( V_dest = self.expr.function_space().dual() except AttributeError: if self.nargs: - V_dest = self.arguments[0].function_space().dual() + V_dest = self.arguments[-1].function_space().dual() else: coeffs = extract_coefficients(self.expr) if len(coeffs): @@ -552,8 +558,6 @@ def _interpolate( else: if isinstance(self.V, (firedrake.Function, firedrake.Cofunction)): V_dest = self.V.function_space() - elif isinstance(self.V, firedrake.Coargument): - V_dest = self.V.function_space().dual() else: V_dest = self.V if output: @@ -679,10 +683,14 @@ class SameMeshInterpolator(Interpolator): def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, bcs=None, matfree=True, allow_missing_dofs=False, **kwargs): if subset is None: + if isinstance(expr, ufl.Interpolate): + operand, = expr.ufl_operands + else: + operand = expr target_mesh = as_domain(V) - source_mesh = extract_unique_domain(expr) + source_mesh = extract_unique_domain(operand) or target_mesh target = target_mesh.topology - source = target if source_mesh is None else source_mesh.topology + source = source_mesh.topology if all(isinstance(m, firedrake.mesh.MeshTopology) for m in [target, source]) and target is not source: composed_map, result_integral_type = source.trans_mesh_entity_map(target, "cell", "everywhere", None) if result_integral_type != "cell": @@ -703,7 +711,7 @@ def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, self.callable, arguments = make_interpolator(expr, V, subset, access, bcs=bcs, matfree=matfree) except FIAT.hdiv_trace.TraceError: raise NotImplementedError("Can't interpolate onto traces sorry") - self.arguments = arguments + self.arguments = expr.arguments() self.nargs = len(arguments) @PETSc.Log.EventDecorator() @@ -735,16 +743,19 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, ** # Interpolation action self.frozen_assembled_interpolator = assembled_interpolator.copy() - if self.nargs == 2: + if hasattr(assembled_interpolator, "handle") and len(function): function, = function if not hasattr(function, "dat"): raise ValueError("The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!") if adjoint: mul = assembled_interpolator.handle.multHermitian V = self.arguments[0].function_space().dual() + assert function.function_space() == self.arguments[1].function_space() else: mul = assembled_interpolator.handle.mult - V = self.V + V = self.arguments[1].function_space().dual() + assert function.function_space() == self.arguments[0].function_space() + result = output or firedrake.Function(V) with function.dat.vec_ro as x, result.dat.vec_wo as out: if x is not out: @@ -772,29 +783,23 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, ** @PETSc.Log.EventDecorator() def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): - assert isinstance(expr, ufl.classes.Expr) + assert isinstance(expr, ufl.Interpolate) + dual_arg, operand = expr.argument_slots() + assert isinstance(dual_arg, (ufl.Coargument, ufl.Cofunction)) + + target_mesh = as_domain(dual_arg) + source_mesh = extract_unique_domain(operand) or target_mesh + same_mesh = target_mesh is source_mesh + + vom_onto_other_vom = ( + isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology) + and isinstance(source_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology) + and target_mesh is not source_mesh + ) - if isinstance(V, (ufl.Coargument, ufl.Cofunction)): - dual_arg = V - V = dual_arg.function_space().dual() - elif isinstance(V, (ufl.FunctionSpace, ufl.Coefficient)): - fs = V if isinstance(V, ufl.FunctionSpace) else V.function_space() - dual_arg = Coargument(fs.dual(), number=0) - - arguments = extract_arguments(expr) - if isinstance(dual_arg, ufl.Coargument): - arguments.append(dual_arg) + arguments = expr.arguments() rank = len(arguments) - - target_mesh = as_domain(V) if rank <= 1: - source_mesh = extract_unique_domain(expr) or target_mesh - vom_onto_other_vom = ( - isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology) - and isinstance(source_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology) - and target_mesh is not source_mesh - ) - if rank == 0: R = firedrake.FunctionSpace(target_mesh, "Real", 0) f = firedrake.Function(R) @@ -817,15 +822,8 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): raise ValueError("Cannot interpolate an expression with an argument into a Function") if len(V) > 1: raise NotImplementedError("Interpolation of mixed expressions with arguments is not supported") - argfs = arguments[0].function_space() - source_mesh = argfs.mesh() argfs_map = argfs.cell_node_map() - vom_onto_other_vom = ( - isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology) - and isinstance(source_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology) - and target_mesh is not source_mesh - ) if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology) and target_mesh is not source_mesh and not vom_onto_other_vom: if not isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology): raise NotImplementedError("Can only interpolate onto a Vertex Only Mesh") @@ -863,8 +861,10 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): else: raise ValueError("Cannot interpolate an expression with %d arguments" % rank) + if not same_mesh: + arguments = extract_arguments(operand) if vom_onto_other_vom: - wrapper = VomOntoVomWrapper(V, source_mesh, target_mesh, expr, arguments, matfree) + wrapper = VomOntoVomWrapper(V, source_mesh, target_mesh, operand, arguments, matfree) # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the # data, including the correct data size and dimensional information # (so for vector function spaces in 2 dimensions we might need a @@ -874,13 +874,13 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): # when it is called. assert f.dat is tensor wrapper.mpi_type, _ = get_dat_mpi_type(f.dat) - assert not len(arguments) + assert len(arguments) == 0 def callable(): wrapper.forward_operation(f.dat) return f else: - assert rank == 2 + assert len(arguments) == 1 assert tensor is None # we know we will be outputting either a function or a cofunction, # both of which will use a dat as a data carrier. At present, the @@ -904,38 +904,40 @@ def callable(): # Make sure we have an expression of the right length i.e. a value for # each component in the value shape of each function space loops = [] - if numpy.prod(expr.ufl_shape, dtype=int) != V.value_size: + if numpy.prod(operand.ufl_shape, dtype=int) != V.value_size: raise RuntimeError('Expression of length %d required, got length %d' - % (V.value_size, numpy.prod(expr.ufl_shape, dtype=int))) + % (V.value_size, numpy.prod(operand.ufl_shape, dtype=int))) if len(V) == 1: - loops.extend(_interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=bcs)) + loops.extend(_interpolator(V, tensor, expr, subset, arguments, access, bcs=bcs)) else: - if (hasattr(expr, "subfunctions") and len(expr.subfunctions) == len(V) - and all(sub_expr.ufl_shape == Vsub.value_shape for Vsub, sub_expr in zip(V, expr.subfunctions))): + if (hasattr(operand, "subfunctions") and len(operand.subfunctions) == len(V) + and all(sub_op.ufl_shape == Vsub.value_shape for Vsub, sub_op in zip(V, operand.subfunctions))): # Use subfunctions if they match the target shapes - expressions = expr.subfunctions + operands = operand.subfunctions else: # Unflatten the expression into the shapes of the mixed components offset = 0 - expressions = [] + operands = [] for Vsub in V: if len(Vsub.value_shape) == 0: - expressions.append(expr[offset]) + operands.append(operand[offset]) else: - components = [expr[offset + j] for j in range(Vsub.value_size)] - expressions.append(ufl.as_tensor(numpy.reshape(components, Vsub.value_shape))) + components = [operand[offset + j] for j in range(Vsub.value_size)] + operands.append(ufl.as_tensor(numpy.reshape(components, Vsub.value_shape))) offset += Vsub.value_size if isinstance(dual_arg, Cofunction): duals = dual_arg.subfunctions elif isinstance(dual_arg, Coargument): - duals = [Coargument(Vsub.dual(), number=dual_arg.number()) for Vsub in V] + duals = [Coargument(Vsub, number=dual_arg.number()) for Vsub in dual_arg.function_space()] else: raise ValueError("dual_arg must be a Cofunction or Coargument") + # Interpolate each sub expression into each function space - for Vsub, sub_tensor, sub_expr, sub_dual in zip(V, tensor, expressions, duals): - loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, sub_dual, subset, arguments, access, bcs=bcs)) + for Vsub, sub_tensor, sub_op, sub_dual in zip(V, tensor, operands, duals): + sub_expr = ufl.Interpolate(sub_op, sub_dual) + loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs)) if bcs and rank == 1: loops.extend(partial(bc.apply, f) for bc in bcs) @@ -949,11 +951,11 @@ def callable(loops, f): @utils.known_pyop2_safe -def _interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=None): - try: - expr = ufl.as_ufl(expr) - except ValueError: - raise ValueError("Expecting to interpolate a UFL expression") +def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): + if not isinstance(expr, ufl.Interpolate): + raise ValueError("Expecting to interpolate a ufl.Interpolate") + dual_arg, operand = expr.argument_slots() + try: to_element = create_element(V.ufl_element()) except KeyError: @@ -963,17 +965,17 @@ def _interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=None if access is op2.READ: raise ValueError("Can't have READ access for output function") - if len(expr.ufl_shape) != len(V.value_shape): + if len(operand.ufl_shape) != len(V.value_shape): raise RuntimeError('Rank mismatch: Expression rank %d, FunctionSpace rank %d' - % (len(expr.ufl_shape), len(V.value_shape))) + % (len(operand.ufl_shape), len(V.value_shape))) - if expr.ufl_shape != V.value_shape: + if operand.ufl_shape != V.value_shape: raise RuntimeError('Shape mismatch: Expression shape %r, FunctionSpace shape %r' - % (expr.ufl_shape, V.value_shape)) + % (operand.ufl_shape, V.value_shape)) # NOTE: The par_loop is always over the target mesh cells. target_mesh = as_domain(V) - source_mesh = extract_unique_domain(expr) or target_mesh + source_mesh = extract_unique_domain(operand) or target_mesh if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology): if target_mesh is not source_mesh: if not isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology): @@ -995,7 +997,13 @@ def _interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=None # have their pointset(s) directly replaced with run-time tabulated # equivalent(s) (i.e. finat.point_set.UnknownPointSet(s)) rt_var_name = 'rt_X' - to_element = rebuild(to_element, expr, rt_var_name) + try: + cell = operand.ufl_element().ufl_cell() + except AttributeError: + # expression must be pure function of spatial coordinates so + # domain has correct ufl cell + cell = source_mesh.ufl_cell() + to_element = rebuild(to_element, cell, rt_var_name) cell_set = target_mesh.cell_set if subset is not None: @@ -1008,6 +1016,7 @@ def _interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=None needs_weight = isinstance(dual_arg, ufl.Cofunction) and not to_element.is_dg() if needs_weight: W = dual_arg.function_space() + # TODO cache DOF multiplicity shapes = (W.finat_element.space_dimension(), W.block_size) domain = "{[i,j]: 0 <= i < %d and 0 <= j < %d}" % shapes instructions = """ @@ -1029,7 +1038,7 @@ def _interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=None # FIXME: for the runtime unknown point set (for cross-mesh # interpolation) we have to pass the finat element we construct # here. Ideally we would only pass the UFL element through. - kernel = compile_expression(cell_set.comm, expr, dual_arg, to_element, V.ufl_element(), + kernel = compile_expression(cell_set.comm, expr, to_element, V.ufl_element(), domain=source_mesh, parameters=parameters) ast = kernel.ast oriented = kernel.oriented @@ -1042,8 +1051,8 @@ def _interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=None parloop_args = [kernel, cell_set] - interp_expr = ufl.Interpolate(expr, dual_arg) - coefficients = tsfc_interface.extract_numbered_coefficients(interp_expr, coefficient_numbers) + expr = ufl.Interpolate(operand, v=dual_arg) + coefficients = tsfc_interface.extract_numbered_coefficients(expr, coefficient_numbers) if needs_external_coords: coefficients = [source_mesh.coordinates] + coefficients @@ -1061,12 +1070,13 @@ def _interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=None if isinstance(tensor, op2.Global): parloop_args.append(tensor(access)) elif isinstance(tensor, op2.Dat): - V_dest = arguments[0].function_space() if isinstance(dual_arg, ufl.Cofunction) else V + V_dest = arguments[-1].function_space() if isinstance(dual_arg, ufl.Cofunction) else V parloop_args.append(tensor(access, V_dest.cell_node_map())) else: assert access == op2.WRITE # Other access descriptors not done for Matrices. rows_map = V.cell_node_map() Vcol = arguments[0].function_space() + assert tensor.handle.getSize() == (V.dim(), Vcol.dim()) if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology): columns_map = Vcol.cell_node_map() if target_mesh is not source_mesh: @@ -1165,9 +1175,10 @@ def _interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=None f"firedrake-tsfc-expression-kernel-cache-uid{os.getuid()}") -def _compile_expression_key(comm, expr, dual_arg, to_element, ufl_element, domain, parameters) -> tuple[Hashable, ...]: +def _compile_expression_key(comm, expr, to_element, ufl_element, domain, parameters) -> tuple[Hashable, ...]: """Generate a cache key suitable for :func:`tsfc.compile_expression_dual_evaluation`.""" - return (hash_expr(expr), type(dual_arg), hash(ufl_element), utils.tuplify(parameters)) + dual_arg, operand = expr.argument_slots() + return (hash_expr(operand), type(dual_arg), hash(ufl_element), utils.tuplify(parameters)) @memory_and_disk_cache( @@ -1185,14 +1196,14 @@ def rebuild(element, expr, rt_var_name): @rebuild.register(finat.fiat_elements.ScalarFiatElement) -def rebuild_dg(element, expr, rt_var_name): +def rebuild_dg(element, expr_cell, rt_var_name): # To tabulate on the given element (which is on a different mesh to the # expression) we must do so at runtime. We therefore create a quadrature # element with runtime points to evaluate for each point in the element's # dual basis. This exists on the same reference cell as the input element # and we can interpolate onto it before mapping the result back onto the # target space. - expr_tdim = extract_unique_domain(expr).topological_dimension() + expr_tdim = expr_cell.topological_dimension() # Need point evaluations and matching weights from dual basis. # This could use FIAT's dual basis as below: # num_points = sum(len(dual.get_point_dict()) for dual in element.fiat_equivalent.dual_basis()) @@ -1212,20 +1223,14 @@ def rebuild_dg(element, expr, rt_var_name): assert rt_var_name.startswith("rt_") runtime_points_expr = gem.Variable(rt_var_name, (num_points, expr_tdim)) rule_pointset = finat.point_set.UnknownPointSet(runtime_points_expr) - try: - expr_fiat_cell = as_fiat_cell(expr.ufl_element().cell) - except AttributeError: - # expression must be pure function of spatial coordinates so - # domain has correct ufl cell - expr_fiat_cell = as_fiat_cell(extract_unique_domain(expr).ufl_cell()) rule = finat.quadrature.QuadratureRule(rule_pointset, weights=weights) - return finat.QuadratureElement(expr_fiat_cell, rule) + return finat.QuadratureElement(as_fiat_cell(expr_cell), rule) @rebuild.register(finat.TensorFiniteElement) -def rebuild_te(element, expr, rt_var_name): +def rebuild_te(element, expr_cell, rt_var_name): return finat.TensorFiniteElement(rebuild(element.base_element, - expr, rt_var_name), + expr_cell, rt_var_name), element._shape, transpose=element._transpose) diff --git a/tests/firedrake/regression/test_interpolate.py b/tests/firedrake/regression/test_interpolate.py index 1bb6711f26..3f3072a881 100644 --- a/tests/firedrake/regression/test_interpolate.py +++ b/tests/firedrake/regression/test_interpolate.py @@ -341,7 +341,6 @@ def test_adjoint_Pk(degree): v_adj_form = assemble(interpolate(TestFunction(Pk), v * dx)) - assert v_adj.function_space() == v_adj_form.function_space() assert np.allclose(v_adj_form.dat.data, v_adj.dat.data) @@ -354,7 +353,6 @@ def test_adjoint_quads(): u_P1 = assemble(conj(TestFunction(P1)) * dx) v_adj = assemble(interpolate(TestFunction(P1), assemble(v * dx))) - assert v_adj.function_space() == u_P1.function_space() assert np.allclose(u_P1.dat.data, v_adj.dat.data) @@ -367,7 +365,6 @@ def test_adjoint_dg(): u_cg = assemble(conj(TestFunction(cg1)) * dx) v_adj = assemble(interpolate(TestFunction(cg1), assemble(v * dx))) - assert v_adj.function_space() == u_cg.function_space() assert np.allclose(u_cg.dat.data, v_adj.dat.data) diff --git a/tsfc/driver.py b/tsfc/driver.py index 55f4ab80fc..29af9067ee 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -182,16 +182,14 @@ def preprocess_parameters(parameters): return parameters -def compile_expression_dual_evaluation(expression, dual_arg, - to_element, ufl_element, *, +def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, domain=None, interface=None, parameters=None): """Compile a UFL expression to be evaluated against a compile-time known reference element's dual basis. Useful for interpolating UFL expressions into e.g. N1curl spaces. - :arg expression: UFL expression to interpolate - :arg dual_arg: A Cofunction or Coargument to act on the interpolated expression + :arg expression: UFL expression :arg to_element: A FInAT element for the target space :arg ufl_element: The UFL element of the target space. :arg domain: optional UFL domain the expression is defined on (required when expression contains no domain). @@ -212,18 +210,23 @@ def compile_expression_dual_evaluation(expression, dual_arg, if isinstance(to_element, (PhysicallyMappedElement, DirectlyDefinedElement)): raise NotImplementedError("Don't know how to interpolate onto zany spaces, sorry") - if not isinstance(dual_arg, (ufl.Coargument, ufl.Cofunction)): - raise ValueError(f"Expecting a Coargument or Cofunction, not {type(dual_arg).__name__}") - - interp_expression = ufl.Interpolate(expression, dual_arg) - orig_expression = interp_expression + orig_expression = expression + if isinstance(expression, ufl.Interpolate): + operand, = expression.ufl_operands + else: + operand = expression # Map into reference space - expression = apply_mapping(expression, ufl_element, domain) + operand = apply_mapping(operand, ufl_element, domain) # Apply UFL preprocessing - expression = ufl_utils.preprocess_expression(expression, - complex_mode=complex_mode) + operand = ufl_utils.preprocess_expression(operand, complex_mode=complex_mode) + + if isinstance(expression, ufl.Interpolate): + dual_arg, _ = expression.argument_slots() + expression = ufl.Interpolate(operand, dual_arg) + else: + expression = operand # Initialise kernel builder if interface is None: @@ -231,7 +234,7 @@ def compile_expression_dual_evaluation(expression, dual_arg, from tsfc.kernel_interface.firedrake_loopy import ExpressionKernelBuilder as interface builder = interface(parameters["scalar_type"]) - arguments = extract_arguments(expression) + arguments = extract_arguments(operand) argument_multiindices = tuple(builder.create_element(arg.ufl_element()).get_indices() for arg in arguments) @@ -241,7 +244,7 @@ def compile_expression_dual_evaluation(expression, dual_arg, assert domain is not None # Collect required coefficients and determine numbering - coefficients = extract_coefficients(interp_expression) + coefficients = extract_coefficients(expression) orig_coefficients = extract_coefficients(orig_expression) coefficient_numbers = tuple(map(orig_coefficients.index, coefficients)) builder.set_coefficient_numbers(coefficient_numbers) @@ -258,7 +261,7 @@ def compile_expression_dual_evaluation(expression, dual_arg, needs_external_coords = True builder.set_coefficients(coefficients) - constants = extract_firedrake_constants(interp_expression) + constants = extract_firedrake_constants(expression) builder.set_constants(constants) # Split mixed coefficients @@ -280,8 +283,18 @@ def compile_expression_dual_evaluation(expression, dual_arg, if isinstance(to_element, finat.QuadratureElement): kernel_cfg["quadrature_rule"] = to_element._rule + # TODO register ufl.Interpolate in fem.compile_ufl + if isinstance(expression, ufl.Interpolate): + operand, = expression.ufl_operands + dual_arg = expression.argument_slots()[0] + if not isinstance(dual_arg, (ufl.Coargument, ufl.Cofunction)): + raise ValueError(f"Expecting a Coargument or Cofunction, not {type(dual_arg).__name__}") + else: + operand = expression + dual_arg = None + # Create callable for translation of UFL expression to gem - fn = DualEvaluationCallable(expression, kernel_cfg) + fn = DualEvaluationCallable(operand, kernel_cfg) # Get the gem expression for dual evaluation and corresponding basis # indices needed for compilation of the expression @@ -296,7 +309,7 @@ def compile_expression_dual_evaluation(expression, dual_arg, basis_indices = () # Build kernel body - return_indices = basis_indices + tuple(chain(*argument_multiindices)) + return_indices = tuple(chain(basis_indices, *argument_multiindices)) return_shape = tuple(i.extent for i in return_indices) if return_shape: return_var = gem.Variable('A', return_shape) From a33cfb3d205617216a6de1714850cfd64d3aebce Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 11 Sep 2025 15:56:09 +0100 Subject: [PATCH 003/125] Fixup, cleanup --- firedrake/assemble.py | 13 ++++++++----- firedrake/interpolation.py | 8 +++++--- tsfc/driver.py | 14 ++++---------- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 2a5f43b079..59ea9b1a26 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -608,8 +608,13 @@ def base_form_assembly_visitor(self, expr, tensor, *args): # Assembly if rank == 0: - result = interpolator._interpolate(output=tensor, default_missing_val=default_missing_val) - return result.dat.data.item() if tensor is None else result + # Assembling the double action. + if same_mesh: + result = interpolator._interpolate(output=tensor, default_missing_val=default_missing_val) + return result.dat.data.item() if tensor is None else result + else: + Iu = interpolator._interpolate(default_missing_val=default_missing_val) + return assemble(ufl.action(v, Iu), tensor=tensor) elif rank == 1: # Assembling the action of the Jacobian adjoint. if is_adjoint: @@ -618,10 +623,8 @@ def base_form_assembly_visitor(self, expr, tensor, *args): elif interpolator.nargs: return interpolator._interpolate(operand, output=tensor, default_missing_val=default_missing_val) # Assembling the operator - elif tensor is None: - return interpolator._interpolate(default_missing_val=default_missing_val) else: - return firedrake.Interpolator(interp_expr, tensor, **interp_data)._interpolate(default_missing_val=default_missing_val) + return interpolator._interpolate(output=tensor, default_missing_val=default_missing_val) elif rank == 2: res = tensor.petscmat if tensor else PETSc.Mat() # Get the interpolation matrix diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 5ce11ed457..8d1acdc9cf 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -783,7 +783,9 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, ** @PETSc.Log.EventDecorator() def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): - assert isinstance(expr, ufl.Interpolate) + if not isinstance(expr, ufl.Interpolate): + fs = V if isinstance(V, ufl.FunctionSpace) else V.function_space() + expr = Interpolate(expr, fs) dual_arg, operand = expr.argument_slots() assert isinstance(dual_arg, (ufl.Coargument, ufl.Cofunction)) @@ -818,7 +820,7 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): f.assign(val) tensor = f.dat elif rank == 2: - if isinstance(V, firedrake.Function): + if isinstance(V, (firedrake.Function, firedrake.Cofunction)): raise ValueError("Cannot interpolate an expression with an argument into a Function") if len(V) > 1: raise NotImplementedError("Interpolation of mixed expressions with arguments is not supported") @@ -1191,7 +1193,7 @@ def compile_expression(comm, *args, **kwargs): @singledispatch -def rebuild(element, expr, rt_var_name): +def rebuild(element, expr_cell, rt_var_name): raise NotImplementedError(f"Cross mesh interpolation not implemented for a {element} element.") diff --git a/tsfc/driver.py b/tsfc/driver.py index 29af9067ee..1a5dfb7ddd 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -223,8 +223,7 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, operand = ufl_utils.preprocess_expression(operand, complex_mode=complex_mode) if isinstance(expression, ufl.Interpolate): - dual_arg, _ = expression.argument_slots() - expression = ufl.Interpolate(operand, dual_arg) + expression = expression._ufl_expr_reconstruct_(operand) else: expression = operand @@ -285,8 +284,7 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, # TODO register ufl.Interpolate in fem.compile_ufl if isinstance(expression, ufl.Interpolate): - operand, = expression.ufl_operands - dual_arg = expression.argument_slots()[0] + dual_arg, operand = expression.argument_slots() if not isinstance(dual_arg, (ufl.Coargument, ufl.Cofunction)): raise ValueError(f"Expecting a Coargument or Cofunction, not {type(dual_arg).__name__}") else: @@ -311,12 +309,8 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, # Build kernel body return_indices = tuple(chain(basis_indices, *argument_multiindices)) return_shape = tuple(i.extent for i in return_indices) - if return_shape: - return_var = gem.Variable('A', return_shape) - return_expr = gem.Indexed(return_var, return_indices) - else: - return_var = gem.Variable('A', (1,)) - return_expr = gem.Indexed(return_var, (0,)) + return_var = gem.Variable('A', return_shape or (1,)) + return_expr = gem.Indexed(return_var, return_indices or (0,)) # TODO: one should apply some GEM optimisations as in assembly, # but we don't for now. From d51db2506dd95f5b87f9bacf8891ee9582935610 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 11 Sep 2025 21:59:35 +0100 Subject: [PATCH 004/125] Reverse indices for dual_arg --- firedrake/assemble.py | 5 ++--- firedrake/interpolation.py | 2 +- tsfc/driver.py | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 59ea9b1a26..dc02b933c4 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -596,14 +596,13 @@ def base_form_assembly_visitor(self, expr, tensor, *args): same_mesh = source_mesh.topology is target_mesh.topology # Get the interpolator - interp_data = expr.interp_data + interp_data = expr.interp_data.copy() default_missing_val = interp_data.pop('default_missing_val', None) if same_mesh and ((is_adjoint and rank == 1) or rank == 0): - interp_data = interp_data.copy() interp_data["access"] = op2.INC dual_arg = v if same_mesh else V - interp_expr = firedrake.Interpolate(operand, v=dual_arg, **interp_data) + interp_expr = reconstruct_interp(operand, v=dual_arg) interpolator = firedrake.Interpolator(interp_expr, V, **interp_data) # Assembly diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 8d1acdc9cf..5dc0701453 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -938,7 +938,7 @@ def callable(): # Interpolate each sub expression into each function space for Vsub, sub_tensor, sub_op, sub_dual in zip(V, tensor, operands, duals): - sub_expr = ufl.Interpolate(sub_op, sub_dual) + sub_expr = expr._ufl_expr_reconstruct_(sub_op, sub_dual) loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs)) if bcs and rank == 1: diff --git a/tsfc/driver.py b/tsfc/driver.py index 1a5dfb7ddd..0182cd3842 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -300,7 +300,7 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, # Compute the adjoint by contracting against the dual argument if dual_arg in coefficients: - beta = basis_indices + beta = basis_indices[::-1] shape = tuple(i.extent for i in beta) gem_dual = gem.Variable(f"w_{coefficients.index(dual_arg)}", shape) evaluation = gem.IndexSum(evaluation * gem_dual[beta], beta) From afe0256a6d63acd9dbf0476367df40e95bf0fac0 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 12 Sep 2025 00:11:07 +0100 Subject: [PATCH 005/125] gem cofunction --- tsfc/driver.py | 12 +++++------- tsfc/fem.py | 5 +++-- tsfc/kernel_interface/common.py | 2 ++ tsfc/ufl_utils.py | 4 +++- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/tsfc/driver.py b/tsfc/driver.py index 0182cd3842..4199f72e48 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -282,11 +282,8 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, if isinstance(to_element, finat.QuadratureElement): kernel_cfg["quadrature_rule"] = to_element._rule - # TODO register ufl.Interpolate in fem.compile_ufl if isinstance(expression, ufl.Interpolate): dual_arg, operand = expression.argument_slots() - if not isinstance(dual_arg, (ufl.Coargument, ufl.Cofunction)): - raise ValueError(f"Expecting a Coargument or Cofunction, not {type(dual_arg).__name__}") else: operand = expression dual_arg = None @@ -299,11 +296,12 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, evaluation, basis_indices = to_element.dual_evaluation(fn) # Compute the adjoint by contracting against the dual argument - if dual_arg in coefficients: - beta = basis_indices[::-1] + if dual_arg and not isinstance(dual_arg, ufl.Coargument): + k = len(basis_indices)-len(operand.ufl_shape) + beta = basis_indices[k:] + basis_indices[:k] shape = tuple(i.extent for i in beta) - gem_dual = gem.Variable(f"w_{coefficients.index(dual_arg)}", shape) - evaluation = gem.IndexSum(evaluation * gem_dual[beta], beta) + gem_dual = gem.Variable(f"w_{coefficients.index(dual_arg)}", shape=shape) + evaluation = gem.IndexSum(evaluation * gem_dual[beta], basis_indices) basis_indices = () # Build kernel body diff --git a/tsfc/fem.py b/tsfc/fem.py index 9166b4b8f0..2b3e0ad05a 100644 --- a/tsfc/fem.py +++ b/tsfc/fem.py @@ -21,8 +21,8 @@ from gem.unconcatenate import unconcatenate from ufl.classes import (Argument, CellCoordinate, CellEdgeVectors, CellFacetJacobian, CellOrientation, CellOrigin, - CellVertices, CellVolume, Coefficient, FacetArea, - FacetCoordinate, GeometricQuantity, Jacobian, + CellVertices, CellVolume, Coefficient, Cofunction, + FacetArea, FacetCoordinate, GeometricQuantity, Jacobian, JacobianDeterminant, NegativeRestricted, PositiveRestricted, QuadratureWeight, ReferenceCellEdgeVectors, ReferenceCellVolume, @@ -665,6 +665,7 @@ def translate_constant_value(terminal, mt, ctx): @translate.register(Coefficient) +@translate.register(Cofunction) def translate_coefficient(terminal, mt, ctx): vec = ctx.coefficient(terminal, mt.restriction) diff --git a/tsfc/kernel_interface/common.py b/tsfc/kernel_interface/common.py index f2369cec0d..56c7ce4880 100644 --- a/tsfc/kernel_interface/common.py +++ b/tsfc/kernel_interface/common.py @@ -64,6 +64,8 @@ def coefficient(self, ufl_coefficient, restriction): else: return kernel_arg[{'+': 0, '-': 1}[restriction]] + cofunction = coefficient + def constant(self, const): return self.constant_map[const] diff --git a/tsfc/ufl_utils.py b/tsfc/ufl_utils.py index 18173a9660..6f33ca1a77 100644 --- a/tsfc/ufl_utils.py +++ b/tsfc/ufl_utils.py @@ -20,7 +20,7 @@ from ufl.corealg.multifunction import MultiFunction from ufl.geometry import QuadratureWeight from ufl.geometry import Jacobian, JacobianDeterminant, JacobianInverse -from ufl.classes import (Abs, Argument, CellOrientation, +from ufl.classes import (Abs, Argument, CellOrientation, Cofunction, Expr, FloatValue, Division, Product, ScalarValue, Sqrt, Zero, CellVolume, FacetArea) @@ -171,6 +171,7 @@ def _modified_terminal(self, o): reference_value = _modified_terminal terminal = _modified_terminal + cofunction = terminal class PickRestriction(MultiFunction, ModifiedTerminalMixin): @@ -225,6 +226,7 @@ def _simplify_abs(o, self, in_abs): raise AssertionError("UFL node expected, not %s" % type(o)) +@_simplify_abs.register(Cofunction) @_simplify_abs.register(Expr) def _simplify_abs_expr(o, self, in_abs): # General case, only wrap the outer expression (if necessary) From 7aa048eedaed4142e2def530b8d71f12dcdfa193 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 12 Sep 2025 12:13:48 +0100 Subject: [PATCH 006/125] More fixups --- firedrake/assemble.py | 22 ++++++++++--------- firedrake/interpolation.py | 13 ++++++----- .../firedrake/regression/test_interp_dual.py | 10 ++++++++- tsfc/fem.py | 5 ++--- tsfc/kernel_interface/common.py | 2 -- tsfc/ufl_utils.py | 4 +--- 6 files changed, 32 insertions(+), 24 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index dc02b933c4..119a734ffc 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -538,6 +538,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args): result = expr.assemble(assembly_opts=opts) return tensor.assign(result) if tensor else result elif isinstance(expr, ufl.Interpolate): + orig_expr = expr # Replace assembled children _, operand = expr.argument_slots() v, *assembled_operand = args @@ -588,31 +589,32 @@ def base_form_assembly_visitor(self, expr, tensor, *args): # Workaround: Renumber argument when needed since Interpolator assumes it takes a zero-numbered argument. if not is_adjoint and rank == 2: - _, v1 = expr.arguments() - operand = ufl.replace(operand, {v1: v1.reconstruct(number=0)}) + v0, v1 = expr.arguments() + expr = ufl.replace(expr, {v0: v0.reconstruct(number=v1.number()), + v1: v1.reconstruct(number=v0.number())}) + v, operand = expr.argument_slots() + # Assemble the interpolator matrix if the meshes are different target_mesh = V.mesh() source_mesh = extract_unique_domain(operand) or target_mesh same_mesh = source_mesh.topology is target_mesh.topology + if not same_mesh: + expr = reconstruct_interp(operand, v=V) # Get the interpolator interp_data = expr.interp_data.copy() default_missing_val = interp_data.pop('default_missing_val', None) if same_mesh and ((is_adjoint and rank == 1) or rank == 0): interp_data["access"] = op2.INC - - dual_arg = v if same_mesh else V - interp_expr = reconstruct_interp(operand, v=dual_arg) - interpolator = firedrake.Interpolator(interp_expr, V, **interp_data) + interpolator = firedrake.Interpolator(expr, V, **interp_data) # Assembly if rank == 0: # Assembling the double action. if same_mesh: - result = interpolator._interpolate(output=tensor, default_missing_val=default_missing_val) - return result.dat.data.item() if tensor is None else result + return interpolator._interpolate(output=tensor, default_missing_val=default_missing_val) else: - Iu = interpolator._interpolate(default_missing_val=default_missing_val) + Iu = interpolator._interpolate(operand, default_missing_val=default_missing_val) return assemble(ufl.action(v, Iu), tensor=tensor) elif rank == 1: # Assembling the action of the Jacobian adjoint. @@ -636,7 +638,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args): # Copy the interpolation matrix into the output tensor petsc_mat.copy(result=res) if tensor is None: - tensor = self.assembled_matrix(expr, res) + tensor = self.assembled_matrix(orig_expr, res) return tensor else: # The case rank == 0 is handled via the DAG restructuring diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 5dc0701453..14d23ac294 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -682,11 +682,12 @@ class SameMeshInterpolator(Interpolator): @no_annotations def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, bcs=None, matfree=True, allow_missing_dofs=False, **kwargs): + if isinstance(expr, ufl.Interpolate): + operand, = expr.ufl_operands + else: + operand = expr + expr = Interpolate(operand, V) if subset is None: - if isinstance(expr, ufl.Interpolate): - operand, = expr.ufl_operands - else: - operand = expr target_mesh = as_domain(V) source_mesh = extract_unique_domain(operand) or target_mesh target = target_mesh.topology @@ -770,13 +771,15 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, ** if output: output.assign(assembled_interpolator) return output - if isinstance(self.V, firedrake.Function): + if isinstance(self.V, (firedrake.Function, firedrake.Cofunction)): if copy_required: self.V.assign(assembled_interpolator) return self.V else: if copy_required: return assembled_interpolator.copy() + elif isinstance(assembled_interpolator.dat, op2.Global): + return assembled_interpolator.dat.data.item() else: return assembled_interpolator diff --git a/tests/firedrake/regression/test_interp_dual.py b/tests/firedrake/regression/test_interp_dual.py index 22d9415975..444f352453 100644 --- a/tests/firedrake/regression/test_interp_dual.py +++ b/tests/firedrake/regression/test_interp_dual.py @@ -66,15 +66,23 @@ def test_assemble_interp_operator(V2, f1): def test_assemble_interp_matrix(V1, V2, f1): # -- I(v1, V2) -- # v1 = TrialFunction(V1) - Iv1 = Interpolate(v1, V2) + Iv1 = interpolate(v1, V2) + assert Iv1.arguments()[0].function_space() == V2.dual() + assert Iv1.arguments()[1].function_space() == V1 b = assemble(interpolate(f1, V2)) + assert b.function_space() == V2 # Get the interpolation matrix a = assemble(Iv1) + assert a.arguments()[0].function_space() == V2.dual() + assert a.arguments()[1].function_space() == V1 + assert a.petscmat.getSize() == (V2.dim(), V1.dim()) + # Check that `I * f1 == b` with I the interpolation matrix # and b the interpolation of f1 into V2. res = assemble(action(a, f1)) + assert res.function_space() == V2 assert np.allclose(res.dat.data, b.dat.data) diff --git a/tsfc/fem.py b/tsfc/fem.py index 2b3e0ad05a..9166b4b8f0 100644 --- a/tsfc/fem.py +++ b/tsfc/fem.py @@ -21,8 +21,8 @@ from gem.unconcatenate import unconcatenate from ufl.classes import (Argument, CellCoordinate, CellEdgeVectors, CellFacetJacobian, CellOrientation, CellOrigin, - CellVertices, CellVolume, Coefficient, Cofunction, - FacetArea, FacetCoordinate, GeometricQuantity, Jacobian, + CellVertices, CellVolume, Coefficient, FacetArea, + FacetCoordinate, GeometricQuantity, Jacobian, JacobianDeterminant, NegativeRestricted, PositiveRestricted, QuadratureWeight, ReferenceCellEdgeVectors, ReferenceCellVolume, @@ -665,7 +665,6 @@ def translate_constant_value(terminal, mt, ctx): @translate.register(Coefficient) -@translate.register(Cofunction) def translate_coefficient(terminal, mt, ctx): vec = ctx.coefficient(terminal, mt.restriction) diff --git a/tsfc/kernel_interface/common.py b/tsfc/kernel_interface/common.py index 56c7ce4880..f2369cec0d 100644 --- a/tsfc/kernel_interface/common.py +++ b/tsfc/kernel_interface/common.py @@ -64,8 +64,6 @@ def coefficient(self, ufl_coefficient, restriction): else: return kernel_arg[{'+': 0, '-': 1}[restriction]] - cofunction = coefficient - def constant(self, const): return self.constant_map[const] diff --git a/tsfc/ufl_utils.py b/tsfc/ufl_utils.py index 6f33ca1a77..18173a9660 100644 --- a/tsfc/ufl_utils.py +++ b/tsfc/ufl_utils.py @@ -20,7 +20,7 @@ from ufl.corealg.multifunction import MultiFunction from ufl.geometry import QuadratureWeight from ufl.geometry import Jacobian, JacobianDeterminant, JacobianInverse -from ufl.classes import (Abs, Argument, CellOrientation, Cofunction, +from ufl.classes import (Abs, Argument, CellOrientation, Expr, FloatValue, Division, Product, ScalarValue, Sqrt, Zero, CellVolume, FacetArea) @@ -171,7 +171,6 @@ def _modified_terminal(self, o): reference_value = _modified_terminal terminal = _modified_terminal - cofunction = terminal class PickRestriction(MultiFunction, ModifiedTerminalMixin): @@ -226,7 +225,6 @@ def _simplify_abs(o, self, in_abs): raise AssertionError("UFL node expected, not %s" % type(o)) -@_simplify_abs.register(Cofunction) @_simplify_abs.register(Expr) def _simplify_abs_expr(o, self, in_abs): # General case, only wrap the outer expression (if necessary) From 9826f0dc8fc7c2da659b883f5337a3db82f52a7f Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 12 Sep 2025 14:53:38 +0100 Subject: [PATCH 007/125] More fixup --- firedrake/assemble.py | 2 +- firedrake/interpolation.py | 18 +++++++----------- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 119a734ffc..cafda5f010 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -614,7 +614,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args): if same_mesh: return interpolator._interpolate(output=tensor, default_missing_val=default_missing_val) else: - Iu = interpolator._interpolate(operand, default_missing_val=default_missing_val) + Iu = interpolator._interpolate(default_missing_val=default_missing_val) return assemble(ufl.action(v, Iu), tensor=tensor) elif rank == 1: # Assembling the action of the Jacobian adjoint. diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 14d23ac294..6094e65542 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -685,8 +685,9 @@ def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, if isinstance(expr, ufl.Interpolate): operand, = expr.ufl_operands else: + fs = V if isinstance(V, ufl.FunctionSpace) else V.function_space() operand = expr - expr = Interpolate(operand, V) + expr = Interpolate(operand, fs) if subset is None: target_mesh = as_domain(V) source_mesh = extract_unique_domain(operand) or target_mesh @@ -713,7 +714,7 @@ def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, except FIAT.hdiv_trace.TraceError: raise NotImplementedError("Can't interpolate onto traces sorry") self.arguments = expr.arguments() - self.nargs = len(arguments) + self.nargs = len(extract_arguments(operand)) @PETSc.Log.EventDecorator() def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **kwargs): @@ -725,11 +726,6 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, ** if transpose is not None: warnings.warn("'transpose' argument is deprecated, use 'adjoint' instead", FutureWarning) adjoint = transpose or adjoint - if adjoint and not self.nargs: - raise ValueError("Can currently only apply adjoint interpolation with arguments.") - if self.nargs != len(function): - raise ValueError("Passed %d Functions to interpolate, expected %d" - % (len(function), self.nargs)) try: assembled_interpolator = self.frozen_assembled_interpolator copy_required = True @@ -750,12 +746,12 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, ** raise ValueError("The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!") if adjoint: mul = assembled_interpolator.handle.multHermitian - V = self.arguments[0].function_space().dual() - assert function.function_space() == self.arguments[1].function_space() + col, row = self.arguments else: mul = assembled_interpolator.handle.mult - V = self.arguments[1].function_space().dual() - assert function.function_space() == self.arguments[0].function_space() + row, col = self.arguments + V = col.function_space().dual() + assert function.function_space() == row.function_space() result = output or firedrake.Function(V) with function.dat.vec_ro as x, result.dat.vec_wo as out: From 4ce9f8e4ff55d71aac480a57d1c3e234ae49ca56 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 12 Sep 2025 16:12:41 +0100 Subject: [PATCH 008/125] Fixup --- firedrake/assemble.py | 5 ++- firedrake/interpolation.py | 64 ++++++++++++++++---------------------- tsfc/driver.py | 3 +- 3 files changed, 32 insertions(+), 40 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index cafda5f010..9617e8db51 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -602,10 +602,13 @@ def base_form_assembly_visitor(self, expr, tensor, *args): expr = reconstruct_interp(operand, v=V) # Get the interpolator - interp_data = expr.interp_data.copy() + interp_data = expr.interp_data default_missing_val = interp_data.pop('default_missing_val', None) if same_mesh and ((is_adjoint and rank == 1) or rank == 0): interp_data["access"] = op2.INC + + if rank == 1 and ((same_mesh and tensor) or isinstance(tensor, firedrake.Function)): + V = tensor interpolator = firedrake.Interpolator(expr, V, **interp_data) # Assembly diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 6094e65542..acc7fee3d4 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -93,10 +93,20 @@ def __init__(self, expr, v, and reduce operations. """ # Check function space + expr = ufl.as_ufl(expr) if isinstance(v, functionspaceimpl.WithGeometry): - expr_args = extract_arguments(ufl.as_ufl(expr)) + expr_args = extract_arguments(expr) is_adjoint = len(expr_args) and expr_args[0].number() == 0 v = Argument(v.dual(), 1 if is_adjoint else 0) + + V = v.arguments()[0].function_space() + if len(expr.ufl_shape) != len(V.value_shape): + raise RuntimeError('Rank mismatch: Expression rank %d, FunctionSpace rank %d' + % (len(expr.ufl_shape), len(V.value_shape))) + + if expr.ufl_shape != V.value_shape: + raise RuntimeError('Shape mismatch: Expression shape %r, FunctionSpace shape %r' + % (expr.ufl_shape, V.value_shape)) super().__init__(expr, v) # -- Interpolate data (e.g. `subset` or `access`) -- # @@ -173,7 +183,7 @@ def interpolate(expr, V, subset=None, access=op2.WRITE, allow_missing_dofs=False raise TypeError(f"Expected a one-form, provided form had {rank} arguments") elif isinstance(V, functionspaceimpl.WithGeometry): dual_arg = Coargument(V.dual(), 0) - expr_args = extract_arguments(expr) + expr_args = extract_arguments(ufl.as_ufl(expr)) if expr_args and expr_args[0].number() == 0: # In this case we are doing adjoint interpolation # When V is a FunctionSpace and expr contains Argument(0), @@ -483,7 +493,7 @@ def __init__( if len(shape) == 0: fs_type = firedrake.FunctionSpace elif len(shape) == 1: - fs_type = firedrake.VectorFunctionSpace + fs_type = partial(firedrake.VectorFunctionSpace, dim=shape[0]) else: fs_type = partial(firedrake.TensorFunctionSpace, shape=shape) P0DG_vom = fs_type(self.vom_dest_node_coords_in_src_mesh, "DG", 0) @@ -710,7 +720,7 @@ def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, super().__init__(expr, V, subset=subset, freeze_expr=freeze_expr, access=access, bcs=bcs, matfree=matfree, allow_missing_dofs=allow_missing_dofs) try: - self.callable, arguments = make_interpolator(expr, V, subset, access, bcs=bcs, matfree=matfree) + self.callable = make_interpolator(expr, V, subset, access, bcs=bcs, matfree=matfree) except FIAT.hdiv_trace.TraceError: raise NotImplementedError("Can't interpolate onto traces sorry") self.arguments = expr.arguments() @@ -726,6 +736,7 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, ** if transpose is not None: warnings.warn("'transpose' argument is deprecated, use 'adjoint' instead", FutureWarning) adjoint = transpose or adjoint + try: assembled_interpolator = self.frozen_assembled_interpolator copy_required = True @@ -740,7 +751,7 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, ** # Interpolation action self.frozen_assembled_interpolator = assembled_interpolator.copy() - if hasattr(assembled_interpolator, "handle") and len(function): + if len(self.arguments) == 2 and len(function): function, = function if not hasattr(function, "dat"): raise ValueError("The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!") @@ -783,15 +794,11 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, ** @PETSc.Log.EventDecorator() def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): if not isinstance(expr, ufl.Interpolate): - fs = V if isinstance(V, ufl.FunctionSpace) else V.function_space() - expr = Interpolate(expr, fs) + raise ValueError(f"Expecting to interpolate a ufl.Interpolate, got {type(expr).__name__}.") dual_arg, operand = expr.argument_slots() - assert isinstance(dual_arg, (ufl.Coargument, ufl.Cofunction)) target_mesh = as_domain(dual_arg) source_mesh = extract_unique_domain(operand) or target_mesh - same_mesh = target_mesh is source_mesh - vom_onto_other_vom = ( isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology) and isinstance(source_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology) @@ -803,7 +810,7 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): if rank <= 1: if rank == 0: R = firedrake.FunctionSpace(target_mesh, "Real", 0) - f = firedrake.Function(R) + f = firedrake.Function(R, dtype=utils.ScalarType) elif isinstance(V, firedrake.Function): f = V V = f.function_space() @@ -862,10 +869,8 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): else: raise ValueError("Cannot interpolate an expression with %d arguments" % rank) - if not same_mesh: - arguments = extract_arguments(operand) if vom_onto_other_vom: - wrapper = VomOntoVomWrapper(V, source_mesh, target_mesh, operand, arguments, matfree) + wrapper = VomOntoVomWrapper(V, source_mesh, target_mesh, operand, matfree) # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the # data, including the correct data size and dimensional information # (so for vector function spaces in 2 dimensions we might need a @@ -875,13 +880,13 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): # when it is called. assert f.dat is tensor wrapper.mpi_type, _ = get_dat_mpi_type(f.dat) - assert len(arguments) == 0 + assert len(arguments) == 1 def callable(): wrapper.forward_operation(f.dat) return f else: - assert len(arguments) == 1 + assert len(arguments) == 2 assert tensor is None # we know we will be outputting either a function or a cofunction, # both of which will use a dat as a data carrier. At present, the @@ -900,15 +905,9 @@ def callable(): def callable(): return wrapper - return callable, arguments + return callable else: - # Make sure we have an expression of the right length i.e. a value for - # each component in the value shape of each function space loops = [] - if numpy.prod(operand.ufl_shape, dtype=int) != V.value_size: - raise RuntimeError('Expression of length %d required, got length %d' - % (V.value_size, numpy.prod(operand.ufl_shape, dtype=int))) - if len(V) == 1: loops.extend(_interpolator(V, tensor, expr, subset, arguments, access, bcs=bcs)) else: @@ -933,8 +932,7 @@ def callable(): elif isinstance(dual_arg, Coargument): duals = [Coargument(Vsub, number=dual_arg.number()) for Vsub in dual_arg.function_space()] else: - raise ValueError("dual_arg must be a Cofunction or Coargument") - + duals = [v for _, v in sorted(firedrake.formmanipulation.split_form(dual_arg))] # Interpolate each sub expression into each function space for Vsub, sub_tensor, sub_op, sub_dual in zip(V, tensor, operands, duals): sub_expr = expr._ufl_expr_reconstruct_(sub_op, sub_dual) @@ -948,7 +946,7 @@ def callable(loops, f): l() return f - return partial(callable, loops, f), arguments + return partial(callable, loops, f) @utils.known_pyop2_safe @@ -966,14 +964,6 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): if access is op2.READ: raise ValueError("Can't have READ access for output function") - if len(operand.ufl_shape) != len(V.value_shape): - raise RuntimeError('Rank mismatch: Expression rank %d, FunctionSpace rank %d' - % (len(operand.ufl_shape), len(V.value_shape))) - - if operand.ufl_shape != V.value_shape: - raise RuntimeError('Shape mismatch: Expression shape %r, FunctionSpace shape %r' - % (operand.ufl_shape, V.value_shape)) - # NOTE: The par_loop is always over the target mesh cells. target_mesh = as_domain(V) source_mesh = extract_unique_domain(operand) or target_mesh @@ -1401,9 +1391,6 @@ class VomOntoVomWrapper(object): expr : `ufl.Expr` The expression to interpolate. If ``arguments`` is not empty, those arguments must be present within it. - arguments : list of `ufl.Argument` - The arguments in the expression. These are not extracted from expr here - since, where we use this, we already have them. matfree : bool If ``False``, the matrix representating the permutation of the points is constructed and used to perform the interpolation. If ``True``, then the @@ -1411,7 +1398,8 @@ class VomOntoVomWrapper(object): PETSc Star Forest. """ - def __init__(self, V, source_vom, target_vom, expr, arguments, matfree): + def __init__(self, V, source_vom, target_vom, expr, matfree): + arguments = extract_arguments(expr) reduce = False if source_vom.input_ordering is target_vom: reduce = True diff --git a/tsfc/driver.py b/tsfc/driver.py index 4199f72e48..b27264c8fe 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -223,7 +223,8 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, operand = ufl_utils.preprocess_expression(operand, complex_mode=complex_mode) if isinstance(expression, ufl.Interpolate): - expression = expression._ufl_expr_reconstruct_(operand) + v, _ = expression.argument_slots() + expression = ufl.Interpolate(operand, v) else: expression = operand From ae80c0ff3ab3dd695698c99bbcb1734c5d5227db Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 12 Sep 2025 21:42:09 +0100 Subject: [PATCH 009/125] cleanup --- firedrake/assemble.py | 28 ++++++++++++---------------- firedrake/interpolation.py | 5 ++--- 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 9617e8db51..3a4545f125 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -597,38 +597,35 @@ def base_form_assembly_visitor(self, expr, tensor, *args): # Assemble the interpolator matrix if the meshes are different target_mesh = V.mesh() source_mesh = extract_unique_domain(operand) or target_mesh - same_mesh = source_mesh.topology is target_mesh.topology - if not same_mesh: + if is_adjoint and source_mesh is not target_mesh: expr = reconstruct_interp(operand, v=V) + matfree = (rank == len(expr.arguments())) and (rank < 2) # Get the interpolator - interp_data = expr.interp_data + interp_data = expr.interp_data.copy() default_missing_val = interp_data.pop('default_missing_val', None) - if same_mesh and ((is_adjoint and rank == 1) or rank == 0): + if matfree and ((is_adjoint and rank == 1) or rank == 0): interp_data["access"] = op2.INC - if rank == 1 and ((same_mesh and tensor) or isinstance(tensor, firedrake.Function)): + if rank == 1 and matfree and isinstance(tensor, firedrake.Function): V = tensor interpolator = firedrake.Interpolator(expr, V, **interp_data) # Assembly - if rank == 0: + if matfree: + # Assembling the operator + return interpolator._interpolate(output=tensor, default_missing_val=default_missing_val) + elif rank == 0: # Assembling the double action. - if same_mesh: - return interpolator._interpolate(output=tensor, default_missing_val=default_missing_val) - else: - Iu = interpolator._interpolate(default_missing_val=default_missing_val) - return assemble(ufl.action(v, Iu), tensor=tensor) + Iu = interpolator._interpolate(default_missing_val=default_missing_val) + return assemble(ufl.Action(v, Iu), tensor=tensor) elif rank == 1: # Assembling the action of the Jacobian adjoint. if is_adjoint: return interpolator._interpolate(v, output=tensor, adjoint=True, default_missing_val=default_missing_val) # Assembling the Jacobian action. - elif interpolator.nargs: - return interpolator._interpolate(operand, output=tensor, default_missing_val=default_missing_val) - # Assembling the operator else: - return interpolator._interpolate(output=tensor, default_missing_val=default_missing_val) + return interpolator._interpolate(operand, output=tensor, default_missing_val=default_missing_val) elif rank == 2: res = tensor.petscmat if tensor else PETSc.Mat() # Get the interpolation matrix @@ -644,7 +641,6 @@ def base_form_assembly_visitor(self, expr, tensor, *args): tensor = self.assembled_matrix(orig_expr, res) return tensor else: - # The case rank == 0 is handled via the DAG restructuring raise ValueError("Incompatible number of arguments.") elif tensor and isinstance(expr, (firedrake.Function, firedrake.Cofunction, firedrake.MatrixBase)): return tensor.assign(expr) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index acc7fee3d4..12aff36f09 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -736,7 +736,6 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, ** if transpose is not None: warnings.warn("'transpose' argument is deprecated, use 'adjoint' instead", FutureWarning) adjoint = transpose or adjoint - try: assembled_interpolator = self.frozen_assembled_interpolator copy_required = True @@ -778,7 +777,7 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, ** if output: output.assign(assembled_interpolator) return output - if isinstance(self.V, (firedrake.Function, firedrake.Cofunction)): + if isinstance(self.V, firedrake.Function): if copy_required: self.V.assign(assembled_interpolator) return self.V @@ -826,7 +825,7 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): f.assign(val) tensor = f.dat elif rank == 2: - if isinstance(V, (firedrake.Function, firedrake.Cofunction)): + if isinstance(V, firedrake.Function): raise ValueError("Cannot interpolate an expression with an argument into a Function") if len(V) > 1: raise NotImplementedError("Interpolation of mixed expressions with arguments is not supported") From 7fa72d6fd62cbcfda3d914aebec7c579e4d59230 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 12 Sep 2025 23:30:51 +0100 Subject: [PATCH 010/125] Fix Real adjoint --- firedrake/interpolation.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 12aff36f09..4f111f694c 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -784,7 +784,7 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, ** else: if copy_required: return assembled_interpolator.copy() - elif isinstance(assembled_interpolator.dat, op2.Global): + elif len(self.arguments) == 0: return assembled_interpolator.dat.data.item() else: return assembled_interpolator @@ -1016,11 +1016,11 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): """ weight = firedrake.Function(W) firedrake.par_loop((domain, instructions), ufl.dx, {"w": (weight, op2.INC)}) - - tmp = firedrake.Function(W) - with weight.dat.vec as w, dual_arg.dat.vec as x, tmp.dat.vec as y: - y.pointwiseDivide(x, w) - dual_arg = tmp + with weight.dat.vec as w: + w.reciprocal() + petscmat = PETSc.Mat().createDiagonal(w) + weight_mat = firedrake.AssembledMatrix((firedrake.TestFunction(W.dual()), firedrake.TrialFunction(W)), None, petscmat) + dual_arg = firedrake.assemble(ufl.action(weight_mat, dual_arg)) # We need to pass both the ufl element and the finat element # because the finat elements might not have the right mapping From 3f5ca713c37c07834aef6babce1bb344d51d0696 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sat, 13 Sep 2025 21:29:19 +0100 Subject: [PATCH 011/125] Fixup --- firedrake/assemble.py | 6 ++++-- firedrake/interpolation.py | 21 +++++++++++---------- tsfc/driver.py | 6 +++--- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 3a4545f125..9394c12aea 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -597,7 +597,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args): # Assemble the interpolator matrix if the meshes are different target_mesh = V.mesh() source_mesh = extract_unique_domain(operand) or target_mesh - if is_adjoint and source_mesh is not target_mesh: + if is_adjoint and rank < 2 and source_mesh is not target_mesh: expr = reconstruct_interp(operand, v=V) matfree = (rank == len(expr.arguments())) and (rank < 2) @@ -634,9 +634,11 @@ def base_form_assembly_visitor(self, expr, tensor, *args): if is_adjoint: # Out-of-place Hermitian transpose petsc_mat.hermitianTranspose(out=res) - else: + elif tensor: # Copy the interpolation matrix into the output tensor petsc_mat.copy(result=res) + else: + res = petsc_mat if tensor is None: tensor = self.assembled_matrix(orig_expr, res) return tensor diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 4f111f694c..24a6fb8bc5 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -724,7 +724,6 @@ def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, except FIAT.hdiv_trace.TraceError: raise NotImplementedError("Can't interpolate onto traces sorry") self.arguments = expr.arguments() - self.nargs = len(extract_arguments(operand)) @PETSc.Log.EventDecorator() def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **kwargs): @@ -743,7 +742,7 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, ** assembled_interpolator = self.callable() copy_required = False # Return the original if self.freeze_expr: - if self.nargs: + if len(self.arguments) == 2: # Interpolation operator self.frozen_assembled_interpolator = assembled_interpolator else: @@ -782,10 +781,10 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, ** self.V.assign(assembled_interpolator) return self.V else: - if copy_required: - return assembled_interpolator.copy() - elif len(self.arguments) == 0: + if len(self.arguments) == 0: return assembled_interpolator.dat.data.item() + elif copy_required: + return assembled_interpolator.copy() else: return assembled_interpolator @@ -908,7 +907,7 @@ def callable(): else: loops = [] if len(V) == 1: - loops.extend(_interpolator(V, tensor, expr, subset, arguments, access, bcs=bcs)) + expressions = (expr,) else: if (hasattr(operand, "subfunctions") and len(operand.subfunctions) == len(V) and all(sub_op.ufl_shape == Vsub.value_shape for Vsub, sub_op in zip(V, operand.subfunctions))): @@ -926,16 +925,18 @@ def callable(): operands.append(ufl.as_tensor(numpy.reshape(components, Vsub.value_shape))) offset += Vsub.value_size + # Split the dual argument if isinstance(dual_arg, Cofunction): duals = dual_arg.subfunctions elif isinstance(dual_arg, Coargument): duals = [Coargument(Vsub, number=dual_arg.number()) for Vsub in dual_arg.function_space()] else: duals = [v for _, v in sorted(firedrake.formmanipulation.split_form(dual_arg))] - # Interpolate each sub expression into each function space - for Vsub, sub_tensor, sub_op, sub_dual in zip(V, tensor, operands, duals): - sub_expr = expr._ufl_expr_reconstruct_(sub_op, sub_dual) - loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs)) + expressions = map(expr._ufl_expr_reconstruct_, operands, duals) + + # Interpolate each sub expression into each function space + for Vsub, sub_tensor, sub_expr in zip(V, tensor, expressions): + loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs)) if bcs and rank == 1: loops.extend(partial(bc.apply, f) for bc in bcs) diff --git a/tsfc/driver.py b/tsfc/driver.py index b27264c8fe..56b809014d 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -296,9 +296,9 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, # indices needed for compilation of the expression evaluation, basis_indices = to_element.dual_evaluation(fn) - # Compute the adjoint by contracting against the dual argument - if dual_arg and not isinstance(dual_arg, ufl.Coargument): - k = len(basis_indices)-len(operand.ufl_shape) + # Compute the action against the dual argument + if dual_arg in coefficients: + k = len(basis_indices) - len(operand.ufl_shape) beta = basis_indices[k:] + basis_indices[:k] shape = tuple(i.extent for i in beta) gem_dual = gem.Variable(f"w_{coefficients.index(dual_arg)}", shape=shape) From 42f848c97b78978219094637137a2a669ff3056c Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 17 Sep 2025 14:09:05 +0100 Subject: [PATCH 012/125] Fix multiindices --- firedrake/interpolation.py | 16 ++--- .../firedrake/regression/test_interpolate.py | 65 +++++++++++-------- tsfc/driver.py | 14 ++-- 3 files changed, 55 insertions(+), 40 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 24a6fb8bc5..9bbf2dc7d5 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -1007,7 +1007,6 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): needs_weight = isinstance(dual_arg, ufl.Cofunction) and not to_element.is_dg() if needs_weight: W = dual_arg.function_space() - # TODO cache DOF multiplicity shapes = (W.finat_element.space_dimension(), W.block_size) domain = "{[i,j]: 0 <= i < %d and 0 <= j < %d}" % shapes instructions = """ @@ -1017,11 +1016,14 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): """ weight = firedrake.Function(W) firedrake.par_loop((domain, instructions), ufl.dx, {"w": (weight, op2.INC)}) - with weight.dat.vec as w: - w.reciprocal() - petscmat = PETSc.Mat().createDiagonal(w) - weight_mat = firedrake.AssembledMatrix((firedrake.TestFunction(W.dual()), firedrake.TrialFunction(W)), None, petscmat) - dual_arg = firedrake.assemble(ufl.action(weight_mat, dual_arg)) + + # Create a copy and apply the weight + # TODO include this in the callables + v = firedrake.Function(dual_arg) + with v.dat.vec as x, weight.dat.vec as w: + x.pointwiseDivide(x, w) + + expr = expr._ufl_expr_reconstruct_(operand, v=v) # We need to pass both the ufl element and the finat element # because the finat elements might not have the right mapping @@ -1039,10 +1041,8 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): name = kernel.name kernel = op2.Kernel(ast, name, requires_zeroed_output_arguments=True, flop_count=kernel.flop_count, events=(kernel.event,)) - parloop_args = [kernel, cell_set] - expr = ufl.Interpolate(operand, v=dual_arg) coefficients = tsfc_interface.extract_numbered_coefficients(expr, coefficient_numbers) if needs_external_coords: coefficients = [source_mesh.coordinates] + coefficients diff --git a/tests/firedrake/regression/test_interpolate.py b/tests/firedrake/regression/test_interpolate.py index 3f3072a881..70f65dc9e4 100644 --- a/tests/firedrake/regression/test_interpolate.py +++ b/tests/firedrake/regression/test_interpolate.py @@ -327,33 +327,44 @@ def test_trace(): assert np.allclose(x_tr_cg.dat.data, x_tr_dir.dat.data) +@pytest.mark.parametrize("rank", (0, 1)) +@pytest.mark.parametrize("mat_type", ("matfree", "aij")) @pytest.mark.parametrize("degree", range(1, 4)) -def test_adjoint_Pk(degree): - mesh = UnitSquareMesh(10, 10) - Pkp1 = FunctionSpace(mesh, "CG", degree+1) - Pk = FunctionSpace(mesh, "CG", degree) - - v = conj(TestFunction(Pkp1)) - u_Pk = assemble(conj(TestFunction(Pk)) * dx) - v_adj = assemble(interpolate(TestFunction(Pk), assemble(v * dx))) - - assert np.allclose(u_Pk.dat.data, v_adj.dat.data) - - v_adj_form = assemble(interpolate(TestFunction(Pk), v * dx)) +@pytest.mark.parametrize("cell", ["triangle", "quadrilateral"]) +@pytest.mark.parametrize("shape", ("scalar", "vector", "tensor")) +def test_adjoint_Pk(rank, mat_type, degree, cell, shape): + quad = (cell == "quadrilateral") + mesh = UnitSquareMesh(5, 5, quadrilateral=quad) - assert np.allclose(v_adj_form.dat.data, v_adj.dat.data) - - -def test_adjoint_quads(): - mesh = UnitSquareMesh(10, 10) - P1 = FunctionSpace(mesh, "CG", 1) - P2 = FunctionSpace(mesh, "CG", 2) - - v = conj(TestFunction(P2)) - u_P1 = assemble(conj(TestFunction(P1)) * dx) - v_adj = assemble(interpolate(TestFunction(P1), assemble(v * dx))) - - assert np.allclose(u_P1.dat.data, v_adj.dat.data) + x = SpatialCoordinate(mesh) + expr = {"scalar": x[0], "vector": x, "tensor": outer(x, x)}[shape] + fs = {"scalar": FunctionSpace, "vector": VectorFunctionSpace, "tensor": TensorFunctionSpace}[shape] + Pk = fs(mesh, "CG", degree) + Pkp1 = fs(mesh, "CG", degree+1) + + v = assemble(inner(expr, TestFunction(Pkp1)) * dx) + + if rank == 0: + operand = Function(Pk).interpolate(expr) + else: + operand = TestFunction(Pk) + + if mat_type == "matfree": + result = assemble(interpolate(operand, v)) + else: + adj_interp = assemble(interpolate(operand, TrialFunction(Pkp1.dual()))) + if rank == 0: + result = assemble(action(v, adj_interp)) + else: + result = assemble(action(adj_interp, v)) + + expect = assemble(inner(expr, operand) * dx) + if rank == 0: + assert np.allclose(result, expect) + else: + assert expect.function_space() == result.function_space() + for x, y in zip(result.subfunctions, expect.subfunctions): + assert np.allclose(x.dat.data, y.dat.data) def test_adjoint_dg(): @@ -361,9 +372,9 @@ def test_adjoint_dg(): cg1 = FunctionSpace(mesh, "CG", 1) dg1 = FunctionSpace(mesh, "DG", 1) - v = conj(TestFunction(dg1)) + L = conj(TestFunction(dg1)) * dx u_cg = assemble(conj(TestFunction(cg1)) * dx) - v_adj = assemble(interpolate(TestFunction(cg1), assemble(v * dx))) + v_adj = assemble(interpolate(TestFunction(cg1), L)) assert np.allclose(u_cg.dat.data, v_adj.dat.data) diff --git a/tsfc/driver.py b/tsfc/driver.py index 56b809014d..cadca7835e 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -1,6 +1,7 @@ import collections import time import sys +import numpy from itertools import chain from finat.physically_mapped import DirectlyDefinedElement, PhysicallyMappedElement @@ -298,11 +299,14 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, # Compute the action against the dual argument if dual_arg in coefficients: - k = len(basis_indices) - len(operand.ufl_shape) - beta = basis_indices[k:] + basis_indices[:k] - shape = tuple(i.extent for i in beta) - gem_dual = gem.Variable(f"w_{coefficients.index(dual_arg)}", shape=shape) - evaluation = gem.IndexSum(evaluation * gem_dual[beta], basis_indices) + name = f"w_{coefficients.index(dual_arg)}" + shape = tuple(i.extent for i in basis_indices) + size = numpy.prod(shape, dtype=int) + gem_dual = gem.Variable(name, shape=(size,)) + gem_dual = gem.reshape(gem_dual, shape) + + evaluation = gem.IndexSum(evaluation * gem_dual[basis_indices], basis_indices) + evaluation = gem.optimise.sum_factorise(*gem.optimise.delta_elimination(*gem.optimise.traverse_product(evaluation))) basis_indices = () # Build kernel body From 08dba92f49e8848ecbf65d966888e8bdddcb413d Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 17 Sep 2025 15:48:51 +0100 Subject: [PATCH 013/125] Remove interpolate(Function(DG0), CG2) from test --- tests/firedrake/regression/test_adjoint_operators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/firedrake/regression/test_adjoint_operators.py b/tests/firedrake/regression/test_adjoint_operators.py index 04ef29e709..03557bf435 100644 --- a/tests/firedrake/regression/test_adjoint_operators.py +++ b/tests/firedrake/regression/test_adjoint_operators.py @@ -112,8 +112,8 @@ def test_interpolate_scalar_valued(rg): def test_interpolate_vector_valued(): mesh = UnitSquareMesh(10, 10) V1 = VectorFunctionSpace(mesh, "CG", 1) - V2 = VectorFunctionSpace(mesh, "DG", 0) - V3 = VectorFunctionSpace(mesh, "CG", 2) + V2 = VectorFunctionSpace(mesh, "CG", 2) + V3 = VectorFunctionSpace(mesh, "CG", 3) x = SpatialCoordinate(mesh) f = assemble(interpolate(as_vector((x[0]*x[1], x[0]+x[1])), V1)) From d1a0710e1dd29884e448fa7b352a26a4da2568f6 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 17 Sep 2025 16:23:15 +0100 Subject: [PATCH 014/125] Do not sum_factorise --- tsfc/driver.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tsfc/driver.py b/tsfc/driver.py index cadca7835e..5794deecf1 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -306,7 +306,6 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, gem_dual = gem.reshape(gem_dual, shape) evaluation = gem.IndexSum(evaluation * gem_dual[basis_indices], basis_indices) - evaluation = gem.optimise.sum_factorise(*gem.optimise.delta_elimination(*gem.optimise.traverse_product(evaluation))) basis_indices = () # Build kernel body From 9329d1a732eba032235e2f4878e23d3e2f6d019b Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 17 Sep 2025 18:35:27 +0100 Subject: [PATCH 015/125] Fix complex conjugate --- tsfc/driver.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tsfc/driver.py b/tsfc/driver.py index 5794deecf1..87863dd1d6 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -302,9 +302,9 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, name = f"w_{coefficients.index(dual_arg)}" shape = tuple(i.extent for i in basis_indices) size = numpy.prod(shape, dtype=int) - gem_dual = gem.Variable(name, shape=(size,)) - gem_dual = gem.reshape(gem_dual, shape) - + gem_dual = gem.reshape(gem.Variable(name, shape=(size,)), shape) + if complex_mode: + evaluation = gem.MathFunction('conj', evaluation) evaluation = gem.IndexSum(evaluation * gem_dual[basis_indices], basis_indices) basis_indices = () From 44a98edf28d996875626cb1c487e4955160096ec Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 18 Sep 2025 10:12:06 +0100 Subject: [PATCH 016/125] Suggestions from review --- firedrake/assemble.py | 6 +++++- firedrake/interpolation.py | 8 +++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 9394c12aea..fe63b34fdc 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -594,7 +594,8 @@ def base_form_assembly_visitor(self, expr, tensor, *args): v1: v1.reconstruct(number=v0.number())}) v, operand = expr.argument_slots() - # Assemble the interpolator matrix if the meshes are different + # Matrix-free adjoint interpolation is only implemented by SameMeshInterpolator + # so we need assemble the interpolator matrix if the meshes are different target_mesh = V.mesh() source_mesh = extract_unique_domain(operand) or target_mesh if is_adjoint and rank < 2 and source_mesh is not target_mesh: @@ -605,6 +606,9 @@ def base_form_assembly_visitor(self, expr, tensor, *args): interp_data = expr.interp_data.copy() default_missing_val = interp_data.pop('default_missing_val', None) if matfree and ((is_adjoint and rank == 1) or rank == 0): + # Adjoint interpolation of a Cofunction or the action of a + # Cofunction on an interpolated Function require INC access + # on the output tensor interp_data["access"] = op2.INC if rank == 1 and matfree and isinstance(tensor, firedrake.Function): diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 9bbf2dc7d5..c941edb0b6 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -389,7 +389,9 @@ def __init__( ) if isinstance(expr, ufl.Interpolate): - expr, = expr.ufl_operands + dual_arg, expr = expr.argument_slots() + if not isinstance(dual_arg, Coargument): + raise NotImplementedError(f"{type(self).__name__} does not support matrix-free adjoint interpolation.") super().__init__(expr, V, subset, freeze_expr, access, bcs, allow_missing_dofs, matfree) self.arguments = extract_arguments(expr) @@ -749,7 +751,7 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, ** # Interpolation action self.frozen_assembled_interpolator = assembled_interpolator.copy() - if len(self.arguments) == 2 and len(function): + if len(self.arguments) == 2 and len(function) > 0: function, = function if not hasattr(function, "dat"): raise ValueError("The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!") @@ -865,7 +867,7 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): tensor = op2.Mat(sparsity) f = tensor else: - raise ValueError("Cannot interpolate an expression with %d arguments" % rank) + raise ValueError(f"Cannot interpolate an expression with {rank} arguments") if vom_onto_other_vom: wrapper = VomOntoVomWrapper(V, source_mesh, target_mesh, operand, matfree) From 120a2a34badff5457285a7e55b5d543a5b939ad6 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 18 Sep 2025 13:32:12 +0100 Subject: [PATCH 017/125] Reusable Interpolator --- firedrake/assemble.py | 2 ++ firedrake/interpolation.py | 20 ++++++++++++------- firedrake/variational_solver.py | 4 ++-- .../firedrake/regression/test_interpolate.py | 9 +++++---- 4 files changed, 22 insertions(+), 13 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index fe63b34fdc..bfff825e42 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -538,6 +538,8 @@ def base_form_assembly_visitor(self, expr, tensor, *args): result = expr.assemble(assembly_opts=opts) return tensor.assign(result) if tensor else result elif isinstance(expr, ufl.Interpolate): + if not isinstance(expr, firedrake.Interpolate): + expr = firedrake.Interpolate(*reversed(expr.dual_args())) orig_expr = expr # Replace assembled children _, operand = expr.argument_slots() diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index c941edb0b6..6f64f60fb0 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -1006,8 +1006,13 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): parameters = {} parameters['scalar_type'] = utils.ScalarType + callables = () + if access == op2.INC: + callables += (tensor.zero,) + needs_weight = isinstance(dual_arg, ufl.Cofunction) and not to_element.is_dg() if needs_weight: + # Compute the reciprocal of the DOF multiplicity W = dual_arg.function_space() shapes = (W.finat_element.space_dimension(), W.block_size) domain = "{[i,j]: 0 <= i < %d and 0 <= j < %d}" % shapes @@ -1018,14 +1023,14 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): """ weight = firedrake.Function(W) firedrake.par_loop((domain, instructions), ufl.dx, {"w": (weight, op2.INC)}) + with weight.dat.vec as w: + w.reciprocal() - # Create a copy and apply the weight - # TODO include this in the callables - v = firedrake.Function(dual_arg) - with v.dat.vec as x, weight.dat.vec as w: - x.pointwiseDivide(x, w) - + # Create a buffer for the weighted Cofunction and a callable to apply the weight + v = firedrake.Function(W) expr = expr._ufl_expr_reconstruct_(operand, v=v) + with weight.dat.vec_ro as w, dual_arg.dat.vec_ro as x, v.dat.vec_wo as y: + callables += (partial(y.pointwiseMult, x, w),) # We need to pass both the ufl element and the finat element # because the finat elements might not have the right mapping @@ -1043,6 +1048,7 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): name = kernel.name kernel = op2.Kernel(ast, name, requires_zeroed_output_arguments=True, flop_count=kernel.flop_count, events=(kernel.event,)) + parloop_args = [kernel, cell_set] coefficients = tsfc_interface.extract_numbered_coefficients(expr, coefficient_numbers) @@ -1158,7 +1164,7 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): if isinstance(tensor, op2.Mat): return parloop_compute_callable, tensor.assemble else: - return copyin + (parloop_compute_callable, ) + copyout + return copyin + callables + (parloop_compute_callable, ) + copyout try: diff --git a/firedrake/variational_solver.py b/firedrake/variational_solver.py index 724c4257d3..b23424aac6 100644 --- a/firedrake/variational_solver.py +++ b/firedrake/variational_solver.py @@ -7,7 +7,7 @@ from firedrake import dmhooks, slate, solving, solving_utils, ufl_expr, utils from firedrake.petsc import PETSc, DEFAULT_KSP_PARAMETERS, DEFAULT_SNES_PARAMETERS from firedrake.function import Function -from firedrake.interpolation import Interpolate +from firedrake.interpolation import interpolate from firedrake.matrix import MatrixBase from firedrake.ufl_expr import TrialFunction, TestFunction from firedrake.bcs import DirichletBC, EquationBC, extract_subdomain_ids, restricted_function_space @@ -98,7 +98,7 @@ def __init__(self, F, u, bcs=None, J=None, F_arg, = F.arguments() self.F = replace(F, {F_arg: v_res, self.u: self.u_restrict}) else: - self.F = Interpolate(v_res, replace(F, {self.u: self.u_restrict})) + self.F = interpolate(v_res, replace(F, {self.u: self.u_restrict})) v_arg, u_arg = self.J.arguments() self.J = replace(self.J, {v_arg: v_res, u_arg: u_res, self.u: self.u_restrict}) diff --git a/tests/firedrake/regression/test_interpolate.py b/tests/firedrake/regression/test_interpolate.py index 70f65dc9e4..69f139ecbb 100644 --- a/tests/firedrake/regression/test_interpolate.py +++ b/tests/firedrake/regression/test_interpolate.py @@ -329,7 +329,7 @@ def test_trace(): @pytest.mark.parametrize("rank", (0, 1)) @pytest.mark.parametrize("mat_type", ("matfree", "aij")) -@pytest.mark.parametrize("degree", range(1, 4)) +@pytest.mark.parametrize("degree", (1, 3)) @pytest.mark.parametrize("cell", ["triangle", "quadrilateral"]) @pytest.mark.parametrize("shape", ("scalar", "vector", "tensor")) def test_adjoint_Pk(rank, mat_type, degree, cell, shape): @@ -350,14 +350,15 @@ def test_adjoint_Pk(rank, mat_type, degree, cell, shape): operand = TestFunction(Pk) if mat_type == "matfree": - result = assemble(interpolate(operand, v)) + interp = interpolate(operand, v) else: adj_interp = assemble(interpolate(operand, TrialFunction(Pkp1.dual()))) if rank == 0: - result = assemble(action(v, adj_interp)) + interp = action(v, adj_interp) else: - result = assemble(action(adj_interp, v)) + interp = action(adj_interp, v) + result = assemble(interp) expect = assemble(inner(expr, operand) * dx) if rank == 0: assert np.allclose(result, expect) From af77aac88bd61a90a0569f308669e819fba2164b Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 18 Sep 2025 19:22:07 +0100 Subject: [PATCH 018/125] Allow interpolate(..., BaseForm) --- firedrake/interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 6f64f60fb0..3883e3d623 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -175,7 +175,7 @@ def interpolate(expr, V, subset=None, access=op2.WRITE, allow_missing_dofs=False """ if isinstance(V, (Cofunction, Coargument)): dual_arg = V - elif isinstance(V, ufl.Form): + elif isinstance(V, (ufl.Form, ufl.BaseForm)): rank = len(V.arguments()) if rank == 1: dual_arg = V From ea5c07b1e1c3e6a2baffb9470bfdd044e8918f93 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 18 Sep 2025 19:52:14 +0100 Subject: [PATCH 019/125] Update firedrake/assemble.py --- firedrake/assemble.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index bfff825e42..74d35b0900 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -640,7 +640,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args): if is_adjoint: # Out-of-place Hermitian transpose petsc_mat.hermitianTranspose(out=res) - elif tensor: + elif res: # Copy the interpolation matrix into the output tensor petsc_mat.copy(result=res) else: From 62fa31e1972f4cd8edf0db16758e101d1ab0e563 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 19 Sep 2025 16:57:57 +0100 Subject: [PATCH 020/125] Explicitly assemble the interpolate adjoint matrix --- firedrake/assemble.py | 9 +- firedrake/interpolation.py | 89 +++++++++++-------- .../firedrake/regression/test_interp_dual.py | 6 +- tsfc/driver.py | 27 +++--- 4 files changed, 71 insertions(+), 60 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 74d35b0900..45cc024d22 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -589,13 +589,6 @@ def base_form_assembly_visitor(self, expr, tensor, *args): assemble(sub_interp, tensor=tensor.subfunctions[i]) return tensor - # Workaround: Renumber argument when needed since Interpolator assumes it takes a zero-numbered argument. - if not is_adjoint and rank == 2: - v0, v1 = expr.arguments() - expr = ufl.replace(expr, {v0: v0.reconstruct(number=v1.number()), - v1: v1.reconstruct(number=v0.number())}) - v, operand = expr.argument_slots() - # Matrix-free adjoint interpolation is only implemented by SameMeshInterpolator # so we need assemble the interpolator matrix if the meshes are different target_mesh = V.mesh() @@ -637,7 +630,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args): # Get the interpolation matrix op2_mat = interpolator.callable() petsc_mat = op2_mat.handle - if is_adjoint: + if is_adjoint and (source_mesh is not target_mesh): # Out-of-place Hermitian transpose petsc_mat.hermitianTranspose(out=res) elif res: diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 3883e3d623..65eed6e7bc 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -278,9 +278,11 @@ def __init__( allow_missing_dofs=False, matfree=True ): - if isinstance(expr, ufl.Interpolate): - expr, = expr.ufl_operands - self.expr = expr + if not isinstance(expr, ufl.Interpolate): + fs = V if isinstance(V, ufl.FunctionSpace) else V.function_space() + expr = firedrake.Interpolate(expr, fs) + operand, = expr.ufl_operands + self.expr = operand self.V = V self.subset = subset self.freeze_expr = freeze_expr @@ -292,11 +294,16 @@ def __init__( # Cope with the different convention of `Interpolate` and `Interpolator`: # -> Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) # -> Interpolator(Argument(V1, 0), V2) - expr_args = extract_arguments(expr) - if expr_args and expr_args[0].number() == 0: - v, = expr_args - expr = replace(expr, {v: v.reconstruct(number=1)}) - self.expr_renumbered = expr + target_mesh = as_domain(V) + source_mesh = extract_unique_domain(operand) or target_mesh + if len(expr.arguments()) == 2 and target_mesh is not source_mesh: + expr_args = extract_arguments(operand) + if expr_args and expr_args[0].number() == 0: + v0, v1 = expr.arguments() + expr = ufl.replace(expr, {v0: v0.reconstruct(number=v1.number()), + v1: v1.reconstruct(number=v0.number())}) + self.expr_renumbered, = expr.ufl_operands + self.interpolate = expr def interpolate(self, *function, transpose=None, adjoint=False, default_missing_val=None): """ @@ -387,14 +394,16 @@ def __init__( raise NotImplementedError( "Can only interpolate into spaces with point evaluation nodes." ) + if not isinstance(expr, ufl.Interpolate): + fs = V if isinstance(V, ufl.FunctionSpace) else V.function_space() + expr = Interpolate(expr, fs) + dual_arg, expr = expr.argument_slots() + if not isinstance(dual_arg, Coargument): + raise NotImplementedError(f"{type(self).__name__} does not support matrix-free adjoint interpolation.") - if isinstance(expr, ufl.Interpolate): - dual_arg, expr = expr.argument_slots() - if not isinstance(dual_arg, Coargument): - raise NotImplementedError(f"{type(self).__name__} does not support matrix-free adjoint interpolation.") super().__init__(expr, V, subset, freeze_expr, access, bcs, allow_missing_dofs, matfree) - self.arguments = extract_arguments(expr) + self.arguments = extract_arguments(self.expr_renumbered) self.nargs = len(self.arguments) if self._allow_missing_dofs: @@ -694,13 +703,11 @@ class SameMeshInterpolator(Interpolator): @no_annotations def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, bcs=None, matfree=True, allow_missing_dofs=False, **kwargs): - if isinstance(expr, ufl.Interpolate): - operand, = expr.ufl_operands - else: + if not isinstance(expr, ufl.Interpolate): fs = V if isinstance(V, ufl.FunctionSpace) else V.function_space() - operand = expr - expr = Interpolate(operand, fs) + expr = Interpolate(expr, fs) if subset is None: + operand, = expr.ufl_operands target_mesh = as_domain(V) source_mesh = extract_unique_domain(operand) or target_mesh target = target_mesh.topology @@ -721,6 +728,7 @@ def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, pass super().__init__(expr, V, subset=subset, freeze_expr=freeze_expr, access=access, bcs=bcs, matfree=matfree, allow_missing_dofs=allow_missing_dofs) + expr = self.interpolate try: self.callable = make_interpolator(expr, V, subset, access, bcs=bcs, matfree=matfree) except FIAT.hdiv_trace.TraceError: @@ -761,8 +769,8 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, ** else: mul = assembled_interpolator.handle.mult row, col = self.arguments - V = col.function_space().dual() - assert function.function_space() == row.function_space() + V = row.function_space().dual() + assert function.function_space() == col.function_space() result = output or firedrake.Function(V) with function.dat.vec_ro as x, result.dat.vec_wo as out: @@ -815,7 +823,7 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): f = V V = f.function_space() else: - V_dest = arguments[-1].function_space().dual() + V_dest = arguments[0].function_space().dual() f = firedrake.Function(V_dest) if access in {firedrake.MIN, firedrake.MAX}: finfo = numpy.finfo(f.dat.dtype) @@ -828,10 +836,11 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): elif rank == 2: if isinstance(V, firedrake.Function): raise ValueError("Cannot interpolate an expression with an argument into a Function") - if len(V) > 1: + Vrow = arguments[0].function_space() + Vcol = arguments[1].function_space() + if len(Vrow) > 1 or len(Vcol) > 1: raise NotImplementedError("Interpolation of mixed expressions with arguments is not supported") - argfs = arguments[0].function_space() - argfs_map = argfs.cell_node_map() + Vcol_map = Vcol.cell_node_map() if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology) and target_mesh is not source_mesh and not vom_onto_other_vom: if not isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology): raise NotImplementedError("Can only interpolate onto a Vertex Only Mesh") @@ -839,29 +848,30 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh: raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target") - if argfs_map: + if Vcol_map: # Since the par_loop is over the target mesh cells we need to # compose a map that takes us from target mesh cells to the - # function space nodes on the source mesh. NOTE: argfs_map is + # function space nodes on the source mesh. NOTE: Vcol_map is # allowed to be None when interpolating from a Real space, even # in the trans-mesh case. if source_mesh.extruded: # ExtrudedSet cannot be a map target so we need to build # this ourselves - argfs_map = vom_cell_parent_node_map_extruded(target_mesh, argfs_map) + Vcol_map = vom_cell_parent_node_map_extruded(target_mesh, Vcol_map) else: - argfs_map = compose_map_and_cache(target_mesh.cell_parent_cell_map, argfs_map) + Vcol_map = compose_map_and_cache(target_mesh.cell_parent_cell_map, Vcol_map) elif vom_onto_other_vom: - argfs_map = argfs.cell_node_map() + Vcol_map = Vcol.cell_node_map() else: - argfs_map = argfs.entity_node_map(target_mesh.topology, "cell", None, None) + Vcol_map = Vcol.entity_node_map(target_mesh.topology, "cell", None, None) if vom_onto_other_vom: # We make our own linear operator for this case using PETSc SFs tensor = None else: - sparsity = op2.Sparsity((V.dof_dset, argfs.dof_dset), - [(V.cell_node_map(), argfs_map, None)], # non-mixed - name="%s_%s_sparsity" % (V.name, argfs.name), + Vrow_map = Vrow.cell_node_map() + sparsity = op2.Sparsity((Vrow.dof_dset, Vcol.dof_dset), + [(Vrow_map, Vcol_map, None)], # non-mixed + name="%s_%s_sparsity" % (Vrow.name, Vcol.name), nest=False, block_sparse=True) tensor = op2.Mat(sparsity) @@ -894,7 +904,7 @@ def callable(): # safely use the argument function space. NOTE: If this changes # after cofunctions are fully implemented, this will need to be # reconsidered. - temp_source_func = firedrake.Function(argfs) + temp_source_func = firedrake.Function(Vcol) wrapper.mpi_type, _ = get_dat_mpi_type(temp_source_func.dat) # Leave wrapper inside a callable so we can access the handle @@ -1073,9 +1083,10 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): parloop_args.append(tensor(access, V_dest.cell_node_map())) else: assert access == op2.WRITE # Other access descriptors not done for Matrices. - rows_map = V.cell_node_map() - Vcol = arguments[0].function_space() - assert tensor.handle.getSize() == (V.dim(), Vcol.dim()) + Vrow = arguments[0].function_space() + Vcol = arguments[1].function_space() + rows_map = Vrow.cell_node_map() + assert tensor.handle.getSize() == (Vrow.dim(), Vcol.dim()) if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology): columns_map = Vcol.cell_node_map() if target_mesh is not source_mesh: @@ -1093,9 +1104,9 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): columns_map = Vcol.entity_node_map(target_mesh.topology, "cell", None, None) lgmaps = None if bcs: - bc_rows = [bc for bc in bcs if bc.function_space() == V] + bc_rows = [bc for bc in bcs if bc.function_space() == Vrow] bc_cols = [bc for bc in bcs if bc.function_space() == Vcol] - lgmaps = [(V.local_to_global_map(bc_rows), Vcol.local_to_global_map(bc_cols))] + lgmaps = [(Vrow.local_to_global_map(bc_rows), Vcol.local_to_global_map(bc_cols))] parloop_args.append(tensor(op2.WRITE, (rows_map, columns_map), lgmaps=lgmaps)) if oriented: co = target_mesh.cell_orientations() diff --git a/tests/firedrake/regression/test_interp_dual.py b/tests/firedrake/regression/test_interp_dual.py index 444f352453..50e29b05cb 100644 --- a/tests/firedrake/regression/test_interp_dual.py +++ b/tests/firedrake/regression/test_interp_dual.py @@ -106,7 +106,11 @@ def test_assemble_interp_adjoint_matrix(V1, V2): # Interpolation from V2* to V1* c1 = Cofunction(V1.dual()).interpolate(c2) # Interpolation matrix (V2* -> V1*) - a = assemble(adjoint(Iv1)) + adj_Iv1 = adjoint(Iv1) + a = assemble(adj_Iv1) + assert a.arguments() == adj_Iv1.arguments() + assert a.petscmat.getSize() == (V1.dim(), V2.dim()) + res = Cofunction(V1.dual()) with c2.dat.vec_ro as x, res.dat.vec_ro as y: a.petscmat.mult(x, y) diff --git a/tsfc/driver.py b/tsfc/driver.py index 87863dd1d6..532eb93902 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -225,9 +225,11 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, if isinstance(expression, ufl.Interpolate): v, _ = expression.argument_slots() - expression = ufl.Interpolate(operand, v) else: - expression = operand + arguments = extract_arguments(operand) + number = 1-arguments[0].number() if len(arguments) else 0 + v = ufl.Coargument(ufl.FunctionSpace(domain, ufl_element), number=number) + expression = ufl.Interpolate(operand, v) # Initialise kernel builder if interface is None: @@ -235,9 +237,10 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, from tsfc.kernel_interface.firedrake_loopy import ExpressionKernelBuilder as interface builder = interface(parameters["scalar_type"]) - arguments = extract_arguments(operand) - argument_multiindices = tuple(builder.create_element(arg.ufl_element()).get_indices() - for arg in arguments) + arguments = expression.arguments() + argument_multiindices = {arg.number(): builder.create_element(arg.ufl_element()).get_indices() + for arg in arguments} + assert len(argument_multiindices) == len(arguments) # Replace coordinates (if any) unless otherwise specified by kwarg if domain is None: @@ -284,11 +287,7 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, if isinstance(to_element, finat.QuadratureElement): kernel_cfg["quadrature_rule"] = to_element._rule - if isinstance(expression, ufl.Interpolate): - dual_arg, operand = expression.argument_slots() - else: - operand = expression - dual_arg = None + dual_arg, operand = expression.argument_slots() # Create callable for translation of UFL expression to gem fn = DualEvaluationCallable(operand, kernel_cfg) @@ -307,9 +306,13 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, evaluation = gem.MathFunction('conj', evaluation) evaluation = gem.IndexSum(evaluation * gem_dual[basis_indices], basis_indices) basis_indices = () + else: + argument_multiindices[dual_arg.number()] = basis_indices + + argument_multiindices = dict(sorted(argument_multiindices.items())) # Build kernel body - return_indices = tuple(chain(basis_indices, *argument_multiindices)) + return_indices = tuple(chain.from_iterable(argument_multiindices.values())) return_shape = tuple(i.extent for i in return_indices) return_var = gem.Variable('A', return_shape or (1,)) return_expr = gem.Indexed(return_var, return_indices or (0,)) @@ -381,7 +384,7 @@ def __call__(self, ps): gem_expr, = fem.compile_ufl(self.expression, translation_context, point_sum=False) # In some cases ps.indices may be dropped from expr, but nothing # new should now appear - argument_multiindices = kernel_cfg["argument_multiindices"] + argument_multiindices = kernel_cfg["argument_multiindices"].values() assert set(gem_expr.free_indices) <= set(chain(ps.indices, *argument_multiindices)) return gem_expr From ef0b10cdab600db5b56f5cf4270ecbf5bd6949ac Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 19 Sep 2025 18:30:35 +0100 Subject: [PATCH 021/125] Fix up --- firedrake/assemble.py | 2 +- firedrake/interpolation.py | 33 +++++++++++++++----------------- firedrake/preconditioners/pmg.py | 6 +++--- tsfc/driver.py | 16 +++++----------- 4 files changed, 24 insertions(+), 33 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 45cc024d22..1a9d4813cd 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -593,7 +593,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args): # so we need assemble the interpolator matrix if the meshes are different target_mesh = V.mesh() source_mesh = extract_unique_domain(operand) or target_mesh - if is_adjoint and rank < 2 and source_mesh is not target_mesh: + if is_adjoint and rank < 2 and (source_mesh is not target_mesh): expr = reconstruct_interp(operand, v=V) matfree = (rank == len(expr.arguments())) and (rank < 2) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 65eed6e7bc..81d7fce7cd 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -291,19 +291,18 @@ def __init__( self._allow_missing_dofs = allow_missing_dofs self.matfree = matfree self.callable = None - # Cope with the different convention of `Interpolate` and `Interpolator`: - # -> Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) - # -> Interpolator(Argument(V1, 0), V2) + # Workaround for matrix-explicit adjoint of cross-mesh interpolation + # Return instead the forward operator target_mesh = as_domain(V) source_mesh = extract_unique_domain(operand) or target_mesh - if len(expr.arguments()) == 2 and target_mesh is not source_mesh: + if len(expr.arguments()) == 2 and (target_mesh is not source_mesh): expr_args = extract_arguments(operand) if expr_args and expr_args[0].number() == 0: v0, v1 = expr.arguments() expr = ufl.replace(expr, {v0: v0.reconstruct(number=v1.number()), v1: v1.reconstruct(number=v0.number())}) self.expr_renumbered, = expr.ufl_operands - self.interpolate = expr + self.ufl_interpolate = expr def interpolate(self, *function, transpose=None, adjoint=False, default_missing_val=None): """ @@ -394,16 +393,14 @@ def __init__( raise NotImplementedError( "Can only interpolate into spaces with point evaluation nodes." ) - if not isinstance(expr, ufl.Interpolate): - fs = V if isinstance(V, ufl.FunctionSpace) else V.function_space() - expr = Interpolate(expr, fs) - dual_arg, expr = expr.argument_slots() - if not isinstance(dual_arg, Coargument): - raise NotImplementedError(f"{type(self).__name__} does not support matrix-free adjoint interpolation.") - + if isinstance(expr, ufl.Interpolate): + dual_arg, operand = expr.argument_slots() + if not isinstance(dual_arg, ufl.Coargument): + raise NotImplementedError(f"{type(self).__name__} does not support matrix-free adjoint action.") super().__init__(expr, V, subset, freeze_expr, access, bcs, allow_missing_dofs, matfree) - self.arguments = extract_arguments(self.expr_renumbered) + expr = self.expr_renumbered + self.arguments = extract_arguments(expr) self.nargs = len(self.arguments) if self._allow_missing_dofs: @@ -703,11 +700,11 @@ class SameMeshInterpolator(Interpolator): @no_annotations def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, bcs=None, matfree=True, allow_missing_dofs=False, **kwargs): - if not isinstance(expr, ufl.Interpolate): - fs = V if isinstance(V, ufl.FunctionSpace) else V.function_space() - expr = Interpolate(expr, fs) if subset is None: - operand, = expr.ufl_operands + if isinstance(expr, ufl.Interpolate): + operand, = expr.ufl_operands + else: + operand = expr target_mesh = as_domain(V) source_mesh = extract_unique_domain(operand) or target_mesh target = target_mesh.topology @@ -728,7 +725,7 @@ def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, pass super().__init__(expr, V, subset=subset, freeze_expr=freeze_expr, access=access, bcs=bcs, matfree=matfree, allow_missing_dofs=allow_missing_dofs) - expr = self.interpolate + expr = self.ufl_interpolate try: self.callable = make_interpolator(expr, V, subset, access, bcs=bcs, matfree=matfree) except FIAT.hdiv_trace.TraceError: diff --git a/firedrake/preconditioners/pmg.py b/firedrake/preconditioners/pmg.py index 1cdea965a6..f4b45a67a5 100644 --- a/firedrake/preconditioners/pmg.py +++ b/firedrake/preconditioners/pmg.py @@ -1457,7 +1457,7 @@ def make_kernels(self, Vf, Vc): except KeyError: pass prolong_kernel, _ = prolongation_transfer_kernel_action(Vf, self.uc) - matrix_kernel, coefficients = prolongation_transfer_kernel_action(Vf, firedrake.TestFunction(Vc)) + matrix_kernel, coefficients = prolongation_transfer_kernel_action(Vf, firedrake.TrialFunction(Vc)) # The way we transpose the prolongation kernel is suboptimal. # A local matrix is generated each time the kernel is executed. @@ -1593,7 +1593,7 @@ def prolongation_matrix_aij(P1, Pk, P1_bcs=[], Pk_bcs=[]): for bc in chain(Pk_bcs_i, P1_bcs_i) if bc is not None) matarg = mat[i, i](op2.WRITE, (Pk.sub(i).cell_node_map(), P1.sub(i).cell_node_map()), lgmaps=((rlgmap, clgmap), ), unroll_map=unroll) - expr = firedrake.TestFunction(P1.sub(i)) + expr = firedrake.TrialFunction(P1.sub(i)) kernel, coefficients = prolongation_transfer_kernel_action(Pk.sub(i), expr) parloop_args = [kernel, mesh.cell_set, matarg] for coefficient in coefficients: @@ -1610,7 +1610,7 @@ def prolongation_matrix_aij(P1, Pk, P1_bcs=[], Pk_bcs=[]): for bc in chain(Pk_bcs, P1_bcs) if bc is not None) matarg = mat(op2.WRITE, (Pk.cell_node_map(), P1.cell_node_map()), lgmaps=((rlgmap, clgmap), ), unroll_map=unroll) - expr = firedrake.TestFunction(P1) + expr = firedrake.TrialFunction(P1) kernel, coefficients = prolongation_transfer_kernel_action(Pk, expr) parloop_args = [kernel, mesh.cell_set, matarg] for coefficient in coefficients: diff --git a/tsfc/driver.py b/tsfc/driver.py index 532eb93902..f420300b17 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -6,7 +6,7 @@ from finat.physically_mapped import DirectlyDefinedElement, PhysicallyMappedElement import ufl -from ufl.algorithms import extract_arguments, extract_coefficients +from ufl.algorithms import extract_coefficients from ufl.algorithms.analysis import has_type from ufl.algorithms.apply_coefficient_split import CoefficientSplitter from ufl.classes import Form, GeometricQuantity @@ -211,11 +211,12 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, if isinstance(to_element, (PhysicallyMappedElement, DirectlyDefinedElement)): raise NotImplementedError("Don't know how to interpolate onto zany spaces, sorry") - orig_expression = expression + orig_coefficients = extract_coefficients(expression) if isinstance(expression, ufl.Interpolate): - operand, = expression.ufl_operands + v, operand = expression.argument_slots() else: operand = expression + v = ufl.FunctionSpace(extract_unique_domain(operand), ufl_element) # Map into reference space operand = apply_mapping(operand, ufl_element, domain) @@ -223,12 +224,6 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, # Apply UFL preprocessing operand = ufl_utils.preprocess_expression(operand, complex_mode=complex_mode) - if isinstance(expression, ufl.Interpolate): - v, _ = expression.argument_slots() - else: - arguments = extract_arguments(operand) - number = 1-arguments[0].number() if len(arguments) else 0 - v = ufl.Coargument(ufl.FunctionSpace(domain, ufl_element), number=number) expression = ufl.Interpolate(operand, v) # Initialise kernel builder @@ -249,7 +244,6 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, # Collect required coefficients and determine numbering coefficients = extract_coefficients(expression) - orig_coefficients = extract_coefficients(orig_expression) coefficient_numbers = tuple(map(orig_coefficients.index, coefficients)) builder.set_coefficient_numbers(coefficient_numbers) @@ -321,7 +315,7 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, # but we don't for now. evaluation, = impero_utils.preprocess_gem([evaluation]) impero_c = impero_utils.compile_gem([(return_expr, evaluation)], return_indices) - index_names = dict((idx, "p%d" % i) for (i, idx) in enumerate(basis_indices)) + index_names = {idx: f"p{i}" for (i, idx) in enumerate(basis_indices)} # Handle kernel interface requirements builder.register_requirements([evaluation]) builder.set_output(return_var) From 34ea5f30e24b83b4260eed0d381bdf72940ee530 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sat, 20 Sep 2025 18:58:13 +0100 Subject: [PATCH 022/125] Move renumbering logic to Interpolator --- firedrake/assemble.py | 52 ++++++--------------------------- firedrake/interpolation.py | 59 ++++++++++++++++++++++++++++++-------- 2 files changed, 55 insertions(+), 56 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 1a9d4813cd..b7534bc412 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -540,7 +540,6 @@ def base_form_assembly_visitor(self, expr, tensor, *args): elif isinstance(expr, ufl.Interpolate): if not isinstance(expr, firedrake.Interpolate): expr = firedrake.Interpolate(*reversed(expr.dual_args())) - orig_expr = expr # Replace assembled children _, operand = expr.argument_slots() v, *assembled_operand = args @@ -559,6 +558,8 @@ def base_form_assembly_visitor(self, expr, tensor, *args): # 4) Interpolate(Argument(V1, 0), Cofunction(...)) -> Action of the Jacobian adjoint # This can be generalized to the case where the first slot is an arbitray expression. rank = len(expr.arguments()) + if rank > 2: + raise ValueError("Cannot assemble an Interpolate with more than two arguments") # If argument numbers have been swapped => Adjoint. arg_operand = ufl.algorithms.extract_arguments(operand) is_adjoint = (arg_operand and arg_operand[0].number() == 0) @@ -589,60 +590,23 @@ def base_form_assembly_visitor(self, expr, tensor, *args): assemble(sub_interp, tensor=tensor.subfunctions[i]) return tensor - # Matrix-free adjoint interpolation is only implemented by SameMeshInterpolator - # so we need assemble the interpolator matrix if the meshes are different - target_mesh = V.mesh() - source_mesh = extract_unique_domain(operand) or target_mesh - if is_adjoint and rank < 2 and (source_mesh is not target_mesh): - expr = reconstruct_interp(operand, v=V) - matfree = (rank == len(expr.arguments())) and (rank < 2) - # Get the interpolator interp_data = expr.interp_data.copy() default_missing_val = interp_data.pop('default_missing_val', None) - if matfree and ((is_adjoint and rank == 1) or rank == 0): + + target_mesh = V.mesh() + source_mesh = extract_unique_domain(operand) or target_mesh + if (source_mesh is target_mesh) and ((is_adjoint and rank == 1) or rank == 0): # Adjoint interpolation of a Cofunction or the action of a # Cofunction on an interpolated Function require INC access # on the output tensor interp_data["access"] = op2.INC - if rank == 1 and matfree and isinstance(tensor, firedrake.Function): + if rank == 1 and isinstance(tensor, firedrake.Function): V = tensor interpolator = firedrake.Interpolator(expr, V, **interp_data) - # Assembly - if matfree: - # Assembling the operator - return interpolator._interpolate(output=tensor, default_missing_val=default_missing_val) - elif rank == 0: - # Assembling the double action. - Iu = interpolator._interpolate(default_missing_val=default_missing_val) - return assemble(ufl.Action(v, Iu), tensor=tensor) - elif rank == 1: - # Assembling the action of the Jacobian adjoint. - if is_adjoint: - return interpolator._interpolate(v, output=tensor, adjoint=True, default_missing_val=default_missing_val) - # Assembling the Jacobian action. - else: - return interpolator._interpolate(operand, output=tensor, default_missing_val=default_missing_val) - elif rank == 2: - res = tensor.petscmat if tensor else PETSc.Mat() - # Get the interpolation matrix - op2_mat = interpolator.callable() - petsc_mat = op2_mat.handle - if is_adjoint and (source_mesh is not target_mesh): - # Out-of-place Hermitian transpose - petsc_mat.hermitianTranspose(out=res) - elif res: - # Copy the interpolation matrix into the output tensor - petsc_mat.copy(result=res) - else: - res = petsc_mat - if tensor is None: - tensor = self.assembled_matrix(orig_expr, res) - return tensor - else: - raise ValueError("Incompatible number of arguments.") + return interpolator.assemble(tensor=tensor, default_missing_val=default_missing_val) elif tensor and isinstance(expr, (firedrake.Function, firedrake.Cofunction, firedrake.MatrixBase)): return tensor.assign(expr) elif tensor and isinstance(expr, ufl.ZeroBaseForm): diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 81d7fce7cd..7598657dfa 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -280,8 +280,9 @@ def __init__( ): if not isinstance(expr, ufl.Interpolate): fs = V if isinstance(V, ufl.FunctionSpace) else V.function_space() - expr = firedrake.Interpolate(expr, fs) - operand, = expr.ufl_operands + expr = Interpolate(expr, fs) + dual_arg, operand = expr.argument_slots() + self.ufl_interpolate = expr self.expr = operand self.V = V self.subset = subset @@ -291,18 +292,20 @@ def __init__( self._allow_missing_dofs = allow_missing_dofs self.matfree = matfree self.callable = None - # Workaround for matrix-explicit adjoint of cross-mesh interpolation - # Return instead the forward operator + # Workaround for adjoint of cross-mesh interpolation + # Assemble the forward operator and then take its adjoint target_mesh = as_domain(V) source_mesh = extract_unique_domain(operand) or target_mesh - if len(expr.arguments()) == 2 and (target_mesh is not source_mesh): + if target_mesh is not source_mesh: + if not isinstance(dual_arg, ufl.Coargument): + expr = expr._ufl_expr_reconstruct_(operand, dual_arg.function_space().dual()) expr_args = extract_arguments(operand) if expr_args and expr_args[0].number() == 0: v0, v1 = expr.arguments() expr = ufl.replace(expr, {v0: v0.reconstruct(number=v1.number()), v1: v1.reconstruct(number=v0.number())}) self.expr_renumbered, = expr.ufl_operands - self.ufl_interpolate = expr + self.ufl_interpolate_renumbered = expr def interpolate(self, *function, transpose=None, adjoint=False, default_missing_val=None): """ @@ -326,6 +329,42 @@ def _interpolate(self, *args, **kwargs): """ pass + def assemble(self, tensor=None, default_missing_val=None): + """Assemble the operator (or its action).""" + from firedrake.assemble import assemble + renumbered = self.ufl_interpolate_renumbered != self.ufl_interpolate + arguments = self.ufl_interpolate.arguments() + if len(arguments) == 2: + # Assembling the operator + res = tensor.petscmat if tensor else PETSc.Mat() + # Get the interpolation matrix + op2mat = self.callable() + petsc_mat = op2mat.handle + if renumbered: + # Out-of-place Hermitian transpose + petsc_mat.hermitianTranspose(out=res) + elif res: + petsc_mat.copy(res) + else: + res = petsc_mat + if tensor is None: + tensor = firedrake.AssembledMatrix(arguments, self.bcs, res) + return tensor + else: + # Assembling the action + missing_args = () + if renumbered: + dual_arg, _ = self.ufl_interpolate.argument_slots() + if not isinstance(dual_arg, ufl.Coargument): + missing_args = (dual_arg,) + + if renumbered and len(arguments) == 0: + Iu = self._interpolate(default_missing_val=default_missing_val) + return assemble(ufl.Action(*missing_args, Iu), tensor=tensor) + else: + return self._interpolate(*missing_args, output=tensor, adjoint=renumbered, + default_missing_val=default_missing_val) + class DofNotDefinedError(Exception): r"""Raised when attempting to interpolate across function spaces where the @@ -393,10 +432,6 @@ def __init__( raise NotImplementedError( "Can only interpolate into spaces with point evaluation nodes." ) - if isinstance(expr, ufl.Interpolate): - dual_arg, operand = expr.argument_slots() - if not isinstance(dual_arg, ufl.Coargument): - raise NotImplementedError(f"{type(self).__name__} does not support matrix-free adjoint action.") super().__init__(expr, V, subset, freeze_expr, access, bcs, allow_missing_dofs, matfree) expr = self.expr_renumbered @@ -544,6 +579,7 @@ def _interpolate( raise ValueError( "Can currently only apply adjoint interpolation with arguments." ) + if self.nargs != len(function): raise ValueError( "Passed %d Functions to interpolate, expected %d" @@ -725,7 +761,7 @@ def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, pass super().__init__(expr, V, subset=subset, freeze_expr=freeze_expr, access=access, bcs=bcs, matfree=matfree, allow_missing_dofs=allow_missing_dofs) - expr = self.ufl_interpolate + expr = self.ufl_interpolate_renumbered try: self.callable = make_interpolator(expr, V, subset, access, bcs=bcs, matfree=matfree) except FIAT.hdiv_trace.TraceError: @@ -801,7 +837,6 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): if not isinstance(expr, ufl.Interpolate): raise ValueError(f"Expecting to interpolate a ufl.Interpolate, got {type(expr).__name__}.") dual_arg, operand = expr.argument_slots() - target_mesh = as_domain(dual_arg) source_mesh = extract_unique_domain(operand) or target_mesh vom_onto_other_vom = ( From f82b14059ad5005d08204b5bce5e3972e0ddf830 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sun, 21 Sep 2025 17:30:00 +0100 Subject: [PATCH 023/125] Fix up --- firedrake/interpolation.py | 1 - firedrake/preconditioners/hiptmair.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 7598657dfa..e06ea21f5d 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -579,7 +579,6 @@ def _interpolate( raise ValueError( "Can currently only apply adjoint interpolation with arguments." ) - if self.nargs != len(function): raise ValueError( "Passed %d Functions to interpolate, expected %d" diff --git a/firedrake/preconditioners/hiptmair.py b/firedrake/preconditioners/hiptmair.py index b778bc2ad8..8d760d22d8 100644 --- a/firedrake/preconditioners/hiptmair.py +++ b/firedrake/preconditioners/hiptmair.py @@ -202,7 +202,7 @@ def coarsen(self, pc): coarse_space_bcs = tuple(coarse_space_bcs) if G_callback is None: - interp_petscmat = chop(Interpolator(dminus(test), V, bcs=bcs + coarse_space_bcs).callable().handle) + interp_petscmat = chop(Interpolator(dminus(trial), V, bcs=bcs + coarse_space_bcs).callable().handle) else: interp_petscmat = G_callback(coarse_space, V, coarse_space_bcs, bcs) From be0f779b35ad29c2ea65da91bc1ecb21988d228b Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 22 Sep 2025 16:33:01 +0100 Subject: [PATCH 024/125] Enhacements for interpolation into VOM --- firedrake/assemble.py | 9 --- firedrake/interpolation.py | 142 +++++++++++++++++-------------------- 2 files changed, 66 insertions(+), 85 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index b7534bc412..be13769378 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -593,15 +593,6 @@ def base_form_assembly_visitor(self, expr, tensor, *args): # Get the interpolator interp_data = expr.interp_data.copy() default_missing_val = interp_data.pop('default_missing_val', None) - - target_mesh = V.mesh() - source_mesh = extract_unique_domain(operand) or target_mesh - if (source_mesh is target_mesh) and ((is_adjoint and rank == 1) or rank == 0): - # Adjoint interpolation of a Cofunction or the action of a - # Cofunction on an interpolated Function require INC access - # on the output tensor - interp_data["access"] = op2.INC - if rank == 1 and isinstance(tensor, firedrake.Function): V = tensor interpolator = firedrake.Interpolator(expr, V, **interp_data) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index e06ea21f5d..ad4aeb2370 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -280,7 +280,7 @@ def __init__( ): if not isinstance(expr, ufl.Interpolate): fs = V if isinstance(V, ufl.FunctionSpace) else V.function_space() - expr = Interpolate(expr, fs) + expr = interpolate(expr, fs) dual_arg, operand = expr.argument_slots() self.ufl_interpolate = expr self.expr = operand @@ -296,7 +296,7 @@ def __init__( # Assemble the forward operator and then take its adjoint target_mesh = as_domain(V) source_mesh = extract_unique_domain(operand) or target_mesh - if target_mesh is not source_mesh: + if not ((target_mesh is source_mesh) or isinstance(target_mesh, VertexOnlyMeshTopology)): if not isinstance(dual_arg, ufl.Coargument): expr = expr._ufl_expr_reconstruct_(operand, dual_arg.function_space().dual()) expr_args = extract_arguments(operand) @@ -304,8 +304,13 @@ def __init__( v0, v1 = expr.arguments() expr = ufl.replace(expr, {v0: v0.reconstruct(number=v1.number()), v1: v1.reconstruct(number=v0.number())}) - self.expr_renumbered, = expr.ufl_operands + + dual_arg, operand = expr.argument_slots() + self.expr_renumbered = operand self.ufl_interpolate_renumbered = expr + if not isinstance(dual_arg, ufl.Coargument): + # Matrix-free assembly of 0-form or 1-form requires INC access + self.access = op2.INC def interpolate(self, *function, transpose=None, adjoint=False, default_missing_val=None): """ @@ -352,17 +357,19 @@ def assemble(self, tensor=None, default_missing_val=None): return tensor else: # Assembling the action - missing_args = () + cofunctions = () if renumbered: + # The renumbered Interpolate has dropped Cofunctions. + # We need to explicitly operate on them. dual_arg, _ = self.ufl_interpolate.argument_slots() if not isinstance(dual_arg, ufl.Coargument): - missing_args = (dual_arg,) + cofunctions = (dual_arg,) if renumbered and len(arguments) == 0: Iu = self._interpolate(default_missing_val=default_missing_val) - return assemble(ufl.Action(*missing_args, Iu), tensor=tensor) + return assemble(ufl.Action(*cofunctions, Iu), tensor=tensor) else: - return self._interpolate(*missing_args, output=tensor, adjoint=renumbered, + return self._interpolate(*cofunctions, output=tensor, adjoint=renumbered, default_missing_val=default_missing_val) @@ -762,7 +769,7 @@ def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, access=access, bcs=bcs, matfree=matfree, allow_missing_dofs=allow_missing_dofs) expr = self.ufl_interpolate_renumbered try: - self.callable = make_interpolator(expr, V, subset, access, bcs=bcs, matfree=matfree) + self.callable = make_interpolator(expr, V, subset, self.access, bcs=bcs, matfree=matfree) except FIAT.hdiv_trace.TraceError: raise NotImplementedError("Can't interpolate onto traces sorry") self.arguments = expr.arguments() @@ -839,8 +846,8 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): target_mesh = as_domain(dual_arg) source_mesh = extract_unique_domain(operand) or target_mesh vom_onto_other_vom = ( - isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology) - and isinstance(source_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology) + isinstance(target_mesh.topology, VertexOnlyMeshTopology) + and isinstance(source_mesh.topology, VertexOnlyMeshTopology) and target_mesh is not source_mesh ) @@ -871,35 +878,20 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): Vcol = arguments[1].function_space() if len(Vrow) > 1 or len(Vcol) > 1: raise NotImplementedError("Interpolation of mixed expressions with arguments is not supported") - Vcol_map = Vcol.cell_node_map() - if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology) and target_mesh is not source_mesh and not vom_onto_other_vom: - if not isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology): + if isinstance(target_mesh.topology, VertexOnlyMeshTopology) and target_mesh is not source_mesh and not vom_onto_other_vom: + if not isinstance(target_mesh.topology, VertexOnlyMeshTopology): raise NotImplementedError("Can only interpolate onto a Vertex Only Mesh") if target_mesh.geometric_dimension() != source_mesh.geometric_dimension(): raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh: raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target") - if Vcol_map: - # Since the par_loop is over the target mesh cells we need to - # compose a map that takes us from target mesh cells to the - # function space nodes on the source mesh. NOTE: Vcol_map is - # allowed to be None when interpolating from a Real space, even - # in the trans-mesh case. - if source_mesh.extruded: - # ExtrudedSet cannot be a map target so we need to build - # this ourselves - Vcol_map = vom_cell_parent_node_map_extruded(target_mesh, Vcol_map) - else: - Vcol_map = compose_map_and_cache(target_mesh.cell_parent_cell_map, Vcol_map) - elif vom_onto_other_vom: - Vcol_map = Vcol.cell_node_map() - else: - Vcol_map = Vcol.entity_node_map(target_mesh.topology, "cell", None, None) + if vom_onto_other_vom: # We make our own linear operator for this case using PETSc SFs tensor = None else: - Vrow_map = Vrow.cell_node_map() + Vrow_map = get_coefficient_map(source_mesh, target_mesh, Vrow) + Vcol_map = get_coefficient_map(source_mesh, target_mesh, Vcol) sparsity = op2.Sparsity((Vrow.dof_dset, Vcol.dof_dset), [(Vrow_map, Vcol_map, None)], # non-mixed name="%s_%s_sparsity" % (Vrow.name, Vcol.name), @@ -1010,9 +1002,9 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): # NOTE: The par_loop is always over the target mesh cells. target_mesh = as_domain(V) source_mesh = extract_unique_domain(operand) or target_mesh - if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology): + if isinstance(target_mesh.topology, VertexOnlyMeshTopology): if target_mesh is not source_mesh: - if not isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology): + if not isinstance(target_mesh.topology, VertexOnlyMeshTopology): raise NotImplementedError("Can only interpolate onto a Vertex Only Mesh") if target_mesh.geometric_dimension() != source_mesh.geometric_dimension(): raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") @@ -1111,30 +1103,22 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): parloop_args.append(tensor(access)) elif isinstance(tensor, op2.Dat): V_dest = arguments[-1].function_space() if isinstance(dual_arg, ufl.Cofunction) else V - parloop_args.append(tensor(access, V_dest.cell_node_map())) + m_ = get_coefficient_map(source_mesh, target_mesh, V_dest) + parloop_args.append(tensor(access, m_)) else: assert access == op2.WRITE # Other access descriptors not done for Matrices. Vrow = arguments[0].function_space() Vcol = arguments[1].function_space() - rows_map = Vrow.cell_node_map() assert tensor.handle.getSize() == (Vrow.dim(), Vcol.dim()) - if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology): - columns_map = Vcol.cell_node_map() - if target_mesh is not source_mesh: - # Since the par_loop is over the target mesh cells we need to - # compose a map that takes us from target mesh cells to the - # function space nodes on the source mesh. - if source_mesh.extruded: - # ExtrudedSet cannot be a map target so we need to build - # this ourselves - columns_map = vom_cell_parent_node_map_extruded(target_mesh, columns_map) - else: - columns_map = compose_map_and_cache(target_mesh.cell_parent_cell_map, - columns_map) - else: - columns_map = Vcol.entity_node_map(target_mesh.topology, "cell", None, None) + rows_map = get_coefficient_map(source_mesh, target_mesh, Vrow) + columns_map = get_coefficient_map(source_mesh, target_mesh, Vcol) + lgmaps = None if bcs: + if ufl.duals.is_dual(Vrow): + Vrow = Vrow.dual() + if ufl.duals.is_dual(Vcol): + Vcol = Vcol.dual() bc_rows = [bc for bc in bcs if bc.function_space() == Vrow] bc_cols = [bc for bc in bcs if bc.function_space() == Vcol] lgmaps = [(Vrow.local_to_global_map(bc_rows), Vcol.local_to_global_map(bc_cols))] @@ -1147,38 +1131,14 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): parloop_args.append(cs.dat(op2.READ, cs.cell_node_map())) for coefficient in coefficients: - if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology): - coeff_mesh = extract_unique_domain(coefficient) - if coeff_mesh is target_mesh or not coeff_mesh: - # NOTE: coeff_mesh is None is allowed e.g. when interpolating from - # a Real space - m_ = coefficient.cell_node_map() - elif coeff_mesh is source_mesh: - if coefficient.cell_node_map(): - # Since the par_loop is over the target mesh cells we need to - # compose a map that takes us from target mesh cells to the - # function space nodes on the source mesh. - if source_mesh.extruded: - # ExtrudedSet cannot be a map target so we need to build - # this ourselves - m_ = vom_cell_parent_node_map_extruded(target_mesh, coefficient.cell_node_map()) - else: - m_ = compose_map_and_cache(target_mesh.cell_parent_cell_map, coefficient.cell_node_map()) - else: - # m_ is allowed to be None when interpolating from a Real space, - # even in the trans-mesh case. - m_ = coefficient.cell_node_map() - else: - raise ValueError("Have coefficient with unexpected mesh") - else: - m_ = coefficient.function_space().entity_node_map(target_mesh.topology, "cell", None, None) + m_ = get_coefficient_map(source_mesh, target_mesh, coefficient.function_space()) parloop_args.append(coefficient.dat(op2.READ, m_)) for const in extract_firedrake_constants(expr): parloop_args.append(const.dat(op2.READ)) # Finally, add the target mesh reference coordinates if they appear in the kernel - if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology): + if isinstance(target_mesh.topology, VertexOnlyMeshTopology): if target_mesh is not source_mesh: # NOTE: TSFC will sometimes drop run-time arguments in generated # kernels if they are deemed not-necessary. @@ -1209,6 +1169,36 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): return copyin + callables + (parloop_compute_callable, ) + copyout +def get_coefficient_map(source_mesh, target_mesh, fs): + if isinstance(target_mesh.topology, VertexOnlyMeshTopology): + coeff_mesh = fs.mesh() + m_ = fs.cell_node_map() + if coeff_mesh is target_mesh or not coeff_mesh: + # NOTE: coeff_mesh is None is allowed e.g. when interpolating from + # a Real space + pass + elif coeff_mesh is source_mesh: + if m_: + # Since the par_loop is over the target mesh cells we need to + # compose a map that takes us from target mesh cells to the + # function space nodes on the source mesh. + if source_mesh.extruded: + # ExtrudedSet cannot be a map target so we need to build + # this ourselves + m_ = vom_cell_parent_node_map_extruded(target_mesh, m_) + else: + m_ = compose_map_and_cache(target_mesh.cell_parent_cell_map, m_) + else: + # m_ is allowed to be None when interpolating from a Real space, + # even in the trans-mesh case. + pass + else: + raise ValueError("Have coefficient with unexpected mesh") + else: + m_ = fs.entity_node_map(target_mesh.topology, "cell", None, None) + return m_ + + try: _expr_cachedir = os.environ["FIREDRAKE_TSFC_KERNEL_CACHE_DIR"] except KeyError: @@ -1399,7 +1389,7 @@ def vom_cell_parent_node_map_extruded(vertex_only_mesh, extruded_cell_node_map): the parent extruded mesh. """ - if not isinstance(vertex_only_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology): + if not isinstance(vertex_only_mesh.topology, VertexOnlyMeshTopology): raise TypeError("The input mesh must be a VertexOnlyMesh") cnm = extruded_cell_node_map vmx = vertex_only_mesh From 6407f4e64e015f1b2bc1684ec55f3e52858b1180 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 24 Sep 2025 15:52:18 +0100 Subject: [PATCH 025/125] SameMeshInterpolator: support matfree/explcit adjoint on Submesh --- firedrake/interpolation.py | 43 +++++------ .../submesh/test_submesh_interpolate.py | 74 ++++++++++++++++++- 2 files changed, 92 insertions(+), 25 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index ad4aeb2370..a65fc49e31 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -292,11 +292,11 @@ def __init__( self._allow_missing_dofs = allow_missing_dofs self.matfree = matfree self.callable = None - # Workaround for adjoint of cross-mesh interpolation - # Assemble the forward operator and then take its adjoint - target_mesh = as_domain(V) - source_mesh = extract_unique_domain(operand) or target_mesh - if not ((target_mesh is source_mesh) or isinstance(target_mesh, VertexOnlyMeshTopology)): + + # CrossMeshInterpolator is not yet aware of self.ufl_interpolate (which carries the dual arguments). + # Instead, we always construct the forward ufl_interpolate and externally operate on the adjoint and + # supply the cofunctions within assemble(). + if not isinstance(self, SameMeshInterpolator): if not isinstance(dual_arg, ufl.Coargument): expr = expr._ufl_expr_reconstruct_(operand, dual_arg.function_space().dual()) expr_args = extract_arguments(operand) @@ -890,8 +890,8 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): # We make our own linear operator for this case using PETSc SFs tensor = None else: - Vrow_map = get_coefficient_map(source_mesh, target_mesh, Vrow) - Vcol_map = get_coefficient_map(source_mesh, target_mesh, Vcol) + Vrow_map = get_interp_node_map(source_mesh, target_mesh, Vrow) + Vcol_map = get_interp_node_map(source_mesh, target_mesh, Vcol) sparsity = op2.Sparsity((Vrow.dof_dset, Vcol.dof_dset), [(Vrow_map, Vcol_map, None)], # non-mixed name="%s_%s_sparsity" % (Vrow.name, Vcol.name), @@ -1047,15 +1047,15 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): if needs_weight: # Compute the reciprocal of the DOF multiplicity W = dual_arg.function_space() - shapes = (W.finat_element.space_dimension(), W.block_size) - domain = "{[i,j]: 0 <= i < %d and 0 <= j < %d}" % shapes - instructions = """ - for i, j - w[i,j] = w[i,j] + 1 - end - """ + wsize = W.finat_element.space_dimension() * W.block_size + kernel_code = f""" + void multiplicity(PetscScalar *restrict w) {{ + for (PetscInt i=0; i<{wsize}; i++) w[i] += 1; + }}""" + kernel = op2.Kernel(kernel_code, "multiplicity", requires_zeroed_output_arguments=False) weight = firedrake.Function(W) - firedrake.par_loop((domain, instructions), ufl.dx, {"w": (weight, op2.INC)}) + m_ = get_interp_node_map(source_mesh, target_mesh, W) + op2.par_loop(kernel, cell_set, weight.dat(op2.INC, m_)) with weight.dat.vec as w: w.reciprocal() @@ -1103,15 +1103,15 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): parloop_args.append(tensor(access)) elif isinstance(tensor, op2.Dat): V_dest = arguments[-1].function_space() if isinstance(dual_arg, ufl.Cofunction) else V - m_ = get_coefficient_map(source_mesh, target_mesh, V_dest) + m_ = get_interp_node_map(source_mesh, target_mesh, V_dest) parloop_args.append(tensor(access, m_)) else: assert access == op2.WRITE # Other access descriptors not done for Matrices. Vrow = arguments[0].function_space() Vcol = arguments[1].function_space() assert tensor.handle.getSize() == (Vrow.dim(), Vcol.dim()) - rows_map = get_coefficient_map(source_mesh, target_mesh, Vrow) - columns_map = get_coefficient_map(source_mesh, target_mesh, Vcol) + rows_map = get_interp_node_map(source_mesh, target_mesh, Vrow) + columns_map = get_interp_node_map(source_mesh, target_mesh, Vcol) lgmaps = None if bcs: @@ -1122,7 +1122,7 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): bc_rows = [bc for bc in bcs if bc.function_space() == Vrow] bc_cols = [bc for bc in bcs if bc.function_space() == Vcol] lgmaps = [(Vrow.local_to_global_map(bc_rows), Vcol.local_to_global_map(bc_cols))] - parloop_args.append(tensor(op2.WRITE, (rows_map, columns_map), lgmaps=lgmaps)) + parloop_args.append(tensor(access, (rows_map, columns_map), lgmaps=lgmaps)) if oriented: co = target_mesh.cell_orientations() parloop_args.append(co.dat(op2.READ, co.cell_node_map())) @@ -1131,7 +1131,7 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): parloop_args.append(cs.dat(op2.READ, cs.cell_node_map())) for coefficient in coefficients: - m_ = get_coefficient_map(source_mesh, target_mesh, coefficient.function_space()) + m_ = get_interp_node_map(source_mesh, target_mesh, coefficient.function_space()) parloop_args.append(coefficient.dat(op2.READ, m_)) for const in extract_firedrake_constants(expr): @@ -1169,7 +1169,8 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): return copyin + callables + (parloop_compute_callable, ) + copyout -def get_coefficient_map(source_mesh, target_mesh, fs): +def get_interp_node_map(source_mesh, target_mesh, fs): + """Return the cell-to-node map required by a parloop on the target_mesh.cell_set.""" if isinstance(target_mesh.topology, VertexOnlyMeshTopology): coeff_mesh = fs.mesh() m_ = fs.cell_node_map() diff --git a/tests/firedrake/submesh/test_submesh_interpolate.py b/tests/firedrake/submesh/test_submesh_interpolate.py index 0d3805974f..dd8913b8ad 100644 --- a/tests/firedrake/submesh/test_submesh_interpolate.py +++ b/tests/firedrake/submesh/test_submesh_interpolate.py @@ -27,14 +27,18 @@ def _get_expr(V): return as_vector([cos(x), sin(y)]) -def _test_submesh_interpolate_cell_cell(mesh, subdomain_cond, fe_fesub): +def make_submesh(mesh, subdomain_cond, label_value): dim = mesh.topological_dimension() - (family, degree), (family_sub, degree_sub) = fe_fesub DG0 = FunctionSpace(mesh, "DG", 0) indicator_function = Function(DG0).interpolate(subdomain_cond) - label_value = 999 mesh.mark_entities(indicator_function, label_value) - subm = Submesh(mesh, dim, label_value) + return Submesh(mesh, dim, label_value) + + +def _test_submesh_interpolate_cell_cell(mesh, subdomain_cond, fe_fesub): + (family, degree), (family_sub, degree_sub) = fe_fesub + label_value = 999 + subm = make_submesh(mesh, subdomain_cond, label_value) V = FunctionSpace(mesh, family, degree) V_ = FunctionSpace(mesh, family_sub, degree_sub) Vsub = FunctionSpace(subm, family_sub, degree_sub) @@ -268,3 +272,65 @@ def expr(m): dg2d_ = Function(DG2d).interpolate(dg3d) error = assemble(inner(dg2d_ - expr(subm), dg2d_ - expr(subm)) * dx)**0.5 assert abs(error) < 1.e-14 + + +@pytest.mark.parametrize('fe_fesub', [[("DG", 2), ("DG", 1)], + [("CG", 3), ("CG", 2)]]) +def test_submesh_interpolate_adjoint(fe_fesub): + (family, degree), (family_sub, degree_sub) = fe_fesub + + mesh = UnitSquareMesh(4, 4) + x, y = SpatialCoordinate(mesh) + subdomain_cond = conditional(LT(x, 0.5), 1, 0) + label_value = 999 + subm = make_submesh(mesh, subdomain_cond, label_value) + + V1 = FunctionSpace(subm, family_sub, degree_sub) + V2 = FunctionSpace(mesh, family, degree) + + x, y = SpatialCoordinate(V1.mesh()) + expr = x * y + u1 = Function(V1).interpolate(expr) + ustar2 = assemble(inner(1, TestFunction(V2))*dx(label_value)) + + expected = assemble(inner(1, u1)*dx(label_value)) + + # Test forward 2-form + I = assemble(interpolate(TrialFunction(V1), TestFunction(V2.dual()), allow_missing_dofs=True)) + assert I.arguments()[0].function_space() == V2.dual() + assert I.arguments()[1].function_space() == V1 + + result_forward_2 = assemble(action(ustar2, action(I, u1))) + assert np.isclose(result_forward_2, expected) + + # Test adjoint 2-form + I_adj = assemble(interpolate(TestFunction(V1), TrialFunction(V2.dual()), allow_missing_dofs=True)) + assert I_adj.arguments()[0].function_space() == V1 + assert I_adj.arguments()[1].function_space() == V2.dual() + + result_adjoint_2 = assemble(action(action(I_adj, ustar2), u1)) + assert np.isclose(result_adjoint_2, expected) + + # Test forward 1-form + Iu1 = assemble(interpolate(u1, TestFunction(V2.dual()), allow_missing_dofs=True)) + assert Iu1.function_space() == V2 + + expected_primal = assemble(action(I, u1)) + assert np.allclose(Iu1.dat.data, expected_primal.dat.data) + + result_forward_1 = assemble(action(ustar2, Iu1)) + assert np.isclose(result_forward_1, expected) + + # Test adjoint 1-form + ustar2I = assemble(interpolate(TestFunction(V1), ustar2, allow_missing_dofs=True)) + assert ustar2I.function_space() == V1.dual() + + expected_dual = assemble(action(I_adj, ustar2)) + assert np.allclose(ustar2I.dat.data, expected_dual.dat.data) + + result_adjoint_1 = assemble(action(ustar2I, u1)) + assert np.isclose(result_adjoint_1, expected) + + # Test 0-form + result_0 = assemble(interpolate(u1, ustar2, allow_missing_dofs=True)) + assert np.isclose(result_0, expected) From 449420fa8b530eaa9aeda31160ae65f2a9a99c09 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 24 Sep 2025 16:10:55 +0100 Subject: [PATCH 026/125] Update tsfc/driver.py --- tsfc/driver.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tsfc/driver.py b/tsfc/driver.py index f420300b17..89db890f24 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -224,6 +224,7 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, # Apply UFL preprocessing operand = ufl_utils.preprocess_expression(operand, complex_mode=complex_mode) + # Reconstructed Interpolate with mapped operand expression = ufl.Interpolate(operand, v) # Initialise kernel builder From f614764c542b3fb53ac3779ea12951c5f6a5e4a8 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 24 Sep 2025 16:51:55 +0100 Subject: [PATCH 027/125] Test submesh in parallel --- .../submesh/test_submesh_interpolate.py | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/tests/firedrake/submesh/test_submesh_interpolate.py b/tests/firedrake/submesh/test_submesh_interpolate.py index dd8913b8ad..0a5c1ce2d1 100644 --- a/tests/firedrake/submesh/test_submesh_interpolate.py +++ b/tests/firedrake/submesh/test_submesh_interpolate.py @@ -274,12 +274,13 @@ def expr(m): assert abs(error) < 1.e-14 +@pytest.mark.parallel(nprocs=[1, 3]) @pytest.mark.parametrize('fe_fesub', [[("DG", 2), ("DG", 1)], [("CG", 3), ("CG", 2)]]) def test_submesh_interpolate_adjoint(fe_fesub): (family, degree), (family_sub, degree_sub) = fe_fesub - mesh = UnitSquareMesh(4, 4) + mesh = UnitSquareMesh(8, 8) x, y = SpatialCoordinate(mesh) subdomain_cond = conditional(LT(x, 0.5), 1, 0) label_value = 999 @@ -311,15 +312,24 @@ def test_submesh_interpolate_adjoint(fe_fesub): result_adjoint_2 = assemble(action(action(I_adj, ustar2), u1)) assert np.isclose(result_adjoint_2, expected) - # Test forward 1-form - Iu1 = assemble(interpolate(u1, TestFunction(V2.dual()), allow_missing_dofs=True)) - assert Iu1.function_space() == V2 + # Test forward 1-form (only in serial for now) + if V1.comm.size == 1: + # Matfree forward interpolation with Submesh currently fails in parallel. + # The ghost nodes of the parent mesh may be redistributed + # into different processes as non-ghost dofs of the submesh. + # The submesh kernel will write into ghost nodes of the parent mesh, + # but this will be ignored in the halo exchange if access=op2.WRITE. - expected_primal = assemble(action(I, u1)) - assert np.allclose(Iu1.dat.data, expected_primal.dat.data) + # See https://github.com/firedrakeproject/firedrake/issues/4483 - result_forward_1 = assemble(action(ustar2, Iu1)) - assert np.isclose(result_forward_1, expected) + Iu1 = assemble(interpolate(u1, TestFunction(V2.dual()), allow_missing_dofs=True)) + assert Iu1.function_space() == V2 + + expected_primal = assemble(action(I, u1)) + assert np.allclose(Iu1.dat.data, expected_primal.dat.data) + + result_forward_1 = assemble(action(ustar2, Iu1)) + assert np.isclose(result_forward_1, expected) # Test adjoint 1-form ustar2I = assemble(interpolate(TestFunction(V1), ustar2, allow_missing_dofs=True)) From ea997ea56c30b2ba508da00cf2dc95039afe17c7 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 24 Sep 2025 18:10:07 +0100 Subject: [PATCH 028/125] VOM onto other VOM still needs renumbering --- firedrake/interpolation.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index a65fc49e31..296041a7c0 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -296,7 +296,12 @@ def __init__( # CrossMeshInterpolator is not yet aware of self.ufl_interpolate (which carries the dual arguments). # Instead, we always construct the forward ufl_interpolate and externally operate on the adjoint and # supply the cofunctions within assemble(). - if not isinstance(self, SameMeshInterpolator): + target_mesh = as_domain(V) + source_mesh = extract_unique_domain(operand) or target_mesh + vom_onto_other_vom = ((source_mesh is not target_mesh) + and isinstance(source_mesh.topology, VertexOnlyMeshTopology) + and isinstance(target_mesh.topology, VertexOnlyMeshTopology)) + if not isinstance(self, SameMeshInterpolator) or vom_onto_other_vom: if not isinstance(dual_arg, ufl.Coargument): expr = expr._ufl_expr_reconstruct_(operand, dual_arg.function_space().dual()) expr_args = extract_arguments(operand) From eee77d69dc32b8f798329fe5b1b7735d88e6db09 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 25 Sep 2025 09:56:37 +0100 Subject: [PATCH 029/125] Clarify insane interface --- firedrake/assemble.py | 6 ------ firedrake/interpolation.py | 36 +++++++++++++++++++++++------------- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index be13769378..bfe671bed6 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -551,12 +551,6 @@ def base_form_assembly_visitor(self, expr, tensor, *args): if (v, operand) != expr.argument_slots(): expr = reconstruct_interp(operand, v=v) - # Different assembly procedures: - # 1) Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) -> Jacobian (Interpolate matrix) - # 2) Interpolate(Coefficient(...), Argument(V2.dual(), 0)) -> Operator (or Jacobian action) - # 3) Interpolate(Argument(V1, 0), Argument(V2.dual(), 1)) -> Jacobian adjoint - # 4) Interpolate(Argument(V1, 0), Cofunction(...)) -> Action of the Jacobian adjoint - # This can be generalized to the case where the first slot is an arbitray expression. rank = len(expr.arguments()) if rank > 2: raise ValueError("Cannot assemble an Interpolate with more than two arguments") diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 296041a7c0..595126aac7 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -293,19 +293,31 @@ def __init__( self.matfree = matfree self.callable = None - # CrossMeshInterpolator is not yet aware of self.ufl_interpolate (which carries the dual arguments). - # Instead, we always construct the forward ufl_interpolate and externally operate on the adjoint and - # supply the cofunctions within assemble(). + # TODO CrossMeshInterpolator and VomOntoVomXXX are not yet aware of + # self.ufl_interpolate (which carries the dual argument). target_mesh = as_domain(V) source_mesh = extract_unique_domain(operand) or target_mesh vom_onto_other_vom = ((source_mesh is not target_mesh) and isinstance(source_mesh.topology, VertexOnlyMeshTopology) and isinstance(target_mesh.topology, VertexOnlyMeshTopology)) if not isinstance(self, SameMeshInterpolator) or vom_onto_other_vom: + # For bespoke interpolation, we currently rely on different assembly procedures: + # 1) Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) -> Forward operator (2-form) + # 2) Interpolate(Argument(V1, 0), Argument(V2.dual(), 1)) -> Adjoint operator (2-form) + # 3) Interpolate(Coefficient(V1), Argument(V2.dual(), 0)) -> Forward action (1-form) + # 4) Interpolate(Argument(V1, 0), Cofunction(V2.dual()) -> Adjoint action (1-form) + # 5) Interpolate(Coefficient(V1), Cofunction(V2.dual()) -> Double action (0-form) + + # CrossMeshInterpolator._interpolate only supports forward interpolation (cases 1 and 3). + # For case 2, we first redundantly assemble case 1 and then construct the transpose. + # For cases 4 and 5, we take the forward Interpolate that corresponds to dropping the Cofunction, + # and we separately compute the action againt the dropped Cofunction within assemble(). if not isinstance(dual_arg, ufl.Coargument): + # Drop the Cofunction expr = expr._ufl_expr_reconstruct_(operand, dual_arg.function_space().dual()) expr_args = extract_arguments(operand) if expr_args and expr_args[0].number() == 0: + # Construct the symbolic forward Interpolate v0, v1 = expr.arguments() expr = ufl.replace(expr, {v0: v0.reconstruct(number=v1.number()), v1: v1.reconstruct(number=v0.number())}) @@ -342,7 +354,7 @@ def _interpolate(self, *args, **kwargs): def assemble(self, tensor=None, default_missing_val=None): """Assemble the operator (or its action).""" from firedrake.assemble import assemble - renumbered = self.ufl_interpolate_renumbered != self.ufl_interpolate + needs_adjoint = self.ufl_interpolate_renumbered != self.ufl_interpolate arguments = self.ufl_interpolate.arguments() if len(arguments) == 2: # Assembling the operator @@ -350,7 +362,7 @@ def assemble(self, tensor=None, default_missing_val=None): # Get the interpolation matrix op2mat = self.callable() petsc_mat = op2mat.handle - if renumbered: + if needs_adjoint: # Out-of-place Hermitian transpose petsc_mat.hermitianTranspose(out=res) elif res: @@ -363,18 +375,18 @@ def assemble(self, tensor=None, default_missing_val=None): else: # Assembling the action cofunctions = () - if renumbered: + if needs_adjoint: # The renumbered Interpolate has dropped Cofunctions. # We need to explicitly operate on them. dual_arg, _ = self.ufl_interpolate.argument_slots() if not isinstance(dual_arg, ufl.Coargument): cofunctions = (dual_arg,) - if renumbered and len(arguments) == 0: + if needs_adjoint and len(arguments) == 0: Iu = self._interpolate(default_missing_val=default_missing_val) return assemble(ufl.Action(*cofunctions, Iu), tensor=tensor) else: - return self._interpolate(*cofunctions, output=tensor, adjoint=renumbered, + return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint, default_missing_val=default_missing_val) @@ -850,11 +862,9 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): dual_arg, operand = expr.argument_slots() target_mesh = as_domain(dual_arg) source_mesh = extract_unique_domain(operand) or target_mesh - vom_onto_other_vom = ( - isinstance(target_mesh.topology, VertexOnlyMeshTopology) - and isinstance(source_mesh.topology, VertexOnlyMeshTopology) - and target_mesh is not source_mesh - ) + vom_onto_other_vom = ((source_mesh is not target_mesh) + and isinstance(source_mesh.topology, VertexOnlyMeshTopology) + and isinstance(target_mesh.topology, VertexOnlyMeshTopology)) arguments = expr.arguments() rank = len(arguments) From 806d60b36407ad9a10837ac1d5a94dec2a846ad0 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 25 Sep 2025 16:36:49 +0100 Subject: [PATCH 030/125] Update firedrake/interpolation.py --- firedrake/interpolation.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 595126aac7..d7147da038 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -1185,7 +1185,11 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): def get_interp_node_map(source_mesh, target_mesh, fs): - """Return the cell-to-node map required by a parloop on the target_mesh.cell_set.""" + """Return the map between cells of the target mesh and nodes of the function space. + + If the function space is defined on the source mesh then the node map is composed + with a map between target and source cells. + """ if isinstance(target_mesh.topology, VertexOnlyMeshTopology): coeff_mesh = fs.mesh() m_ = fs.cell_node_map() From 5c6e9119af36911d38be2f6146e3759a37295cbc Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 25 Sep 2025 16:47:38 +0100 Subject: [PATCH 031/125] Update firedrake/interpolation.py --- firedrake/interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index d7147da038..bca2808b75 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -311,7 +311,7 @@ def __init__( # CrossMeshInterpolator._interpolate only supports forward interpolation (cases 1 and 3). # For case 2, we first redundantly assemble case 1 and then construct the transpose. # For cases 4 and 5, we take the forward Interpolate that corresponds to dropping the Cofunction, - # and we separately compute the action againt the dropped Cofunction within assemble(). + # and we separately compute the action against the dropped Cofunction within assemble(). if not isinstance(dual_arg, ufl.Coargument): # Drop the Cofunction expr = expr._ufl_expr_reconstruct_(operand, dual_arg.function_space().dual()) From e3039402a04928a73cc52ba3b54043a845ff0929 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 25 Sep 2025 17:15:13 +0100 Subject: [PATCH 032/125] Update firedrake/interpolation.py --- firedrake/interpolation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index bca2808b75..c153b1767c 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -295,6 +295,7 @@ def __init__( # TODO CrossMeshInterpolator and VomOntoVomXXX are not yet aware of # self.ufl_interpolate (which carries the dual argument). + # See github issue https://github.com/firedrakeproject/firedrake/issues/4592 target_mesh = as_domain(V) source_mesh = extract_unique_domain(operand) or target_mesh vom_onto_other_vom = ((source_mesh is not target_mesh) From cc1806340baad8516c75292540259ba945b5ee85 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 26 Sep 2025 12:08:27 +0100 Subject: [PATCH 033/125] Apply suggestions from code review Co-authored-by: Connor Ward --- firedrake/interpolation.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 3883e3d623..2cd50b7fe6 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -101,12 +101,10 @@ def __init__(self, expr, v, V = v.arguments()[0].function_space() if len(expr.ufl_shape) != len(V.value_shape): - raise RuntimeError('Rank mismatch: Expression rank %d, FunctionSpace rank %d' - % (len(expr.ufl_shape), len(V.value_shape))) + raise RuntimeError(f'Rank mismatch: Expression rank {len(expr.ufl_shape)}, FunctionSpace rank {len(V.value_shape)}') if expr.ufl_shape != V.value_shape: - raise RuntimeError('Shape mismatch: Expression shape %r, FunctionSpace shape %r' - % (expr.ufl_shape, V.value_shape)) + raise RuntimeError('Shape mismatch: Expression shape {expr.ufl_shape}, FunctionSpace shape {V.value_shape}') super().__init__(expr, v) # -- Interpolate data (e.g. `subset` or `access`) -- # From fd9a3f60b51483343c17b52b245a2e289036d5b5 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 26 Sep 2025 12:30:45 +0100 Subject: [PATCH 034/125] Update firedrake/interpolation.py --- firedrake/interpolation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 2cd50b7fe6..6983aaa4dd 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -1008,6 +1008,9 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): if access == op2.INC: callables += (tensor.zero,) + # For the matfree adjoint 1-form and the 0-form, the cellwise kernel will add multiple + # contributions from the facet DOFs of the dual argument. + # The incoming Cofunction needs to be weighted by the reciprocal of the DOF multiplicity. needs_weight = isinstance(dual_arg, ufl.Cofunction) and not to_element.is_dg() if needs_weight: # Compute the reciprocal of the DOF multiplicity From 89f97d56fbc5f1ee279a009e0cbf175f90ec6d10 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 25 Sep 2025 15:19:03 +0100 Subject: [PATCH 035/125] add dataclass add type hint --- firedrake/interpolation.py | 60 +++++++++++++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index f42b534a78..36aa991ea7 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -4,7 +4,8 @@ import abc import warnings from functools import partial, singledispatch -from typing import Hashable +from typing import Hashable, Optional +from dataclasses import dataclass import FIAT import ufl @@ -14,6 +15,7 @@ from pyop2 import op2 from pyop2.caching import memory_and_disk_cache +from pyop2.types import Access from finat.element_factory import create_element, as_fiat_cell from tsfc import compile_expression_dual_evaluation @@ -42,6 +44,62 @@ "SameMeshInterpolator", ) +@dataclass(frozen=True) +class InterpolateOptions: + """Options for interpolation operations. + + Attributes + ---------- + subset : pyop2.types.set.Subset, optional + An optional subset to apply the interpolation over. + Cannot, at present, be used when interpolating across meshes unless + the target mesh is a :func:`.VertexOnlyMesh`. + access : pyop2.types.access.Access, default op2.WRITE + The pyop2 access descriptor for combining updates to shared + DoFs. Possible values include ``WRITE`` and ``INC``. Only ``WRITE`` is + supported at present when interpolating across meshes unless the target + mesh is a :func:`.VertexOnlyMesh`. + + .. note:: + If you use an access descriptor other than ``WRITE``, the + behaviour of interpolation changes if interpolating into a + function space, or an existing function. If the former, then + the newly allocated function will be initialised with + appropriate values (e.g. for MIN access, it will be initialised + with MAX_FLOAT). On the other hand, if you provide a function, + then it is assumed that its values should take part in the + reduction (hence using MIN will compute the MIN between the + existing values and any new values). + + allow_missing_dofs : bool, default False + For interpolation across meshes: allow degrees of freedom (aka DoFs/nodes) + in the target mesh that cannot be defined on the source mesh. + For example, where nodes are point evaluations, points in the target mesh + that are not in the source mesh. When ``False`` this raises a ``ValueError`` + should this occur. When ``True`` the corresponding values are either + (a) unchanged if some ``output`` is given to the :meth:`interpolate` method + or (b) set to zero. + Can be overwritten with the ``default_missing_val`` kwarg of :meth:`interpolate`. + This does not affect adjoint interpolation. Ignored if interpolating within + the same mesh or onto a :func:`.VertexOnlyMesh` (the behaviour of a + :func:`.VertexOnlyMesh` in this scenario is, at present, set when it is created). + default_missing_val : float, optional + For interpolation across meshes: the optional value to assign to DoFs + in the target mesh that are outside the source mesh. If this is not set + then the values are either (a) unchanged if some ``output`` is given to + the :meth:`interpolate` method or (b) set to zero. + Ignored if interpolating within the same mesh or onto a :func:`.VertexOnlyMesh`. + matfree : bool, default True + If ``False``, then construct the permutation matrix for interpolating + between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast + and reduce operations. + """ + subset: Optional[op2.Subset] = None + access: Access = op2.WRITE + allow_missing_dofs: bool = False + default_missing_val: Optional[float] = None + matfree: bool = True + class Interpolate(ufl.Interpolate): From 613f630b750c274cf1af2bf0f49302c7d5d5b324 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 25 Sep 2025 15:24:02 +0100 Subject: [PATCH 036/125] use kwargs; dataclass --- firedrake/interpolation.py | 27 ++++++--------------------- 1 file changed, 6 insertions(+), 21 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 36aa991ea7..85387b2308 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -5,7 +5,7 @@ import warnings from functools import partial, singledispatch from typing import Hashable, Optional -from dataclasses import dataclass +from dataclasses import asdict, dataclass import FIAT import ufl @@ -103,12 +103,7 @@ class InterpolateOptions: class Interpolate(ufl.Interpolate): - def __init__(self, expr, v, - subset=None, - access=op2.WRITE, - allow_missing_dofs=False, - default_missing_val=None, - matfree=True): + def __init__(self, expr, v, **kwargs): """Symbolic representation of the interpolation operator. Parameters @@ -167,12 +162,8 @@ def __init__(self, expr, v, % (expr.ufl_shape, V.value_shape)) super().__init__(expr, v) - # -- Interpolate data (e.g. `subset` or `access`) -- # - self.interp_data = {"subset": subset, - "access": access, - "allow_missing_dofs": allow_missing_dofs, - "default_missing_val": default_missing_val, - "matfree": matfree} + self._options = InterpolateOptions(**kwargs) + self.interp_data = asdict(self._options) function_space = ufl.Interpolate.ufl_function_space @@ -182,7 +173,7 @@ def _ufl_expr_reconstruct_(self, expr, v=None, **interp_data): @PETSc.Log.EventDecorator() -def interpolate(expr, V, subset=None, access=op2.WRITE, allow_missing_dofs=False, default_missing_val=None, matfree=True): +def interpolate(expr, V, **kwargs): """Returns a UFL expression for the interpolation operation of ``expr`` into ``V``. :arg expr: a UFL expression. @@ -250,13 +241,7 @@ def interpolate(expr, V, subset=None, access=op2.WRITE, allow_missing_dofs=False else: raise TypeError(f"V must be a FunctionSpace, Cofunction, Coargument or one-form, not a {type(V).__name__}") - interp = Interpolate(expr, dual_arg, - subset=subset, access=access, - allow_missing_dofs=allow_missing_dofs, - default_missing_val=default_missing_val, - matfree=matfree) - - return interp + return Interpolate(expr, dual_arg, **kwargs) class Interpolator(abc.ABC): From 0ee1e2efd18a03aa829357e3b27f9932dc60ab66 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 25 Sep 2025 15:25:10 +0100 Subject: [PATCH 037/125] `interpolate` docstring; simplify function fix --- firedrake/interpolation.py | 79 ++++++++------------------------------ 1 file changed, 15 insertions(+), 64 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 85387b2308..4b9a537033 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -176,72 +176,23 @@ def _ufl_expr_reconstruct_(self, expr, v=None, **interp_data): def interpolate(expr, V, **kwargs): """Returns a UFL expression for the interpolation operation of ``expr`` into ``V``. - :arg expr: a UFL expression. - :arg V: a :class:`.FunctionSpace` to interpolate into, or a :class:`.Cofunction`, - or :class:`.Coargument`, or a :class:`ufl.form.Form` with one argument (a one-form). - If a :class:`.Cofunction` or a one-form is provided, then we do adjoint interpolation. - :kwarg subset: An optional :class:`pyop2.types.set.Subset` to apply the - interpolation over. Cannot, at present, be used when interpolating - across meshes unless the target mesh is a :func:`.VertexOnlyMesh`. - :kwarg access: The pyop2 access descriptor for combining updates to shared - DoFs. Possible values include ``WRITE`` and ``INC``. Only ``WRITE`` is - supported at present when interpolating across meshes unless the target - mesh is a :func:`.VertexOnlyMesh`. See note below. - :kwarg allow_missing_dofs: For interpolation across meshes: allow - degrees of freedom (aka DoFs/nodes) in the target mesh that cannot be - defined on the source mesh. For example, where nodes are point - evaluations, points in the target mesh that are not in the source mesh. - When ``False`` this raises a ``ValueError`` should this occur. When - ``True`` the corresponding values are either (a) unchanged if - some ``output`` is given to the :meth:`interpolate` method or (b) set - to zero. In either case, if ``default_missing_val`` is specified, that - value is used. This does not affect adjoint interpolation. Ignored if - interpolating within the same mesh or onto a :func:`.VertexOnlyMesh` - (the behaviour of a :func:`.VertexOnlyMesh` in this scenario is, at - present, set when it is created). - :kwarg default_missing_val: For interpolation across meshes: the optional - value to assign to DoFs in the target mesh that are outside the source - mesh. If this is not set then the values are either (a) unchanged if - some ``output`` is given to the :meth:`interpolate` method or (b) set - to zero. Ignored if interpolating within the same mesh or onto a - :func:`.VertexOnlyMesh`. - :kwarg matfree: If ``False``, then construct the permutation matrix for interpolating - between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast - and reduce operations. - :returns: A symbolic :class:`.Interpolate` object - - .. note:: + Parameters + ---------- + expr : ufl.core.expr.Expr + The UFL expression to interpolate. + V : firedrake.functionspaceimpl.WithGeometry or ufl.BaseForm + The function space to interpolate into or the coargument defined + on the dual of the function space to interpolate into. + **kwargs + Additional interpolation options. See :class:`InterpolateOptions` + for available parameters and their descriptions. - If you use an access descriptor other than ``WRITE``, the - behaviour of interpolation changes if interpolating into a - function space, or an existing function. If the former, then - the newly allocated function will be initialised with - appropriate values (e.g. for MIN access, it will be initialised - with MAX_FLOAT). On the other hand, if you provide a function, - then it is assumed that its values should take part in the - reduction (hence using MIN will compute the MIN between the - existing values and any new values). + Returns + ------- + Interpolate + A symbolic :class:`Interpolate` object representing the interpolation operation. """ - if isinstance(V, (Cofunction, Coargument)): - dual_arg = V - elif isinstance(V, (ufl.Form, ufl.BaseForm)): - rank = len(V.arguments()) - if rank == 1: - dual_arg = V - else: - raise TypeError(f"Expected a one-form, provided form had {rank} arguments") - elif isinstance(V, functionspaceimpl.WithGeometry): - dual_arg = Coargument(V.dual(), 0) - expr_args = extract_arguments(ufl.as_ufl(expr)) - if expr_args and expr_args[0].number() == 0: - warnings.warn("Passing argument numbered 0 in expression for forward interpolation is deprecated. " - "Use a TrialFunction in the expression.") - v, = expr_args - expr = replace(expr, {v: v.reconstruct(number=1)}) - else: - raise TypeError(f"V must be a FunctionSpace, Cofunction, Coargument or one-form, not a {type(V).__name__}") - - return Interpolate(expr, dual_arg, **kwargs) + return Interpolate(expr, V, **kwargs) class Interpolator(abc.ABC): From be3b6b65a8baf0c81262b18939ef9eb83f03ab0f Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 25 Sep 2025 15:32:36 +0100 Subject: [PATCH 038/125] simplify `Interpolate` --- firedrake/interpolation.py | 66 +++++++++----------------------------- 1 file changed, 15 insertions(+), 51 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 4b9a537033..7824150f6b 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -25,12 +25,13 @@ import finat import firedrake -from firedrake import tsfc_interface, utils, functionspaceimpl +from firedrake import tsfc_interface, utils from firedrake.ufl_expr import Argument, Coargument, action, adjoint as expr_adjoint from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology from firedrake.petsc import PETSc from firedrake.halo import _get_mtype as get_dat_mpi_type from firedrake.cofunction import Cofunction +from firedrake.functionspaceimpl import WithGeometry from mpi4py import MPI from pyadjoint import stop_annotating, no_annotations @@ -103,67 +104,30 @@ class InterpolateOptions: class Interpolate(ufl.Interpolate): - def __init__(self, expr, v, **kwargs): + def __init__(self, expr, V, **kwargs): """Symbolic representation of the interpolation operator. Parameters ---------- - expr : ufl.core.expr.Expr or ufl.BaseForm + expr : ufl.core.expr.Expr The UFL expression to interpolate. - v : firedrake.functionspaceimpl.WithGeometryBase or firedrake.ufl_expr.Coargument + V : firedrake.functionspaceimpl.WithGeometry or ufl.BaseForm The function space to interpolate into or the coargument defined on the dual of the function space to interpolate into. - subset : pyop2.types.set.Subset - An optional subset to apply the interpolation over. - Cannot, at present, be used when interpolating across meshes unless - the target mesh is a :func:`.VertexOnlyMesh`. - access : pyop2.types.access.Access - The pyop2 access descriptor for combining updates to shared - DoFs. Possible values include ``WRITE`` and ``INC``. Only ``WRITE`` is - supported at present when interpolating across meshes. See note in - :func:`.interpolate` if changing this from default. - allow_missing_dofs : bool - For interpolation across meshes: allow degrees of freedom (aka DoFs/nodes) - in the target mesh that cannot be defined on the source mesh. - For example, where nodes are point evaluations, points in the target mesh - that are not in the source mesh. When ``False`` this raises a ``ValueError`` - should this occur. When ``True`` the corresponding values are either - (a) unchanged if some ``output`` is given to the :meth:`interpolate` method - or (b) set to zero. - Can be overwritten with the ``default_missing_val`` kwarg of :meth:`interpolate`. - This does not affect adjoint interpolation. Ignored if interpolating within - the same mesh or onto a :func:`.VertexOnlyMesh` (the behaviour of a - :func:`.VertexOnlyMesh` in this scenario is, at present, set when it is created). - default_missing_val : float - For interpolation across meshes: the optional value to assign to DoFs - in the target mesh that are outside the source mesh. If this is not set - then the values are either (a) unchanged if some ``output`` is given to - the :meth:`interpolate` method or (b) set to zero. - Ignored if interpolating within the same mesh or onto a :func:`.VertexOnlyMesh`. - matfree : bool - If ``False``, then construct the permutation matrix for interpolating - between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast - and reduce operations. + **kwargs + Additional interpolation options. See :class:`InterpolateOptions` + for available parameters and their descriptions. """ - # Check function space - expr = ufl.as_ufl(expr) - if isinstance(v, functionspaceimpl.WithGeometry): - expr_args = extract_arguments(expr) + # TODO: should we allow RHS to be FiredrakeDualSpace? + if isinstance(V, WithGeometry): + # Need to create a Firedrake Coargument so it has a .function_space() method + expr_args = extract_arguments(ufl.as_ufl(expr)) is_adjoint = len(expr_args) and expr_args[0].number() == 0 - v = Argument(v.dual(), 1 if is_adjoint else 0) - - V = v.arguments()[0].function_space() - if len(expr.ufl_shape) != len(V.value_shape): - raise RuntimeError('Rank mismatch: Expression rank %d, FunctionSpace rank %d' - % (len(expr.ufl_shape), len(V.value_shape))) - - if expr.ufl_shape != V.value_shape: - raise RuntimeError('Shape mismatch: Expression shape %r, FunctionSpace shape %r' - % (expr.ufl_shape, V.value_shape)) - super().__init__(expr, v) + V = Argument(V.dual(), 1 if is_adjoint else 0) + super().__init__(expr, V) self._options = InterpolateOptions(**kwargs) - self.interp_data = asdict(self._options) + self.interp_data = asdict(self._options) # TODO: remove this function_space = ufl.Interpolate.ufl_function_space From dd6770ae12a876df1507dcbe394dd17e9db5e8b2 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 25 Sep 2025 15:42:29 +0100 Subject: [PATCH 039/125] add _get_interpolator function --- firedrake/assemble.py | 3 ++- firedrake/interpolation.py | 40 +++++++++++++++++++++----------------- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index e90ee06445..9f2173ca18 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -24,6 +24,7 @@ from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit from firedrake.functionspaceimpl import WithGeometry, FunctionSpace, FiredrakeDualSpace from firedrake.functionspacedata import entity_dofs_key, entity_permutations_key +from firedrake.interpolation import _get_interpolator from firedrake.petsc import PETSc from firedrake.slate import slac, slate from firedrake.slate.slac.kernel_builder import CellFacetKernelArg, LayerCountKernelArg @@ -607,7 +608,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args): default_missing_val = interp_data.pop('default_missing_val', None) if rank == 1 and isinstance(tensor, firedrake.Function): V = tensor - interpolator = firedrake.Interpolator(expr, V, **interp_data) + interpolator = _get_interpolator(expr, V, **interp_data) # Assembly return interpolator.assemble(tensor=tensor, default_missing_val=default_missing_val) elif tensor and isinstance(expr, (firedrake.Function, firedrake.Cofunction, firedrake.MatrixBase)): diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 7824150f6b..758889e982 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -134,6 +134,10 @@ def __init__(self, expr, V, **kwargs): def _ufl_expr_reconstruct_(self, expr, v=None, **interp_data): interp_data = interp_data or self.interp_data.copy() return ufl.Interpolate._ufl_expr_reconstruct_(self, expr, v=v, **interp_data) + + @property + def target_space(self): + return self.argument_slots()[0].function_space().dual() @PETSc.Log.EventDecorator() @@ -206,24 +210,6 @@ class Interpolator(abc.ABC): :class:`Interpolator` is also collected). """ - - def __new__(cls, expr, V, **kwargs): - if isinstance(expr, ufl.Interpolate): - expr, = expr.ufl_operands - target_mesh = as_domain(V) - source_mesh = extract_unique_domain(expr) or target_mesh - submesh_interp_implemented = \ - all(isinstance(m.topology, firedrake.mesh.MeshTopology) for m in [target_mesh, source_mesh]) and \ - target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1] and \ - target_mesh.topological_dimension() == source_mesh.topological_dimension() - if target_mesh is source_mesh or submesh_interp_implemented: - return object.__new__(SameMeshInterpolator) - else: - if isinstance(target_mesh.topology, VertexOnlyMeshTopology): - return object.__new__(SameMeshInterpolator) - else: - return object.__new__(CrossMeshInterpolator) - def __init__( self, expr, @@ -348,6 +334,24 @@ def assemble(self, tensor=None, default_missing_val=None): default_missing_val=default_missing_val) +def _get_interpolator(expr, V, **kwargs) -> Interpolator: + if isinstance(expr, ufl.Interpolate): + expr, = expr.ufl_operands + target_mesh = as_domain(V) + source_mesh = extract_unique_domain(expr) or target_mesh + submesh_interp_implemented = \ + all(isinstance(m.topology, firedrake.mesh.MeshTopology) for m in [target_mesh, source_mesh]) and \ + target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1] and \ + target_mesh.topological_dimension() == source_mesh.topological_dimension() + if target_mesh is source_mesh or submesh_interp_implemented: + return SameMeshInterpolator(expr, V, **kwargs) + else: + if isinstance(target_mesh.topology, VertexOnlyMeshTopology): + return SameMeshInterpolator(expr, V, **kwargs) + else: + return CrossMeshInterpolator(expr, V, **kwargs) + + class DofNotDefinedError(Exception): r"""Raised when attempting to interpolate across function spaces where the target function space contains degrees of freedom (i.e. nodes) which cannot From 21dea34d17397c8b405e3a0bd6d3a5e86b825c97 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 25 Sep 2025 17:27:46 +0100 Subject: [PATCH 040/125] fix `_get_interpolator` --- firedrake/assemble.py | 2 -- firedrake/interpolation.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 9f2173ca18..4acd3aaaa3 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -603,13 +603,11 @@ def base_form_assembly_visitor(self, expr, tensor, *args): assemble(sub_interp, tensor=tensor.subfunctions[i]) return tensor - # Get the interpolator interp_data = expr.interp_data.copy() default_missing_val = interp_data.pop('default_missing_val', None) if rank == 1 and isinstance(tensor, firedrake.Function): V = tensor interpolator = _get_interpolator(expr, V, **interp_data) - # Assembly return interpolator.assemble(tensor=tensor, default_missing_val=default_missing_val) elif tensor and isinstance(expr, (firedrake.Function, firedrake.Cofunction, firedrake.MatrixBase)): return tensor.assign(expr) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 758889e982..ca26bdc9e2 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -335,8 +335,6 @@ def assemble(self, tensor=None, default_missing_val=None): def _get_interpolator(expr, V, **kwargs) -> Interpolator: - if isinstance(expr, ufl.Interpolate): - expr, = expr.ufl_operands target_mesh = as_domain(V) source_mesh = extract_unique_domain(expr) or target_mesh submesh_interp_implemented = \ From 8ed21514546af560040ab8b43f3333b8b6bc5f58 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 25 Sep 2025 17:39:35 +0100 Subject: [PATCH 041/125] fixes for `test_interpolate_cross_mesh` --- tests/firedrake/regression/test_interpolate_cross_mesh.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/firedrake/regression/test_interpolate_cross_mesh.py b/tests/firedrake/regression/test_interpolate_cross_mesh.py index 52c1746d74..84ec1dc2d8 100644 --- a/tests/firedrake/regression/test_interpolate_cross_mesh.py +++ b/tests/firedrake/regression/test_interpolate_cross_mesh.py @@ -541,7 +541,7 @@ def test_missing_dofs(): V_src = FunctionSpace(m_src, "CG", 2) V_dest = FunctionSpace(m_dest, "CG", 3) with pytest.raises(DofNotDefinedError): - Interpolator(TestFunction(V_src), V_dest) + assemble(interpolate(TrialFunction(V_src), V_dest)) f_src = Function(V_src).interpolate(expr) f_dest = assemble(interpolate(f_src, V_dest, allow_missing_dofs=True)) dest_eval = PointEvaluator(m_dest, coords) From c18037442425a5dc9920405970e580d265455249 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 25 Sep 2025 18:07:00 +0100 Subject: [PATCH 042/125] remove parameters from Interpolator --- firedrake/assemble.py | 2 +- firedrake/interpolation.py | 141 ++++++++++--------------------------- 2 files changed, 37 insertions(+), 106 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 4acd3aaaa3..f2c718c47d 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -607,7 +607,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args): default_missing_val = interp_data.pop('default_missing_val', None) if rank == 1 and isinstance(tensor, firedrake.Function): V = tensor - interpolator = _get_interpolator(expr, V, **interp_data) + interpolator = _get_interpolator(expr, V) return interpolator.assemble(tensor=tensor, default_missing_val=default_missing_val) elif tensor and isinstance(expr, (firedrake.Function, firedrake.Cofunction, firedrake.MatrixBase)): return tensor.assign(expr) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index ca26bdc9e2..4bdb9ca8f9 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -138,6 +138,10 @@ def _ufl_expr_reconstruct_(self, expr, v=None, **interp_data): @property def target_space(self): return self.argument_slots()[0].function_space().dual() + + @property + def options(self): + return self._options @PETSc.Log.EventDecorator() @@ -164,76 +168,29 @@ def interpolate(expr, V, **kwargs): class Interpolator(abc.ABC): - """A reusable interpolation object. - - :arg expr: The expression to interpolate. - :arg V: The :class:`.FunctionSpace` or :class:`.Function` to - interpolate into. - :kwarg subset: An optional :class:`pyop2.types.set.Subset` to apply the - interpolation over. Cannot, at present, be used when interpolating - across meshes unless the target mesh is a :func:`.VertexOnlyMesh`. - :kwarg freeze_expr: Set to True to prevent the expression being - re-evaluated on each call. Cannot, at present, be used when - interpolating across meshes unless the target mesh is a - :func:`.VertexOnlyMesh`. - :kwarg access: The pyop2 access descriptor for combining updates to shared - DoFs. Possible values include ``WRITE`` and ``INC``. Only ``WRITE`` is - supported at present when interpolating across meshes. See note in - :func:`.interpolate` if changing this from default. - :kwarg bcs: An optional list of boundary conditions to zero-out in the - output function space. Interpolator rows or columns which are - associated with boundary condition nodes are zeroed out when this is - specified. - :kwarg allow_missing_dofs: For interpolation across meshes: allow - degrees of freedom (aka DoFs/nodes) in the target mesh that cannot be - defined on the source mesh. For example, where nodes are point - evaluations, points in the target mesh that are not in the source mesh. - When ``False`` this raises a ``ValueError`` should this occur. When - ``True`` the corresponding values are either (a) unchanged if - some ``output`` is given to the :meth:`interpolate` method or (b) set - to zero. Can be overwritten with the ``default_missing_val`` kwarg - of :meth:`interpolate`. This does not affect adjoint interpolation. - Ignored if interpolating within the same mesh or onto a - :func:`.VertexOnlyMesh` (the behaviour of a :func:`.VertexOnlyMesh` in - this scenario is, at present, set when it is created). - :kwarg matfree: If ``False``, then construct the permutation matrix for interpolating - between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast - and reduce operations. - - This object can be used to carry out the same interpolation - multiple times (for example in a timestepping loop). - .. note:: + def __init__(self, expr: Interpolate, V, bcs=None): + """Initialise Interpolator. - The :class:`Interpolator` holds a reference to the provided - arguments (such that they won't be collected until the - :class:`Interpolator` is also collected). - - """ - def __init__( - self, - expr, - V, - subset=None, - freeze_expr=False, - access=op2.WRITE, - bcs=None, - allow_missing_dofs=False, - matfree=True - ): - if not isinstance(expr, ufl.Interpolate): - fs = V if isinstance(V, ufl.FunctionSpace) else V.function_space() - expr = interpolate(expr, fs) + Parameters + ---------- + expr : Interpolate + The symbolic interpolation expression. + V : FunctionSpace or Function to interpolate into. + _description_ + bcs : list, optional + List of boundary conditions to zero-out in the output function space. By default None. + """ dual_arg, operand = expr.argument_slots() self.ufl_interpolate = expr self.expr = operand self.V = V - self.subset = subset - self.freeze_expr = freeze_expr - self.access = access + self.subset = expr.options.subset + self.freeze_expr = False + self.access = expr.options.access self.bcs = bcs - self._allow_missing_dofs = allow_missing_dofs - self.matfree = matfree + self._allow_missing_dofs = expr.options.allow_missing_dofs + self.matfree = expr.options.matfree self.callable = None # TODO CrossMeshInterpolator and VomOntoVomXXX are not yet aware of @@ -273,18 +230,6 @@ def __init__( # Matrix-free assembly of 0-form or 1-form requires INC access self.access = op2.INC - def interpolate(self, *function, transpose=None, adjoint=False, default_missing_val=None): - """ - .. warning:: - - This method has been removed. Use the function :func:`interpolate` to return a symbolic - :class:`Interpolate` object. - """ - raise FutureWarning( - "The 'interpolate' method on `Interpolator` objects has been " - "removed. Use the `interpolate` function instead." - ) - @abc.abstractmethod def _interpolate(self, *args, **kwargs): """ @@ -334,7 +279,7 @@ def assemble(self, tensor=None, default_missing_val=None): default_missing_val=default_missing_val) -def _get_interpolator(expr, V, **kwargs) -> Interpolator: +def _get_interpolator(expr: Interpolate, V) -> Interpolator: target_mesh = as_domain(V) source_mesh = extract_unique_domain(expr) or target_mesh submesh_interp_implemented = \ @@ -342,12 +287,12 @@ def _get_interpolator(expr, V, **kwargs) -> Interpolator: target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1] and \ target_mesh.topological_dimension() == source_mesh.topological_dimension() if target_mesh is source_mesh or submesh_interp_implemented: - return SameMeshInterpolator(expr, V, **kwargs) + return SameMeshInterpolator(expr, V) else: if isinstance(target_mesh.topology, VertexOnlyMeshTopology): - return SameMeshInterpolator(expr, V, **kwargs) + return SameMeshInterpolator(expr, V) else: - return CrossMeshInterpolator(expr, V, **kwargs) + return CrossMeshInterpolator(expr, V) class DofNotDefinedError(Exception): @@ -387,26 +332,17 @@ class CrossMeshInterpolator(Interpolator): """ @no_annotations - def __init__( - self, - expr, - V, - subset=None, - freeze_expr=False, - access=op2.WRITE, - bcs=None, - allow_missing_dofs=False, - matfree=True - ): - if subset: + def __init__(self, expr, V, bcs=None): + super().__init__(expr, V, bcs) + if self.subset: raise NotImplementedError("subset not implemented") - if freeze_expr: + if self.freeze_expr: # Probably just need to pass freeze_expr to the various # interpolators for this to work. raise NotImplementedError("freeze_expr not implemented") - if access != op2.WRITE: + if self.access != op2.WRITE: raise NotImplementedError("access other than op2.WRITE not implemented") - if bcs: + if self.bcs: raise NotImplementedError("bcs not implemented") if V.ufl_element().mapping() != "identity": # Identity mapping between reference cell and physical coordinates @@ -416,7 +352,6 @@ def __init__( raise NotImplementedError( "Can only interpolate into spaces with point evaluation nodes." ) - super().__init__(expr, V, subset, freeze_expr, access, bcs, allow_missing_dofs, matfree) expr = self.expr_renumbered self.arguments = extract_arguments(expr) @@ -490,10 +425,7 @@ def __init__( expr_subfunctions, V_dest.subspaces ): self.sub_interpolators.append( - interpolate( - input_sub_func, target_subspace, subset=subset, - access=access, allow_missing_dofs=allow_missing_dofs - ) + interpolate(input_sub_func, target_subspace, **self.interp_data) ) return @@ -717,8 +649,9 @@ class SameMeshInterpolator(Interpolator): """ @no_annotations - def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, - bcs=None, matfree=True, allow_missing_dofs=False, **kwargs): + def __init__(self, expr, V, bcs=None): + super().__init__(expr, V, bcs=bcs) + subset = self.subset if subset is None: if isinstance(expr, ufl.Interpolate): operand, = expr.ufl_operands @@ -736,17 +669,15 @@ def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, make_subset = not indices_active.all() make_subset = target.comm.allreduce(make_subset, op=MPI.LOR) if make_subset: - if not allow_missing_dofs: + if not self._allow_missing_dofs: raise ValueError("iteration (sub)set unclear: run with `allow_missing_dofs=True`") subset = op2.Subset(target.cell_set, numpy.where(indices_active)) else: # Do not need subset as target <= source. pass - super().__init__(expr, V, subset=subset, freeze_expr=freeze_expr, - access=access, bcs=bcs, matfree=matfree, allow_missing_dofs=allow_missing_dofs) expr = self.ufl_interpolate_renumbered try: - self.callable = make_interpolator(expr, V, subset, self.access, bcs=bcs, matfree=matfree) + self.callable = make_interpolator(expr, V, subset, self.access, bcs=bcs, matfree=self.matfree) except FIAT.hdiv_trace.TraceError: raise NotImplementedError("Can't interpolate onto traces sorry") self.arguments = expr.arguments() From a2d1166686b3ccb7cfda4b5ec63caa35f4cab5f3 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 25 Sep 2025 18:48:08 +0100 Subject: [PATCH 043/125] remove `freeze_expr` and logic --- firedrake/interpolation.py | 33 ++++----------------------------- 1 file changed, 4 insertions(+), 29 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 4bdb9ca8f9..60cbc02c3e 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -115,7 +115,7 @@ def __init__(self, expr, V, **kwargs): The function space to interpolate into or the coargument defined on the dual of the function space to interpolate into. **kwargs - Additional interpolation options. See :class:`InterpolateOptions` + Additional interpolation options. See :class:`InterpolateOptions` for available parameters and their descriptions. """ # TODO: should we allow RHS to be FiredrakeDualSpace? @@ -186,7 +186,6 @@ def __init__(self, expr: Interpolate, V, bcs=None): self.expr = operand self.V = V self.subset = expr.options.subset - self.freeze_expr = False self.access = expr.options.access self.bcs = bcs self._allow_missing_dofs = expr.options.allow_missing_dofs @@ -335,11 +334,7 @@ class CrossMeshInterpolator(Interpolator): def __init__(self, expr, V, bcs=None): super().__init__(expr, V, bcs) if self.subset: - raise NotImplementedError("subset not implemented") - if self.freeze_expr: - # Probably just need to pass freeze_expr to the various - # interpolators for this to work. - raise NotImplementedError("freeze_expr not implemented") + raise NotImplementedError("Subset not implemented.") if self.access != op2.WRITE: raise NotImplementedError("access other than op2.WRITE not implemented") if self.bcs: @@ -683,28 +678,12 @@ def __init__(self, expr, V, bcs=None): self.arguments = expr.arguments() @PETSc.Log.EventDecorator() - def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **kwargs): + def _interpolate(self, *function, output=None, adjoint=False, **kwargs): """Compute the interpolation. For arguments, see :class:`.Interpolator`. """ - - if transpose is not None: - warnings.warn("'transpose' argument is deprecated, use 'adjoint' instead", FutureWarning) - adjoint = transpose or adjoint - try: - assembled_interpolator = self.frozen_assembled_interpolator - copy_required = True - except AttributeError: - assembled_interpolator = self.callable() - copy_required = False # Return the original - if self.freeze_expr: - if len(self.arguments) == 2: - # Interpolation operator - self.frozen_assembled_interpolator = assembled_interpolator - else: - # Interpolation action - self.frozen_assembled_interpolator = assembled_interpolator.copy() + assembled_interpolator = self.callable() if len(self.arguments) == 2 and len(function) > 0: function, = function @@ -734,14 +713,10 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, ** output.assign(assembled_interpolator) return output if isinstance(self.V, firedrake.Function): - if copy_required: - self.V.assign(assembled_interpolator) return self.V else: if len(self.arguments) == 0: return assembled_interpolator.dat.data.item() - elif copy_required: - return assembled_interpolator.copy() else: return assembled_interpolator From 29b9d25a12ddd2428d358d6b50a1ca50cf1898dd Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 25 Sep 2025 19:41:13 +0100 Subject: [PATCH 044/125] remove parameters --- firedrake/interpolation.py | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 60cbc02c3e..23466ab146 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -45,7 +45,7 @@ "SameMeshInterpolator", ) -@dataclass(frozen=True) +@dataclass class InterpolateOptions: """Options for interpolation operations. @@ -185,11 +185,8 @@ def __init__(self, expr: Interpolate, V, bcs=None): self.ufl_interpolate = expr self.expr = operand self.V = V - self.subset = expr.options.subset - self.access = expr.options.access + self.options = expr.options self.bcs = bcs - self._allow_missing_dofs = expr.options.allow_missing_dofs - self.matfree = expr.options.matfree self.callable = None # TODO CrossMeshInterpolator and VomOntoVomXXX are not yet aware of @@ -227,7 +224,7 @@ def __init__(self, expr: Interpolate, V, bcs=None): self.ufl_interpolate_renumbered = expr if not isinstance(dual_arg, ufl.Coargument): # Matrix-free assembly of 0-form or 1-form requires INC access - self.access = op2.INC + self.options.access = op2.INC @abc.abstractmethod def _interpolate(self, *args, **kwargs): @@ -331,14 +328,14 @@ class CrossMeshInterpolator(Interpolator): """ @no_annotations - def __init__(self, expr, V, bcs=None): + def __init__(self, expr: Interpolate, V, bcs=None): super().__init__(expr, V, bcs) - if self.subset: - raise NotImplementedError("Subset not implemented.") - if self.access != op2.WRITE: - raise NotImplementedError("access other than op2.WRITE not implemented") + if self.options.access != op2.WRITE: + raise NotImplementedError( + "Access other than op2.WRITE not implemented for cross-mesh interpolation." + ) if self.bcs: - raise NotImplementedError("bcs not implemented") + raise NotImplementedError("bcs not implemented.") if V.ufl_element().mapping() != "identity": # Identity mapping between reference cell and physical coordinates # implies point evaluation nodes. A more general version would @@ -352,7 +349,7 @@ def __init__(self, expr, V, bcs=None): self.arguments = extract_arguments(expr) self.nargs = len(self.arguments) - if self._allow_missing_dofs: + if self.options.allow_missing_dofs: missing_points_behaviour = MissingPointsBehaviour.IGNORE else: missing_points_behaviour = MissingPointsBehaviour.ERROR @@ -420,7 +417,7 @@ def __init__(self, expr, V, bcs=None): expr_subfunctions, V_dest.subspaces ): self.sub_interpolators.append( - interpolate(input_sub_func, target_subspace, **self.interp_data) + interpolate(input_sub_func, target_subspace, **asdict(self.options)) ) return @@ -568,7 +565,7 @@ def _interpolate( f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_wo[ : ] = default_missing_val - elif self._allow_missing_dofs: + elif self.options.allow_missing_dofs: # If we have allowed missing points we know we might end up # with points in the target mesh that are not in the source # mesh. However, since we haven't specified a default missing @@ -582,7 +579,7 @@ def _interpolate( assemble(interp, tensor=f_src_at_dest_node_coords_dest_mesh_decomp) # we can now confidently assign this to a function on V_dest - if self._allow_missing_dofs and default_missing_val is None: + if self.options.allow_missing_dofs and default_missing_val is None: indices = numpy.where( ~numpy.isnan(f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_ro) )[0] @@ -646,7 +643,7 @@ class SameMeshInterpolator(Interpolator): @no_annotations def __init__(self, expr, V, bcs=None): super().__init__(expr, V, bcs=bcs) - subset = self.subset + subset = self.options.subset if subset is None: if isinstance(expr, ufl.Interpolate): operand, = expr.ufl_operands @@ -664,7 +661,7 @@ def __init__(self, expr, V, bcs=None): make_subset = not indices_active.all() make_subset = target.comm.allreduce(make_subset, op=MPI.LOR) if make_subset: - if not self._allow_missing_dofs: + if not self.options.allow_missing_dofs: raise ValueError("iteration (sub)set unclear: run with `allow_missing_dofs=True`") subset = op2.Subset(target.cell_set, numpy.where(indices_active)) else: @@ -672,7 +669,7 @@ def __init__(self, expr, V, bcs=None): pass expr = self.ufl_interpolate_renumbered try: - self.callable = make_interpolator(expr, V, subset, self.access, bcs=bcs, matfree=self.matfree) + self.callable = make_interpolator(expr, V, subset, self.options.access, bcs=bcs, matfree=self.options.matfree) except FIAT.hdiv_trace.TraceError: raise NotImplementedError("Can't interpolate onto traces sorry") self.arguments = expr.arguments() From 5e0853e782a10d266aae7b04a581274d1df04ba2 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 25 Sep 2025 22:13:43 +0100 Subject: [PATCH 045/125] remove interp_data dict --- firedrake/assemble.py | 4 +--- firedrake/interpolation.py | 13 ++++++++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index f2c718c47d..1875b2fc2d 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -603,12 +603,10 @@ def base_form_assembly_visitor(self, expr, tensor, *args): assemble(sub_interp, tensor=tensor.subfunctions[i]) return tensor - interp_data = expr.interp_data.copy() - default_missing_val = interp_data.pop('default_missing_val', None) if rank == 1 and isinstance(tensor, firedrake.Function): V = tensor interpolator = _get_interpolator(expr, V) - return interpolator.assemble(tensor=tensor, default_missing_val=default_missing_val) + return interpolator.assemble(tensor=tensor) elif tensor and isinstance(expr, (firedrake.Function, firedrake.Cofunction, firedrake.MatrixBase)): return tensor.assign(expr) elif tensor and isinstance(expr, ufl.ZeroBaseForm): diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 23466ab146..a070a5ed69 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -127,17 +127,20 @@ def __init__(self, expr, V, **kwargs): super().__init__(expr, V) self._options = InterpolateOptions(**kwargs) - self.interp_data = asdict(self._options) # TODO: remove this function_space = ufl.Interpolate.ufl_function_space def _ufl_expr_reconstruct_(self, expr, v=None, **interp_data): - interp_data = interp_data or self.interp_data.copy() + interp_data = interp_data or asdict(self.options) return ufl.Interpolate._ufl_expr_reconstruct_(self, expr, v=v, **interp_data) @property def target_space(self): return self.argument_slots()[0].function_space().dual() + + @property + def source_space(self): + return self.argument_slots()[1].function_space() @property def options(self): @@ -236,7 +239,7 @@ def _interpolate(self, *args, **kwargs): """ pass - def assemble(self, tensor=None, default_missing_val=None): + def assemble(self, tensor=None): """Assemble the operator (or its action).""" from firedrake.assemble import assemble needs_adjoint = self.ufl_interpolate_renumbered != self.ufl_interpolate @@ -268,11 +271,11 @@ def assemble(self, tensor=None, default_missing_val=None): cofunctions = (dual_arg,) if needs_adjoint and len(arguments) == 0: - Iu = self._interpolate(default_missing_val=default_missing_val) + Iu = self._interpolate(default_missing_val=self.options.default_missing_val) return assemble(ufl.Action(*cofunctions, Iu), tensor=tensor) else: return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint, - default_missing_val=default_missing_val) + default_missing_val=self.options.default_missing_val) def _get_interpolator(expr: Interpolate, V) -> Interpolator: From e8125e7bf122b01e98a1e90b787d6140c8393255 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 25 Sep 2025 22:17:04 +0100 Subject: [PATCH 046/125] simplify logic --- firedrake/interpolation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index a070a5ed69..d9c2014480 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -257,9 +257,7 @@ def assemble(self, tensor=None): petsc_mat.copy(res) else: res = petsc_mat - if tensor is None: - tensor = firedrake.AssembledMatrix(arguments, self.bcs, res) - return tensor + return tensor or firedrake.AssembledMatrix(arguments, self.bcs, res) else: # Assembling the action cofunctions = () From df935f8426b784a604b10e18f9e209518c746661 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 25 Sep 2025 22:28:59 +0100 Subject: [PATCH 047/125] remove default_missing_val argument --- firedrake/interpolation.py | 26 +++++++------------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index d9c2014480..3759170d1c 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -269,11 +269,10 @@ def assemble(self, tensor=None): cofunctions = (dual_arg,) if needs_adjoint and len(arguments) == 0: - Iu = self._interpolate(default_missing_val=self.options.default_missing_val) + Iu = self._interpolate() return assemble(ufl.Action(*cofunctions, Iu), tensor=tensor) else: - return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint, - default_missing_val=self.options.default_missing_val) + return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint) def _get_interpolator(expr: Interpolate, V) -> Interpolator: @@ -466,24 +465,13 @@ def __init__(self, expr: Interpolate, V, bcs=None): # interpolation method below. @PETSc.Log.EventDecorator() - def _interpolate( - self, - *function, - output=None, - transpose=None, - adjoint=False, - default_missing_val=None, - **kwargs, - ): + def _interpolate(self, *function, output=None, adjoint=False): """Compute the interpolation. For arguments, see :class:`.Interpolator`. """ from firedrake.assemble import assemble - if transpose is not None: - warnings.warn("'transpose' argument is deprecated, use 'adjoint' instead", FutureWarning) - adjoint = transpose or adjoint if adjoint and not self.nargs: raise ValueError( "Can currently only apply adjoint interpolation with arguments." @@ -562,10 +550,10 @@ def _interpolate( ) # We have to create the Function before interpolating so we can # set default missing values (if requested). - if default_missing_val is not None: + if self.options.default_missing_val is not None: f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_wo[ : - ] = default_missing_val + ] = self.options.default_missing_val elif self.options.allow_missing_dofs: # If we have allowed missing points we know we might end up # with points in the target mesh that are not in the source @@ -580,7 +568,7 @@ def _interpolate( assemble(interp, tensor=f_src_at_dest_node_coords_dest_mesh_decomp) # we can now confidently assign this to a function on V_dest - if self.options.allow_missing_dofs and default_missing_val is None: + if self.options.allow_missing_dofs and self.options.default_missing_val is None: indices = numpy.where( ~numpy.isnan(f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_ro) )[0] @@ -676,7 +664,7 @@ def __init__(self, expr, V, bcs=None): self.arguments = expr.arguments() @PETSc.Log.EventDecorator() - def _interpolate(self, *function, output=None, adjoint=False, **kwargs): + def _interpolate(self, *function, output=None, adjoint=False): """Compute the interpolation. For arguments, see :class:`.Interpolator`. From ce2659a46c710ac326bf649bc418de87614e75bc Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 25 Sep 2025 22:51:18 +0100 Subject: [PATCH 048/125] tidy --- firedrake/interpolation.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 3759170d1c..0457de9195 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -473,15 +473,9 @@ def _interpolate(self, *function, output=None, adjoint=False): from firedrake.assemble import assemble if adjoint and not self.nargs: - raise ValueError( - "Can currently only apply adjoint interpolation with arguments." - ) + raise ValueError("Can currently only apply adjoint interpolation with arguments.") if self.nargs != len(function): - raise ValueError( - "Passed %d Functions to interpolate, expected %d" - % (len(function), self.nargs) - ) - + raise ValueError(f"Passed {len(function)} Functions to interpolate, expected {self.nargs}") if self.nargs: (f_src,) = function if not hasattr(f_src, "dat"): From 9c372774577eae1d9ef4dd183056b1611f53b594 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Fri, 26 Sep 2025 11:59:32 +0100 Subject: [PATCH 049/125] refactor crossmeshinterpolator --- firedrake/interpolation.py | 180 ++++++++++++++++--------------------- 1 file changed, 78 insertions(+), 102 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 0457de9195..6a686659dc 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -25,7 +25,7 @@ import finat import firedrake -from firedrake import tsfc_interface, utils +from firedrake import tsfc_interface, utils, TrialFunction from firedrake.ufl_expr import Argument, Coargument, action, adjoint as expr_adjoint from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology from firedrake.petsc import PETSc @@ -241,7 +241,6 @@ def _interpolate(self, *args, **kwargs): def assemble(self, tensor=None): """Assemble the operator (or its action).""" - from firedrake.assemble import assemble needs_adjoint = self.ufl_interpolate_renumbered != self.ufl_interpolate arguments = self.ufl_interpolate.arguments() if len(arguments) == 2: @@ -270,7 +269,7 @@ def assemble(self, tensor=None): if needs_adjoint and len(arguments) == 0: Iu = self._interpolate() - return assemble(ufl.Action(*cofunctions, Iu), tensor=tensor) + return firedrake.assemble(ufl.Action(*cofunctions, Iu), tensor=tensor) else: return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint) @@ -350,119 +349,97 @@ def __init__(self, expr: Interpolate, V, bcs=None): self.nargs = len(self.arguments) if self.options.allow_missing_dofs: - missing_points_behaviour = MissingPointsBehaviour.IGNORE + self.missing_points_behaviour = MissingPointsBehaviour.IGNORE else: - missing_points_behaviour = MissingPointsBehaviour.ERROR + self.missing_points_behaviour = MissingPointsBehaviour.ERROR # setup - V_dest = V.function_space() if isinstance(V, firedrake.Function) else V - src_mesh = extract_unique_domain(expr) - dest_mesh = as_domain(V_dest) - src_mesh_gdim = src_mesh.geometric_dimension() - dest_mesh_gdim = dest_mesh.geometric_dimension() - if src_mesh_gdim != dest_mesh_gdim: - raise ValueError( - "geometric dimensions of source and destination meshes must match" - ) - self.src_mesh = src_mesh - self.dest_mesh = dest_mesh + self.V_dest = V.function_space() if isinstance(V, firedrake.Function) else V + self.src_mesh = extract_unique_domain(expr) + self.dest_mesh = as_domain(self.V_dest) + if self.src_mesh.geometric_dimension() != self.dest_mesh.geometric_dimension(): + raise ValueError("Geometric dimensions of source and destination meshes must match.") self.sub_interpolators = [] - - # Create a VOM at the nodes of V_dest in src_mesh. We don't include halo - # node coordinates because interpolation doesn't usually include halos. - # NOTE: it is very important to set redundant=False, otherwise the - # input ordering VOM will only contain the points on rank 0! - # QUESTION: Should any of the below have annotation turned off? - ufl_scalar_element = V_dest.ufl_element() - if isinstance(ufl_scalar_element, finat.ufl.MixedElement): - if all( - ufl_scalar_element.sub_elements[0] == e - for e in ufl_scalar_element.sub_elements - ): - # For a VectorElement or TensorElement the correct - # VectorFunctionSpace equivalent is built from the scalar - # sub-element. - ufl_scalar_element = ufl_scalar_element.sub_elements[0] - if ufl_scalar_element.reference_value_shape != (): + dest_element = self.V_dest.ufl_element() + if isinstance(dest_element, finat.ufl.MixedElement): + if isinstance(dest_element, (finat.ufl.VectorElement, finat.ufl.TensorElement)): + base_element = dest_element.sub_elements[0] + if base_element.reference_value_shape != (): raise NotImplementedError( "Can't yet cross-mesh interpolate onto function spaces made from VectorElements or TensorElements made from sub elements with value shape other than ()." ) + self.dest_element = base_element + self._symbolic_expression() else: - # Build and save an interpolator for each sub-element - # separately for MixedFunctionSpaces. NOTE: since we can't have - # expressions for MixedFunctionSpaces we know that the input - # argument ``expr`` must be a Function. V_dest can be a Function - # or a FunctionSpace, and subfunctions works for both. - if self.nargs == 1: - # Arguments don't have a subfunctions property so I have to - # make them myself. NOTE: this will not be correct when we - # start allowing interpolators created from an expression - # with arguments, as opposed to just being the argument. - expr_subfunctions = [ - firedrake.TestFunction(V_src_sub_func) - for V_src_sub_func in self.expr.function_space().subspaces - ] - elif self.nargs > 1: - raise NotImplementedError( - "Can't yet create an interpolator from an expression with multiple arguments." - ) - else: - expr_subfunctions = self.expr.subfunctions - if len(expr_subfunctions) != len(V_dest.subspaces): - raise NotImplementedError( - "Can't interpolate from a non-mixed function space into a mixed function space." - ) - for input_sub_func, target_subspace in zip( - expr_subfunctions, V_dest.subspaces - ): - self.sub_interpolators.append( - interpolate(input_sub_func, target_subspace, **asdict(self.options)) - ) - return + self._mixed_function_space() + else: + self.dest_element = dest_element + self._symbolic_expression() + def _symbolic_expression(self): from firedrake.assemble import assemble - V_dest_vec = firedrake.VectorFunctionSpace(dest_mesh, ufl_scalar_element) - f_dest_node_coords = Interpolate(dest_mesh.coordinates, V_dest_vec) - f_dest_node_coords = assemble(f_dest_node_coords) - dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, dest_mesh_gdim) + # Immerse coordinates of V_dest point evaluation dofs in src_mesh + V_dest_vec = firedrake.VectorFunctionSpace(self.dest_mesh, self.dest_element) + f_dest_node_coords = assemble(interpolate(self.dest_mesh.coordinates, V_dest_vec)) + dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, self.dest_mesh.geometric_dimension()) try: - self.vom_dest_node_coords_in_src_mesh = firedrake.VertexOnlyMesh( - src_mesh, + self.vom = firedrake.VertexOnlyMesh( + self.src_mesh, dest_node_coords, redundant=False, - missing_points_behaviour=missing_points_behaviour, + missing_points_behaviour=self.missing_points_behaviour, ) except VertexOnlyMeshMissingPointsError: - raise DofNotDefinedError(src_mesh, dest_mesh) - # vom_dest_node_coords_in_src_mesh uses the parallel decomposition of - # the global node coordinates of V_dest in the SOURCE mesh (src_mesh). - # I first point evaluate my expression at these locations, giving a - # P0DG function on the VOM. As described in the manual, this is an - # interpolation operation. - shape = V_dest.ufl_function_space().value_shape + raise DofNotDefinedError(self.src_mesh, self.dest_mesh) + + # Evaluate expr at the immersed coordinates + shape = self.V_dest.ufl_function_space().value_shape if len(shape) == 0: fs_type = firedrake.FunctionSpace elif len(shape) == 1: fs_type = partial(firedrake.VectorFunctionSpace, dim=shape[0]) else: fs_type = partial(firedrake.TensorFunctionSpace, shape=shape) - P0DG_vom = fs_type(self.vom_dest_node_coords_in_src_mesh, "DG", 0) - self.point_eval_interpolate = Interpolate(self.expr_renumbered, P0DG_vom) - # The parallel decomposition of the nodes of V_dest in the DESTINATION - # mesh (dest_mesh) is retrieved using the input_ordering attribute of the - # VOM. This again is an interpolation operation, which, under the hood - # is a PETSc SF reduce. - P0DG_vom_i_o = fs_type( - self.vom_dest_node_coords_in_src_mesh.input_ordering, "DG", 0 - ) - self.to_input_ordering_interpolate = Interpolate( - firedrake.TrialFunction(P0DG_vom), P0DG_vom_i_o - ) - # The P0DG function outputted by the above interpolation has the - # correct parallel decomposition for the nodes of V_dest in dest_mesh so - # we can safely assign the dat values. This is all done in the actual - # interpolation method below. + P0DG_vom = fs_type(self.vom, "DG", 0) + self.point_eval = interpolate(self.expr_renumbered, P0DG_vom) + + # Interpolate into the input-ordering + P0DG_vom_i_o = fs_type(self.vom.input_ordering, "DG", 0) + self.point_eval_io = interpolate(TrialFunction(P0DG_vom), P0DG_vom_i_o) + + def _mixed_function_space(self): + # Build and save an interpolator for each sub-element + # separately for MixedFunctionSpaces. NOTE: since we can't have + # expressions for MixedFunctionSpaces we know that the input + # argument ``expr`` must be a Function. V_dest can be a Function + # or a FunctionSpace, and subfunctions works for both. + if self.nargs == 1: + # Arguments don't have a subfunctions property so I have to + # make them myself. NOTE: this will not be correct when we + # start allowing interpolators created from an expression + # with arguments, as opposed to just being the argument. + expr_subfunctions = [ + firedrake.TestFunction(V_src_sub_func) + for V_src_sub_func in self.expr.function_space().subspaces + ] + elif self.nargs > 1: + raise NotImplementedError( + "Can't yet create an interpolator from an expression with multiple arguments." + ) + else: + expr_subfunctions = self.expr.subfunctions + + if len(expr_subfunctions) != len(self.V_dest.subspaces): + raise NotImplementedError( + "Can't interpolate from a non-mixed function space into a mixed function space." + ) + for input_sub_func, target_subspace in zip( + expr_subfunctions, self.V_dest.subspaces + ): + self.sub_interpolators.append( + interpolate(input_sub_func, target_subspace, **asdict(self.options)) + ) @PETSc.Log.EventDecorator() def _interpolate(self, *function, output=None, adjoint=False): @@ -471,7 +448,6 @@ def _interpolate(self, *function, output=None, adjoint=False): For arguments, see :class:`.Interpolator`. """ from firedrake.assemble import assemble - if adjoint and not self.nargs: raise ValueError("Can currently only apply adjoint interpolation with arguments.") if self.nargs != len(function): @@ -533,14 +509,14 @@ def _interpolate(self, *function, output=None, adjoint=False): # f_src is already contained in self.point_eval_interpolate assert not self.nargs f_src_at_dest_node_coords_src_mesh_decomp = ( - assemble(self.point_eval_interpolate) + assemble(self.point_eval) ) else: f_src_at_dest_node_coords_src_mesh_decomp = ( - assemble(action(self.point_eval_interpolate, f_src)) + assemble(action(self.point_eval, f_src)) ) f_src_at_dest_node_coords_dest_mesh_decomp = firedrake.Function( - self.to_input_ordering_interpolate.function_space() + self.point_eval_io.function_space() ) # We have to create the Function before interpolating so we can # set default missing values (if requested). @@ -558,7 +534,7 @@ def _interpolate(self, *function, output=None, adjoint=False): # the output function. f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_wo[:] = numpy.nan - interp = action(self.to_input_ordering_interpolate, f_src_at_dest_node_coords_src_mesh_decomp) + interp = action(self.point_eval_io, f_src_at_dest_node_coords_src_mesh_decomp) assemble(interp, tensor=f_src_at_dest_node_coords_dest_mesh_decomp) # we can now confidently assign this to a function on V_dest @@ -585,7 +561,7 @@ def _interpolate(self, *function, output=None, adjoint=False): # cofunction on the input-ordering VOM (which has this parallel # decomposition and ordering) and assign the dat values. f_src_at_dest_node_coords_dest_mesh_decomp = firedrake.Cofunction( - self.to_input_ordering_interpolate.function_space().dual() + self.point_eval_io.function_space().dual() ) f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_wo[ : @@ -596,7 +572,7 @@ def _interpolate(self, *function, output=None, adjoint=False): # don't have to worry about skipping over missing points here # because I'm going from the input ordering VOM to the original VOM # and all points from the input ordering VOM are in the original. - interp = action(expr_adjoint(self.to_input_ordering_interpolate), f_src_at_dest_node_coords_dest_mesh_decomp) + interp = action(expr_adjoint(self.point_eval_io), f_src_at_dest_node_coords_dest_mesh_decomp) f_src_at_src_node_coords = assemble(interp) # NOTE: if I wanted the default missing value to be applied to # adjoint interpolation I would have to do it here. However, @@ -609,7 +585,7 @@ def _interpolate(self, *function, output=None, adjoint=False): # SameMeshInterpolator.interpolate did not effect the result. For # now, I say in the docstring that it only applies to forward # interpolation. - interp = action(expr_adjoint(self.point_eval_interpolate), f_src_at_src_node_coords) + interp = action(expr_adjoint(self.point_eval), f_src_at_src_node_coords) assemble(interp, tensor=output) return output From 3146105182698e6194f3ca6accfa2b0779b8cdcd Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Fri, 26 Sep 2025 12:07:01 +0100 Subject: [PATCH 050/125] fix --- firedrake/interpolation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 6a686659dc..a8ee9474c5 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -25,7 +25,7 @@ import finat import firedrake -from firedrake import tsfc_interface, utils, TrialFunction +from firedrake import tsfc_interface, utils from firedrake.ufl_expr import Argument, Coargument, action, adjoint as expr_adjoint from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology from firedrake.petsc import PETSc @@ -406,7 +406,7 @@ def _symbolic_expression(self): # Interpolate into the input-ordering P0DG_vom_i_o = fs_type(self.vom.input_ordering, "DG", 0) - self.point_eval_io = interpolate(TrialFunction(P0DG_vom), P0DG_vom_i_o) + self.point_eval_io = interpolate(firedrake.TrialFunction(P0DG_vom), P0DG_vom_i_o) def _mixed_function_space(self): # Build and save an interpolator for each sub-element From f90f6fcf387ba764fb9dcc1a6e4eb6ad7c9407a1 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Fri, 26 Sep 2025 13:22:08 +0100 Subject: [PATCH 051/125] tidy --- firedrake/interpolation.py | 57 ++++++++++++++++++++------------------ 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index a8ee9474c5..39d903e8a8 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -344,8 +344,7 @@ def __init__(self, expr: Interpolate, V, bcs=None): "Can only interpolate into spaces with point evaluation nodes." ) - expr = self.expr_renumbered - self.arguments = extract_arguments(expr) + self.arguments = extract_arguments(self.expr_renumbered) self.nargs = len(self.arguments) if self.options.allow_missing_dofs: @@ -353,31 +352,39 @@ def __init__(self, expr: Interpolate, V, bcs=None): else: self.missing_points_behaviour = MissingPointsBehaviour.ERROR - # setup self.V_dest = V.function_space() if isinstance(V, firedrake.Function) else V - self.src_mesh = extract_unique_domain(expr) + self.src_mesh = extract_unique_domain(self.expr_renumbered) self.dest_mesh = as_domain(self.V_dest) if self.src_mesh.geometric_dimension() != self.dest_mesh.geometric_dimension(): raise ValueError("Geometric dimensions of source and destination meshes must match.") self.sub_interpolators = [] dest_element = self.V_dest.ufl_element() - if isinstance(dest_element, finat.ufl.MixedElement): - if isinstance(dest_element, (finat.ufl.VectorElement, finat.ufl.TensorElement)): - base_element = dest_element.sub_elements[0] - if base_element.reference_value_shape != (): - raise NotImplementedError( - "Can't yet cross-mesh interpolate onto function spaces made from VectorElements or TensorElements made from sub elements with value shape other than ()." - ) - self.dest_element = base_element - self._symbolic_expression() - else: - self._mixed_function_space() + if isinstance(dest_element, (finat.ufl.VectorElement, finat.ufl.TensorElement)): + # In this case all sub elements are equal + base_element = dest_element.sub_elements[0] + if base_element.reference_value_shape != (): + raise NotImplementedError( + "Can't yet cross-mesh interpolate onto function spaces made from VectorElements or TensorElements made from sub elements with value shape other than ()." + ) + self.dest_element = base_element + self._get_symbolic_expressions() + elif isinstance(dest_element, finat.ufl.MixedElement): + self._mixed_function_space() else: + # scalar fiat/finat element self.dest_element = dest_element - self._symbolic_expression() + self._get_symbolic_expressions() + + def _get_symbolic_expressions(self): + """Constructs the symbolic Interpolate expressions for cross-mesh interpolation. - def _symbolic_expression(self): + Raises + ------ + DofNotDefinedError + If some DoFs in the target function space cannot be defined + in the source function space. + """ from firedrake.assemble import assemble # Immerse coordinates of V_dest point evaluation dofs in src_mesh V_dest_vec = firedrake.VectorFunctionSpace(self.dest_mesh, self.dest_element) @@ -409,16 +416,14 @@ def _symbolic_expression(self): self.point_eval_io = interpolate(firedrake.TrialFunction(P0DG_vom), P0DG_vom_i_o) def _mixed_function_space(self): - # Build and save an interpolator for each sub-element - # separately for MixedFunctionSpaces. NOTE: since we can't have - # expressions for MixedFunctionSpaces we know that the input - # argument ``expr`` must be a Function. V_dest can be a Function - # or a FunctionSpace, and subfunctions works for both. + """Builds symbolic Interpolate expressions for each sub-element of a MixedFunctionSpace. + """ + # NOTE: since we can't have expressions for MixedFunctionSpaces + # we know that the input argument ``expr`` must be a Function. + # V_dest can be a Function or a FunctionSpace, and subfunctions works for both. if self.nargs == 1: # Arguments don't have a subfunctions property so I have to - # make them myself. NOTE: this will not be correct when we - # start allowing interpolators created from an expression - # with arguments, as opposed to just being the argument. + # make them myself. expr_subfunctions = [ firedrake.TestFunction(V_src_sub_func) for V_src_sub_func in self.expr.function_space().subspaces @@ -444,8 +449,6 @@ def _mixed_function_space(self): @PETSc.Log.EventDecorator() def _interpolate(self, *function, output=None, adjoint=False): """Compute the interpolation. - - For arguments, see :class:`.Interpolator`. """ from firedrake.assemble import assemble if adjoint and not self.nargs: From 4097207cf8d1b7d8edcaf1504536bbf97eba6398 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Fri, 26 Sep 2025 14:18:21 +0100 Subject: [PATCH 052/125] io -> input_ordering tidy --- firedrake/interpolation.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 39d903e8a8..e4cba895a1 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -2,7 +2,6 @@ import os import tempfile import abc -import warnings from functools import partial, singledispatch from typing import Hashable, Optional from dataclasses import asdict, dataclass @@ -413,7 +412,7 @@ def _get_symbolic_expressions(self): # Interpolate into the input-ordering P0DG_vom_i_o = fs_type(self.vom.input_ordering, "DG", 0) - self.point_eval_io = interpolate(firedrake.TrialFunction(P0DG_vom), P0DG_vom_i_o) + self.point_eval_input_ordering = interpolate(firedrake.TrialFunction(P0DG_vom), P0DG_vom_i_o) def _mixed_function_space(self): """Builds symbolic Interpolate expressions for each sub-element of a MixedFunctionSpace. @@ -519,7 +518,7 @@ def _interpolate(self, *function, output=None, adjoint=False): assemble(action(self.point_eval, f_src)) ) f_src_at_dest_node_coords_dest_mesh_decomp = firedrake.Function( - self.point_eval_io.function_space() + self.point_eval_input_ordering.function_space() ) # We have to create the Function before interpolating so we can # set default missing values (if requested). @@ -537,7 +536,7 @@ def _interpolate(self, *function, output=None, adjoint=False): # the output function. f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_wo[:] = numpy.nan - interp = action(self.point_eval_io, f_src_at_dest_node_coords_src_mesh_decomp) + interp = action(self.point_eval_input_ordering, f_src_at_dest_node_coords_src_mesh_decomp) assemble(interp, tensor=f_src_at_dest_node_coords_dest_mesh_decomp) # we can now confidently assign this to a function on V_dest @@ -564,7 +563,7 @@ def _interpolate(self, *function, output=None, adjoint=False): # cofunction on the input-ordering VOM (which has this parallel # decomposition and ordering) and assign the dat values. f_src_at_dest_node_coords_dest_mesh_decomp = firedrake.Cofunction( - self.point_eval_io.function_space().dual() + self.point_eval_input_ordering.function_space().dual() ) f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_wo[ : @@ -575,7 +574,7 @@ def _interpolate(self, *function, output=None, adjoint=False): # don't have to worry about skipping over missing points here # because I'm going from the input ordering VOM to the original VOM # and all points from the input ordering VOM are in the original. - interp = action(expr_adjoint(self.point_eval_io), f_src_at_dest_node_coords_dest_mesh_decomp) + interp = action(expr_adjoint(self.point_eval_input_ordering), f_src_at_dest_node_coords_dest_mesh_decomp) f_src_at_src_node_coords = assemble(interp) # NOTE: if I wanted the default missing value to be applied to # adjoint interpolation I would have to do it here. However, @@ -618,13 +617,13 @@ def __init__(self, expr, V, bcs=None): if all(isinstance(m, firedrake.mesh.MeshTopology) for m in [target, source]) and target is not source: composed_map, result_integral_type = source.trans_mesh_entity_map(target, "cell", "everywhere", None) if result_integral_type != "cell": - raise AssertionError("Only cell-cell interpolation supported") + raise AssertionError("Only cell-cell interpolation supported.") indices_active = composed_map.indices_active_with_halo make_subset = not indices_active.all() make_subset = target.comm.allreduce(make_subset, op=MPI.LOR) if make_subset: if not self.options.allow_missing_dofs: - raise ValueError("iteration (sub)set unclear: run with `allow_missing_dofs=True`") + raise ValueError("Iteration (sub)set unclear: run with `allow_missing_dofs=True`.") subset = op2.Subset(target.cell_set, numpy.where(indices_active)) else: # Do not need subset as target <= source. @@ -633,7 +632,7 @@ def __init__(self, expr, V, bcs=None): try: self.callable = make_interpolator(expr, V, subset, self.options.access, bcs=bcs, matfree=self.options.matfree) except FIAT.hdiv_trace.TraceError: - raise NotImplementedError("Can't interpolate onto traces sorry") + raise NotImplementedError("Can't interpolate onto traces.") self.arguments = expr.arguments() @PETSc.Log.EventDecorator() From f895423c44b7b6281a5c93d3d42dd9e21f8b2055 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Fri, 26 Sep 2025 15:05:43 +0100 Subject: [PATCH 053/125] `make_interpolator` -> `_get_callable` --- firedrake/interpolation.py | 294 +++++++++++++++++++------------------ 1 file changed, 150 insertions(+), 144 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index e4cba895a1..78aeb7b605 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -284,6 +284,8 @@ def _get_interpolator(expr: Interpolate, V) -> Interpolator: return SameMeshInterpolator(expr, V) else: if isinstance(target_mesh.topology, VertexOnlyMeshTopology): + if isinstance(source_mesh.topology, VertexOnlyMeshTopology): + return VomOntoVomInterpolator(expr, V) return SameMeshInterpolator(expr, V) else: return CrossMeshInterpolator(expr, V) @@ -628,12 +630,154 @@ def __init__(self, expr, V, bcs=None): else: # Do not need subset as target <= source. pass - expr = self.ufl_interpolate_renumbered + self.subset = subset try: - self.callable = make_interpolator(expr, V, subset, self.options.access, bcs=bcs, matfree=self.options.matfree) + self.callable = self._get_callable(V) except FIAT.hdiv_trace.TraceError: raise NotImplementedError("Can't interpolate onto traces.") - self.arguments = expr.arguments() + self.arguments = self.ufl_interpolate_renumbered.arguments() + + def _get_callable(self, V): + expr = self.ufl_interpolate_renumbered + dual_arg, operand = expr.argument_slots() + target_mesh = as_domain(dual_arg) + source_mesh = extract_unique_domain(operand) or target_mesh + vom_onto_other_vom = ((source_mesh is not target_mesh) + and isinstance(source_mesh.topology, VertexOnlyMeshTopology) + and isinstance(target_mesh.topology, VertexOnlyMeshTopology)) + + arguments = expr.arguments() + rank = len(arguments) + if rank <= 1: + if rank == 0: + R = firedrake.FunctionSpace(target_mesh, "Real", 0) + f = firedrake.Function(R, dtype=utils.ScalarType) + elif isinstance(V, firedrake.Function): + f = V + V = f.function_space() + else: + V_dest = arguments[0].function_space().dual() + f = firedrake.Function(V_dest) + if self.options.access in {firedrake.MIN, firedrake.MAX}: + finfo = numpy.finfo(f.dat.dtype) + if self.options.access == firedrake.MIN: + val = firedrake.Constant(finfo.max) + else: + val = firedrake.Constant(finfo.min) + f.assign(val) + tensor = f.dat + elif rank == 2: + if isinstance(V, firedrake.Function): + raise ValueError("Cannot interpolate an expression with an argument into a Function") + Vrow = arguments[0].function_space() + Vcol = arguments[1].function_space() + if len(Vrow) > 1 or len(Vcol) > 1: + raise NotImplementedError("Interpolation of mixed expressions with arguments is not supported") + if isinstance(target_mesh.topology, VertexOnlyMeshTopology) and target_mesh is not source_mesh and not vom_onto_other_vom: + if not isinstance(target_mesh.topology, VertexOnlyMeshTopology): + raise NotImplementedError("Can only interpolate onto a Vertex Only Mesh") + if target_mesh.geometric_dimension() != source_mesh.geometric_dimension(): + raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") + if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh: + raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target") + + if vom_onto_other_vom: + # We make our own linear operator for this case using PETSc SFs + tensor = None + else: + Vrow_map = get_interp_node_map(source_mesh, target_mesh, Vrow) + Vcol_map = get_interp_node_map(source_mesh, target_mesh, Vcol) + sparsity = op2.Sparsity((Vrow.dof_dset, Vcol.dof_dset), + [(Vrow_map, Vcol_map, None)], # non-mixed + name="%s_%s_sparsity" % (Vrow.name, Vcol.name), + nest=False, + block_sparse=True) + tensor = op2.Mat(sparsity) + f = tensor + else: + raise ValueError(f"Cannot interpolate an expression with {rank} arguments") + + if vom_onto_other_vom: + wrapper = VomOntoVomWrapper(V, source_mesh, target_mesh, operand, self.options.matfree) + # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the + # data, including the correct data size and dimensional information + # (so for vector function spaces in 2 dimensions we might need a + # concatenation of 2 MPI.DOUBLE types when we are in real mode) + if tensor is not None: + # Callable will do interpolation into our pre-supplied function f + # when it is called. + assert f.dat is tensor + wrapper.mpi_type, _ = get_dat_mpi_type(f.dat) + assert len(arguments) == 1 + + def callable(): + wrapper.forward_operation(f.dat) + return f + else: + assert len(arguments) == 2 + assert tensor is None + # we know we will be outputting either a function or a cofunction, + # both of which will use a dat as a data carrier. At present, the + # data type does not depend on function space dimension, so we can + # safely use the argument function space. NOTE: If this changes + # after cofunctions are fully implemented, this will need to be + # reconsidered. + temp_source_func = firedrake.Function(Vcol) + wrapper.mpi_type, _ = get_dat_mpi_type(temp_source_func.dat) + + # Leave wrapper inside a callable so we can access the handle + # property. If matfree is True, then the handle is a PETSc SF + # pretending to be a PETSc Mat. If matfree is False, then this + # will be a PETSc Mat representing the equivalent permutation + # matrix + def callable(): + return wrapper + + return callable + else: + loops = [] + if len(V) == 1: + expressions = (expr,) + else: + if (hasattr(operand, "subfunctions") and len(operand.subfunctions) == len(V) + and all(sub_op.ufl_shape == Vsub.value_shape for Vsub, sub_op in zip(V, operand.subfunctions))): + # Use subfunctions if they match the target shapes + operands = operand.subfunctions + else: + # Unflatten the expression into the shapes of the mixed components + offset = 0 + operands = [] + for Vsub in V: + if len(Vsub.value_shape) == 0: + operands.append(operand[offset]) + else: + components = [operand[offset + j] for j in range(Vsub.value_size)] + operands.append(ufl.as_tensor(numpy.reshape(components, Vsub.value_shape))) + offset += Vsub.value_size + + # Split the dual argument + if isinstance(dual_arg, Cofunction): + duals = dual_arg.subfunctions + elif isinstance(dual_arg, Coargument): + duals = [Coargument(Vsub, number=dual_arg.number()) for Vsub in dual_arg.function_space()] + else: + duals = [v for _, v in sorted(firedrake.formmanipulation.split_form(dual_arg))] + expressions = map(expr._ufl_expr_reconstruct_, operands, duals) + + # Interpolate each sub expression into each function space + for Vsub, sub_tensor, sub_expr in zip(V, tensor, expressions): + loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, self.subset, arguments, self.options.access, bcs=self.bcs)) + + if self.bcs and rank == 1: + loops.extend(partial(bc.apply, f) for bc in self.bcs) + + def callable(loops, f): + for l in loops: + l() + return f + + return partial(callable, loops, f) + @PETSc.Log.EventDecorator() def _interpolate(self, *function, output=None, adjoint=False): @@ -679,148 +823,10 @@ def _interpolate(self, *function, output=None, adjoint=False): return assembled_interpolator -@PETSc.Log.EventDecorator() -def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): - if not isinstance(expr, ufl.Interpolate): - raise ValueError(f"Expecting to interpolate a ufl.Interpolate, got {type(expr).__name__}.") - dual_arg, operand = expr.argument_slots() - target_mesh = as_domain(dual_arg) - source_mesh = extract_unique_domain(operand) or target_mesh - vom_onto_other_vom = ((source_mesh is not target_mesh) - and isinstance(source_mesh.topology, VertexOnlyMeshTopology) - and isinstance(target_mesh.topology, VertexOnlyMeshTopology)) - - arguments = expr.arguments() - rank = len(arguments) - if rank <= 1: - if rank == 0: - R = firedrake.FunctionSpace(target_mesh, "Real", 0) - f = firedrake.Function(R, dtype=utils.ScalarType) - elif isinstance(V, firedrake.Function): - f = V - V = f.function_space() - else: - V_dest = arguments[0].function_space().dual() - f = firedrake.Function(V_dest) - if access in {firedrake.MIN, firedrake.MAX}: - finfo = numpy.finfo(f.dat.dtype) - if access == firedrake.MIN: - val = firedrake.Constant(finfo.max) - else: - val = firedrake.Constant(finfo.min) - f.assign(val) - tensor = f.dat - elif rank == 2: - if isinstance(V, firedrake.Function): - raise ValueError("Cannot interpolate an expression with an argument into a Function") - Vrow = arguments[0].function_space() - Vcol = arguments[1].function_space() - if len(Vrow) > 1 or len(Vcol) > 1: - raise NotImplementedError("Interpolation of mixed expressions with arguments is not supported") - if isinstance(target_mesh.topology, VertexOnlyMeshTopology) and target_mesh is not source_mesh and not vom_onto_other_vom: - if not isinstance(target_mesh.topology, VertexOnlyMeshTopology): - raise NotImplementedError("Can only interpolate onto a Vertex Only Mesh") - if target_mesh.geometric_dimension() != source_mesh.geometric_dimension(): - raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") - if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh: - raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target") - - if vom_onto_other_vom: - # We make our own linear operator for this case using PETSc SFs - tensor = None - else: - Vrow_map = get_interp_node_map(source_mesh, target_mesh, Vrow) - Vcol_map = get_interp_node_map(source_mesh, target_mesh, Vcol) - sparsity = op2.Sparsity((Vrow.dof_dset, Vcol.dof_dset), - [(Vrow_map, Vcol_map, None)], # non-mixed - name="%s_%s_sparsity" % (Vrow.name, Vcol.name), - nest=False, - block_sparse=True) - tensor = op2.Mat(sparsity) - f = tensor - else: - raise ValueError(f"Cannot interpolate an expression with {rank} arguments") - - if vom_onto_other_vom: - wrapper = VomOntoVomWrapper(V, source_mesh, target_mesh, operand, matfree) - # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the - # data, including the correct data size and dimensional information - # (so for vector function spaces in 2 dimensions we might need a - # concatenation of 2 MPI.DOUBLE types when we are in real mode) - if tensor is not None: - # Callable will do interpolation into our pre-supplied function f - # when it is called. - assert f.dat is tensor - wrapper.mpi_type, _ = get_dat_mpi_type(f.dat) - assert len(arguments) == 1 - - def callable(): - wrapper.forward_operation(f.dat) - return f - else: - assert len(arguments) == 2 - assert tensor is None - # we know we will be outputting either a function or a cofunction, - # both of which will use a dat as a data carrier. At present, the - # data type does not depend on function space dimension, so we can - # safely use the argument function space. NOTE: If this changes - # after cofunctions are fully implemented, this will need to be - # reconsidered. - temp_source_func = firedrake.Function(Vcol) - wrapper.mpi_type, _ = get_dat_mpi_type(temp_source_func.dat) - - # Leave wrapper inside a callable so we can access the handle - # property. If matfree is True, then the handle is a PETSc SF - # pretending to be a PETSc Mat. If matfree is False, then this - # will be a PETSc Mat representing the equivalent permutation - # matrix - def callable(): - return wrapper - - return callable - else: - loops = [] - if len(V) == 1: - expressions = (expr,) - else: - if (hasattr(operand, "subfunctions") and len(operand.subfunctions) == len(V) - and all(sub_op.ufl_shape == Vsub.value_shape for Vsub, sub_op in zip(V, operand.subfunctions))): - # Use subfunctions if they match the target shapes - operands = operand.subfunctions - else: - # Unflatten the expression into the shapes of the mixed components - offset = 0 - operands = [] - for Vsub in V: - if len(Vsub.value_shape) == 0: - operands.append(operand[offset]) - else: - components = [operand[offset + j] for j in range(Vsub.value_size)] - operands.append(ufl.as_tensor(numpy.reshape(components, Vsub.value_shape))) - offset += Vsub.value_size - - # Split the dual argument - if isinstance(dual_arg, Cofunction): - duals = dual_arg.subfunctions - elif isinstance(dual_arg, Coargument): - duals = [Coargument(Vsub, number=dual_arg.number()) for Vsub in dual_arg.function_space()] - else: - duals = [v for _, v in sorted(firedrake.formmanipulation.split_form(dual_arg))] - expressions = map(expr._ufl_expr_reconstruct_, operands, duals) - - # Interpolate each sub expression into each function space - for Vsub, sub_tensor, sub_expr in zip(V, tensor, expressions): - loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs)) - - if bcs and rank == 1: - loops.extend(partial(bc.apply, f) for bc in bcs) +class VomOntoVomInterpolator(SameMeshInterpolator): - def callable(loops, f): - for l in loops: - l() - return f - - return partial(callable, loops, f) + def __init__(self, expr: Interpolate, V, bcs=None): + super().__init__(expr, V, bcs=bcs) @utils.known_pyop2_safe From 4a843ae1faab94d11f7ce0b717e90b01bcbb670b Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Fri, 26 Sep 2025 15:06:55 +0100 Subject: [PATCH 054/125] remove comment --- firedrake/interpolation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 78aeb7b605..f628eea176 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -117,7 +117,6 @@ def __init__(self, expr, V, **kwargs): Additional interpolation options. See :class:`InterpolateOptions` for available parameters and their descriptions. """ - # TODO: should we allow RHS to be FiredrakeDualSpace? if isinstance(V, WithGeometry): # Need to create a Firedrake Coargument so it has a .function_space() method expr_args = extract_arguments(ufl.as_ufl(expr)) From f43b2c02fb845720e8d640e8190d50729f853ae1 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Fri, 26 Sep 2025 15:07:38 +0100 Subject: [PATCH 055/125] remove properties --- firedrake/interpolation.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index f628eea176..f0937db286 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -132,14 +132,6 @@ def _ufl_expr_reconstruct_(self, expr, v=None, **interp_data): interp_data = interp_data or asdict(self.options) return ufl.Interpolate._ufl_expr_reconstruct_(self, expr, v=v, **interp_data) - @property - def target_space(self): - return self.argument_slots()[0].function_space().dual() - - @property - def source_space(self): - return self.argument_slots()[1].function_space() - @property def options(self): return self._options From 9e33ce25940155f2eeda8497aee35e4be018fea8 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 26 Sep 2025 15:17:45 +0100 Subject: [PATCH 056/125] suggestions from review --- firedrake/interpolation.py | 78 ++++++++++++------- .../submesh/test_submesh_interpolate.py | 34 ++++---- 2 files changed, 68 insertions(+), 44 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index d38219652f..c318a8c2b7 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -3,6 +3,8 @@ import tempfile import abc import warnings +from collections.abc import Iterable +from typing import Literal from functools import partial, singledispatch from typing import Hashable @@ -23,6 +25,7 @@ import finat import firedrake +import firedrake.bcs from firedrake import tsfc_interface, utils, functionspaceimpl from firedrake.ufl_expr import Argument, Coargument, action, adjoint as expr_adjoint from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology @@ -47,7 +50,7 @@ class Interpolate(ufl.Interpolate): def __init__(self, expr, v, subset=None, - access=op2.WRITE, + access=None, allow_missing_dofs=False, default_missing_val=None, matfree=True): @@ -122,7 +125,7 @@ def _ufl_expr_reconstruct_(self, expr, v=None, **interp_data): @PETSc.Log.EventDecorator() -def interpolate(expr, V, subset=None, access=op2.WRITE, allow_missing_dofs=False, default_missing_val=None, matfree=True): +def interpolate(expr, V, subset=None, access=None, allow_missing_dofs=False, default_missing_val=None, matfree=True): """Returns a UFL expression for the interpolation operation of ``expr`` into ``V``. :arg expr: a UFL expression. @@ -203,25 +206,34 @@ def interpolate(expr, V, subset=None, access=op2.WRITE, allow_missing_dofs=False class Interpolator(abc.ABC): """A reusable interpolation object. - :arg expr: The expression to interpolate. - :arg V: The :class:`.FunctionSpace` or :class:`.Function` to + Parameters + ---------- + expr + The underlying ufl.Interpolate or the operand to the ufl.Interpolate. + V + The :class:`.FunctionSpace` or :class:`.Function` to interpolate into. - :kwarg subset: An optional :class:`pyop2.types.set.Subset` to apply the + subset + An optional :class:`pyop2.types.set.Subset` to apply the interpolation over. Cannot, at present, be used when interpolating across meshes unless the target mesh is a :func:`.VertexOnlyMesh`. - :kwarg freeze_expr: Set to True to prevent the expression being + freeze_expr + Set to True to prevent the expression being re-evaluated on each call. Cannot, at present, be used when interpolating across meshes unless the target mesh is a :func:`.VertexOnlyMesh`. - :kwarg access: The pyop2 access descriptor for combining updates to shared - DoFs. Possible values include ``WRITE`` and ``INC``. Only ``WRITE`` is - supported at present when interpolating across meshes. See note in - :func:`.interpolate` if changing this from default. - :kwarg bcs: An optional list of boundary conditions to zero-out in the + access + The pyop2 access descriptor for combining updates to shared DoFs. + Only ``op2.WRITE`` is supported at present when interpolating across meshes. + Only ``op2.INC`` is supported for the matrix-free adjoint interpolation. + See note in :func:`.interpolate` if changing this from default. + bcs + An optional list of boundary conditions to zero-out in the output function space. Interpolator rows or columns which are associated with boundary condition nodes are zeroed out when this is specified. - :kwarg allow_missing_dofs: For interpolation across meshes: allow + allow_missing_dofs + For interpolation across meshes: allow degrees of freedom (aka DoFs/nodes) in the target mesh that cannot be defined on the source mesh. For example, where nodes are point evaluations, points in the target mesh that are not in the source mesh. @@ -233,14 +245,16 @@ class Interpolator(abc.ABC): Ignored if interpolating within the same mesh or onto a :func:`.VertexOnlyMesh` (the behaviour of a :func:`.VertexOnlyMesh` in this scenario is, at present, set when it is created). - :kwarg matfree: If ``False``, then construct the permutation matrix for interpolating + matfree + If ``False``, then construct the permutation matrix for interpolating between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast and reduce operations. This object can be used to carry out the same interpolation multiple times (for example in a timestepping loop). - .. note:: + Note + ---- The :class:`Interpolator` holds a reference to the provided arguments (such that they won't be collected until the @@ -267,14 +281,14 @@ def __new__(cls, expr, V, **kwargs): def __init__( self, - expr, - V, - subset=None, - freeze_expr=False, - access=op2.WRITE, - bcs=None, - allow_missing_dofs=False, - matfree=True + expr: ufl.Interpolate | ufl.classes.Expr, + V: ufl.FunctionSpace | firedrake.function.Function, + subset: op2.Subset | None = None, + freeze_expr: bool = False, + access: Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC] | None = None, + bcs: Iterable[firedrake.bcs.BCBase] | None = None, + allow_missing_dofs: bool = False, + matfree: bool = True ): if not isinstance(expr, ufl.Interpolate): fs = V if isinstance(V, ufl.FunctionSpace) else V.function_space() @@ -285,7 +299,6 @@ def __init__( self.V = V self.subset = subset self.freeze_expr = freeze_expr - self.access = access self.bcs = bcs self._allow_missing_dofs = allow_missing_dofs self.matfree = matfree @@ -324,9 +337,16 @@ def __init__( dual_arg, operand = expr.argument_slots() self.expr_renumbered = operand self.ufl_interpolate_renumbered = expr + if not isinstance(dual_arg, ufl.Coargument): # Matrix-free assembly of 0-form or 1-form requires INC access - self.access = op2.INC + if access and access != op2.INC: + raise ValueError("Matfree adjoint interpolation requires INC access") + access = op2.INC + elif access is None: + # Default access for forward 1-form or 2-form (forward and adjoint) + access = op2.WRITE + self.access = access def interpolate(self, *function, transpose=None, adjoint=False, default_missing_val=None): """ @@ -432,7 +452,7 @@ def __init__( V, subset=None, freeze_expr=False, - access=op2.WRITE, + access=None, bcs=None, allow_missing_dofs=False, matfree=True @@ -756,7 +776,7 @@ class SameMeshInterpolator(Interpolator): """ @no_annotations - def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, + def __init__(self, expr, V, subset=None, freeze_expr=False, access=None, bcs=None, matfree=True, allow_missing_dofs=False, **kwargs): if subset is None: if isinstance(expr, ufl.Interpolate): @@ -1187,9 +1207,9 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): def get_interp_node_map(source_mesh, target_mesh, fs): - """Return the map between cells of the target mesh and nodes of the function space. - - If the function space is defined on the source mesh then the node map is composed + """Return the map between cells of the target mesh and nodes of the function space. + + If the function space is defined on the source mesh then the node map is composed with a map between target and source cells. """ if isinstance(target_mesh.topology, VertexOnlyMeshTopology): diff --git a/tests/firedrake/submesh/test_submesh_interpolate.py b/tests/firedrake/submesh/test_submesh_interpolate.py index 0a5c1ce2d1..19fc2cd334 100644 --- a/tests/firedrake/submesh/test_submesh_interpolate.py +++ b/tests/firedrake/submesh/test_submesh_interpolate.py @@ -1,6 +1,7 @@ import pytest from firedrake import * import numpy as np +from mpi4py import MPI from ufl.conditional import GT, LT from os.path import abspath, dirname, join @@ -282,7 +283,7 @@ def test_submesh_interpolate_adjoint(fe_fesub): mesh = UnitSquareMesh(8, 8) x, y = SpatialCoordinate(mesh) - subdomain_cond = conditional(LT(x, 0.5), 1, 0) + subdomain_cond = conditional(And(LT(x, 0.5), LT(y, 0.5)), 1, 0) label_value = 999 subm = make_submesh(mesh, subdomain_cond, label_value) @@ -312,24 +313,27 @@ def test_submesh_interpolate_adjoint(fe_fesub): result_adjoint_2 = assemble(action(action(I_adj, ustar2), u1)) assert np.isclose(result_adjoint_2, expected) - # Test forward 1-form (only in serial for now) - if V1.comm.size == 1: - # Matfree forward interpolation with Submesh currently fails in parallel. - # The ghost nodes of the parent mesh may be redistributed - # into different processes as non-ghost dofs of the submesh. - # The submesh kernel will write into ghost nodes of the parent mesh, - # but this will be ignored in the halo exchange if access=op2.WRITE. + # Test forward 1-form (only works in serial for continuous elements) + # Matfree forward interpolation with Submesh currently fails in parallel. + # The ghost nodes of the parent mesh may be redistributed + # into different processes as non-ghost dofs of the submesh. + # The submesh kernel will write into ghost nodes of the parent mesh, + # but this will be ignored in the halo exchange if access=op2.WRITE. - # See https://github.com/firedrakeproject/firedrake/issues/4483 + # See https://github.com/firedrakeproject/firedrake/issues/4483 + expected_to_pass = (V2.comm.size == 1 or V2.finat_element.is_dg()) - Iu1 = assemble(interpolate(u1, TestFunction(V2.dual()), allow_missing_dofs=True)) - assert Iu1.function_space() == V2 + Iu1 = assemble(interpolate(u1, TestFunction(V2.dual()), allow_missing_dofs=True)) + assert Iu1.function_space() == V2 - expected_primal = assemble(action(I, u1)) - assert np.allclose(Iu1.dat.data, expected_primal.dat.data) + expected_primal = assemble(action(I, u1)) + test1 = np.allclose(Iu1.dat.data, expected_primal.dat.data) + test1 = V2.comm.allreduce(test1, MPI.LAND) + assert test1 == expected_to_pass - result_forward_1 = assemble(action(ustar2, Iu1)) - assert np.isclose(result_forward_1, expected) + result_forward_1 = assemble(action(ustar2, Iu1)) + test0 = np.isclose(result_forward_1, expected) + assert test0 == expected_to_pass # Test adjoint 1-form ustar2I = assemble(interpolate(TestFunction(V1), ustar2, allow_missing_dofs=True)) From cce10bd10a4f17abf97c87a9908f1a6afc11142b Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Fri, 26 Sep 2025 15:53:57 +0100 Subject: [PATCH 057/125] tidy / add comments --- firedrake/interpolation.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 967e926b1c..bef19dc9e6 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -356,7 +356,7 @@ def __init__(self, expr: Interpolate, V, bcs=None): if self.src_mesh.geometric_dimension() != self.dest_mesh.geometric_dimension(): raise ValueError("Geometric dimensions of source and destination meshes must match.") - self.sub_interpolators = [] + self.sub_interpolates = [] dest_element = self.V_dest.ufl_element() if isinstance(dest_element, (finat.ufl.VectorElement, finat.ufl.TensorElement)): # In this case all sub elements are equal @@ -398,7 +398,7 @@ def _get_symbolic_expressions(self): except VertexOnlyMeshMissingPointsError: raise DofNotDefinedError(self.src_mesh, self.dest_mesh) - # Evaluate expr at the immersed coordinates + # Get the correct type of function space shape = self.V_dest.ufl_function_space().value_shape if len(shape) == 0: fs_type = firedrake.FunctionSpace @@ -406,10 +406,12 @@ def _get_symbolic_expressions(self): fs_type = partial(firedrake.VectorFunctionSpace, dim=shape[0]) else: fs_type = partial(firedrake.TensorFunctionSpace, shape=shape) + + # Get expression for point evaluation at the dest_node_coords P0DG_vom = fs_type(self.vom, "DG", 0) self.point_eval = interpolate(self.expr_renumbered, P0DG_vom) - # Interpolate into the input-ordering + # Interpolate into the input-ordering VOM P0DG_vom_i_o = fs_type(self.vom.input_ordering, "DG", 0) self.point_eval_input_ordering = interpolate(firedrake.TrialFunction(P0DG_vom), P0DG_vom_i_o) @@ -434,15 +436,10 @@ def _mixed_function_space(self): expr_subfunctions = self.expr.subfunctions if len(expr_subfunctions) != len(self.V_dest.subspaces): - raise NotImplementedError( - "Can't interpolate from a non-mixed function space into a mixed function space." - ) - for input_sub_func, target_subspace in zip( - expr_subfunctions, self.V_dest.subspaces - ): - self.sub_interpolators.append( - interpolate(input_sub_func, target_subspace, **asdict(self.options)) - ) + raise NotImplementedError("Can't interpolate from a non-mixed function space into a mixed function space.") + + for sub_func, subspace in zip(expr_subfunctions, self.V_dest.subspaces): + self.sub_interpolates.append(interpolate(sub_func, subspace, **asdict(self.options))) @PETSc.Log.EventDecorator() def _interpolate(self, *function, output=None, adjoint=False): @@ -490,14 +487,14 @@ def _interpolate(self, *function, output=None, adjoint=False): else: output = firedrake.Function(V_dest) - if len(self.sub_interpolators): + if len(self.sub_interpolates): # MixedFunctionSpace case for sub_interpolate, f_src_sub_func, output_sub_func in zip( - self.sub_interpolators, f_src.subfunctions, output.subfunctions + self.sub_interpolates, f_src.subfunctions, output.subfunctions ): if f_src is self.expr: # f_src is already contained in self.point_eval_interpolate, - # so the sub_interpolators are already prepared to interpolate + # so the sub_interpolates are already prepared to interpolate # without needing to be given a Function assert not self.nargs assemble(sub_interpolate, tensor=output_sub_func) From 6b3906297b6217dc43f411972b95a9fc08870cf5 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Fri, 26 Sep 2025 15:54:26 +0100 Subject: [PATCH 058/125] Test -> Trial inside `CrossMeshInterpolator._mixed_function_space` --- firedrake/interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index bef19dc9e6..4539a57c7e 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -425,7 +425,7 @@ def _mixed_function_space(self): # Arguments don't have a subfunctions property so I have to # make them myself. expr_subfunctions = [ - firedrake.TestFunction(V_src_sub_func) + firedrake.TrialFunction(V_src_sub_func) for V_src_sub_func in self.expr.function_space().subspaces ] elif self.nargs > 1: From 899a0ecc745374e27472db0e5109bee4642a6227 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Fri, 26 Sep 2025 15:57:01 +0100 Subject: [PATCH 059/125] lint --- firedrake/interpolation.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 4539a57c7e..4c9f141504 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -2,9 +2,7 @@ import os import tempfile import abc -import warnings -from collections.abc import Iterable -from typing import Literal + from functools import partial, singledispatch from typing import Hashable, Optional from dataclasses import asdict, dataclass @@ -12,7 +10,7 @@ import FIAT import ufl import finat.ufl -from ufl.algorithms import extract_arguments, extract_coefficients, replace +from ufl.algorithms import extract_arguments, extract_coefficients from ufl.domain import as_domain, extract_unique_domain from pyop2 import op2 @@ -27,8 +25,6 @@ import finat import firedrake -import firedrake.bcs -from firedrake import tsfc_interface, utils, functionspaceimpl from firedrake import tsfc_interface, utils from firedrake.ufl_expr import Argument, Coargument, action, adjoint as expr_adjoint from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology @@ -49,10 +45,11 @@ "SameMeshInterpolator", ) + @dataclass class InterpolateOptions: """Options for interpolation operations. - + Attributes ---------- subset : pyop2.types.set.Subset, optional @@ -63,8 +60,8 @@ class InterpolateOptions: The pyop2 access descriptor for combining updates to shared DoFs. Possible values include ``WRITE`` and ``INC``. Only ``WRITE`` is supported at present when interpolating across meshes unless the target - mesh is a :func:`.VertexOnlyMesh`. - + mesh is a :func:`.VertexOnlyMesh`. + .. note:: If you use an access descriptor other than ``WRITE``, the behaviour of interpolation changes if interpolating into a @@ -75,7 +72,7 @@ class InterpolateOptions: then it is assumed that its values should take part in the reduction (hence using MIN will compute the MIN between the existing values and any new values). - + allow_missing_dofs : bool, default False For interpolation across meshes: allow degrees of freedom (aka DoFs/nodes) in the target mesh that cannot be defined on the source mesh. @@ -136,7 +133,7 @@ def __init__(self, expr, V, **kwargs): def _ufl_expr_reconstruct_(self, expr, v=None, **interp_data): interp_data = interp_data or asdict(self.options) return ufl.Interpolate._ufl_expr_reconstruct_(self, expr, v=v, **interp_data) - + @property def options(self): return self._options @@ -154,7 +151,7 @@ def interpolate(expr, V, **kwargs): The function space to interpolate into or the coargument defined on the dual of the function space to interpolate into. **kwargs - Additional interpolation options. See :class:`InterpolateOptions` + Additional interpolation options. See :class:`InterpolateOptions` for available parameters and their descriptions. Returns @@ -418,8 +415,8 @@ def _get_symbolic_expressions(self): def _mixed_function_space(self): """Builds symbolic Interpolate expressions for each sub-element of a MixedFunctionSpace. """ - # NOTE: since we can't have expressions for MixedFunctionSpaces - # we know that the input argument ``expr`` must be a Function. + # NOTE: since we can't have expressions for MixedFunctionSpaces + # we know that the input argument ``expr`` must be a Function. # V_dest can be a Function or a FunctionSpace, and subfunctions works for both. if self.nargs == 1: # Arguments don't have a subfunctions property so I have to @@ -630,15 +627,15 @@ def __init__(self, expr, V, bcs=None): except FIAT.hdiv_trace.TraceError: raise NotImplementedError("Can't interpolate onto traces.") self.arguments = self.ufl_interpolate_renumbered.arguments() - + def _get_callable(self, V): expr = self.ufl_interpolate_renumbered dual_arg, operand = expr.argument_slots() target_mesh = as_domain(dual_arg) source_mesh = extract_unique_domain(operand) or target_mesh vom_onto_other_vom = ((source_mesh is not target_mesh) - and isinstance(source_mesh.topology, VertexOnlyMeshTopology) - and isinstance(target_mesh.topology, VertexOnlyMeshTopology)) + and isinstance(source_mesh.topology, VertexOnlyMeshTopology) + and isinstance(target_mesh.topology, VertexOnlyMeshTopology)) arguments = expr.arguments() rank = len(arguments) @@ -772,7 +769,6 @@ def callable(loops, f): return partial(callable, loops, f) - @PETSc.Log.EventDecorator() def _interpolate(self, *function, output=None, adjoint=False): """Compute the interpolation. From 733da2d5b8c0386a2b82592883ec3e8adf163f22 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Fri, 26 Sep 2025 16:53:08 +0100 Subject: [PATCH 060/125] create `_get_tensor` method --- firedrake/interpolation.py | 176 ++++++++++++++++++++----------------- 1 file changed, 95 insertions(+), 81 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 4c9f141504..d2bb769b4b 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -603,7 +603,7 @@ def __init__(self, expr, V, bcs=None): operand, = expr.ufl_operands else: operand = expr - target_mesh = as_domain(V) + target_mesh = as_domain(self.V) source_mesh = extract_unique_domain(operand) or target_mesh target = target_mesh.topology source = source_mesh.topology @@ -623,12 +623,12 @@ def __init__(self, expr, V, bcs=None): pass self.subset = subset try: - self.callable = self._get_callable(V) + self.callable = self._get_callable() except FIAT.hdiv_trace.TraceError: raise NotImplementedError("Can't interpolate onto traces.") self.arguments = self.ufl_interpolate_renumbered.arguments() - def _get_callable(self, V): + def _get_tensor(self): expr = self.ufl_interpolate_renumbered dual_arg, operand = expr.argument_slots() target_mesh = as_domain(dual_arg) @@ -643,9 +643,9 @@ def _get_callable(self, V): if rank == 0: R = firedrake.FunctionSpace(target_mesh, "Real", 0) f = firedrake.Function(R, dtype=utils.ScalarType) - elif isinstance(V, firedrake.Function): - f = V - V = f.function_space() + elif isinstance(self.V, firedrake.Function): + f = self.V + self.V = f.function_space() else: V_dest = arguments[0].function_space().dual() f = firedrake.Function(V_dest) @@ -658,7 +658,7 @@ def _get_callable(self, V): f.assign(val) tensor = f.dat elif rank == 2: - if isinstance(V, firedrake.Function): + if isinstance(self.V, firedrake.Function): raise ValueError("Cannot interpolate an expression with an argument into a Function") Vrow = arguments[0].function_space() Vcol = arguments[1].function_space() @@ -687,87 +687,57 @@ def _get_callable(self, V): f = tensor else: raise ValueError(f"Cannot interpolate an expression with {rank} arguments") + return f, tensor - if vom_onto_other_vom: - wrapper = VomOntoVomWrapper(V, source_mesh, target_mesh, operand, self.options.matfree) - # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the - # data, including the correct data size and dimensional information - # (so for vector function spaces in 2 dimensions we might need a - # concatenation of 2 MPI.DOUBLE types when we are in real mode) - if tensor is not None: - # Callable will do interpolation into our pre-supplied function f - # when it is called. - assert f.dat is tensor - wrapper.mpi_type, _ = get_dat_mpi_type(f.dat) - assert len(arguments) == 1 - - def callable(): - wrapper.forward_operation(f.dat) - return f - else: - assert len(arguments) == 2 - assert tensor is None - # we know we will be outputting either a function or a cofunction, - # both of which will use a dat as a data carrier. At present, the - # data type does not depend on function space dimension, so we can - # safely use the argument function space. NOTE: If this changes - # after cofunctions are fully implemented, this will need to be - # reconsidered. - temp_source_func = firedrake.Function(Vcol) - wrapper.mpi_type, _ = get_dat_mpi_type(temp_source_func.dat) - - # Leave wrapper inside a callable so we can access the handle - # property. If matfree is True, then the handle is a PETSc SF - # pretending to be a PETSc Mat. If matfree is False, then this - # will be a PETSc Mat representing the equivalent permutation - # matrix - def callable(): - return wrapper - - return callable + def _get_callable(self): + expr = self.ufl_interpolate_renumbered + dual_arg, operand = expr.argument_slots() + arguments = expr.arguments() + rank = len(arguments) + f, tensor = self._get_tensor() + + loops = [] + if len(self.V) == 1: + expressions = (expr,) else: - loops = [] - if len(V) == 1: - expressions = (expr,) + if (hasattr(operand, "subfunctions") and len(operand.subfunctions) == len(self.V) + and all(sub_op.ufl_shape == Vsub.value_shape for Vsub, sub_op in zip(self.V, operand.subfunctions))): + # Use subfunctions if they match the target shapes + operands = operand.subfunctions else: - if (hasattr(operand, "subfunctions") and len(operand.subfunctions) == len(V) - and all(sub_op.ufl_shape == Vsub.value_shape for Vsub, sub_op in zip(V, operand.subfunctions))): - # Use subfunctions if they match the target shapes - operands = operand.subfunctions - else: - # Unflatten the expression into the shapes of the mixed components - offset = 0 - operands = [] - for Vsub in V: - if len(Vsub.value_shape) == 0: - operands.append(operand[offset]) - else: - components = [operand[offset + j] for j in range(Vsub.value_size)] - operands.append(ufl.as_tensor(numpy.reshape(components, Vsub.value_shape))) - offset += Vsub.value_size - - # Split the dual argument - if isinstance(dual_arg, Cofunction): - duals = dual_arg.subfunctions - elif isinstance(dual_arg, Coargument): - duals = [Coargument(Vsub, number=dual_arg.number()) for Vsub in dual_arg.function_space()] - else: - duals = [v for _, v in sorted(firedrake.formmanipulation.split_form(dual_arg))] - expressions = map(expr._ufl_expr_reconstruct_, operands, duals) + # Unflatten the expression into the shapes of the mixed components + offset = 0 + operands = [] + for Vsub in self.V: + if len(Vsub.value_shape) == 0: + operands.append(operand[offset]) + else: + components = [operand[offset + j] for j in range(Vsub.value_size)] + operands.append(ufl.as_tensor(numpy.reshape(components, Vsub.value_shape))) + offset += Vsub.value_size + + # Split the dual argument + if isinstance(dual_arg, Cofunction): + duals = dual_arg.subfunctions + elif isinstance(dual_arg, Coargument): + duals = [Coargument(Vsub, number=dual_arg.number()) for Vsub in dual_arg.function_space()] + else: + duals = [v for _, v in sorted(firedrake.formmanipulation.split_form(dual_arg))] + expressions = map(expr._ufl_expr_reconstruct_, operands, duals) - # Interpolate each sub expression into each function space - for Vsub, sub_tensor, sub_expr in zip(V, tensor, expressions): - loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, self.subset, arguments, self.options.access, bcs=self.bcs)) + # Interpolate each sub expression into each function space + for Vsub, sub_tensor, sub_expr in zip(self.V, tensor, expressions): + loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, self.subset, arguments, self.options.access, bcs=self.bcs)) - if self.bcs and rank == 1: - loops.extend(partial(bc.apply, f) for bc in self.bcs) + if self.bcs and rank == 1: + loops.extend(partial(bc.apply, f) for bc in self.bcs) - def callable(loops, f): - for l in loops: - l() - return f + def callable(loops, f): + for l in loops: + l() + return f - return partial(callable, loops, f) + return partial(callable, loops, f) @PETSc.Log.EventDecorator() def _interpolate(self, *function, output=None, adjoint=False): @@ -818,6 +788,50 @@ class VomOntoVomInterpolator(SameMeshInterpolator): def __init__(self, expr: Interpolate, V, bcs=None): super().__init__(expr, V, bcs=bcs) + def _get_callable(self): + expr = self.ufl_interpolate_renumbered + dual_arg, operand = expr.argument_slots() + target_mesh = as_domain(dual_arg) + source_mesh = extract_unique_domain(operand) or target_mesh + arguments = expr.arguments() + f, tensor = self._get_tensor() + wrapper = VomOntoVomWrapper(self.V, source_mesh, target_mesh, operand, self.options.matfree) + # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the + # data, including the correct data size and dimensional information + # (so for vector function spaces in 2 dimensions we might need a + # concatenation of 2 MPI.DOUBLE types when we are in real mode) + if tensor is not None: + # Callable will do interpolation into our pre-supplied function f + # when it is called. + assert f.dat is tensor + wrapper.mpi_type, _ = get_dat_mpi_type(f.dat) + assert len(arguments) == 1 + + def callable(): + wrapper.forward_operation(f.dat) + return f + else: + assert len(arguments) == 2 + assert tensor is None + # we know we will be outputting either a function or a cofunction, + # both of which will use a dat as a data carrier. At present, the + # data type does not depend on function space dimension, so we can + # safely use the argument function space. NOTE: If this changes + # after cofunctions are fully implemented, this will need to be + # reconsidered. + temp_source_func = firedrake.Function(arguments[1].function_space()) + wrapper.mpi_type, _ = get_dat_mpi_type(temp_source_func.dat) + + # Leave wrapper inside a callable so we can access the handle + # property. If matfree is True, then the handle is a PETSc SF + # pretending to be a PETSc Mat. If matfree is False, then this + # will be a PETSc Mat representing the equivalent permutation + # matrix + def callable(): + return wrapper + + return callable + @utils.known_pyop2_safe def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): From d89348880d2430078ae3e39505bb0a978a130ba5 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Fri, 26 Sep 2025 16:58:52 +0100 Subject: [PATCH 061/125] tidy --- firedrake/interpolation.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index d2bb769b4b..98aa330afd 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -633,10 +633,6 @@ def _get_tensor(self): dual_arg, operand = expr.argument_slots() target_mesh = as_domain(dual_arg) source_mesh = extract_unique_domain(operand) or target_mesh - vom_onto_other_vom = ((source_mesh is not target_mesh) - and isinstance(source_mesh.topology, VertexOnlyMeshTopology) - and isinstance(target_mesh.topology, VertexOnlyMeshTopology)) - arguments = expr.arguments() rank = len(arguments) if rank <= 1: @@ -664,6 +660,7 @@ def _get_tensor(self): Vcol = arguments[1].function_space() if len(Vrow) > 1 or len(Vcol) > 1: raise NotImplementedError("Interpolation of mixed expressions with arguments is not supported") + vom_onto_other_vom = isinstance(self, VomOntoVomInterpolator) if isinstance(target_mesh.topology, VertexOnlyMeshTopology) and target_mesh is not source_mesh and not vom_onto_other_vom: if not isinstance(target_mesh.topology, VertexOnlyMeshTopology): raise NotImplementedError("Can only interpolate onto a Vertex Only Mesh") @@ -680,7 +677,7 @@ def _get_tensor(self): Vcol_map = get_interp_node_map(source_mesh, target_mesh, Vcol) sparsity = op2.Sparsity((Vrow.dof_dset, Vcol.dof_dset), [(Vrow_map, Vcol_map, None)], # non-mixed - name="%s_%s_sparsity" % (Vrow.name, Vcol.name), + name=f"{Vrow.name}_{Vcol.name}_sparsity", nest=False, block_sparse=True) tensor = op2.Mat(sparsity) From 05237c9a29b0e70171652bc2481c50db2fe08ccf Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Mon, 29 Sep 2025 10:22:00 +0100 Subject: [PATCH 062/125] simplify --- firedrake/interpolation.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 98aa330afd..5f1c28abac 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -363,13 +363,12 @@ def __init__(self, expr: Interpolate, V, bcs=None): "Can't yet cross-mesh interpolate onto function spaces made from VectorElements or TensorElements made from sub elements with value shape other than ()." ) self.dest_element = base_element - self._get_symbolic_expressions() elif isinstance(dest_element, finat.ufl.MixedElement): - self._mixed_function_space() + return self._mixed_function_space() else: # scalar fiat/finat element self.dest_element = dest_element - self._get_symbolic_expressions() + self._get_symbolic_expressions() def _get_symbolic_expressions(self): """Constructs the symbolic Interpolate expressions for cross-mesh interpolation. From c7426b20417e740a6e50b6e3182b339fbbce5d69 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Tue, 23 Sep 2025 15:51:07 +0100 Subject: [PATCH 063/125] simplify . --- firedrake/interpolation.py | 30 ++++-------------------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 12ed650f48..96842dac17 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -174,32 +174,10 @@ def interpolate(expr, V, subset=None, access=None, allow_missing_dofs=False, def reduction (hence using MIN will compute the MIN between the existing values and any new values). """ - if isinstance(V, (Cofunction, Coargument)): - dual_arg = V - elif isinstance(V, ufl.BaseForm): - rank = len(V.arguments()) - if rank == 1: - dual_arg = V - else: - raise TypeError(f"Expected a one-form, provided form had {rank} arguments") - elif isinstance(V, functionspaceimpl.WithGeometry): - dual_arg = Coargument(V.dual(), 0) - expr_args = extract_arguments(ufl.as_ufl(expr)) - if expr_args and expr_args[0].number() == 0: - warnings.warn("Passing argument numbered 0 in expression for forward interpolation is deprecated. " - "Use a TrialFunction in the expression.") - v, = expr_args - expr = replace(expr, {v: v.reconstruct(number=1)}) - else: - raise TypeError(f"V must be a FunctionSpace, Cofunction, Coargument or one-form, not a {type(V).__name__}") - - interp = Interpolate(expr, dual_arg, - subset=subset, access=access, - allow_missing_dofs=allow_missing_dofs, - default_missing_val=default_missing_val, - matfree=matfree) - - return interp + return Interpolate( + expr, V, subset=subset, access=access, allow_missing_dofs=allow_missing_dofs, + default_missing_val=default_missing_val, matfree=matfree + ) class Interpolator(abc.ABC): From 106ea87bc63557a74ab311f2e9e3815b16806b55 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Tue, 23 Sep 2025 15:55:12 +0100 Subject: [PATCH 064/125] tidy function interpolate --- firedrake/function.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/firedrake/function.py b/firedrake/function.py index b2cda5bc4e..38dc1d0e41 100644 --- a/firedrake/function.py +++ b/firedrake/function.py @@ -384,7 +384,10 @@ def interpolate(self, """ from firedrake import interpolation, assemble V = self.function_space() - interp = interpolation.Interpolate(expression, V, **kwargs) + interp = interpolate( + expression, V, subset=subset, allow_missing_dofs=allow_missing_dofs, + default_missing_val=default_missing_val + ) return assemble(interp, tensor=self, ad_block_tag=ad_block_tag) def zero(self, subset=None): From 91caf943550c0af1f458303dde5e55f63f15c6c0 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Tue, 23 Sep 2025 16:12:40 +0100 Subject: [PATCH 065/125] create Coargument in Firedrake --- firedrake/interpolation.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 96842dac17..bab0b93d36 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -174,6 +174,11 @@ def interpolate(expr, V, subset=None, access=None, allow_missing_dofs=False, def reduction (hence using MIN will compute the MIN between the existing values and any new values). """ + if isinstance(V, functionspaceimpl.WithGeometry): + # Need to create a Firedrake Argument so that it has a .function_space() method + expr_args = extract_arguments(ufl.as_ufl(expr)) + is_adjoint = len(expr_args) and expr_args[0].number() == 0 + V = Argument(V.dual(), 1 if is_adjoint else 0) return Interpolate( expr, V, subset=subset, access=access, allow_missing_dofs=allow_missing_dofs, default_missing_val=default_missing_val, matfree=matfree From 91e4500c0a1ebeda743497990a6b0db5e9d48545 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Tue, 23 Sep 2025 17:28:25 +0100 Subject: [PATCH 066/125] Change `Interpolate` to `interpolate` --- demos/boussinesq/boussinesq.py.rst | 2 +- demos/multicomponent/multicomponent.py.rst | 2 +- firedrake/external_operators/point_expr_operator.py | 10 +++++----- firedrake/interpolation.py | 6 +++--- firedrake/mesh.py | 2 +- firedrake/mg/utils.py | 2 +- firedrake/preconditioners/gtmg.py | 4 ++-- firedrake/preconditioners/patch.py | 4 ++-- firedrake/pyplot/mpl.py | 12 ++++++------ firedrake/utility_meshes.py | 8 ++++---- tests/firedrake/adjoint/test_reduced_functional.py | 4 ++-- .../external_operators/test_external_operators.py | 2 +- tests/firedrake/multigrid/test_poisson_gtmg.py | 2 +- tests/firedrake/regression/test_adjoint_operators.py | 2 +- tests/firedrake/regression/test_interpolate_zany.py | 6 +++--- tests/firedrake/submesh/test_submesh_interpolate.py | 4 ++-- 16 files changed, 36 insertions(+), 36 deletions(-) diff --git a/demos/boussinesq/boussinesq.py.rst b/demos/boussinesq/boussinesq.py.rst index edfdf3c1a5..5cc6708cf0 100644 --- a/demos/boussinesq/boussinesq.py.rst +++ b/demos/boussinesq/boussinesq.py.rst @@ -184,7 +184,7 @@ implements a boundary condition that fixes a field at a single point. :: # Take the basis function with the largest abs value at bc_point v = TestFunction(V) - F = assemble(Interpolate(inner(v, v), Fvom)) + F = assemble(interpolate(inner(v, v), Fvom)) with F.dat.vec as Fvec: max_index, _ = Fvec.max() nodes = V.dof_dset.lgmap.applyInverse([max_index]) diff --git a/demos/multicomponent/multicomponent.py.rst b/demos/multicomponent/multicomponent.py.rst index bf74e5d2e0..7a29d6ef1c 100644 --- a/demos/multicomponent/multicomponent.py.rst +++ b/demos/multicomponent/multicomponent.py.rst @@ -521,7 +521,7 @@ mathematically valid to do this):: # Take the basis function with the largest abs value at bc_point v = TestFunction(V) - F = assemble(Interpolate(inner(v, v), Fvom)) + F = assemble(interpolate(inner(v, v), Fvom)) with F.dat.vec as Fvec: max_index, _ = Fvec.max() nodes = V.dof_dset.lgmap.applyInverse([max_index]) diff --git a/firedrake/external_operators/point_expr_operator.py b/firedrake/external_operators/point_expr_operator.py index 3aa40e1d5b..4e7183e47f 100644 --- a/firedrake/external_operators/point_expr_operator.py +++ b/firedrake/external_operators/point_expr_operator.py @@ -5,7 +5,7 @@ import firedrake.ufl_expr as ufl_expr from firedrake.assemble import assemble -from firedrake.interpolation import Interpolate +from firedrake.interpolation import interpolate from firedrake.external_operators import AbstractExternalOperator, assemble_method @@ -58,7 +58,7 @@ def assemble_operator(self, *args, **kwargs): V = self.function_space() expr = as_ufl(self.expr(*self.ufl_operands)) if len(V) < 2: - interp = Interpolate(expr, self.function_space()) + interp = interpolate(expr, self.function_space()) return assemble(interp) # Interpolation of UFL expressions for mixed functions is not yet supported # -> `Function.assign` might be enough in some cases. @@ -72,7 +72,7 @@ def assemble_operator(self, *args, **kwargs): def assemble_Jacobian_action(self, *args, **kwargs): V = self.function_space() expr = as_ufl(self.expr(*self.ufl_operands)) - interp = Interpolate(expr, V) + interp = interpolate(expr, V) u, = [e for i, e in enumerate(self.ufl_operands) if self.derivatives[i] == 1] w = self.argument_slots()[-1] @@ -83,7 +83,7 @@ def assemble_Jacobian_action(self, *args, **kwargs): def assemble_Jacobian(self, *args, assembly_opts, **kwargs): V = self.function_space() expr = as_ufl(self.expr(*self.ufl_operands)) - interp = Interpolate(expr, V) + interp = interpolate(expr, V) u, = [e for i, e in enumerate(self.ufl_operands) if self.derivatives[i] == 1] jac = ufl_expr.derivative(interp, u) @@ -99,7 +99,7 @@ def assemble_Jacobian_adjoint(self, *args, assembly_opts, **kwargs): def assemble_Jacobian_adjoint_action(self, *args, **kwargs): V = self.function_space() expr = as_ufl(self.expr(*self.ufl_operands)) - interp = Interpolate(expr, V) + interp = interpolate(expr, V) u, = [e for i, e in enumerate(self.ufl_operands) if self.derivatives[i] == 1] ustar = self.argument_slots()[0] diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index bab0b93d36..b008b748e0 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -511,7 +511,7 @@ def __init__( from firedrake.assemble import assemble V_dest_vec = firedrake.VectorFunctionSpace(dest_mesh, ufl_scalar_element) - f_dest_node_coords = Interpolate(dest_mesh.coordinates, V_dest_vec) + f_dest_node_coords = interpolate(dest_mesh.coordinates, V_dest_vec) f_dest_node_coords = assemble(f_dest_node_coords) dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, dest_mesh_gdim) try: @@ -536,7 +536,7 @@ def __init__( else: fs_type = partial(firedrake.TensorFunctionSpace, shape=shape) P0DG_vom = fs_type(self.vom_dest_node_coords_in_src_mesh, "DG", 0) - self.point_eval_interpolate = Interpolate(self.expr_renumbered, P0DG_vom) + self.point_eval_interpolate = interpolate(self.expr_renumbered, P0DG_vom) # The parallel decomposition of the nodes of V_dest in the DESTINATION # mesh (dest_mesh) is retrieved using the input_ordering attribute of the # VOM. This again is an interpolation operation, which, under the hood @@ -544,7 +544,7 @@ def __init__( P0DG_vom_i_o = fs_type( self.vom_dest_node_coords_in_src_mesh.input_ordering, "DG", 0 ) - self.to_input_ordering_interpolate = Interpolate( + self.to_input_ordering_interpolate = interpolate( firedrake.TrialFunction(P0DG_vom), P0DG_vom_i_o ) # The P0DG function outputted by the above interpolation has the diff --git a/firedrake/mesh.py b/firedrake/mesh.py index 5c9f92fb02..b28b8ea6b1 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -4153,7 +4153,7 @@ def _parent_mesh_embedding( # nessesary, to other processes. P0DG = functionspace.FunctionSpace(parent_mesh, "DG", 0) with stop_annotating(): - visible_ranks = interpolation.Interpolate( + visible_ranks = interpolation.interpolate( constant.Constant(parent_mesh.comm.rank), P0DG ) visible_ranks = assemble(visible_ranks).dat.data_ro_with_halos.real diff --git a/firedrake/mg/utils.py b/firedrake/mg/utils.py index 37832b64dc..886cc7530c 100644 --- a/firedrake/mg/utils.py +++ b/firedrake/mg/utils.py @@ -143,7 +143,7 @@ def physical_node_locations(V): Vc = V.collapse().reconstruct(element=finat.ufl.VectorElement(element, dim=mesh.geometric_dimension())) # FIXME: This is unsafe for DG coordinates and CG target spaces. - locations = firedrake.assemble(firedrake.Interpolate(firedrake.SpatialCoordinate(mesh), Vc)) + locations = firedrake.assemble(firedrake.interpolate(firedrake.SpatialCoordinate(mesh), Vc)) return cache.setdefault(key, locations) diff --git a/firedrake/preconditioners/gtmg.py b/firedrake/preconditioners/gtmg.py index 6ce73cd6b4..2ac5df9a5d 100644 --- a/firedrake/preconditioners/gtmg.py +++ b/firedrake/preconditioners/gtmg.py @@ -4,7 +4,7 @@ from firedrake.petsc import PETSc from firedrake.preconditioners.base import PCBase from firedrake.parameters import parameters -from firedrake.interpolation import Interpolate +from firedrake.interpolation import interpolate from firedrake.solving_utils import _SNESContext from firedrake.matrix_free.operators import ImplicitMatrixContext import firedrake.dmhooks as dmhooks @@ -155,7 +155,7 @@ def initialize(self, pc): # Create interpolation matrix from coarse space to fine space fine_space = ctx.J.arguments()[0].function_space() coarse_test, coarse_trial = coarse_operator.arguments() - interp = assemble(Interpolate(coarse_trial, fine_space)) + interp = assemble(interpolate(coarse_trial, fine_space)) interp_petscmat = interp.petscmat restr_petscmat = appctx.get("restriction_matrix", None) diff --git a/firedrake/preconditioners/patch.py b/firedrake/preconditioners/patch.py index bb47093a3d..4910ba1452 100644 --- a/firedrake/preconditioners/patch.py +++ b/firedrake/preconditioners/patch.py @@ -4,7 +4,7 @@ from firedrake.solving_utils import _SNESContext from firedrake.utils import cached_property, complex_mode, IntType from firedrake.dmhooks import get_appctx, push_appctx, pop_appctx -from firedrake.interpolation import Interpolate +from firedrake.interpolation import interpolate from firedrake.ufl_expr import extract_domains from collections import namedtuple @@ -668,7 +668,7 @@ def sort_entities(self, dm, axis, dir, ndiv=None, divisions=None): # with access descriptor MAX to define a consistent opinion # about where the vertices are. CGk = V.reconstruct(family="Lagrange") - coordinates = assemble(Interpolate(coordinates, CGk, access=op2.MAX)) + coordinates = assemble(interpolate(coordinates, CGk, access=op2.MAX)) select = partial(select_entity, dm=dm, exclude="pyop2_ghost") entities = [(p, self.coords(dm, p, coordinates)) for p in diff --git a/firedrake/pyplot/mpl.py b/firedrake/pyplot/mpl.py index 3cf010a1c9..d6a7aa5112 100644 --- a/firedrake/pyplot/mpl.py +++ b/firedrake/pyplot/mpl.py @@ -18,7 +18,7 @@ import mpl_toolkits.mplot3d from mpl_toolkits.mplot3d.art3d import Line3DCollection, Poly3DCollection from math import factorial -from firedrake import (Interpolate, sqrt, inner, Function, SpatialCoordinate, +from firedrake import (interpolate, sqrt, inner, Function, SpatialCoordinate, FunctionSpace, VectorFunctionSpace, PointNotInDomainError, Constant, assemble, dx) from firedrake.mesh import MeshGeometry @@ -120,7 +120,7 @@ def triplot(mesh, axes=None, interior_kw={}, boundary_kw={}): if element.degree() != 1: # Interpolate to piecewise linear. V = VectorFunctionSpace(mesh, element.family(), 1) - coordinates = assemble(Interpolate(coordinates, V)) + coordinates = assemble(interpolate(coordinates, V)) coords = toreal(coordinates.dat.data_ro_with_halos, "real") result = [] @@ -215,7 +215,7 @@ def _plot_2d_field(method_name, function, *args, complex_component="real", **kwa if len(function.ufl_shape) == 1: element = function.ufl_element().sub_elements[0] Q = FunctionSpace(mesh, element) - function = assemble(Interpolate(sqrt(inner(function, function)), Q)) + function = assemble(interpolate(sqrt(inner(function, function)), Q)) num_sample_points = kwargs.pop("num_sample_points", 10) function_plotter = FunctionPlotter(mesh, num_sample_points) @@ -326,7 +326,7 @@ def trisurf(function, *args, complex_component="real", **kwargs): if len(function.ufl_shape) == 1: element = function.ufl_element().sub_elements[0] Q = FunctionSpace(mesh, element) - function = assemble(Interpolate(sqrt(inner(function, function)), Q)) + function = assemble(interpolate(sqrt(inner(function, function)), Q)) num_sample_points = kwargs.pop("num_sample_points", 10) function_plotter = FunctionPlotter(mesh, num_sample_points) @@ -355,7 +355,7 @@ def quiver(function, *, complex_component="real", **kwargs): coords = toreal(extract_unique_domain(function).coordinates.dat.data_ro, "real") V = extract_unique_domain(function).coordinates.function_space() - function_interp = assemble(Interpolate(function, V)) + function_interp = assemble(interpolate(function, V)) vals = toreal(function_interp.dat.data_ro, complex_component) C = np.linalg.norm(vals, axis=1) return axes.quiver(*(coords.T), *(vals.T), C, **kwargs) @@ -816,7 +816,7 @@ def _bezier_plot(function, axes, complex_component="real", **kwargs): mesh = function.function_space().mesh() if deg == 0: V = FunctionSpace(mesh, "DG", 1) - interp = assemble(Interpolate(function, V)) + interp = assemble(interpolate(function, V)) return _bezier_plot(interp, axes, complex_component=complex_component, **kwargs) y_vals = _bezier_calculate_points(function) diff --git a/firedrake/utility_meshes.py b/firedrake/utility_meshes.py index 223fa59a2e..7b1818c2de 100644 --- a/firedrake/utility_meshes.py +++ b/firedrake/utility_meshes.py @@ -11,7 +11,7 @@ Function, Constant, assemble, - Interpolate, + interpolate, FiniteElement, interval, tetrahedron, @@ -2351,7 +2351,7 @@ def OctahedralSphereMesh( ) if degree > 1: # use it to build a higher-order mesh - m = assemble(Interpolate(ufl.SpatialCoordinate(m), VectorFunctionSpace(m, "CG", degree))) + m = assemble(interpolate(ufl.SpatialCoordinate(m), VectorFunctionSpace(m, "CG", degree))) m = mesh.Mesh( m, name=name, @@ -2386,11 +2386,11 @@ def OctahedralSphereMesh( # Make a copy of the coordinates so that we can blend two different # mappings near the pole Vc = m.coordinates.function_space() - Xlatitudinal = assemble(Interpolate( + Xlatitudinal = assemble(interpolate( Constant(radius) * ufl.as_vector([x * scale, y * scale, znew]), Vc )) Vlow = VectorFunctionSpace(m, "CG", 1) - Xlow = assemble(Interpolate(Xlatitudinal, Vlow)) + Xlow = assemble(interpolate(Xlatitudinal, Vlow)) r = ufl.sqrt(Xlow[0] ** 2 + Xlow[1] ** 2 + Xlow[2] ** 2) Xradial = Constant(radius) * Xlow / r diff --git a/tests/firedrake/adjoint/test_reduced_functional.py b/tests/firedrake/adjoint/test_reduced_functional.py index e440967fe9..eb20b4bf82 100644 --- a/tests/firedrake/adjoint/test_reduced_functional.py +++ b/tests/firedrake/adjoint/test_reduced_functional.py @@ -214,7 +214,7 @@ def test_interpolate(): f = Function(V) f.dat.data[:] = 2 - J = assemble(Interpolate(f**2, c)) + J = assemble(interpolate(f**2, c)) Jhat = ReducedFunctional(J, Control(f)) h = Function(V) @@ -244,7 +244,7 @@ def test_interpolate_mixed(): f1, f2 = split(f) exprs = [f2 * div(f1)**2, grad(f2) * div(f1)] expr = as_vector([e[i] for e in exprs for i in np.ndindex(e.ufl_shape)]) - J = assemble(Interpolate(expr, c)) + J = assemble(interpolate(expr, c)) Jhat = ReducedFunctional(J, Control(f)) h = Function(V) diff --git a/tests/firedrake/external_operators/test_external_operators.py b/tests/firedrake/external_operators/test_external_operators.py index b6153f3d1f..a47953a566 100644 --- a/tests/firedrake/external_operators/test_external_operators.py +++ b/tests/firedrake/external_operators/test_external_operators.py @@ -104,7 +104,7 @@ def test_assemble(V, f): assert isinstance(jac, MatrixBase) # Assemble the exact Jacobian, i.e. the interpolation matrix: `Interpolate(dexpr(u,v,w)/du, V)` - jac_exact = assemble(Interpolate(derivative(expr(u, v, w), u), V)) + jac_exact = assemble(interpolate(derivative(expr(u, v, w), u), V)) np.allclose(jac.petscmat[:, :], jac_exact.petscmat[:, :], rtol=1e-14) # -- dNdu(u, v, w; δu, v*) (TLM) -- # diff --git a/tests/firedrake/multigrid/test_poisson_gtmg.py b/tests/firedrake/multigrid/test_poisson_gtmg.py index f70e5c6825..a4154dd392 100644 --- a/tests/firedrake/multigrid/test_poisson_gtmg.py +++ b/tests/firedrake/multigrid/test_poisson_gtmg.py @@ -60,7 +60,7 @@ def p1_callback(): if custom_transfer: P1 = get_p1_space() V = FunctionSpace(mesh, "DGT", degree - 1) - I = assemble(Interpolate(TrialFunction(P1), V)).petscmat + I = assemble(interpolate(TrialFunction(P1), V)).petscmat R = PETSc.Mat().createTranspose(I) appctx['interpolation_matrix'] = I appctx['restriction_matrix'] = R diff --git a/tests/firedrake/regression/test_adjoint_operators.py b/tests/firedrake/regression/test_adjoint_operators.py index cc4f1ade43..57faf80477 100644 --- a/tests/firedrake/regression/test_adjoint_operators.py +++ b/tests/firedrake/regression/test_adjoint_operators.py @@ -729,7 +729,7 @@ def test_copy_function(): g = f.copy(deepcopy=True) J = assemble(g*dx) rf = ReducedFunctional(J, Control(f)) - a = assemble(Interpolate(-one, V)) + a = assemble(interpolate(-one, V)) assert np.isclose(rf(a), -J) diff --git a/tests/firedrake/regression/test_interpolate_zany.py b/tests/firedrake/regression/test_interpolate_zany.py index b2054843cd..875d0fb0c1 100644 --- a/tests/firedrake/regression/test_interpolate_zany.py +++ b/tests/firedrake/regression/test_interpolate_zany.py @@ -117,7 +117,7 @@ def test_interpolate_zany_into_vom(V, mesh, which, expr_at_vom): P0 = expr_at_vom.function_space() # Interpolate a Function into P0(vom) - f_at_vom = assemble(Interpolate(fexpr, P0)) + f_at_vom = assemble(interpolate(fexpr, P0)) assert numpy.allclose(f_at_vom.dat.data_ro, expr_at_vom.dat.data_ro) # Construct a Cofunction on P0(vom)* @@ -125,10 +125,10 @@ def test_interpolate_zany_into_vom(V, mesh, which, expr_at_vom): expected_action = assemble(action(Fvom, expr_at_vom)) # Interpolate a Function into Fvom - f_at_vom = assemble(Interpolate(fexpr, Fvom)) + f_at_vom = assemble(interpolate(fexpr, Fvom)) assert numpy.allclose(f_at_vom, expected_action) # Interpolate a TestFunction into Fvom - expr_vom = assemble(Interpolate(vexpr, Fvom)) + expr_vom = assemble(interpolate(vexpr, Fvom)) f_at_vom = assemble(action(expr_vom, f)) assert numpy.allclose(f_at_vom, expected_action) diff --git a/tests/firedrake/submesh/test_submesh_interpolate.py b/tests/firedrake/submesh/test_submesh_interpolate.py index a26c1acb08..92e422cf6a 100644 --- a/tests/firedrake/submesh/test_submesh_interpolate.py +++ b/tests/firedrake/submesh/test_submesh_interpolate.py @@ -51,7 +51,7 @@ def _test_submesh_interpolate_cell_cell(mesh, subdomain_cond, fe_fesub): f = Function(V_).interpolate(f) v0 = Coargument(V.dual(), 0) v1 = TrialFunction(Vsub) - interp = Interpolate(v1, v0, allow_missing_dofs=True) + interp = interpolate(v1, v0, allow_missing_dofs=True) A = assemble(interp) g = assemble(action(A, gsub)) assert assemble(inner(g - f, g - f) * dx(label_value)).real < 1e-14 @@ -165,7 +165,7 @@ def test_submesh_interpolate_subcell_subcell_2_processes(): f_l.dat.data_with_halos[:] = 3.0 v0 = Coargument(V_r.dual(), 0) v1 = TrialFunction(V_l) - interp = Interpolate(v1, v0, allow_missing_dofs=True) + interp = interpolate(v1, v0, allow_missing_dofs=True) A = assemble(interp) f_r = assemble(action(A, f_l)) g_r = Function(V_r).interpolate(conditional(x < 2.001, 3.0, 0.0)) From f661be2d1a979bae6d9b3287bf2846219017b92e Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Tue, 23 Sep 2025 17:28:47 +0100 Subject: [PATCH 067/125] update `test_interp_dual.py` --- .../firedrake/regression/test_interp_dual.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/firedrake/regression/test_interp_dual.py b/tests/firedrake/regression/test_interp_dual.py index 50e29b05cb..ccd4de13f0 100644 --- a/tests/firedrake/regression/test_interp_dual.py +++ b/tests/firedrake/regression/test_interp_dual.py @@ -54,7 +54,7 @@ def test_assemble_interp_adjoint_tensor(mesh, V1, f1): def test_assemble_interp_operator(V2, f1): # Check type - If1 = Interpolate(f1, V2) + If1 = Interpolate(f1, Argument(V2.dual(), 0)) assert isinstance(If1, ufl.Interpolate) # -- I(f1, V2) -- # @@ -89,7 +89,7 @@ def test_assemble_interp_matrix(V1, V2, f1): def test_assemble_interp_tlm(V1, V2, f1): # -- Action(I(v1, V2), f1) -- # v1 = TrialFunction(V1) - Iv1 = Interpolate(v1, V2) + Iv1 = Interpolate(v1, Argument(V2.dual(), 0)) b = assemble(interpolate(f1, V2)) assembled_action_Iv1 = assemble(action(Iv1, f1)) @@ -99,7 +99,7 @@ def test_assemble_interp_tlm(V1, V2, f1): def test_assemble_interp_adjoint_matrix(V1, V2): # -- Adjoint(I(v1, V2)) -- # v1 = TrialFunction(V1) - Iv1 = Interpolate(v1, V2) + Iv1 = Interpolate(v1, Argument(V2.dual(), 0)) v2 = TestFunction(V2) c2 = assemble(conj(v2) * dx) @@ -120,7 +120,7 @@ def test_assemble_interp_adjoint_matrix(V1, V2): def test_assemble_interp_adjoint_model(V1, V2): # -- Action(Adjoint(I(v1, v2)), fstar) -- # v1 = TrialFunction(V1) - Iv1 = Interpolate(v1, V2) + Iv1 = Interpolate(v1, Argument(V2.dual(), 0)) fstar = Cofunction(V2.dual()) v = Argument(V1, 0) @@ -167,9 +167,9 @@ def test_assemble_base_form_operator_expressions(mesh): f2 = Function(V1).interpolate(sin(2*pi*y)) f3 = Function(V1).interpolate(cos(2*pi*x)) - If1 = Interpolate(f1, V2) - If2 = Interpolate(f2, V2) - If3 = Interpolate(f3, V2) + If1 = Interpolate(f1, Argument(V2.dual(), 0)) + If2 = Interpolate(f2, Argument(V2.dual(), 0)) + If3 = Interpolate(f3, Argument(V2.dual(), 0)) # Sum of BaseFormOperators (1-form) res = assemble(If1 + If2 + If3) @@ -234,7 +234,7 @@ def test_solve_interp_f(mesh): # -- Solution where the source term is interpolated via `ufl.Interpolate` u2 = Function(V1) - If = Interpolate(f1, V2) + If = Interpolate(f1, Argument(V2.dual(), 0)) # This requires assembling If F2 = inner(grad(u2), grad(w))*dx + inner(u2, w)*dx - inner(If, w)*dx solve(F2 == 0, u2) @@ -267,7 +267,7 @@ def test_solve_interp_u(mesh): # -- Solution where u2 is interpolated via `ufl.Interpolate` (mat-free) u2 = Function(V1) # Iu is the identity - Iu = Interpolate(u2, V1) + Iu = Interpolate(u2, Argument(V1.dual(), 0)) # This requires assembling the action the Jacobian of Iu F2 = inner(grad(u2), grad(w))*dx + inner(Iu, w)*dx - inner(f, w)*dx solve(F2 == 0, u2, solver_parameters={"mat_type": "matfree", @@ -278,7 +278,7 @@ def test_solve_interp_u(mesh): # Same problem with grad(Iu) instead of grad(Iu) u2 = Function(V1) # Iu is the identity - Iu = Interpolate(u2, V1) + Iu = Interpolate(u2, Argument(V1.dual(), 0)) # This requires assembling the action the Jacobian of Iu F2 = inner(grad(Iu), grad(w))*dx + inner(Iu, w)*dx - inner(f, w)*dx solve(F2 == 0, u2, solver_parameters={"mat_type": "matfree", From bc4e48b46f9a05436eddb6e1cc29dd81fca09401 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Tue, 23 Sep 2025 17:35:47 +0100 Subject: [PATCH 068/125] DROP BEFORE MERGE: use UFL branch --- .github/workflows/core.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/core.yml b/.github/workflows/core.yml index 55df1c8e01..61ea7f9d34 100644 --- a/.github/workflows/core.yml +++ b/.github/workflows/core.yml @@ -210,6 +210,7 @@ jobs: --extra-index-url https://download.pytorch.org/whl/cpu \ "$(echo ./firedrake-repo/dist/firedrake-*.tar.gz)[ci]" + pip install -I "fenics-ufl @ git+https://github.com/firedrakeproject/ufl.git@leo/sanitise-interpolate" firedrake-clean pip list From ac252dc31b78d31e057b5175c919714092804d71 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Tue, 23 Sep 2025 17:36:30 +0100 Subject: [PATCH 069/125] move FunctionSpace check into `Interpolate` fix --- firedrake/interpolation.py | 27 +++++++-------------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index b008b748e0..72ab00b038 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -48,7 +48,7 @@ class Interpolate(ufl.Interpolate): - def __init__(self, expr, v, + def __init__(self, expr, V, subset=None, access=None, allow_missing_dofs=False, @@ -60,7 +60,7 @@ def __init__(self, expr, v, ---------- expr : ufl.core.expr.Expr or ufl.BaseForm The UFL expression to interpolate. - v : firedrake.functionspaceimpl.WithGeometryBase or firedrake.ufl_expr.Coargument + V : firedrake.functionspaceimpl.WithGeometryBase or firedrake.ufl_expr.Coargument The function space to interpolate into or the coargument defined on the dual of the function space to interpolate into. subset : pyop2.types.set.Subset @@ -95,20 +95,12 @@ def __init__(self, expr, v, between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast and reduce operations. """ - # Check function space - expr = ufl.as_ufl(expr) - if isinstance(v, functionspaceimpl.WithGeometry): - expr_args = extract_arguments(expr) + if isinstance(V, functionspaceimpl.WithGeometry): + # Need to create a Firedrake Argument so that it has a .function_space() method + expr_args = extract_arguments(ufl.as_ufl(expr)) is_adjoint = len(expr_args) and expr_args[0].number() == 0 - v = Argument(v.dual(), 1 if is_adjoint else 0) - - V = v.arguments()[0].function_space() - if len(expr.ufl_shape) != len(V.value_shape): - raise RuntimeError(f'Rank mismatch: Expression rank {len(expr.ufl_shape)}, FunctionSpace rank {len(V.value_shape)}') - - if expr.ufl_shape != V.value_shape: - raise RuntimeError('Shape mismatch: Expression shape {expr.ufl_shape}, FunctionSpace shape {V.value_shape}') - super().__init__(expr, v) + V = Argument(V.dual(), 1 if is_adjoint else 0) + super().__init__(expr, V) # -- Interpolate data (e.g. `subset` or `access`) -- # self.interp_data = {"subset": subset, @@ -174,11 +166,6 @@ def interpolate(expr, V, subset=None, access=None, allow_missing_dofs=False, def reduction (hence using MIN will compute the MIN between the existing values and any new values). """ - if isinstance(V, functionspaceimpl.WithGeometry): - # Need to create a Firedrake Argument so that it has a .function_space() method - expr_args = extract_arguments(ufl.as_ufl(expr)) - is_adjoint = len(expr_args) and expr_args[0].number() == 0 - V = Argument(V.dual(), 1 if is_adjoint else 0) return Interpolate( expr, V, subset=subset, access=access, allow_missing_dofs=allow_missing_dofs, default_missing_val=default_missing_val, matfree=matfree From 79d6c8359a39267be177a30addc5bca3817ff69f Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Tue, 23 Sep 2025 17:45:44 +0100 Subject: [PATCH 070/125] lint --- firedrake/interpolation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 72ab00b038..46a741244c 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -27,11 +27,10 @@ import firedrake import firedrake.bcs from firedrake import tsfc_interface, utils, functionspaceimpl -from firedrake.ufl_expr import Argument, Coargument, action, adjoint as expr_adjoint +from firedrake.ufl_expr import Argument, action, adjoint as expr_adjoint from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology from firedrake.petsc import PETSc from firedrake.halo import _get_mtype as get_dat_mpi_type -from firedrake.cofunction import Cofunction from mpi4py import MPI from pyadjoint import stop_annotating, no_annotations From 4a509f77cf3e1be4dc92c026ce210e80adab110e Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Wed, 24 Sep 2025 11:52:24 +0100 Subject: [PATCH 071/125] test -> trial test -> trial fix hypre-ads --- firedrake/preconditioners/hypre_ads.py | 6 +++--- firedrake/preconditioners/hypre_ams.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/firedrake/preconditioners/hypre_ads.py b/firedrake/preconditioners/hypre_ads.py index 89c10dc438..98443f2c75 100644 --- a/firedrake/preconditioners/hypre_ads.py +++ b/firedrake/preconditioners/hypre_ads.py @@ -1,7 +1,7 @@ from firedrake.preconditioners.base import PCBase from firedrake.petsc import PETSc from firedrake.function import Function -from firedrake.ufl_expr import TestFunction +from firedrake.ufl_expr import TrialFunction from firedrake.dmhooks import get_function_space from firedrake.preconditioners.hypre_ams import chop from firedrake.interpolation import interpolate @@ -31,12 +31,12 @@ def initialize(self, obj): NC1 = V.reconstruct(family="N1curl" if mesh.ufl_cell().is_simplex() else "NCE", degree=1) G_callback = appctx.get("get_gradient", None) if G_callback is None: - G = chop(assemble(interpolate(grad(TestFunction(P1)), NC1)).petscmat) + G = chop(assemble(interpolate(grad(TrialFunction(P1)), NC1)).petscmat) else: G = G_callback(P1, NC1) C_callback = appctx.get("get_curl", None) if C_callback is None: - C = chop(assemble(interpolate(curl(TestFunction(NC1)), V)).petscmat) + C = chop(assemble(interpolate(curl(TrialFunction(NC1)), V)).petscmat) else: C = C_callback(NC1, V) diff --git a/firedrake/preconditioners/hypre_ams.py b/firedrake/preconditioners/hypre_ams.py index 9a59702af4..594fe88590 100644 --- a/firedrake/preconditioners/hypre_ams.py +++ b/firedrake/preconditioners/hypre_ams.py @@ -2,7 +2,7 @@ from firedrake.preconditioners.base import PCBase from firedrake.petsc import PETSc from firedrake.function import Function -from firedrake.ufl_expr import TestFunction +from firedrake.ufl_expr import TrialFunction from firedrake.dmhooks import get_function_space from firedrake.utils import complex_mode from firedrake.interpolation import interpolate @@ -51,7 +51,7 @@ def initialize(self, obj): P1 = V.reconstruct(family="Lagrange", degree=1) G_callback = appctx.get("get_gradient", None) if G_callback is None: - G = chop(assemble(interpolate(grad(TestFunction(P1)), V)).petscmat) + G = chop(assemble(interpolate(grad(TrialFunction(P1)), V)).petscmat) else: G = G_callback(P1, V) From 812f5d8fd567ada72c9edfd6a9be44dc74e0e62f Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Wed, 1 Oct 2025 13:53:33 +0100 Subject: [PATCH 072/125] tidy function.interpolate --- firedrake/function.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/firedrake/function.py b/firedrake/function.py index 38dc1d0e41..06e83b1c56 100644 --- a/firedrake/function.py +++ b/firedrake/function.py @@ -382,12 +382,9 @@ def interpolate(self, firedrake.function.Function Returns `self` """ - from firedrake import interpolation, assemble + from firedrake import interpolate, assemble V = self.function_space() - interp = interpolate( - expression, V, subset=subset, allow_missing_dofs=allow_missing_dofs, - default_missing_val=default_missing_val - ) + interp = interpolate(expression, V, **kwargs) return assemble(interp, tensor=self, ad_block_tag=ad_block_tag) def zero(self, subset=None): From e50cf4ea6ec53807b2ae3fadbea59f0a0c2dc952 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Wed, 1 Oct 2025 13:54:27 +0100 Subject: [PATCH 073/125] tidy cofunction.interpolate cofunction docstring --- firedrake/cofunction.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/firedrake/cofunction.py b/firedrake/cofunction.py index f0fda10f63..9ee867c622 100644 --- a/firedrake/cofunction.py +++ b/firedrake/cofunction.py @@ -318,7 +318,7 @@ def interpolate(self, Parameters ---------- expression - A dual UFL expression to interpolate. + A UFL BaseForm to adjoint interpolate. ad_block_tag An optional string for tagging the resulting assemble block on the Pyadjoint tape. @@ -331,9 +331,9 @@ def interpolate(self, firedrake.cofunction.Cofunction Returns `self` """ - from firedrake import interpolation, assemble + from firedrake import interpolate, assemble v, = self.arguments() - interp = interpolation.Interpolate(v, expression, **kwargs) + interp = interpolate(v, expression, **kwargs) return assemble(interp, tensor=self, ad_block_tag=ad_block_tag) @property From 5d9c81e5dc8bdb3d171c9b7675818e4a28f7288a Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Wed, 1 Oct 2025 14:10:49 +0100 Subject: [PATCH 074/125] tidy type hints in function.py --- firedrake/function.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/firedrake/function.py b/firedrake/function.py index 06e83b1c56..e2ac708fd2 100644 --- a/firedrake/function.py +++ b/firedrake/function.py @@ -14,7 +14,7 @@ from numbers import Number from pathlib import Path from functools import partial -from typing import Tuple +from typing import Tuple, Optional from pyop2 import op2, mpi from pyop2.exceptions import DataTypeError, DataValueError @@ -362,7 +362,7 @@ def function_space(self): @PETSc.Log.EventDecorator() def interpolate(self, expression: ufl.classes.Expr, - ad_block_tag: str | None = None, + ad_block_tag: Optional[str] = None, **kwargs): """Interpolate an expression onto this :class:`Function`. @@ -701,13 +701,13 @@ def __init__(self, domain, point): self.point = point def __str__(self): - return "domain %s does not contain point %s" % (self.domain, self.point) + return f"Domain {self.domain} does not contain point {self.point}" class PointEvaluator: r"""Convenience class for evaluating a :class:`Function` at a set of points.""" - def __init__(self, mesh: MeshGeometry, points: np.ndarray | list, tolerance: float | None = None, + def __init__(self, mesh: MeshGeometry, points: np.ndarray | list, tolerance: Optional[float] = None, missing_points_behaviour: str = "error", redundant: bool = True) -> None: r""" Parameters @@ -716,7 +716,7 @@ def __init__(self, mesh: MeshGeometry, points: np.ndarray | list, tolerance: flo The mesh on which to embed the points. points : numpy.ndarray | list Array or list of points to evaluate at. - tolerance : float | None + tolerance : Optional[float] Tolerance to use when checking if a point is in a cell. If ``None`` (the default), the ``tolerance`` of the ``mesh`` is used. missing_points_behaviour : str From fddcb796e54c6afe903cd9d3a638007860e08558 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Wed, 1 Oct 2025 15:04:55 +0100 Subject: [PATCH 075/125] remove UFL branch --- .github/workflows/core.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/core.yml b/.github/workflows/core.yml index 61ea7f9d34..55df1c8e01 100644 --- a/.github/workflows/core.yml +++ b/.github/workflows/core.yml @@ -210,7 +210,6 @@ jobs: --extra-index-url https://download.pytorch.org/whl/cpu \ "$(echo ./firedrake-repo/dist/firedrake-*.tar.gz)[ci]" - pip install -I "fenics-ufl @ git+https://github.com/firedrakeproject/ufl.git@leo/sanitise-interpolate" firedrake-clean pip list From 1bc83d4577392b19b493aa4ea00ebfc078e1ff36 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Wed, 1 Oct 2025 15:19:01 +0100 Subject: [PATCH 076/125] lint --- firedrake/function.py | 6 +++--- firedrake/interpolation.py | 9 ++++----- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/firedrake/function.py b/firedrake/function.py index e2ac708fd2..bea8e965de 100644 --- a/firedrake/function.py +++ b/firedrake/function.py @@ -14,7 +14,7 @@ from numbers import Number from pathlib import Path from functools import partial -from typing import Tuple, Optional +from typing import Tuple from pyop2 import op2, mpi from pyop2.exceptions import DataTypeError, DataValueError @@ -362,7 +362,7 @@ def function_space(self): @PETSc.Log.EventDecorator() def interpolate(self, expression: ufl.classes.Expr, - ad_block_tag: Optional[str] = None, + ad_block_tag: str | None = None, **kwargs): """Interpolate an expression onto this :class:`Function`. @@ -707,7 +707,7 @@ def __str__(self): class PointEvaluator: r"""Convenience class for evaluating a :class:`Function` at a set of points.""" - def __init__(self, mesh: MeshGeometry, points: np.ndarray | list, tolerance: Optional[float] = None, + def __init__(self, mesh: MeshGeometry, points: np.ndarray | list, tolerance: float | None = None, missing_points_behaviour: str = "error", redundant: bool = True) -> None: r""" Parameters diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 46a741244c..c6040b2a2f 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -4,14 +4,13 @@ import abc import warnings from collections.abc import Iterable -from typing import Literal from functools import partial, singledispatch from typing import Hashable import FIAT import ufl import finat.ufl -from ufl.algorithms import extract_arguments, extract_coefficients, replace +from ufl.algorithms import extract_arguments, extract_coefficients from ufl.domain import as_domain, extract_unique_domain from pyop2 import op2 @@ -25,9 +24,9 @@ import finat import firedrake -import firedrake.bcs from firedrake import tsfc_interface, utils, functionspaceimpl -from firedrake.ufl_expr import Argument, action, adjoint as expr_adjoint +from firedrake.ufl_expr import Argument, Coargument, action, adjoint as expr_adjoint +from firedrake.cofunction import Cofunction from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology from firedrake.petsc import PETSc from firedrake.halo import _get_mtype as get_dat_mpi_type @@ -263,7 +262,7 @@ def __init__( V: ufl.FunctionSpace | firedrake.function.Function, subset: op2.Subset | None = None, freeze_expr: bool = False, - access: Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC] | None = None, + access: op2.Access | None = None, bcs: Iterable[firedrake.bcs.BCBase] | None = None, allow_missing_dofs: bool = False, matfree: bool = True From 8ce32aec3a3509bfaab29dab31423667cd2fa0f6 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Wed, 1 Oct 2025 15:23:06 +0100 Subject: [PATCH 077/125] fix typing fix --- firedrake/interpolation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index c6040b2a2f..b0a49062c1 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -5,7 +5,7 @@ import warnings from collections.abc import Iterable from functools import partial, singledispatch -from typing import Hashable +from typing import Hashable, Literal import FIAT import ufl @@ -262,7 +262,7 @@ def __init__( V: ufl.FunctionSpace | firedrake.function.Function, subset: op2.Subset | None = None, freeze_expr: bool = False, - access: op2.Access | None = None, + access: Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC] | None = None, bcs: Iterable[firedrake.bcs.BCBase] | None = None, allow_missing_dofs: bool = False, matfree: bool = True From 78dac92684c7dbfcbc951c1e802598a957b2a55f Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Wed, 1 Oct 2025 22:43:50 +0100 Subject: [PATCH 078/125] runtimeerror -> valueerror --- tests/firedrake/regression/test_function.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/firedrake/regression/test_function.py b/tests/firedrake/regression/test_function.py index 85c6cdc0a0..1f9ed7a429 100644 --- a/tests/firedrake/regression/test_function.py +++ b/tests/firedrake/regression/test_function.py @@ -81,22 +81,22 @@ def test_firedrake_tensor_function_nonstandard_shape(W_nonstandard_shape): def test_mismatching_rank_interpolation(V): f = Function(V) - with pytest.raises(RuntimeError): + with pytest.raises(ValueError): f.interpolate(Constant((1, 2))) VV = VectorFunctionSpace(V.mesh(), 'CG', 1) f = Function(VV) - with pytest.raises(RuntimeError): + with pytest.raises(ValueError): f.interpolate(Constant((1, 2))) VVV = TensorFunctionSpace(V.mesh(), 'CG', 1) f = Function(VVV) - with pytest.raises(RuntimeError): + with pytest.raises(ValueError): f.interpolate(Constant((1, 2))) def test_mismatching_shape_interpolation(V): VV = VectorFunctionSpace(V.mesh(), 'CG', 1) f = Function(VV) - with pytest.raises(RuntimeError): + with pytest.raises(ValueError): f.interpolate(Constant([1] * (VV.value_shape[0] + 1))) From 62aff37cb6360b5d45ddec4069109f68293bf060 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Wed, 1 Oct 2025 22:48:12 +0100 Subject: [PATCH 079/125] add check for shape mismatch to `Interpolate` --- firedrake/interpolation.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index b0a49062c1..6f20cc29ef 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -98,6 +98,11 @@ def __init__(self, expr, V, expr_args = extract_arguments(ufl.as_ufl(expr)) is_adjoint = len(expr_args) and expr_args[0].number() == 0 V = Argument(V.dual(), 1 if is_adjoint else 0) + + target_shape = V.arguments()[0].function_space().value_shape + if expr.ufl_shape != target_shape: + raise ValueError(f"Shape mismatch: Expression shape {expr.ufl_shape}, FunctionSpace shape {target_shape}.") + super().__init__(expr, V) # -- Interpolate data (e.g. `subset` or `access`) -- # From 68e552723db79c35152adc6890657af81a98ca15 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 2 Oct 2025 11:24:03 +0100 Subject: [PATCH 080/125] use ufl.as_expr --- firedrake/interpolation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 6f20cc29ef..bca2fd9e65 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -93,9 +93,10 @@ def __init__(self, expr, V, between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast and reduce operations. """ + expr = ufl.as_ufl(expr) if isinstance(V, functionspaceimpl.WithGeometry): # Need to create a Firedrake Argument so that it has a .function_space() method - expr_args = extract_arguments(ufl.as_ufl(expr)) + expr_args = extract_arguments(expr) is_adjoint = len(expr_args) and expr_args[0].number() == 0 V = Argument(V.dual(), 1 if is_adjoint else 0) From d7823655ec51c3776b940c322327c318aa94395d Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Mon, 6 Oct 2025 12:34:27 +0100 Subject: [PATCH 081/125] update expr arg check --- firedrake/interpolation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index bca2fd9e65..620363f20a 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -12,6 +12,7 @@ import finat.ufl from ufl.algorithms import extract_arguments, extract_coefficients from ufl.domain import as_domain, extract_unique_domain +from ufl.duals import is_dual from pyop2 import op2 from pyop2.caching import memory_and_disk_cache @@ -96,8 +97,8 @@ def __init__(self, expr, V, expr = ufl.as_ufl(expr) if isinstance(V, functionspaceimpl.WithGeometry): # Need to create a Firedrake Argument so that it has a .function_space() method - expr_args = extract_arguments(expr) - is_adjoint = len(expr_args) and expr_args[0].number() == 0 + expr_arg_numbers = {arg.number() for arg in extract_arguments(expr) if not is_dual(arg)} + is_adjoint = len(expr_arg_numbers) and expr_arg_numbers[0] == 0 V = Argument(V.dual(), 1 if is_adjoint else 0) target_shape = V.arguments()[0].function_space().value_shape From f6f4a110764e5dfbd278db06dd3b727127c432dc Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Mon, 6 Oct 2025 12:35:15 +0100 Subject: [PATCH 082/125] lint --- firedrake/interpolation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 620363f20a..529dd05fe1 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -27,7 +27,6 @@ import firedrake from firedrake import tsfc_interface, utils, functionspaceimpl from firedrake.ufl_expr import Argument, Coargument, action, adjoint as expr_adjoint -from firedrake.cofunction import Cofunction from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology from firedrake.petsc import PETSc from firedrake.halo import _get_mtype as get_dat_mpi_type From 3a59ff18a94e18815e8f1a6dca5c02f52a639623 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Mon, 6 Oct 2025 12:45:19 +0100 Subject: [PATCH 083/125] fix check --- firedrake/interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 529dd05fe1..47a5d1b7d1 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -97,7 +97,7 @@ def __init__(self, expr, V, if isinstance(V, functionspaceimpl.WithGeometry): # Need to create a Firedrake Argument so that it has a .function_space() method expr_arg_numbers = {arg.number() for arg in extract_arguments(expr) if not is_dual(arg)} - is_adjoint = len(expr_arg_numbers) and expr_arg_numbers[0] == 0 + is_adjoint = len(expr_arg_numbers) and expr_arg_numbers == {0} V = Argument(V.dual(), 1 if is_adjoint else 0) target_shape = V.arguments()[0].function_space().value_shape From c8ffb2b6d4e9af6787b160ca24a1590dc927b322 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Wed, 1 Oct 2025 16:15:20 +0100 Subject: [PATCH 084/125] tidy typing --- firedrake/interpolation.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index e98f4736b3..34f99b79e3 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -4,7 +4,7 @@ import abc from functools import partial, singledispatch -from typing import Hashable, Optional +from typing import Hashable, Literal from dataclasses import asdict, dataclass import FIAT @@ -97,10 +97,10 @@ class InterpolateOptions: between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast and reduce operations. """ - subset: Optional[op2.Subset] = None - access: Access = op2.WRITE + subset: op2.Subset | None = None + access: Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC] = op2.WRITE allow_missing_dofs: bool = False - default_missing_val: Optional[float] = None + default_missing_val: float | None = None matfree: bool = True @@ -219,6 +219,7 @@ def __init__(self, expr: Interpolate, V, bcs=None): self.expr_renumbered = operand self.ufl_interpolate_renumbered = expr + access = self.options.access if not isinstance(dual_arg, ufl.Coargument): # Matrix-free assembly of 0-form or 1-form requires INC access self.options.access = op2.INC @@ -365,7 +366,8 @@ def __init__(self, expr: Interpolate, V, bcs=None): ) self.dest_element = base_element elif isinstance(dest_element, finat.ufl.MixedElement): - return self._mixed_function_space() + self._mixed_function_space() + return else: # scalar fiat/finat element self.dest_element = dest_element From 8ffb3fc3de617e8e07ae647a2e9c06ba9cec2066 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Wed, 1 Oct 2025 17:32:54 +0100 Subject: [PATCH 085/125] review suggestions --- firedrake/interpolation.py | 56 ++++++++++++++++++-------------------- 1 file changed, 27 insertions(+), 29 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 34f99b79e3..d645d645df 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -12,10 +12,10 @@ import finat.ufl from ufl.algorithms import extract_arguments, extract_coefficients from ufl.domain import as_domain, extract_unique_domain +from ufl.classes import Expr from pyop2 import op2 from pyop2.caching import memory_and_disk_cache -from pyop2.types import Access from finat.element_factory import create_element, as_fiat_cell from tsfc import compile_expression_dual_evaluation @@ -59,21 +59,10 @@ class InterpolateOptions: the target mesh is a :func:`.VertexOnlyMesh`. access : pyop2.types.access.Access, default op2.WRITE The pyop2 access descriptor for combining updates to shared - DoFs. Possible values include ``WRITE`` and ``INC``. Only ``WRITE`` is - supported at present when interpolating across meshes unless the target - mesh is a :func:`.VertexOnlyMesh`. - - .. note:: - If you use an access descriptor other than ``WRITE``, the - behaviour of interpolation changes if interpolating into a - function space, or an existing function. If the former, then - the newly allocated function will be initialised with - appropriate values (e.g. for MIN access, it will be initialised - with MAX_FLOAT). On the other hand, if you provide a function, - then it is assumed that its values should take part in the - reduction (hence using MIN will compute the MIN between the - existing values and any new values). - + DoFs. Possible values include ``WRITE``, ``MIN``, ``MAX``, and ``INC``. + Only ``WRITE`` is supported at present when interpolating across meshes + unless the target mesh is a :func:`.VertexOnlyMesh`. Only ``INC`` is + supported for the matrix-free adjoint interpolation. allow_missing_dofs : bool, default False For interpolation across meshes: allow degrees of freedom (aka DoFs/nodes) in the target mesh that cannot be defined on the source mesh. @@ -98,7 +87,7 @@ class InterpolateOptions: and reduce operations. """ subset: op2.Subset | None = None - access: Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC] = op2.WRITE + access: Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC] | None = None allow_missing_dofs: bool = False default_missing_val: float | None = None matfree: bool = True @@ -106,7 +95,7 @@ class InterpolateOptions: class Interpolate(ufl.Interpolate): - def __init__(self, expr, V, **kwargs): + def __init__(self, expr: Expr, V: WithGeometry | ufl.BaseForm, **kwargs): """Symbolic representation of the interpolation operator. Parameters @@ -125,13 +114,20 @@ def __init__(self, expr, V, **kwargs): expr_args = extract_arguments(ufl.as_ufl(expr)) is_adjoint = len(expr_args) and expr_args[0].number() == 0 V = Argument(V.dual(), 1 if is_adjoint else 0) + + target_shape = V.arguments()[0].function_space().value_shape + if expr.ufl_shape != target_shape: + raise ValueError(f"Shape mismatch: Expression shape {expr.ufl_shape}, FunctionSpace shape {target_shape}.") + super().__init__(expr, V) self._options = InterpolateOptions(**kwargs) function_space = ufl.Interpolate.ufl_function_space - def _ufl_expr_reconstruct_(self, expr, v=None, **interp_data): + def _ufl_expr_reconstruct_( + self, expr: Expr, v: WithGeometry | ufl.BaseForm | None = None, **interp_data + ): interp_data = interp_data or asdict(self.options) return ufl.Interpolate._ufl_expr_reconstruct_(self, expr, v=v, **interp_data) @@ -141,7 +137,7 @@ def options(self): @PETSc.Log.EventDecorator() -def interpolate(expr, V, **kwargs): +def interpolate(expr: Expr, V: WithGeometry | ufl.BaseForm, **kwargs): """Returns a UFL expression for the interpolation operation of ``expr`` into ``V``. Parameters @@ -188,12 +184,8 @@ def __init__(self, expr: Interpolate, V, bcs=None): # TODO CrossMeshInterpolator and VomOntoVomXXX are not yet aware of # self.ufl_interpolate (which carries the dual argument). # See github issue https://github.com/firedrakeproject/firedrake/issues/4592 - target_mesh = as_domain(V) - source_mesh = extract_unique_domain(operand) or target_mesh - vom_onto_other_vom = ((source_mesh is not target_mesh) - and isinstance(source_mesh.topology, VertexOnlyMeshTopology) - and isinstance(target_mesh.topology, VertexOnlyMeshTopology)) - if not isinstance(self, SameMeshInterpolator) or vom_onto_other_vom: + + if isinstance(self, CrossMeshInterpolator | VomOntoVomInterpolator): # For bespoke interpolation, we currently rely on different assembly procedures: # 1) Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) -> Forward operator (2-form) # 2) Interpolate(Argument(V1, 0), Argument(V2.dual(), 1)) -> Adjoint operator (2-form) @@ -222,7 +214,13 @@ def __init__(self, expr: Interpolate, V, bcs=None): access = self.options.access if not isinstance(dual_arg, ufl.Coargument): # Matrix-free assembly of 0-form or 1-form requires INC access - self.options.access = op2.INC + if access and access != op2.INC: + raise ValueError("Matfree adjoint interpolation requires INC access") + access = op2.INC + elif access is None: + # Default access for forward 1-form or 2-form (forward and adjoint) + access = op2.WRITE + self.options.access = access @abc.abstractmethod def _interpolate(self, *args, **kwargs): @@ -269,7 +267,7 @@ def assemble(self, tensor=None): return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint) -def _get_interpolator(expr: Interpolate, V) -> Interpolator: +def _get_interpolator(expr: Interpolate | Expr, V) -> Interpolator: target_mesh = as_domain(V) source_mesh = extract_unique_domain(expr) or target_mesh submesh_interp_implemented = \ @@ -689,7 +687,7 @@ def _get_tensor(self): return f, tensor def _get_callable(self): - expr = self.ufl_interpolate_renumbered + expr = self.ufl_interpolate dual_arg, operand = expr.argument_slots() arguments = expr.arguments() rank = len(arguments) From a73e23f40123ff5a25bb7f646388df15a453c802 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Wed, 1 Oct 2025 17:45:09 +0100 Subject: [PATCH 086/125] Squashed commit of the following: commit 3aeb517cb053decfd1749923d62ba75a5401745e Author: Pablo Brubeck Date: Wed Oct 1 17:17:44 2025 +0100 Test the MixedInterpolator 0-form across different meshes commit fee366abc0ef081c22d1d819fdd73d9c29407f59 Author: Pablo Brubeck Date: Wed Oct 1 16:56:56 2025 +0100 Implement missing functionality in CrossMeshInterpolator commit e400fc99cc263c5695d0c60ac62818a5e03555ba Author: Pablo Brubeck Date: Wed Oct 1 15:16:39 2025 +0100 cleanup commit 997e63886f43be79bb55f6b81582581afee58e85 Author: Pablo Brubeck Date: Wed Oct 1 12:57:59 2025 +0100 cleanup commit f1080ea3f9d3cb0477b7b48f1367a65a4ba0de28 Author: Pablo Brubeck Date: Wed Oct 1 09:11:07 2025 +0100 Interpolate: support fieldsplit commit 76ec36770f649f7418e4eb42315f307f88ed356e Author: Pablo Brubeck Date: Wed Oct 1 07:26:07 2025 +0100 Fixup commit 3bd935e36bbed7e0f5c4e5aebb83ce7eafd71c95 Author: Pablo Brubeck Date: Wed Oct 1 00:23:21 2025 +0100 MixedInterpolator --- firedrake/assemble.py | 31 +-- firedrake/interpolation.py | 177 +++++++++++++----- .../firedrake/regression/test_interpolate.py | 47 +++++ 3 files changed, 180 insertions(+), 75 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 1875b2fc2d..291ed300eb 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -18,7 +18,6 @@ import finat.ufl from firedrake import (extrusion_utils as eutils, matrix, parameters, solving, tsfc_interface, utils) -from firedrake.formmanipulation import split_form from firedrake.adjoint_utils import annotate_assemble from firedrake.ufl_expr import extract_unique_domain from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit @@ -573,36 +572,12 @@ def base_form_assembly_visitor(self, expr, tensor, *args): rank = len(expr.arguments()) if rank > 2: raise ValueError("Cannot assemble an Interpolate with more than two arguments") - # If argument numbers have been swapped => Adjoint. - arg_operand = ufl.algorithms.extract_arguments(operand) - is_adjoint = (arg_operand and arg_operand[0].number() == 0) - # Get the target space V = v.function_space().dual() - # Dual interpolation from mixed source - if is_adjoint and len(V) > 1: - cur = 0 - sub_operands = [] - components = numpy.reshape(operand, (-1,)) - for Vi in V: - sub_operands.append(ufl.as_tensor(components[cur:cur+Vi.value_size].reshape(Vi.value_shape))) - cur += Vi.value_size - - # Component-split of the primal operands interpolated into the dual argument-split - split_interp = sum(reconstruct_interp(sub_operands[i], v=vi) for (i,), vi in split_form(v)) - return assemble(split_interp, tensor=tensor) - - # Dual interpolation into mixed target - if is_adjoint and len(arg_operand[0].function_space()) > 1 and rank == 1: - V = arg_operand[0].function_space() - tensor = tensor or firedrake.Cofunction(V.dual()) - - # Argument-split of the Interpolate gets assembled into the corresponding sub-tensor - for (i,), sub_interp in split_form(expr): - assemble(sub_interp, tensor=tensor.subfunctions[i]) - return tensor - + # Get the interpolator + interp_data = expr.interp_data.copy() + default_missing_val = interp_data.pop('default_missing_val', None) if rank == 1 and isinstance(tensor, firedrake.Function): V = tensor interpolator = _get_interpolator(expr, V) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index d645d645df..63261044c4 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -245,8 +245,8 @@ def assemble(self, tensor=None): if needs_adjoint: # Out-of-place Hermitian transpose petsc_mat.hermitianTranspose(out=res) - elif res: - petsc_mat.copy(res) + elif tensor: + petsc_mat.copy(tensor.petscmat) else: res = petsc_mat return tensor or firedrake.AssembledMatrix(arguments, self.bcs, res) @@ -268,6 +268,18 @@ def assemble(self, tensor=None): def _get_interpolator(expr: Interpolate | Expr, V) -> Interpolator: + V_target = V if isinstance(V, ufl.FunctionSpace) else V.function_space() + if not isinstance(expr, ufl.Interpolate): + expr = interpolate(expr, V_target) + + spaces = [a.function_space() for a in expr.arguments()] + has_mixed_spaces = any(len(space) > 1 for space in spaces) + if len(spaces) == 2 and has_mixed_spaces: + return object.__new__(MixedInterpolator) + + operand, = expr.ufl_operands + target_mesh = as_domain(V) + source_mesh = extract_unique_domain(operand) or target_mesh target_mesh = as_domain(V) source_mesh = extract_unique_domain(expr) or target_mesh submesh_interp_implemented = \ @@ -484,21 +496,6 @@ def _interpolate(self, *function, output=None, adjoint=False): else: output = firedrake.Function(V_dest) - if len(self.sub_interpolates): - # MixedFunctionSpace case - for sub_interpolate, f_src_sub_func, output_sub_func in zip( - self.sub_interpolates, f_src.subfunctions, output.subfunctions - ): - if f_src is self.expr: - # f_src is already contained in self.point_eval_interpolate, - # so the sub_interpolates are already prepared to interpolate - # without needing to be given a Function - assert not self.nargs - assemble(sub_interpolate, tensor=output_sub_func) - else: - assemble(action(sub_interpolate, f_src_sub_func), tensor=output_sub_func) - return output - if not adjoint: if f_src is self.expr: # f_src is already contained in self.point_eval_interpolate @@ -694,37 +691,37 @@ def _get_callable(self): f, tensor = self._get_tensor() loops = [] - if len(self.V) == 1: - expressions = (expr,) - else: - if (hasattr(operand, "subfunctions") and len(operand.subfunctions) == len(self.V) - and all(sub_op.ufl_shape == Vsub.value_shape for Vsub, sub_op in zip(self.V, operand.subfunctions))): - # Use subfunctions if they match the target shapes - operands = operand.subfunctions - else: - # Unflatten the expression into the shapes of the mixed components - offset = 0 - operands = [] - for Vsub in self.V: - if len(Vsub.value_shape) == 0: - operands.append(operand[offset]) - else: - components = [operand[offset + j] for j in range(Vsub.value_size)] - operands.append(ufl.as_tensor(numpy.reshape(components, Vsub.value_shape))) - offset += Vsub.value_size - # Split the dual argument - if isinstance(dual_arg, Cofunction): - duals = dual_arg.subfunctions - elif isinstance(dual_arg, Coargument): - duals = [Coargument(Vsub, number=dual_arg.number()) for Vsub in dual_arg.function_space()] - else: - duals = [v for _, v in sorted(firedrake.formmanipulation.split_form(dual_arg))] - expressions = map(expr._ufl_expr_reconstruct_, operands, duals) + if access == op2.INC: + loops.append(tensor.zero) + + dual_arg, operand = expr.argument_slots() + # Arguments in the operand are allowed to be from a MixedFunctionSpace + # We need to split the target space V and generate separate kernels + if len(V) == 1: + expressions = {(0,): expr} + elif isinstance(dual_arg, Coargument): + # Split in the coargument + expressions = dict(firedrake.formmanipulation.split_form(expr)) + else: + # Split in the cofunction: split_form can only split in the coargument + # Replace the cofunction with a coargument to construct the Jacobian + interp = expr._ufl_expr_reconstruct_(operand, V) + # Split the Jacobian into blocks + interp_split = dict(firedrake.formmanipulation.split_form(interp)) + # Split the cofunction + dual_split = dict(firedrake.formmanipulation.split_form(dual_arg)) + # Combine the splits by taking their action + expressions = {i: action(interp_split[i], dual_split[i[-1:]]) for i in interp_split} # Interpolate each sub expression into each function space - for Vsub, sub_tensor, sub_expr in zip(self.V, tensor, expressions): - loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, self.subset, arguments, self.options.access, bcs=self.bcs)) + for indices, sub_expr in expressions.items(): + if isinstance(sub_expr, ufl.ZeroBaseForm): + continue + arguments = sub_expr.arguments() + sub_space = sub_expr.argument_slots()[0].function_space().dual() + sub_tensor = tensor[indices[0]] if rank == 1 else tensor + loops.extend(_interpolator(sub_space, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs)) if self.bcs and rank == 1: loops.extend(partial(bc.apply, f) for bc in self.bcs) @@ -886,8 +883,6 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): parameters['scalar_type'] = utils.ScalarType callables = () - if access == op2.INC: - callables += (tensor.zero,) # For the matfree adjoint 1-form and the 0-form, the cellwise kernel will add multiple # contributions from the facet DOFs of the dual argument. @@ -1532,3 +1527,91 @@ def _wrap_dummy_mat(self): def duplicate(self, mat=None, op=None): return self._wrap_dummy_mat() + + +class MixedInterpolator(Interpolator): + """A reusable interpolation object between MixedFunctionSpaces. + + Parameters + ---------- + expr + The underlying ufl.Interpolate or the operand to the ufl.Interpolate. + V + The :class:`.FunctionSpace` or :class:`.Function` to + interpolate into. + bcs + A list of boundary conditions. + **kwargs + Any extra kwargs are passed on to the sub Interpolators. + For details see :class:`firedrake.interpolation.Interpolator`. + """ + def __init__(self, expr, V, bcs=None, **kwargs): + super(MixedInterpolator, self).__init__(expr, V, bcs=bcs, **kwargs) + expr = self.ufl_interpolate + self.arguments = expr.arguments() + rank = len(self.arguments) + + needs_action = len([a for a in self.arguments if isinstance(a, Coargument)]) == 0 + if needs_action: + dual_arg, operand = expr.argument_slots() + # Split the dual argument + dual_split = dict(firedrake.formmanipulation.split_form(dual_arg)) + # Create the Jacobian to be split into blocks + expr = expr._ufl_expr_reconstruct_(operand, V) + + Isub = {} + for indices, form in firedrake.formmanipulation.split_form(expr): + if isinstance(form, ufl.ZeroBaseForm): + # Ensure block sparsity + continue + vi, _ = form.argument_slots() + Vtarget = vi.function_space().dual() + if bcs and rank != 0: + args = form.arguments() + Vsource = args[1-vi.number()].function_space() + sub_bcs = [bc for bc in bcs if bc.function_space() in {Vsource, Vtarget}] + else: + sub_bcs = None + if needs_action: + # Take the action of each sub-cofunction against each block + form = action(form, dual_split[indices[-1:]]) + + Isub[indices] = Interpolator(form, Vtarget, bcs=sub_bcs, **kwargs) + + self._sub_interpolators = Isub + self.callable = self._get_callable + + def __getitem__(self, item): + return self._sub_interpolators[item] + + def __iter__(self): + return iter(self._sub_interpolators) + + def _get_callable(self): + """Assemble the operator.""" + shape = tuple(len(a.function_space()) for a in self.arguments) + blocks = numpy.full(shape, PETSc.Mat(), dtype=object) + for i in self: + blocks[i] = self[i].callable().handle + petscmat = PETSc.Mat().createNest(blocks) + tensor = firedrake.AssembledMatrix(self.arguments, self.bcs, petscmat) + return tensor.M + + def _interpolate(self, *function, output=None, adjoint=False, **kwargs): + """Assemble the action.""" + rank = len(self.arguments) + if rank == 0: + result = sum(self[i].assemble(**kwargs) for i in self) + return output.assign(result) if output else result + + if output is None: + output = firedrake.Function(self.arguments[-1].function_space().dual()) + + if rank == 1: + for k, sub_tensor in enumerate(output.subfunctions): + sub_tensor.assign(sum(self[i].assemble(**kwargs) for i in self if i[0] == k)) + elif rank == 2: + for k, sub_tensor in enumerate(output.subfunctions): + sub_tensor.assign(sum(self[i]._interpolate(*function, adjoint=adjoint, **kwargs) + for i in self if i[0] == k)) + return output diff --git a/tests/firedrake/regression/test_interpolate.py b/tests/firedrake/regression/test_interpolate.py index 69f139ecbb..40a831ad47 100644 --- a/tests/firedrake/regression/test_interpolate.py +++ b/tests/firedrake/regression/test_interpolate.py @@ -517,3 +517,50 @@ def test_interpolate_logical_not(): a = assemble(interpolate(conditional(Not(x < .2), 1, 0), V)) b = assemble(interpolate(conditional(x >= .2, 1, 0), V)) assert np.allclose(a.dat.data, b.dat.data) + + +@pytest.mark.parametrize("mode", ("forward", "adjoint")) +def test_mixed_matrix(mode): + nx = 3 + mesh = UnitSquareMesh(nx, nx) + + V1 = VectorFunctionSpace(mesh, "CG", 2) + V2 = FunctionSpace(mesh, "CG", 1) + V3 = FunctionSpace(mesh, "CG", 1) + V4 = FunctionSpace(mesh, "DG", 1) + + Z = V1 * V2 + W = V3 * V3 * V4 + + if mode == "forward": + I = Interpolate(TrialFunction(Z), TestFunction(W.dual())) + a = assemble(I) + assert a.arguments()[0].function_space() == W.dual() + assert a.arguments()[1].function_space() == Z + assert a.petscmat.getSize() == (W.dim(), Z.dim()) + assert a.petscmat.getType() == "nest" + + u = Function(Z) + u.subfunctions[0].sub(0).assign(1) + u.subfunctions[0].sub(1).assign(2) + u.subfunctions[1].assign(3) + result_matfree = assemble(Interpolate(u, TestFunction(W.dual()))) + elif mode == "adjoint": + I = Interpolate(TestFunction(Z), TrialFunction(W.dual())) + a = assemble(I) + assert a.arguments()[1].function_space() == W.dual() + assert a.arguments()[0].function_space() == Z + assert a.petscmat.getSize() == (Z.dim(), W.dim()) + assert a.petscmat.getType() == "nest" + + u = Function(W.dual()) + u.subfunctions[0].assign(1) + u.subfunctions[1].assign(2) + u.subfunctions[2].assign(3) + result_matfree = assemble(Interpolate(TestFunction(Z), u)) + else: + raise ValueError(f"Unrecognized mode {mode}") + + result_explicit = assemble(action(a, u)) + for x, y in zip(result_explicit.subfunctions, result_matfree.subfunctions): + assert np.allclose(x.dat.data, y.dat.data) From 37b9cac5c686edb0f053c235cf140c38e841a83a Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Wed, 1 Oct 2025 20:26:41 +0100 Subject: [PATCH 087/125] tidy fix options fixes --- firedrake/assemble.py | 3 - firedrake/interpolation.py | 127 +++++++++++++++---------------------- 2 files changed, 51 insertions(+), 79 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 291ed300eb..d424fb77e4 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -575,9 +575,6 @@ def base_form_assembly_visitor(self, expr, tensor, *args): # Get the target space V = v.function_space().dual() - # Get the interpolator - interp_data = expr.interp_data.copy() - default_missing_val = interp_data.pop('default_missing_val', None) if rank == 1 and isinstance(tensor, firedrake.Function): V = tensor interpolator = _get_interpolator(expr, V) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 63261044c4..2383b07b09 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -173,11 +173,16 @@ def __init__(self, expr: Interpolate, V, bcs=None): bcs : list, optional List of boundary conditions to zero-out in the output function space. By default None. """ + if not isinstance(expr, ufl.Interpolate): + expr = interpolate(expr, V if isinstance(V, ufl.FunctionSpace) else V.function_space()) dual_arg, operand = expr.argument_slots() self.ufl_interpolate = expr self.expr = operand self.V = V - self.options = expr.options + self.subset = expr.options.subset + self.allow_missing_dofs = expr.options.allow_missing_dofs + self.default_missing_val = expr.options.default_missing_val + self.matfree = expr.options.matfree self.bcs = bcs self.callable = None @@ -211,7 +216,7 @@ def __init__(self, expr: Interpolate, V, bcs=None): self.expr_renumbered = operand self.ufl_interpolate_renumbered = expr - access = self.options.access + access = expr.options.access if not isinstance(dual_arg, ufl.Coargument): # Matrix-free assembly of 0-form or 1-form requires INC access if access and access != op2.INC: @@ -220,7 +225,7 @@ def __init__(self, expr: Interpolate, V, bcs=None): elif access is None: # Default access for forward 1-form or 2-form (forward and adjoint) access = op2.WRITE - self.options.access = access + self.access = access @abc.abstractmethod def _interpolate(self, *args, **kwargs): @@ -267,34 +272,34 @@ def assemble(self, tensor=None): return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint) -def _get_interpolator(expr: Interpolate | Expr, V) -> Interpolator: +def _get_interpolator(expr: Interpolate | Expr, V, bcs=None) -> Interpolator: V_target = V if isinstance(V, ufl.FunctionSpace) else V.function_space() if not isinstance(expr, ufl.Interpolate): expr = interpolate(expr, V_target) - spaces = [a.function_space() for a in expr.arguments()] - has_mixed_spaces = any(len(space) > 1 for space in spaces) - if len(spaces) == 2 and has_mixed_spaces: - return object.__new__(MixedInterpolator) + arguments = expr.arguments() + has_mixed_arguments = any(len(arg.function_space()) > 1 for arg in arguments) + if len(arguments) == 2 and has_mixed_arguments: + return MixedInterpolator(expr, V, bcs=bcs) operand, = expr.ufl_operands target_mesh = as_domain(V) source_mesh = extract_unique_domain(operand) or target_mesh - target_mesh = as_domain(V) - source_mesh = extract_unique_domain(expr) or target_mesh submesh_interp_implemented = \ all(isinstance(m.topology, firedrake.mesh.MeshTopology) for m in [target_mesh, source_mesh]) and \ target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1] and \ target_mesh.topological_dimension() == source_mesh.topological_dimension() if target_mesh is source_mesh or submesh_interp_implemented: - return SameMeshInterpolator(expr, V) + return SameMeshInterpolator(expr, V, bcs=bcs) else: if isinstance(target_mesh.topology, VertexOnlyMeshTopology): if isinstance(source_mesh.topology, VertexOnlyMeshTopology): - return VomOntoVomInterpolator(expr, V) - return SameMeshInterpolator(expr, V) + return VomOntoVomInterpolator(expr, V, bcs=bcs) + return SameMeshInterpolator(expr, V, bcs=bcs) + elif has_mixed_arguments or len(V_target) > 1: + return MixedInterpolator(expr, V, bcs=bcs) else: - return CrossMeshInterpolator(expr, V) + return CrossMeshInterpolator(expr, V, bcs=bcs) class DofNotDefinedError(Exception): @@ -336,7 +341,7 @@ class CrossMeshInterpolator(Interpolator): @no_annotations def __init__(self, expr: Interpolate, V, bcs=None): super().__init__(expr, V, bcs) - if self.options.access != op2.WRITE: + if self.access != op2.WRITE: raise NotImplementedError( "Access other than op2.WRITE not implemented for cross-mesh interpolation." ) @@ -354,7 +359,7 @@ def __init__(self, expr: Interpolate, V, bcs=None): self.arguments = extract_arguments(self.expr_renumbered) self.nargs = len(self.arguments) - if self.options.allow_missing_dofs: + if self.allow_missing_dofs: self.missing_points_behaviour = MissingPointsBehaviour.IGNORE else: self.missing_points_behaviour = MissingPointsBehaviour.ERROR @@ -365,19 +370,18 @@ def __init__(self, expr: Interpolate, V, bcs=None): if self.src_mesh.geometric_dimension() != self.dest_mesh.geometric_dimension(): raise ValueError("Geometric dimensions of source and destination meshes must match.") - self.sub_interpolates = [] dest_element = self.V_dest.ufl_element() - if isinstance(dest_element, (finat.ufl.VectorElement, finat.ufl.TensorElement)): - # In this case all sub elements are equal - base_element = dest_element.sub_elements[0] - if base_element.reference_value_shape != (): - raise NotImplementedError( - "Can't yet cross-mesh interpolate onto function spaces made from VectorElements or TensorElements made from sub elements with value shape other than ()." - ) - self.dest_element = base_element - elif isinstance(dest_element, finat.ufl.MixedElement): - self._mixed_function_space() - return + if isinstance(dest_element, finat.ufl.MixedElement): + if isinstance(dest_element, (finat.ufl.VectorElement, finat.ufl.TensorElement)): + # In this case all sub elements are equal + base_element = dest_element.sub_elements[0] + if base_element.reference_value_shape != (): + raise NotImplementedError( + "Can't yet cross-mesh interpolate onto function spaces made from VectorElements or TensorElements made from sub elements with value shape other than ()." + ) + self.dest_element = base_element + else: + raise NotImplementedError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator") else: # scalar fiat/finat element self.dest_element = dest_element @@ -424,32 +428,6 @@ def _get_symbolic_expressions(self): P0DG_vom_i_o = fs_type(self.vom.input_ordering, "DG", 0) self.point_eval_input_ordering = interpolate(firedrake.TrialFunction(P0DG_vom), P0DG_vom_i_o) - def _mixed_function_space(self): - """Builds symbolic Interpolate expressions for each sub-element of a MixedFunctionSpace. - """ - # NOTE: since we can't have expressions for MixedFunctionSpaces - # we know that the input argument ``expr`` must be a Function. - # V_dest can be a Function or a FunctionSpace, and subfunctions works for both. - if self.nargs == 1: - # Arguments don't have a subfunctions property so I have to - # make them myself. - expr_subfunctions = [ - firedrake.TrialFunction(V_src_sub_func) - for V_src_sub_func in self.expr.function_space().subspaces - ] - elif self.nargs > 1: - raise NotImplementedError( - "Can't yet create an interpolator from an expression with multiple arguments." - ) - else: - expr_subfunctions = self.expr.subfunctions - - if len(expr_subfunctions) != len(self.V_dest.subspaces): - raise NotImplementedError("Can't interpolate from a non-mixed function space into a mixed function space.") - - for sub_func, subspace in zip(expr_subfunctions, self.V_dest.subspaces): - self.sub_interpolates.append(interpolate(sub_func, subspace, **asdict(self.options))) - @PETSc.Log.EventDecorator() def _interpolate(self, *function, output=None, adjoint=False): """Compute the interpolation. @@ -512,11 +490,11 @@ def _interpolate(self, *function, output=None, adjoint=False): ) # We have to create the Function before interpolating so we can # set default missing values (if requested). - if self.options.default_missing_val is not None: + if self.default_missing_val is not None: f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_wo[ : - ] = self.options.default_missing_val - elif self.options.allow_missing_dofs: + ] = self.default_missing_val + elif self.allow_missing_dofs: # If we have allowed missing points we know we might end up # with points in the target mesh that are not in the source # mesh. However, since we haven't specified a default missing @@ -530,7 +508,7 @@ def _interpolate(self, *function, output=None, adjoint=False): assemble(interp, tensor=f_src_at_dest_node_coords_dest_mesh_decomp) # we can now confidently assign this to a function on V_dest - if self.options.allow_missing_dofs and self.options.default_missing_val is None: + if self.allow_missing_dofs and self.default_missing_val is None: indices = numpy.where( ~numpy.isnan(f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_ro) )[0] @@ -594,7 +572,7 @@ class SameMeshInterpolator(Interpolator): @no_annotations def __init__(self, expr, V, bcs=None): super().__init__(expr, V, bcs=bcs) - subset = self.options.subset + subset = self.subset if subset is None: if isinstance(expr, ufl.Interpolate): operand, = expr.ufl_operands @@ -612,7 +590,7 @@ def __init__(self, expr, V, bcs=None): make_subset = not indices_active.all() make_subset = target.comm.allreduce(make_subset, op=MPI.LOR) if make_subset: - if not self.options.allow_missing_dofs: + if not self.allow_missing_dofs: raise ValueError("Iteration (sub)set unclear: run with `allow_missing_dofs=True`.") subset = op2.Subset(target.cell_set, numpy.where(indices_active)) else: @@ -642,9 +620,9 @@ def _get_tensor(self): else: V_dest = arguments[0].function_space().dual() f = firedrake.Function(V_dest) - if self.options.access in {firedrake.MIN, firedrake.MAX}: + if self.access in {firedrake.MIN, firedrake.MAX}: finfo = numpy.finfo(f.dat.dtype) - if self.options.access == firedrake.MIN: + if self.access == firedrake.MIN: val = firedrake.Constant(finfo.max) else: val = firedrake.Constant(finfo.min) @@ -656,11 +634,11 @@ def _get_tensor(self): Vrow = arguments[0].function_space() Vcol = arguments[1].function_space() if len(Vrow) > 1 or len(Vcol) > 1: - raise NotImplementedError("Interpolation of mixed expressions with arguments is not supported") + raise NotImplementedError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator") vom_onto_other_vom = isinstance(self, VomOntoVomInterpolator) if isinstance(target_mesh.topology, VertexOnlyMeshTopology) and target_mesh is not source_mesh and not vom_onto_other_vom: if not isinstance(target_mesh.topology, VertexOnlyMeshTopology): - raise NotImplementedError("Can only interpolate onto a Vertex Only Mesh") + raise NotImplementedError("Can only interpolate onto a VertexOnlyMesh") if target_mesh.geometric_dimension() != source_mesh.geometric_dimension(): raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh: @@ -684,7 +662,7 @@ def _get_tensor(self): return f, tensor def _get_callable(self): - expr = self.ufl_interpolate + expr = self.ufl_interpolate_renumbered dual_arg, operand = expr.argument_slots() arguments = expr.arguments() rank = len(arguments) @@ -692,13 +670,13 @@ def _get_callable(self): loops = [] - if access == op2.INC: + if self.access == op2.INC: loops.append(tensor.zero) dual_arg, operand = expr.argument_slots() # Arguments in the operand are allowed to be from a MixedFunctionSpace # We need to split the target space V and generate separate kernels - if len(V) == 1: + if len(arguments) == 2: expressions = {(0,): expr} elif isinstance(dual_arg, Coargument): # Split in the coargument @@ -706,7 +684,7 @@ def _get_callable(self): else: # Split in the cofunction: split_form can only split in the coargument # Replace the cofunction with a coargument to construct the Jacobian - interp = expr._ufl_expr_reconstruct_(operand, V) + interp = expr._ufl_expr_reconstruct_(operand, self.V) # Split the Jacobian into blocks interp_split = dict(firedrake.formmanipulation.split_form(interp)) # Split the cofunction @@ -721,7 +699,7 @@ def _get_callable(self): arguments = sub_expr.arguments() sub_space = sub_expr.argument_slots()[0].function_space().dual() sub_tensor = tensor[indices[0]] if rank == 1 else tensor - loops.extend(_interpolator(sub_space, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs)) + loops.extend(_interpolator(sub_space, sub_tensor, sub_expr, self.subset, arguments, self.access, bcs=self.bcs)) if self.bcs and rank == 1: loops.extend(partial(bc.apply, f) for bc in self.bcs) @@ -789,7 +767,7 @@ def _get_callable(self): source_mesh = extract_unique_domain(operand) or target_mesh arguments = expr.arguments() f, tensor = self._get_tensor() - wrapper = VomOntoVomWrapper(self.V, source_mesh, target_mesh, operand, self.options.matfree) + wrapper = VomOntoVomWrapper(self.V, source_mesh, target_mesh, operand, self.matfree) # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the # data, including the correct data size and dimensional information # (so for vector function spaces in 2 dimensions we might need a @@ -1541,12 +1519,9 @@ class MixedInterpolator(Interpolator): interpolate into. bcs A list of boundary conditions. - **kwargs - Any extra kwargs are passed on to the sub Interpolators. - For details see :class:`firedrake.interpolation.Interpolator`. """ - def __init__(self, expr, V, bcs=None, **kwargs): - super(MixedInterpolator, self).__init__(expr, V, bcs=bcs, **kwargs) + def __init__(self, expr, V, bcs=None): + super().__init__(expr, V, bcs=bcs) expr = self.ufl_interpolate self.arguments = expr.arguments() rank = len(self.arguments) @@ -1576,7 +1551,7 @@ def __init__(self, expr, V, bcs=None, **kwargs): # Take the action of each sub-cofunction against each block form = action(form, dual_split[indices[-1:]]) - Isub[indices] = Interpolator(form, Vtarget, bcs=sub_bcs, **kwargs) + Isub[indices] = _get_interpolator(form, Vtarget, bcs=sub_bcs) self._sub_interpolators = Isub self.callable = self._get_callable From 90a0fa51ff9c24d27a46ea5239c8bd234eceee95 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Wed, 8 Oct 2025 16:09:29 +0100 Subject: [PATCH 088/125] remove `V` argument fixes --- firedrake/assemble.py | 6 +--- firedrake/interpolation.py | 59 +++++++++++++++++--------------------- 2 files changed, 27 insertions(+), 38 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index d424fb77e4..149a4f3f5a 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -572,12 +572,8 @@ def base_form_assembly_visitor(self, expr, tensor, *args): rank = len(expr.arguments()) if rank > 2: raise ValueError("Cannot assemble an Interpolate with more than two arguments") - # Get the target space - V = v.function_space().dual() - if rank == 1 and isinstance(tensor, firedrake.Function): - V = tensor - interpolator = _get_interpolator(expr, V) + interpolator = _get_interpolator(expr) return interpolator.assemble(tensor=tensor) elif tensor and isinstance(expr, (firedrake.Function, firedrake.Cofunction, firedrake.MatrixBase)): return tensor.assign(expr) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 2383b07b09..d7a4c0171a 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -161,24 +161,20 @@ def interpolate(expr: Expr, V: WithGeometry | ufl.BaseForm, **kwargs): class Interpolator(abc.ABC): - def __init__(self, expr: Interpolate, V, bcs=None): + def __init__(self, expr: Interpolate, bcs=None): """Initialise Interpolator. Parameters ---------- expr : Interpolate The symbolic interpolation expression. - V : FunctionSpace or Function to interpolate into. - _description_ bcs : list, optional List of boundary conditions to zero-out in the output function space. By default None. """ - if not isinstance(expr, ufl.Interpolate): - expr = interpolate(expr, V if isinstance(V, ufl.FunctionSpace) else V.function_space()) dual_arg, operand = expr.argument_slots() self.ufl_interpolate = expr self.expr = operand - self.V = V + self.V = dual_arg.function_space().dual() self.subset = expr.options.subset self.allow_missing_dofs = expr.options.allow_missing_dofs self.default_missing_val = expr.options.default_missing_val @@ -272,15 +268,13 @@ def assemble(self, tensor=None): return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint) -def _get_interpolator(expr: Interpolate | Expr, V, bcs=None) -> Interpolator: - V_target = V if isinstance(V, ufl.FunctionSpace) else V.function_space() - if not isinstance(expr, ufl.Interpolate): - expr = interpolate(expr, V_target) +def _get_interpolator(expr: Interpolate, bcs=None) -> Interpolator: + V = expr.argument_slots()[0].function_space().dual() # Target function space arguments = expr.arguments() has_mixed_arguments = any(len(arg.function_space()) > 1 for arg in arguments) if len(arguments) == 2 and has_mixed_arguments: - return MixedInterpolator(expr, V, bcs=bcs) + return MixedInterpolator(expr, bcs=bcs) operand, = expr.ufl_operands target_mesh = as_domain(V) @@ -290,16 +284,16 @@ def _get_interpolator(expr: Interpolate | Expr, V, bcs=None) -> Interpolator: target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1] and \ target_mesh.topological_dimension() == source_mesh.topological_dimension() if target_mesh is source_mesh or submesh_interp_implemented: - return SameMeshInterpolator(expr, V, bcs=bcs) + return SameMeshInterpolator(expr, bcs=bcs) else: if isinstance(target_mesh.topology, VertexOnlyMeshTopology): if isinstance(source_mesh.topology, VertexOnlyMeshTopology): - return VomOntoVomInterpolator(expr, V, bcs=bcs) - return SameMeshInterpolator(expr, V, bcs=bcs) - elif has_mixed_arguments or len(V_target) > 1: - return MixedInterpolator(expr, V, bcs=bcs) + return VomOntoVomInterpolator(expr, bcs=bcs) + return SameMeshInterpolator(expr, bcs=bcs) + elif has_mixed_arguments or len(V) > 1: + return MixedInterpolator(expr, bcs=bcs) else: - return CrossMeshInterpolator(expr, V, bcs=bcs) + return CrossMeshInterpolator(expr, bcs=bcs) class DofNotDefinedError(Exception): @@ -339,15 +333,15 @@ class CrossMeshInterpolator(Interpolator): """ @no_annotations - def __init__(self, expr: Interpolate, V, bcs=None): - super().__init__(expr, V, bcs) + def __init__(self, expr: Interpolate, bcs=None): + super().__init__(expr, bcs) if self.access != op2.WRITE: raise NotImplementedError( "Access other than op2.WRITE not implemented for cross-mesh interpolation." ) if self.bcs: raise NotImplementedError("bcs not implemented.") - if V.ufl_element().mapping() != "identity": + if self.V.ufl_element().mapping() != "identity": # Identity mapping between reference cell and physical coordinates # implies point evaluation nodes. A more general version would # require finding the global coordinates of all quadrature points @@ -364,13 +358,12 @@ def __init__(self, expr: Interpolate, V, bcs=None): else: self.missing_points_behaviour = MissingPointsBehaviour.ERROR - self.V_dest = V.function_space() if isinstance(V, firedrake.Function) else V self.src_mesh = extract_unique_domain(self.expr_renumbered) - self.dest_mesh = as_domain(self.V_dest) + self.dest_mesh = as_domain(self.V) if self.src_mesh.geometric_dimension() != self.dest_mesh.geometric_dimension(): raise ValueError("Geometric dimensions of source and destination meshes must match.") - dest_element = self.V_dest.ufl_element() + dest_element = self.V.ufl_element() if isinstance(dest_element, finat.ufl.MixedElement): if isinstance(dest_element, (finat.ufl.VectorElement, finat.ufl.TensorElement)): # In this case all sub elements are equal @@ -412,7 +405,7 @@ def _get_symbolic_expressions(self): raise DofNotDefinedError(self.src_mesh, self.dest_mesh) # Get the correct type of function space - shape = self.V_dest.ufl_function_space().value_shape + shape = self.V.ufl_function_space().value_shape if len(shape) == 0: fs_type = firedrake.FunctionSpace elif len(shape) == 1: @@ -570,8 +563,8 @@ class SameMeshInterpolator(Interpolator): """ @no_annotations - def __init__(self, expr, V, bcs=None): - super().__init__(expr, V, bcs=bcs) + def __init__(self, expr, bcs=None): + super().__init__(expr, bcs=bcs) subset = self.subset if subset is None: if isinstance(expr, ufl.Interpolate): @@ -757,8 +750,8 @@ def _interpolate(self, *function, output=None, adjoint=False): class VomOntoVomInterpolator(SameMeshInterpolator): - def __init__(self, expr: Interpolate, V, bcs=None): - super().__init__(expr, V, bcs=bcs) + def __init__(self, expr: Interpolate, bcs=None): + super().__init__(expr, bcs=bcs) def _get_callable(self): expr = self.ufl_interpolate_renumbered @@ -1520,8 +1513,8 @@ class MixedInterpolator(Interpolator): bcs A list of boundary conditions. """ - def __init__(self, expr, V, bcs=None): - super().__init__(expr, V, bcs=bcs) + def __init__(self, expr, bcs=None): + super().__init__(expr, bcs=bcs) expr = self.ufl_interpolate self.arguments = expr.arguments() rank = len(self.arguments) @@ -1532,7 +1525,7 @@ def __init__(self, expr, V, bcs=None): # Split the dual argument dual_split = dict(firedrake.formmanipulation.split_form(dual_arg)) # Create the Jacobian to be split into blocks - expr = expr._ufl_expr_reconstruct_(operand, V) + expr = expr._ufl_expr_reconstruct_(operand, self.V) Isub = {} for indices, form in firedrake.formmanipulation.split_form(expr): @@ -1543,7 +1536,7 @@ def __init__(self, expr, V, bcs=None): Vtarget = vi.function_space().dual() if bcs and rank != 0: args = form.arguments() - Vsource = args[1-vi.number()].function_space() + Vsource = args[1 - vi.number()].function_space() sub_bcs = [bc for bc in bcs if bc.function_space() in {Vsource, Vtarget}] else: sub_bcs = None @@ -1551,7 +1544,7 @@ def __init__(self, expr, V, bcs=None): # Take the action of each sub-cofunction against each block form = action(form, dual_split[indices[-1:]]) - Isub[indices] = _get_interpolator(form, Vtarget, bcs=sub_bcs) + Isub[indices] = _get_interpolator(form, bcs=sub_bcs) self._sub_interpolators = Isub self.callable = self._get_callable From 72bd29e4e4edf4fc528e0c7d635b16ff75865490 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Wed, 8 Oct 2025 19:00:41 +0100 Subject: [PATCH 089/125] assemble cross-mesh interpolation matrix; add test --- firedrake/interpolation.py | 27 +++++++++++++------ .../regression/test_interpolate_cross_mesh.py | 6 +++++ 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index d7a4c0171a..c991389305 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -350,7 +350,7 @@ def __init__(self, expr: Interpolate, bcs=None): "Can only interpolate into spaces with point evaluation nodes." ) - self.arguments = extract_arguments(self.expr_renumbered) + self.arguments = self.ufl_interpolate_renumbered.arguments() self.nargs = len(self.arguments) if self.allow_missing_dofs: @@ -378,6 +378,8 @@ def __init__(self, expr: Interpolate, bcs=None): else: # scalar fiat/finat element self.dest_element = dest_element + if self.nargs == 2: + self.matfree = False self._get_symbolic_expressions() def _get_symbolic_expressions(self): @@ -419,18 +421,27 @@ def _get_symbolic_expressions(self): # Interpolate into the input-ordering VOM P0DG_vom_i_o = fs_type(self.vom.input_ordering, "DG", 0) - self.point_eval_input_ordering = interpolate(firedrake.TrialFunction(P0DG_vom), P0DG_vom_i_o) + self.point_eval_input_ordering = interpolate(firedrake.TrialFunction(P0DG_vom), P0DG_vom_i_o, matfree=self.matfree) + + if self.nargs == 2: + # The cross-mesh interpolation matrix is the product of the + # `self.point_eval_interpolate` and the permutation + # given by `self.to_input_ordering_interpolate`. + self.handle = assemble(action(self.point_eval_input_ordering, self.point_eval)).petscmat + self.callable = lambda: self @PETSc.Log.EventDecorator() def _interpolate(self, *function, output=None, adjoint=False): """Compute the interpolation. """ from firedrake.assemble import assemble - if adjoint and not self.nargs: + expr_args = extract_arguments(self.expr_renumbered) + nargs = len(expr_args) + if adjoint and not nargs: raise ValueError("Can currently only apply adjoint interpolation with arguments.") - if self.nargs != len(function): + if nargs != len(function): raise ValueError(f"Passed {len(function)} Functions to interpolate, expected {self.nargs}") - if self.nargs: + if nargs: (f_src,) = function if not hasattr(f_src, "dat"): raise ValueError( @@ -443,8 +454,8 @@ def _interpolate(self, *function, output=None, adjoint=False): try: V_dest = self.expr.function_space().dual() except AttributeError: - if self.nargs: - V_dest = self.arguments[-1].function_space().dual() + if nargs: + V_dest = expr_args[-1].function_space().dual() else: coeffs = extract_coefficients(self.expr) if len(coeffs): @@ -470,7 +481,7 @@ def _interpolate(self, *function, output=None, adjoint=False): if not adjoint: if f_src is self.expr: # f_src is already contained in self.point_eval_interpolate - assert not self.nargs + assert not nargs f_src_at_dest_node_coords_src_mesh_decomp = ( assemble(self.point_eval) ) diff --git a/tests/firedrake/regression/test_interpolate_cross_mesh.py b/tests/firedrake/regression/test_interpolate_cross_mesh.py index f9f0719d64..4fafc80f4e 100644 --- a/tests/firedrake/regression/test_interpolate_cross_mesh.py +++ b/tests/firedrake/regression/test_interpolate_cross_mesh.py @@ -680,6 +680,12 @@ def test_interpolate_matrix_cross_mesh(): f_interp2.dat.data_wo[:] = f_at_points_correct_order3.dat.data_ro[:] assert np.allclose(f_interp2.dat.data_ro, g.dat.data_ro) + interp_mat2 = assemble(interpolate(TrialFunction(U), V)) + assert interp_mat2.arguments() == (TestFunction(V.dual()), TrialFunction(U)) + f_interp3 = assemble(interp_mat2 @ f) + assert f_interp3.function_space() == V + assert np.allclose(f_interp3.dat.data_ro, g.dat.data_ro) + @pytest.mark.parallel([2, 3, 4]) def test_voting_algorithm_edgecases(): From a8e7d303a24fc9f890bc06acfe4bf74200936c97 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Wed, 8 Oct 2025 20:24:56 +0100 Subject: [PATCH 090/125] assemble adjoint cross-mesh interpolation matrix First stage of removing the renumbered interpolate --- firedrake/interpolation.py | 19 +++++++++---------- .../regression/test_interpolate_cross_mesh.py | 19 +++++++++++++++++++ 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index c991389305..1d3a204f9d 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -109,11 +109,11 @@ def __init__(self, expr: Expr, V: WithGeometry | ufl.BaseForm, **kwargs): Additional interpolation options. See :class:`InterpolateOptions` for available parameters and their descriptions. """ + expr_args = extract_arguments(ufl.as_ufl(expr)) + self.is_adjoint = len(expr_args) and expr_args[0].number() == 0 if isinstance(V, WithGeometry): # Need to create a Firedrake Coargument so it has a .function_space() method - expr_args = extract_arguments(ufl.as_ufl(expr)) - is_adjoint = len(expr_args) and expr_args[0].number() == 0 - V = Argument(V.dual(), 1 if is_adjoint else 0) + V = Argument(V.dual(), 1 if self.is_adjoint else 0) target_shape = V.arguments()[0].function_space().value_shape if expr.ufl_shape != target_shape: @@ -201,8 +201,7 @@ def __init__(self, expr: Interpolate, bcs=None): if not isinstance(dual_arg, ufl.Coargument): # Drop the Cofunction expr = expr._ufl_expr_reconstruct_(operand, dual_arg.function_space().dual()) - expr_args = extract_arguments(operand) - if expr_args and expr_args[0].number() == 0: + if expr.is_adjoint: # Construct the symbolic forward Interpolate v0, v1 = expr.arguments() expr = ufl.replace(expr, {v0: v0.reconstruct(number=v1.number()), @@ -243,10 +242,7 @@ def assemble(self, tensor=None): # Get the interpolation matrix op2mat = self.callable() petsc_mat = op2mat.handle - if needs_adjoint: - # Out-of-place Hermitian transpose - petsc_mat.hermitianTranspose(out=res) - elif tensor: + if tensor: petsc_mat.copy(tensor.petscmat) else: res = petsc_mat @@ -427,7 +423,10 @@ def _get_symbolic_expressions(self): # The cross-mesh interpolation matrix is the product of the # `self.point_eval_interpolate` and the permutation # given by `self.to_input_ordering_interpolate`. - self.handle = assemble(action(self.point_eval_input_ordering, self.point_eval)).petscmat + symbolic = action(self.point_eval_input_ordering, self.point_eval) + if self.ufl_interpolate.is_adjoint: + symbolic = expr_adjoint(symbolic) + self.handle = assemble(symbolic).petscmat self.callable = lambda: self @PETSc.Log.EventDecorator() diff --git a/tests/firedrake/regression/test_interpolate_cross_mesh.py b/tests/firedrake/regression/test_interpolate_cross_mesh.py index 4fafc80f4e..e35367e08c 100644 --- a/tests/firedrake/regression/test_interpolate_cross_mesh.py +++ b/tests/firedrake/regression/test_interpolate_cross_mesh.py @@ -687,6 +687,25 @@ def test_interpolate_matrix_cross_mesh(): assert np.allclose(f_interp3.dat.data_ro, g.dat.data_ro) +def test_interpolate_matrix_cross_mesh_adjoint(): + mesh_fine = UnitSquareMesh(4, 4) + mesh_coarse = UnitSquareMesh(2, 2) + + V_coarse = FunctionSpace(mesh_coarse, "CG", 1) + V_fine = FunctionSpace(mesh_fine, "CG", 1) + + cofunc_fine = assemble(TestFunction(V_fine) * dx) + + interp = assemble(interpolate(TestFunction(V_coarse), TrialFunction(V_fine.dual()))) + cofunc_coarse = assemble(Action(interp, cofunc_fine)) + assert interp.arguments() == (TestFunction(V_coarse), TrialFunction(V_fine.dual())) + assert cofunc_coarse.function_space() == V_coarse.dual() + + # Compare cofunc_fine with direct interpolation + cofunc_coarse_direct = assemble(TestFunction(V_coarse) * dx) + assert np.allclose(cofunc_coarse.dat.data_ro, cofunc_coarse_direct.dat.data_ro) + + @pytest.mark.parallel([2, 3, 4]) def test_voting_algorithm_edgecases(): # this triggers lots of cases where the VOM voting algorithm has to deal From 40c23da4871be0ffaf934773d5a48e8d146955a1 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 9 Oct 2025 12:58:54 +0100 Subject: [PATCH 091/125] changes --- firedrake/interpolation.py | 83 ++++++++++++++++---------------------- 1 file changed, 35 insertions(+), 48 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 1d3a204f9d..4cb168d3b5 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -330,6 +330,7 @@ class CrossMeshInterpolator(Interpolator): @no_annotations def __init__(self, expr: Interpolate, bcs=None): + from firedrake.assemble import assemble super().__init__(expr, bcs) if self.access != op2.WRITE: raise NotImplementedError( @@ -346,7 +347,7 @@ def __init__(self, expr: Interpolate, bcs=None): "Can only interpolate into spaces with point evaluation nodes." ) - self.arguments = self.ufl_interpolate_renumbered.arguments() + self.arguments = self.ufl_interpolate.arguments() self.nargs = len(self.arguments) if self.allow_missing_dofs: @@ -374,10 +375,19 @@ def __init__(self, expr: Interpolate, bcs=None): else: # scalar fiat/finat element self.dest_element = dest_element - if self.nargs == 2: - self.matfree = False + self._get_symbolic_expressions() + if self.nargs == 2: + # The cross-mesh interpolation matrix is the product of the + # `self.point_eval_interpolate` and the permutation + # given by `self.to_input_ordering_interpolate`. + symbolic = action(self.point_eval_input_ordering, self.point_eval) + if self.ufl_interpolate.is_adjoint: + symbolic = expr_adjoint(symbolic) + self.handle = assemble(symbolic).petscmat + self.callable = lambda: self + def _get_symbolic_expressions(self): """Constructs the symbolic Interpolate expressions for cross-mesh interpolation. @@ -415,20 +425,14 @@ def _get_symbolic_expressions(self): P0DG_vom = fs_type(self.vom, "DG", 0) self.point_eval = interpolate(self.expr_renumbered, P0DG_vom) + if self.nargs == 2: + # If assembling the operator, we need the concrete permutation matrix + self.matfree = False + # Interpolate into the input-ordering VOM P0DG_vom_i_o = fs_type(self.vom.input_ordering, "DG", 0) self.point_eval_input_ordering = interpolate(firedrake.TrialFunction(P0DG_vom), P0DG_vom_i_o, matfree=self.matfree) - if self.nargs == 2: - # The cross-mesh interpolation matrix is the product of the - # `self.point_eval_interpolate` and the permutation - # given by `self.to_input_ordering_interpolate`. - symbolic = action(self.point_eval_input_ordering, self.point_eval) - if self.ufl_interpolate.is_adjoint: - symbolic = expr_adjoint(symbolic) - self.handle = assemble(symbolic).petscmat - self.callable = lambda: self - @PETSc.Log.EventDecorator() def _interpolate(self, *function, output=None, adjoint=False): """Compute the interpolation. @@ -464,30 +468,21 @@ def _interpolate(self, *function, output=None, adjoint=False): "Can't adjoint interpolate an expression with no coefficients or arguments." ) else: - if isinstance(self.V, (firedrake.Function, firedrake.Cofunction)): - V_dest = self.V.function_space() - else: - V_dest = self.V + V_dest = self.V + if output: if output.function_space() != V_dest: raise ValueError("Given output has the wrong function space!") else: - if isinstance(self.V, (firedrake.Function, firedrake.Cofunction)): - output = self.V - else: - output = firedrake.Function(V_dest) + output = firedrake.Function(V_dest) if not adjoint: if f_src is self.expr: # f_src is already contained in self.point_eval_interpolate assert not nargs - f_src_at_dest_node_coords_src_mesh_decomp = ( - assemble(self.point_eval) - ) + f_src_at_dest_node_coords_src_mesh_decomp = assemble(self.point_eval) else: - f_src_at_dest_node_coords_src_mesh_decomp = ( - assemble(action(self.point_eval, f_src)) - ) + f_src_at_dest_node_coords_src_mesh_decomp = assemble(action(self.point_eval, f_src)) f_src_at_dest_node_coords_dest_mesh_decomp = firedrake.Function( self.point_eval_input_ordering.function_space() ) @@ -617,9 +612,6 @@ def _get_tensor(self): if rank == 0: R = firedrake.FunctionSpace(target_mesh, "Real", 0) f = firedrake.Function(R, dtype=utils.ScalarType) - elif isinstance(self.V, firedrake.Function): - f = self.V - self.V = f.function_space() else: V_dest = arguments[0].function_space().dual() f = firedrake.Function(V_dest) @@ -632,8 +624,6 @@ def _get_tensor(self): f.assign(val) tensor = f.dat elif rank == 2: - if isinstance(self.V, firedrake.Function): - raise ValueError("Cannot interpolate an expression with an argument into a Function") Vrow = arguments[0].function_space() Vcol = arguments[1].function_space() if len(Vrow) > 1 or len(Vcol) > 1: @@ -646,19 +636,14 @@ def _get_tensor(self): raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh: raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target") - - if vom_onto_other_vom: - # We make our own linear operator for this case using PETSc SFs - tensor = None - else: - Vrow_map = get_interp_node_map(source_mesh, target_mesh, Vrow) - Vcol_map = get_interp_node_map(source_mesh, target_mesh, Vcol) - sparsity = op2.Sparsity((Vrow.dof_dset, Vcol.dof_dset), - [(Vrow_map, Vcol_map, None)], # non-mixed - name=f"{Vrow.name}_{Vcol.name}_sparsity", - nest=False, - block_sparse=True) - tensor = op2.Mat(sparsity) + Vrow_map = get_interp_node_map(source_mesh, target_mesh, Vrow) + Vcol_map = get_interp_node_map(source_mesh, target_mesh, Vcol) + sparsity = op2.Sparsity((Vrow.dof_dset, Vcol.dof_dset), + [(Vrow_map, Vcol_map, None)], # non-mixed + name=f"{Vrow.name}_{Vcol.name}_sparsity", + nest=False, + block_sparse=True) + tensor = op2.Mat(sparsity) f = tensor else: raise ValueError(f"Cannot interpolate an expression with {rank} arguments") @@ -769,15 +754,17 @@ def _get_callable(self): target_mesh = as_domain(dual_arg) source_mesh = extract_unique_domain(operand) or target_mesh arguments = expr.arguments() - f, tensor = self._get_tensor() + if len(arguments) == 2: + # We make our own linear operator for this case using PETSc SFs + tensor = None + else: + f, tensor = self._get_tensor() wrapper = VomOntoVomWrapper(self.V, source_mesh, target_mesh, operand, self.matfree) # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the # data, including the correct data size and dimensional information # (so for vector function spaces in 2 dimensions we might need a # concatenation of 2 MPI.DOUBLE types when we are in real mode) if tensor is not None: - # Callable will do interpolation into our pre-supplied function f - # when it is called. assert f.dat is tensor wrapper.mpi_type, _ = get_dat_mpi_type(f.dat) assert len(arguments) == 1 From fcdb7c5c37cf67f720f35f4080d5cdf5a75e2a48 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 9 Oct 2025 14:13:01 +0100 Subject: [PATCH 092/125] remove repeated checks --- firedrake/interpolation.py | 121 ++++++++++++++++++------------------- 1 file changed, 59 insertions(+), 62 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 4cb168d3b5..da1246c142 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -275,21 +275,30 @@ def _get_interpolator(expr: Interpolate, bcs=None) -> Interpolator: operand, = expr.ufl_operands target_mesh = as_domain(V) source_mesh = extract_unique_domain(operand) or target_mesh - submesh_interp_implemented = \ - all(isinstance(m.topology, firedrake.mesh.MeshTopology) for m in [target_mesh, source_mesh]) and \ - target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1] and \ - target_mesh.topological_dimension() == source_mesh.topological_dimension() + submesh_interp_implemented = ( + all(isinstance(m.topology, firedrake.mesh.MeshTopology) for m in [target_mesh, source_mesh]) + and target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1] + and target_mesh.topological_dimension() == source_mesh.topological_dimension() + ) if target_mesh is source_mesh or submesh_interp_implemented: return SameMeshInterpolator(expr, bcs=bcs) - else: - if isinstance(target_mesh.topology, VertexOnlyMeshTopology): - if isinstance(source_mesh.topology, VertexOnlyMeshTopology): - return VomOntoVomInterpolator(expr, bcs=bcs) - return SameMeshInterpolator(expr, bcs=bcs) - elif has_mixed_arguments or len(V) > 1: - return MixedInterpolator(expr, bcs=bcs) - else: - return CrossMeshInterpolator(expr, bcs=bcs) + + target_topology = target_mesh.topology + source_topology = source_mesh.topology + + if isinstance(target_topology, VertexOnlyMeshTopology): + if isinstance(source_topology, VertexOnlyMeshTopology): + return VomOntoVomInterpolator(expr, bcs=bcs) + if target_mesh.geometric_dimension() != source_mesh.geometric_dimension(): + raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") + if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh: + raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target") + return SameMeshInterpolator(expr, bcs=bcs) + + if has_mixed_arguments or len(V) > 1: + return MixedInterpolator(expr, bcs=bcs) + + return CrossMeshInterpolator(expr, bcs=bcs) class DofNotDefinedError(Exception): @@ -608,34 +617,26 @@ def _get_tensor(self): source_mesh = extract_unique_domain(operand) or target_mesh arguments = expr.arguments() rank = len(arguments) - if rank <= 1: - if rank == 0: - R = firedrake.FunctionSpace(target_mesh, "Real", 0) - f = firedrake.Function(R, dtype=utils.ScalarType) - else: - V_dest = arguments[0].function_space().dual() - f = firedrake.Function(V_dest) - if self.access in {firedrake.MIN, firedrake.MAX}: - finfo = numpy.finfo(f.dat.dtype) - if self.access == firedrake.MIN: - val = firedrake.Constant(finfo.max) - else: - val = firedrake.Constant(finfo.min) - f.assign(val) + if rank == 0: + R = firedrake.FunctionSpace(target_mesh, "Real", 0) + f = firedrake.Function(R, dtype=utils.ScalarType) + tensor = f.dat + elif rank == 1: + V_dest = arguments[0].function_space().dual() + f = firedrake.Function(V_dest) + if self.access in {firedrake.MIN, firedrake.MAX}: + finfo = numpy.finfo(f.dat.dtype) + if self.access == firedrake.MIN: + val = firedrake.Constant(finfo.max) + else: + val = firedrake.Constant(finfo.min) + f.assign(val) tensor = f.dat elif rank == 2: Vrow = arguments[0].function_space() Vcol = arguments[1].function_space() if len(Vrow) > 1 or len(Vcol) > 1: raise NotImplementedError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator") - vom_onto_other_vom = isinstance(self, VomOntoVomInterpolator) - if isinstance(target_mesh.topology, VertexOnlyMeshTopology) and target_mesh is not source_mesh and not vom_onto_other_vom: - if not isinstance(target_mesh.topology, VertexOnlyMeshTopology): - raise NotImplementedError("Can only interpolate onto a VertexOnlyMesh") - if target_mesh.geometric_dimension() != source_mesh.geometric_dimension(): - raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") - if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh: - raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target") Vrow_map = get_interp_node_map(source_mesh, target_mesh, Vrow) Vcol_map = get_interp_node_map(source_mesh, target_mesh, Vcol) sparsity = op2.Sparsity((Vrow.dof_dset, Vcol.dof_dset), @@ -814,33 +815,26 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): target_mesh = as_domain(V) source_mesh = extract_unique_domain(operand) or target_mesh if isinstance(target_mesh.topology, VertexOnlyMeshTopology): - if target_mesh is not source_mesh: - if not isinstance(target_mesh.topology, VertexOnlyMeshTopology): - raise NotImplementedError("Can only interpolate onto a Vertex Only Mesh") - if target_mesh.geometric_dimension() != source_mesh.geometric_dimension(): - raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") - if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh: - raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target") - # For trans-mesh interpolation we use a FInAT QuadratureElement as the - # (base) target element with runtime point set expressions as their - # quadrature rule point set and weights from their dual basis. - # NOTE: This setup is useful for thinking about future design - in the - # future this `rebuild` function can be absorbed into FInAT as a - # transformer that eats an element and gives you an equivalent (which - # may or may not be a QuadratureElement) that lets you do run time - # tabulation. Alternatively (and this all depends on future design - # decision about FInAT how dual evaluation should work) the - # to_element's dual basis (which look rather like quadrature rules) can - # have their pointset(s) directly replaced with run-time tabulated - # equivalent(s) (i.e. finat.point_set.UnknownPointSet(s)) - rt_var_name = 'rt_X' - try: - cell = operand.ufl_element().ufl_cell() - except AttributeError: - # expression must be pure function of spatial coordinates so - # domain has correct ufl cell - cell = source_mesh.ufl_cell() - to_element = rebuild(to_element, cell, rt_var_name) + # For trans-mesh interpolation we use a FInAT QuadratureElement as the + # (base) target element with runtime point set expressions as their + # quadrature rule point set and weights from their dual basis. + # NOTE: This setup is useful for thinking about future design - in the + # future this `rebuild` function can be absorbed into FInAT as a + # transformer that eats an element and gives you an equivalent (which + # may or may not be a QuadratureElement) that lets you do run time + # tabulation. Alternatively (and this all depends on future design + # decision about FInAT how dual evaluation should work) the + # to_element's dual basis (which look rather like quadrature rules) can + # have their pointset(s) directly replaced with run-time tabulated + # equivalent(s) (i.e. finat.point_set.UnknownPointSet(s)) + rt_var_name = 'rt_X' + try: + cell = operand.ufl_element().ufl_cell() + except AttributeError: + # expression must be pure function of spatial coordinates so + # domain has correct ufl cell + cell = source_mesh.ufl_cell() + to_element = rebuild(to_element, cell, rt_var_name) cell_set = target_mesh.cell_set if subset is not None: @@ -911,6 +905,7 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): else: copyin = () copyout = () + if isinstance(tensor, op2.Global): parloop_args.append(tensor(access)) elif isinstance(tensor, op2.Dat): @@ -935,9 +930,11 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): bc_cols = [bc for bc in bcs if bc.function_space() == Vcol] lgmaps = [(Vrow.local_to_global_map(bc_rows), Vcol.local_to_global_map(bc_cols))] parloop_args.append(tensor(access, (rows_map, columns_map), lgmaps=lgmaps)) + if oriented: co = target_mesh.cell_orientations() parloop_args.append(co.dat(op2.READ, co.cell_node_map())) + if needs_cell_sizes: cs = source_mesh.cell_sizes parloop_args.append(cs.dat(op2.READ, cs.cell_node_map())) From e7fe358259ff945626e3eea6b47514e55b58cb43 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 9 Oct 2025 17:24:13 +0100 Subject: [PATCH 093/125] progress on adjoint cross-mesh / vom-to-vom --- firedrake/interpolation.py | 79 ++++++++++++++++++++++++-------------- 1 file changed, 50 insertions(+), 29 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index da1246c142..f912dbd7ab 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -186,26 +186,26 @@ def __init__(self, expr: Interpolate, bcs=None): # self.ufl_interpolate (which carries the dual argument). # See github issue https://github.com/firedrakeproject/firedrake/issues/4592 - if isinstance(self, CrossMeshInterpolator | VomOntoVomInterpolator): - # For bespoke interpolation, we currently rely on different assembly procedures: - # 1) Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) -> Forward operator (2-form) - # 2) Interpolate(Argument(V1, 0), Argument(V2.dual(), 1)) -> Adjoint operator (2-form) - # 3) Interpolate(Coefficient(V1), Argument(V2.dual(), 0)) -> Forward action (1-form) - # 4) Interpolate(Argument(V1, 0), Cofunction(V2.dual()) -> Adjoint action (1-form) - # 5) Interpolate(Coefficient(V1), Cofunction(V2.dual()) -> Double action (0-form) - - # CrossMeshInterpolator._interpolate only supports forward interpolation (cases 1 and 3). - # For case 2, we first redundantly assemble case 1 and then construct the transpose. - # For cases 4 and 5, we take the forward Interpolate that corresponds to dropping the Cofunction, - # and we separately compute the action against the dropped Cofunction within assemble(). - if not isinstance(dual_arg, ufl.Coargument): - # Drop the Cofunction - expr = expr._ufl_expr_reconstruct_(operand, dual_arg.function_space().dual()) - if expr.is_adjoint: - # Construct the symbolic forward Interpolate - v0, v1 = expr.arguments() - expr = ufl.replace(expr, {v0: v0.reconstruct(number=v1.number()), - v1: v1.reconstruct(number=v0.number())}) + # if isinstance(self, CrossMeshInterpolator | VomOntoVomInterpolator): + # # For bespoke interpolation, we currently rely on different assembly procedures: + # # 1) Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) -> Forward operator (2-form) + # # 2) Interpolate(Argument(V1, 0), Argument(V2.dual(), 1)) -> Adjoint operator (2-form) + # # 3) Interpolate(Coefficient(V1), Argument(V2.dual(), 0)) -> Forward action (1-form) + # # 4) Interpolate(Argument(V1, 0), Cofunction(V2.dual()) -> Adjoint action (1-form) + # # 5) Interpolate(Coefficient(V1), Cofunction(V2.dual()) -> Double action (0-form) + + # # CrossMeshInterpolator._interpolate only supports forward interpolation (cases 1 and 3). + # # For case 2, we first redundantly assemble case 1 and then construct the transpose. + # # For cases 4 and 5, we take the forward Interpolate that corresponds to dropping the Cofunction, + # # and we separately compute the action against the dropped Cofunction within assemble(). + # if not isinstance(dual_arg, ufl.Coargument): + # # Drop the Cofunction + # expr = expr._ufl_expr_reconstruct_(operand, dual_arg.function_space().dual()) + # if expr.is_adjoint: + # # Construct the symbolic forward Interpolate + # v0, v1 = expr.arguments() + # expr = ufl.replace(expr, {v0: v0.reconstruct(number=v1.number()), + # v1: v1.reconstruct(number=v0.number())}) dual_arg, operand = expr.argument_slots() self.expr_renumbered = operand @@ -234,7 +234,8 @@ def _interpolate(self, *args, **kwargs): def assemble(self, tensor=None): """Assemble the operator (or its action).""" - needs_adjoint = self.ufl_interpolate_renumbered != self.ufl_interpolate + # needs_adjoint = self.ufl_interpolate_renumbered != self.ufl_interpolate + needs_adjoint = self.ufl_interpolate.is_adjoint arguments = self.ufl_interpolate.arguments() if len(arguments) == 2: # Assembling the operator @@ -342,9 +343,10 @@ def __init__(self, expr: Interpolate, bcs=None): from firedrake.assemble import assemble super().__init__(expr, bcs) if self.access != op2.WRITE: - raise NotImplementedError( - "Access other than op2.WRITE not implemented for cross-mesh interpolation." - ) + # raise NotImplementedError( + # "Access other than op2.WRITE not implemented for cross-mesh interpolation." + # ) + self.access = op2.WRITE if self.bcs: raise NotImplementedError("bcs not implemented.") if self.V.ufl_element().mapping() != "identity": @@ -562,7 +564,7 @@ def _interpolate(self, *function, output=None, adjoint=False): # SameMeshInterpolator.interpolate did not effect the result. For # now, I say in the docstring that it only applies to forward # interpolation. - interp = action(expr_adjoint(self.point_eval), f_src_at_src_node_coords) + interp = action(self.point_eval, f_src_at_src_node_coords) assemble(interp, tensor=output) return output @@ -755,6 +757,7 @@ def _get_callable(self): target_mesh = as_domain(dual_arg) source_mesh = extract_unique_domain(operand) or target_mesh arguments = expr.arguments() + if len(arguments) == 2: # We make our own linear operator for this case using PETSc SFs tensor = None @@ -769,10 +772,15 @@ def _get_callable(self): assert f.dat is tensor wrapper.mpi_type, _ = get_dat_mpi_type(f.dat) assert len(arguments) == 1 - - def callable(): - wrapper.forward_operation(f.dat) - return f + if expr.is_adjoint: + assert isinstance(dual_arg, ufl.Cofunction) + def callable(): + wrapper.adjoint_operation(dual_arg.dat, f.dat) + return f + else: + def callable(): + wrapper.forward_operation(f.dat) + return f else: assert len(arguments) == 2 assert tensor is None @@ -1298,6 +1306,19 @@ def forward_operation(self, target_dat): with coeff.dat.vec_ro as coeff_vec, target_dat.vec_wo as target_vec: self.handle.mult(coeff_vec, target_vec) + def adjoint_operation(self, source_dat, target_dat): + """Apply the adjoint interpolation operation. + + Parameters + ---------- + source_dat : dat + The dat from the cofunction (on the target mesh in forward sense). + target_dat : dat + The dat to write the result to (on the source mesh in forward sense). + """ + with source_dat.vec_ro as source_vec, target_dat.vec_wo as target_vec: + self.handle.multHermitian(source_vec, target_vec) + class VomOntoVomDummyMat(object): """Dummy object to stand in for a PETSc ``Mat`` when we are interpolating From efc45c77cb1a41f66925b6d5186a7e335cacfe01 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 9 Oct 2025 18:16:24 +0100 Subject: [PATCH 094/125] tidy --- firedrake/interpolation.py | 58 ++++++++------------------------------ 1 file changed, 11 insertions(+), 47 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index f912dbd7ab..5ad457a970 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -182,35 +182,6 @@ def __init__(self, expr: Interpolate, bcs=None): self.bcs = bcs self.callable = None - # TODO CrossMeshInterpolator and VomOntoVomXXX are not yet aware of - # self.ufl_interpolate (which carries the dual argument). - # See github issue https://github.com/firedrakeproject/firedrake/issues/4592 - - # if isinstance(self, CrossMeshInterpolator | VomOntoVomInterpolator): - # # For bespoke interpolation, we currently rely on different assembly procedures: - # # 1) Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) -> Forward operator (2-form) - # # 2) Interpolate(Argument(V1, 0), Argument(V2.dual(), 1)) -> Adjoint operator (2-form) - # # 3) Interpolate(Coefficient(V1), Argument(V2.dual(), 0)) -> Forward action (1-form) - # # 4) Interpolate(Argument(V1, 0), Cofunction(V2.dual()) -> Adjoint action (1-form) - # # 5) Interpolate(Coefficient(V1), Cofunction(V2.dual()) -> Double action (0-form) - - # # CrossMeshInterpolator._interpolate only supports forward interpolation (cases 1 and 3). - # # For case 2, we first redundantly assemble case 1 and then construct the transpose. - # # For cases 4 and 5, we take the forward Interpolate that corresponds to dropping the Cofunction, - # # and we separately compute the action against the dropped Cofunction within assemble(). - # if not isinstance(dual_arg, ufl.Coargument): - # # Drop the Cofunction - # expr = expr._ufl_expr_reconstruct_(operand, dual_arg.function_space().dual()) - # if expr.is_adjoint: - # # Construct the symbolic forward Interpolate - # v0, v1 = expr.arguments() - # expr = ufl.replace(expr, {v0: v0.reconstruct(number=v1.number()), - # v1: v1.reconstruct(number=v0.number())}) - - dual_arg, operand = expr.argument_slots() - self.expr_renumbered = operand - self.ufl_interpolate_renumbered = expr - access = expr.options.access if not isinstance(dual_arg, ufl.Coargument): # Matrix-free assembly of 0-form or 1-form requires INC access @@ -234,7 +205,6 @@ def _interpolate(self, *args, **kwargs): def assemble(self, tensor=None): """Assemble the operator (or its action).""" - # needs_adjoint = self.ufl_interpolate_renumbered != self.ufl_interpolate needs_adjoint = self.ufl_interpolate.is_adjoint arguments = self.ufl_interpolate.arguments() if len(arguments) == 2: @@ -250,15 +220,9 @@ def assemble(self, tensor=None): return tensor or firedrake.AssembledMatrix(arguments, self.bcs, res) else: # Assembling the action - cofunctions = () - if needs_adjoint: - # The renumbered Interpolate has dropped Cofunctions. - # We need to explicitly operate on them. - dual_arg, _ = self.ufl_interpolate.argument_slots() - if not isinstance(dual_arg, ufl.Coargument): - cofunctions = (dual_arg,) - - if needs_adjoint and len(arguments) == 0: + dual_arg, _ = self.ufl_interpolate.argument_slots() + cofunctions = (dual_arg,) + if len(arguments) == 0: Iu = self._interpolate() return firedrake.assemble(ufl.Action(*cofunctions, Iu), tensor=tensor) else: @@ -366,7 +330,7 @@ def __init__(self, expr: Interpolate, bcs=None): else: self.missing_points_behaviour = MissingPointsBehaviour.ERROR - self.src_mesh = extract_unique_domain(self.expr_renumbered) + self.src_mesh = extract_unique_domain(self.expr) self.dest_mesh = as_domain(self.V) if self.src_mesh.geometric_dimension() != self.dest_mesh.geometric_dimension(): raise ValueError("Geometric dimensions of source and destination meshes must match.") @@ -434,7 +398,7 @@ def _get_symbolic_expressions(self): # Get expression for point evaluation at the dest_node_coords P0DG_vom = fs_type(self.vom, "DG", 0) - self.point_eval = interpolate(self.expr_renumbered, P0DG_vom) + self.point_eval = interpolate(self.expr, P0DG_vom) if self.nargs == 2: # If assembling the operator, we need the concrete permutation matrix @@ -449,11 +413,11 @@ def _interpolate(self, *function, output=None, adjoint=False): """Compute the interpolation. """ from firedrake.assemble import assemble - expr_args = extract_arguments(self.expr_renumbered) + expr_args = extract_arguments(self.expr) nargs = len(expr_args) if adjoint and not nargs: raise ValueError("Can currently only apply adjoint interpolation with arguments.") - if nargs != len(function): + if self.nargs != len(function): raise ValueError(f"Passed {len(function)} Functions to interpolate, expected {self.nargs}") if nargs: (f_src,) = function @@ -610,10 +574,10 @@ def __init__(self, expr, bcs=None): self.callable = self._get_callable() except FIAT.hdiv_trace.TraceError: raise NotImplementedError("Can't interpolate onto traces.") - self.arguments = self.ufl_interpolate_renumbered.arguments() + self.arguments = self.ufl_interpolate.arguments() def _get_tensor(self): - expr = self.ufl_interpolate_renumbered + expr = self.ufl_interpolate dual_arg, operand = expr.argument_slots() target_mesh = as_domain(dual_arg) source_mesh = extract_unique_domain(operand) or target_mesh @@ -653,7 +617,7 @@ def _get_tensor(self): return f, tensor def _get_callable(self): - expr = self.ufl_interpolate_renumbered + expr = self.ufl_interpolate dual_arg, operand = expr.argument_slots() arguments = expr.arguments() rank = len(arguments) @@ -752,7 +716,7 @@ def __init__(self, expr: Interpolate, bcs=None): super().__init__(expr, bcs=bcs) def _get_callable(self): - expr = self.ufl_interpolate_renumbered + expr = self.ufl_interpolate dual_arg, operand = expr.argument_slots() target_mesh = as_domain(dual_arg) source_mesh = extract_unique_domain(operand) or target_mesh From f71dd4745488a05c9e9076bbb3de5028a67f3d06 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 9 Oct 2025 19:55:03 +0100 Subject: [PATCH 095/125] tidy --- firedrake/interpolation.py | 55 ++++++++++++++------------------------ 1 file changed, 20 insertions(+), 35 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 5ad457a970..e18819962c 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -205,7 +205,6 @@ def _interpolate(self, *args, **kwargs): def assemble(self, tensor=None): """Assemble the operator (or its action).""" - needs_adjoint = self.ufl_interpolate.is_adjoint arguments = self.ufl_interpolate.arguments() if len(arguments) == 2: # Assembling the operator @@ -219,14 +218,7 @@ def assemble(self, tensor=None): res = petsc_mat return tensor or firedrake.AssembledMatrix(arguments, self.bcs, res) else: - # Assembling the action - dual_arg, _ = self.ufl_interpolate.argument_slots() - cofunctions = (dual_arg,) - if len(arguments) == 0: - Iu = self._interpolate() - return firedrake.assemble(ufl.Action(*cofunctions, Iu), tensor=tensor) - else: - return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint) + return self._interpolate(output=tensor) def _get_interpolator(expr: Interpolate, bcs=None) -> Interpolator: @@ -409,26 +401,17 @@ def _get_symbolic_expressions(self): self.point_eval_input_ordering = interpolate(firedrake.TrialFunction(P0DG_vom), P0DG_vom_i_o, matfree=self.matfree) @PETSc.Log.EventDecorator() - def _interpolate(self, *function, output=None, adjoint=False): + def _interpolate(self, output=None): """Compute the interpolation. """ from firedrake.assemble import assemble - expr_args = extract_arguments(self.expr) + expr_args = self.ufl_interpolate.arguments() nargs = len(expr_args) + adjoint = self.ufl_interpolate.is_adjoint if adjoint and not nargs: raise ValueError("Can currently only apply adjoint interpolation with arguments.") - if self.nargs != len(function): - raise ValueError(f"Passed {len(function)} Functions to interpolate, expected {self.nargs}") - if nargs: - (f_src,) = function - if not hasattr(f_src, "dat"): - raise ValueError( - "The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!" - ) - else: - f_src = self.expr - if adjoint: + f_src = self.ufl_interpolate.argument_slots()[0] try: V_dest = self.expr.function_space().dual() except AttributeError: @@ -443,6 +426,7 @@ def _interpolate(self, *function, output=None, adjoint=False): "Can't adjoint interpolate an expression with no coefficients or arguments." ) else: + f_src = self.expr V_dest = self.V if output: @@ -454,7 +438,6 @@ def _interpolate(self, *function, output=None, adjoint=False): if not adjoint: if f_src is self.expr: # f_src is already contained in self.point_eval_interpolate - assert not nargs f_src_at_dest_node_coords_src_mesh_decomp = assemble(self.point_eval) else: f_src_at_dest_node_coords_src_mesh_decomp = assemble(action(self.point_eval, f_src)) @@ -667,15 +650,19 @@ def callable(loops, f): return partial(callable, loops, f) @PETSc.Log.EventDecorator() - def _interpolate(self, *function, output=None, adjoint=False): + def _interpolate(self, output=None): """Compute the interpolation. For arguments, see :class:`.Interpolator`. """ assembled_interpolator = self.callable() - - if len(self.arguments) == 2 and len(function) > 0: - function, = function + adjoint = self.ufl_interpolate.is_adjoint + dual_arg, operand = self.ufl_interpolate.argument_slots() + if adjoint: + function = dual_arg + else: + function = operand + if len(self.arguments) == 2: if not hasattr(function, "dat"): raise ValueError("The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!") if adjoint: @@ -701,13 +688,10 @@ def _interpolate(self, *function, output=None, adjoint=False): if output: output.assign(assembled_interpolator) return output - if isinstance(self.V, firedrake.Function): - return self.V + if len(self.arguments) == 0: + return assembled_interpolator.dat.data.item() else: - if len(self.arguments) == 0: - return assembled_interpolator.dat.data.item() - else: - return assembled_interpolator + return assembled_interpolator class VomOntoVomInterpolator(SameMeshInterpolator): @@ -1544,7 +1528,7 @@ def _get_callable(self): tensor = firedrake.AssembledMatrix(self.arguments, self.bcs, petscmat) return tensor.M - def _interpolate(self, *function, output=None, adjoint=False, **kwargs): + def _interpolate(self, output=None, **kwargs): """Assemble the action.""" rank = len(self.arguments) if rank == 0: @@ -1559,6 +1543,7 @@ def _interpolate(self, *function, output=None, adjoint=False, **kwargs): sub_tensor.assign(sum(self[i].assemble(**kwargs) for i in self if i[0] == k)) elif rank == 2: for k, sub_tensor in enumerate(output.subfunctions): - sub_tensor.assign(sum(self[i]._interpolate(*function, adjoint=adjoint, **kwargs) + sub_tensor.assign(sum(self[i]._interpolate(**kwargs) for i in self if i[0] == k)) return output + \ No newline at end of file From 17469c540e53fafc465e3c9ca1ad738025864bdf Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Fri, 10 Oct 2025 12:56:02 +0100 Subject: [PATCH 096/125] fixes --- firedrake/interpolation.py | 272 ++++++++++++++----------------------- 1 file changed, 101 insertions(+), 171 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index e18819962c..cd8289dd04 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -28,6 +28,7 @@ from firedrake import tsfc_interface, utils from firedrake.ufl_expr import Argument, Coargument, action, adjoint as expr_adjoint from firedrake.cofunction import Cofunction +from firedrake.function import Function from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology from firedrake.petsc import PETSc from firedrake.halo import _get_mtype as get_dat_mpi_type @@ -114,10 +115,9 @@ def __init__(self, expr: Expr, V: WithGeometry | ufl.BaseForm, **kwargs): if isinstance(V, WithGeometry): # Need to create a Firedrake Coargument so it has a .function_space() method V = Argument(V.dual(), 1 if self.is_adjoint else 0) - - target_shape = V.arguments()[0].function_space().value_shape - if expr.ufl_shape != target_shape: - raise ValueError(f"Shape mismatch: Expression shape {expr.ufl_shape}, FunctionSpace shape {target_shape}.") + self.target_space = V.arguments()[0].function_space() + if expr.ufl_shape != self.target_space.value_shape: + raise ValueError(f"Shape mismatch: Expression shape {expr.ufl_shape}, FunctionSpace shape {self.target_space.value_shape}.") super().__init__(expr, V) @@ -172,9 +172,16 @@ def __init__(self, expr: Interpolate, bcs=None): List of boundary conditions to zero-out in the output function space. By default None. """ dual_arg, operand = expr.argument_slots() - self.ufl_interpolate = expr - self.expr = operand - self.V = dual_arg.function_space().dual() + self.expr = expr + self.expr_args = expr.arguments() + self.rank = len(self.expr_args) + self.operand = operand + self.dual_arg = dual_arg + self.V_dest = self.expr.target_space + self.target_mesh = as_domain(self.V_dest) + self.source_mesh = extract_unique_domain(operand) or self.target_mesh + + # Interpolation options self.subset = expr.options.subset self.allow_missing_dofs = expr.options.allow_missing_dofs self.default_missing_val = expr.options.default_missing_val @@ -205,8 +212,7 @@ def _interpolate(self, *args, **kwargs): def assemble(self, tensor=None): """Assemble the operator (or its action).""" - arguments = self.ufl_interpolate.arguments() - if len(arguments) == 2: + if self.rank == 2: # Assembling the operator res = tensor.petscmat if tensor else PETSc.Mat() # Get the interpolation matrix @@ -216,21 +222,19 @@ def assemble(self, tensor=None): petsc_mat.copy(tensor.petscmat) else: res = petsc_mat - return tensor or firedrake.AssembledMatrix(arguments, self.bcs, res) + return tensor or firedrake.AssembledMatrix(self.expr_args, self.bcs, res) else: return self._interpolate(output=tensor) def _get_interpolator(expr: Interpolate, bcs=None) -> Interpolator: - V = expr.argument_slots()[0].function_space().dual() # Target function space - arguments = expr.arguments() has_mixed_arguments = any(len(arg.function_space()) > 1 for arg in arguments) if len(arguments) == 2 and has_mixed_arguments: return MixedInterpolator(expr, bcs=bcs) operand, = expr.ufl_operands - target_mesh = as_domain(V) + target_mesh = as_domain(expr.target_space) source_mesh = extract_unique_domain(operand) or target_mesh submesh_interp_implemented = ( all(isinstance(m.topology, firedrake.mesh.MeshTopology) for m in [target_mesh, source_mesh]) @@ -252,7 +256,7 @@ def _get_interpolator(expr: Interpolate, bcs=None) -> Interpolator: raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target") return SameMeshInterpolator(expr, bcs=bcs) - if has_mixed_arguments or len(V) > 1: + if has_mixed_arguments or len(expr.target_space) > 1: return MixedInterpolator(expr, bcs=bcs) return CrossMeshInterpolator(expr, bcs=bcs) @@ -305,7 +309,7 @@ def __init__(self, expr: Interpolate, bcs=None): self.access = op2.WRITE if self.bcs: raise NotImplementedError("bcs not implemented.") - if self.V.ufl_element().mapping() != "identity": + if self.V_dest.ufl_element().mapping() != "identity": # Identity mapping between reference cell and physical coordinates # implies point evaluation nodes. A more general version would # require finding the global coordinates of all quadrature points @@ -314,20 +318,15 @@ def __init__(self, expr: Interpolate, bcs=None): "Can only interpolate into spaces with point evaluation nodes." ) - self.arguments = self.ufl_interpolate.arguments() - self.nargs = len(self.arguments) - if self.allow_missing_dofs: self.missing_points_behaviour = MissingPointsBehaviour.IGNORE else: self.missing_points_behaviour = MissingPointsBehaviour.ERROR - self.src_mesh = extract_unique_domain(self.expr) - self.dest_mesh = as_domain(self.V) - if self.src_mesh.geometric_dimension() != self.dest_mesh.geometric_dimension(): + if self.source_mesh.geometric_dimension() != self.target_mesh.geometric_dimension(): raise ValueError("Geometric dimensions of source and destination meshes must match.") - dest_element = self.V.ufl_element() + dest_element = self.V_dest.ufl_element() if isinstance(dest_element, finat.ufl.MixedElement): if isinstance(dest_element, (finat.ufl.VectorElement, finat.ufl.TensorElement)): # In this case all sub elements are equal @@ -345,12 +344,12 @@ def __init__(self, expr: Interpolate, bcs=None): self._get_symbolic_expressions() - if self.nargs == 2: + if self.rank == 2: # The cross-mesh interpolation matrix is the product of the # `self.point_eval_interpolate` and the permutation # given by `self.to_input_ordering_interpolate`. symbolic = action(self.point_eval_input_ordering, self.point_eval) - if self.ufl_interpolate.is_adjoint: + if self.expr.is_adjoint: symbolic = expr_adjoint(symbolic) self.handle = assemble(symbolic).petscmat self.callable = lambda: self @@ -366,21 +365,21 @@ def _get_symbolic_expressions(self): """ from firedrake.assemble import assemble # Immerse coordinates of V_dest point evaluation dofs in src_mesh - V_dest_vec = firedrake.VectorFunctionSpace(self.dest_mesh, self.dest_element) - f_dest_node_coords = assemble(interpolate(self.dest_mesh.coordinates, V_dest_vec)) - dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, self.dest_mesh.geometric_dimension()) + V_dest_vec = firedrake.VectorFunctionSpace(self.target_mesh, self.dest_element) + f_dest_node_coords = assemble(interpolate(self.target_mesh.coordinates, V_dest_vec)) + dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, self.target_mesh.geometric_dimension()) try: self.vom = firedrake.VertexOnlyMesh( - self.src_mesh, + self.source_mesh, dest_node_coords, redundant=False, missing_points_behaviour=self.missing_points_behaviour, ) except VertexOnlyMeshMissingPointsError: - raise DofNotDefinedError(self.src_mesh, self.dest_mesh) + raise DofNotDefinedError(self.source_mesh, self.target_mesh) # Get the correct type of function space - shape = self.V.ufl_function_space().value_shape + shape = self.V_dest.ufl_function_space().value_shape if len(shape) == 0: fs_type = firedrake.FunctionSpace elif len(shape) == 1: @@ -390,9 +389,9 @@ def _get_symbolic_expressions(self): # Get expression for point evaluation at the dest_node_coords P0DG_vom = fs_type(self.vom, "DG", 0) - self.point_eval = interpolate(self.expr, P0DG_vom) + self.point_eval = interpolate(self.operand, P0DG_vom) - if self.nargs == 2: + if self.rank == 2: # If assembling the operator, we need the concrete permutation matrix self.matfree = False @@ -405,42 +404,19 @@ def _interpolate(self, output=None): """Compute the interpolation. """ from firedrake.assemble import assemble - expr_args = self.ufl_interpolate.arguments() - nargs = len(expr_args) - adjoint = self.ufl_interpolate.is_adjoint - if adjoint and not nargs: - raise ValueError("Can currently only apply adjoint interpolation with arguments.") - if adjoint: - f_src = self.ufl_interpolate.argument_slots()[0] - try: - V_dest = self.expr.function_space().dual() - except AttributeError: - if nargs: - V_dest = expr_args[-1].function_space().dual() - else: - coeffs = extract_coefficients(self.expr) - if len(coeffs): - V_dest = coeffs[0].function_space().dual() - else: - raise ValueError( - "Can't adjoint interpolate an expression with no coefficients or arguments." - ) + adjoint = self.expr.is_adjoint + if adjoint: + f_src = self.dual_arg + V_dest = self.expr_args[0].function_space().dual() else: - f_src = self.expr - V_dest = self.V + f_src = self.operand + V_dest = self.V_dest - if output: - if output.function_space() != V_dest: - raise ValueError("Given output has the wrong function space!") - else: - output = firedrake.Function(V_dest) + output = output or Function(V_dest) if not adjoint: - if f_src is self.expr: - # f_src is already contained in self.point_eval_interpolate - f_src_at_dest_node_coords_src_mesh_decomp = assemble(self.point_eval) - else: - f_src_at_dest_node_coords_src_mesh_decomp = assemble(action(self.point_eval, f_src)) + # f_src is already contained in self.point_eval_interpolate + f_src_at_dest_node_coords_src_mesh_decomp = assemble(self.point_eval) f_src_at_dest_node_coords_dest_mesh_decomp = firedrake.Function( self.point_eval_input_ordering.function_space() ) @@ -534,7 +510,7 @@ def __init__(self, expr, bcs=None): operand, = expr.ufl_operands else: operand = expr - target_mesh = as_domain(self.V) + target_mesh = as_domain(self.V_dest) source_mesh = extract_unique_domain(operand) or target_mesh target = target_mesh.topology source = source_mesh.topology @@ -553,26 +529,19 @@ def __init__(self, expr, bcs=None): # Do not need subset as target <= source. pass self.subset = subset + try: self.callable = self._get_callable() except FIAT.hdiv_trace.TraceError: raise NotImplementedError("Can't interpolate onto traces.") - self.arguments = self.ufl_interpolate.arguments() - def _get_tensor(self): - expr = self.ufl_interpolate - dual_arg, operand = expr.argument_slots() - target_mesh = as_domain(dual_arg) - source_mesh = extract_unique_domain(operand) or target_mesh - arguments = expr.arguments() - rank = len(arguments) - if rank == 0: - R = firedrake.FunctionSpace(target_mesh, "Real", 0) - f = firedrake.Function(R, dtype=utils.ScalarType) - tensor = f.dat - elif rank == 1: - V_dest = arguments[0].function_space().dual() - f = firedrake.Function(V_dest) + def _get_tensor(self) -> op2.Mat | Function | Cofunction: + if self.rank == 0: + R = firedrake.FunctionSpace(self.target_mesh, "Real", 0) + f = Function(R, dtype=utils.ScalarType) + elif self.rank == 1: + V_dest = self.expr_args[0].function_space().dual() + f = Function(V_dest) if self.access in {firedrake.MIN, firedrake.MAX}: finfo = numpy.finfo(f.dat.dtype) if self.access == firedrake.MIN: @@ -580,53 +549,50 @@ def _get_tensor(self): else: val = firedrake.Constant(finfo.min) f.assign(val) - tensor = f.dat - elif rank == 2: - Vrow = arguments[0].function_space() - Vcol = arguments[1].function_space() + elif self.rank == 2: + Vrow = self.expr_args[0].function_space() + Vcol = self.expr_args[1].function_space() if len(Vrow) > 1 or len(Vcol) > 1: raise NotImplementedError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator") - Vrow_map = get_interp_node_map(source_mesh, target_mesh, Vrow) - Vcol_map = get_interp_node_map(source_mesh, target_mesh, Vcol) + Vrow_map = get_interp_node_map(self.source_mesh, self.target_mesh, Vrow) + Vcol_map = get_interp_node_map(self.source_mesh, self.target_mesh, Vcol) sparsity = op2.Sparsity((Vrow.dof_dset, Vcol.dof_dset), [(Vrow_map, Vcol_map, None)], # non-mixed name=f"{Vrow.name}_{Vcol.name}_sparsity", nest=False, block_sparse=True) - tensor = op2.Mat(sparsity) - f = tensor + f = op2.Mat(sparsity) else: - raise ValueError(f"Cannot interpolate an expression with {rank} arguments") - return f, tensor + raise ValueError(f"Cannot interpolate an expression with {self.rank} arguments") + return f def _get_callable(self): - expr = self.ufl_interpolate - dual_arg, operand = expr.argument_slots() - arguments = expr.arguments() - rank = len(arguments) - f, tensor = self._get_tensor() + f = self._get_tensor() + if isinstance(f, op2.Mat): + tensor = f + else: + tensor = f.dat loops = [] if self.access == op2.INC: loops.append(tensor.zero) - dual_arg, operand = expr.argument_slots() # Arguments in the operand are allowed to be from a MixedFunctionSpace # We need to split the target space V and generate separate kernels - if len(arguments) == 2: - expressions = {(0,): expr} - elif isinstance(dual_arg, Coargument): + if self.rank == 2: + expressions = {(0,): self.expr} + elif isinstance(self.dual_arg, Coargument): # Split in the coargument - expressions = dict(firedrake.formmanipulation.split_form(expr)) + expressions = dict(firedrake.formmanipulation.split_form(self.expr)) else: # Split in the cofunction: split_form can only split in the coargument # Replace the cofunction with a coargument to construct the Jacobian - interp = expr._ufl_expr_reconstruct_(operand, self.V) + interp = self.expr._ufl_expr_reconstruct_(self.operand, self.V_dest) # Split the Jacobian into blocks interp_split = dict(firedrake.formmanipulation.split_form(interp)) # Split the cofunction - dual_split = dict(firedrake.formmanipulation.split_form(dual_arg)) + dual_split = dict(firedrake.formmanipulation.split_form(self.dual_arg)) # Combine the splits by taking their action expressions = {i: action(interp_split[i], dual_split[i[-1:]]) for i in interp_split} @@ -636,10 +602,10 @@ def _get_callable(self): continue arguments = sub_expr.arguments() sub_space = sub_expr.argument_slots()[0].function_space().dual() - sub_tensor = tensor[indices[0]] if rank == 1 else tensor + sub_tensor = tensor[indices[0]] if self.rank == 1 else tensor loops.extend(_interpolator(sub_space, sub_tensor, sub_expr, self.subset, arguments, self.access, bcs=self.bcs)) - if self.bcs and rank == 1: + if self.bcs and self.rank == 1: loops.extend(partial(bc.apply, f) for bc in self.bcs) def callable(loops, f): @@ -655,43 +621,16 @@ def _interpolate(self, output=None): For arguments, see :class:`.Interpolator`. """ + assert self.rank < 2 assembled_interpolator = self.callable() - adjoint = self.ufl_interpolate.is_adjoint - dual_arg, operand = self.ufl_interpolate.argument_slots() - if adjoint: - function = dual_arg - else: - function = operand - if len(self.arguments) == 2: - if not hasattr(function, "dat"): - raise ValueError("The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!") - if adjoint: - mul = assembled_interpolator.handle.multHermitian - col, row = self.arguments - else: - mul = assembled_interpolator.handle.mult - row, col = self.arguments - V = row.function_space().dual() - assert function.function_space() == col.function_space() - - result = output or firedrake.Function(V) - with function.dat.vec_ro as x, result.dat.vec_wo as out: - if x is not out: - mul(x, out) - else: - out_ = out.duplicate() - mul(x, out_) - out_.copy(result=out) - return result + if output: + output.assign(assembled_interpolator) + return output + if self.rank == 0: + return assembled_interpolator.dat.data.item() else: - if output: - output.assign(assembled_interpolator) - return output - if len(self.arguments) == 0: - return assembled_interpolator.dat.data.item() - else: - return assembled_interpolator + return assembled_interpolator class VomOntoVomInterpolator(SameMeshInterpolator): @@ -700,45 +639,41 @@ def __init__(self, expr: Interpolate, bcs=None): super().__init__(expr, bcs=bcs) def _get_callable(self): - expr = self.ufl_interpolate - dual_arg, operand = expr.argument_slots() - target_mesh = as_domain(dual_arg) - source_mesh = extract_unique_domain(operand) or target_mesh - arguments = expr.arguments() + target_mesh = as_domain(self.dual_arg) + source_mesh = extract_unique_domain(self.operand) or target_mesh - if len(arguments) == 2: + if self.rank == 2: # We make our own linear operator for this case using PETSc SFs tensor = None else: - f, tensor = self._get_tensor() - wrapper = VomOntoVomWrapper(self.V, source_mesh, target_mesh, operand, self.matfree) + f = self._get_tensor() + tensor = f.dat + wrapper = VomOntoVomWrapper(self.V_dest, source_mesh, target_mesh, self.operand, self.matfree) # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the # data, including the correct data size and dimensional information # (so for vector function spaces in 2 dimensions we might need a # concatenation of 2 MPI.DOUBLE types when we are in real mode) if tensor is not None: - assert f.dat is tensor wrapper.mpi_type, _ = get_dat_mpi_type(f.dat) - assert len(arguments) == 1 - if expr.is_adjoint: - assert isinstance(dual_arg, ufl.Cofunction) + assert self.rank == 1 + if self.expr.is_adjoint: + assert isinstance(self.dual_arg, ufl.Cofunction) def callable(): - wrapper.adjoint_operation(dual_arg.dat, f.dat) + wrapper.adjoint_operation(self.dual_arg.dat, f.dat) return f else: def callable(): wrapper.forward_operation(f.dat) return f else: - assert len(arguments) == 2 - assert tensor is None + assert self.rank == 2 # we know we will be outputting either a function or a cofunction, # both of which will use a dat as a data carrier. At present, the # data type does not depend on function space dimension, so we can # safely use the argument function space. NOTE: If this changes # after cofunctions are fully implemented, this will need to be # reconsidered. - temp_source_func = firedrake.Function(arguments[1].function_space()) + temp_source_func = firedrake.Function(self.expr_args[1].function_space()) wrapper.mpi_type, _ = get_dat_mpi_type(temp_source_func.dat) # Leave wrapper inside a callable so we can access the handle @@ -1260,9 +1195,9 @@ def adjoint_operation(self, source_dat, target_dat): Parameters ---------- source_dat : dat - The dat from the cofunction (on the target mesh in forward sense). + The dat from the cofunction. target_dat : dat - The dat to write the result to (on the source mesh in forward sense). + The dat to write the result to. """ with source_dat.vec_ro as source_vec, target_dat.vec_wo as target_vec: self.handle.multHermitian(source_vec, target_vec) @@ -1478,17 +1413,13 @@ class MixedInterpolator(Interpolator): """ def __init__(self, expr, bcs=None): super().__init__(expr, bcs=bcs) - expr = self.ufl_interpolate - self.arguments = expr.arguments() - rank = len(self.arguments) - needs_action = len([a for a in self.arguments if isinstance(a, Coargument)]) == 0 + needs_action = len([a for a in self.expr_args if isinstance(a, Coargument)]) == 0 if needs_action: - dual_arg, operand = expr.argument_slots() # Split the dual argument - dual_split = dict(firedrake.formmanipulation.split_form(dual_arg)) + dual_split = dict(firedrake.formmanipulation.split_form(self.dual_arg)) # Create the Jacobian to be split into blocks - expr = expr._ufl_expr_reconstruct_(operand, self.V) + self.expr = self.expr._ufl_expr_reconstruct_(self.operand, self.V_dest) Isub = {} for indices, form in firedrake.formmanipulation.split_form(expr): @@ -1497,7 +1428,7 @@ def __init__(self, expr, bcs=None): continue vi, _ = form.argument_slots() Vtarget = vi.function_space().dual() - if bcs and rank != 0: + if bcs and self.rank != 0: args = form.arguments() Vsource = args[1 - vi.number()].function_space() sub_bcs = [bc for bc in bcs if bc.function_space() in {Vsource, Vtarget}] @@ -1520,28 +1451,27 @@ def __iter__(self): def _get_callable(self): """Assemble the operator.""" - shape = tuple(len(a.function_space()) for a in self.arguments) + shape = tuple(len(a.function_space()) for a in self.expr_args) blocks = numpy.full(shape, PETSc.Mat(), dtype=object) for i in self: blocks[i] = self[i].callable().handle petscmat = PETSc.Mat().createNest(blocks) - tensor = firedrake.AssembledMatrix(self.arguments, self.bcs, petscmat) + tensor = firedrake.AssembledMatrix(self.expr_args, self.bcs, petscmat) return tensor.M def _interpolate(self, output=None, **kwargs): """Assemble the action.""" - rank = len(self.arguments) - if rank == 0: + if self.rank == 0: result = sum(self[i].assemble(**kwargs) for i in self) return output.assign(result) if output else result if output is None: - output = firedrake.Function(self.arguments[-1].function_space().dual()) + output = firedrake.Function(self.expr_args[-1].function_space().dual()) - if rank == 1: + if self.rank == 1: for k, sub_tensor in enumerate(output.subfunctions): sub_tensor.assign(sum(self[i].assemble(**kwargs) for i in self if i[0] == k)) - elif rank == 2: + elif self.rank == 2: for k, sub_tensor in enumerate(output.subfunctions): sub_tensor.assign(sum(self[i]._interpolate(**kwargs) for i in self if i[0] == k)) From 21d10686f21e1ff7dba11e555fee0283f17d2116 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Fri, 10 Oct 2025 13:32:34 +0100 Subject: [PATCH 097/125] fixes --- firedrake/interpolation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index cd8289dd04..bed074640e 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -177,7 +177,7 @@ def __init__(self, expr: Interpolate, bcs=None): self.rank = len(self.expr_args) self.operand = operand self.dual_arg = dual_arg - self.V_dest = self.expr.target_space + self.V_dest = dual_arg.function_space().dual() self.target_mesh = as_domain(self.V_dest) self.source_mesh = extract_unique_domain(operand) or self.target_mesh @@ -1414,7 +1414,8 @@ class MixedInterpolator(Interpolator): def __init__(self, expr, bcs=None): super().__init__(expr, bcs=bcs) - needs_action = len([a for a in self.expr_args if isinstance(a, Coargument)]) == 0 + # We need a Coargument in order to split the Interpolate + needs_action = not any(isinstance(a, Coargument) for a in self.expr_args) if needs_action: # Split the dual argument dual_split = dict(firedrake.formmanipulation.split_form(self.dual_arg)) From 19187a5c72ebcbc65483878b971a968a1d3da3e6 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Fri, 10 Oct 2025 14:52:26 +0100 Subject: [PATCH 098/125] fixed zero-form cross mesh --- firedrake/interpolation.py | 134 +++++++++++++++---------------------- 1 file changed, 53 insertions(+), 81 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index bed074640e..7bfa754c10 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -388,108 +388,80 @@ def _get_symbolic_expressions(self): fs_type = partial(firedrake.TensorFunctionSpace, shape=shape) # Get expression for point evaluation at the dest_node_coords - P0DG_vom = fs_type(self.vom, "DG", 0) - self.point_eval = interpolate(self.operand, P0DG_vom) + self.P0DG_vom = fs_type(self.vom, "DG", 0) + self.point_eval = interpolate(self.operand, self.P0DG_vom) if self.rank == 2: # If assembling the operator, we need the concrete permutation matrix self.matfree = False # Interpolate into the input-ordering VOM - P0DG_vom_i_o = fs_type(self.vom.input_ordering, "DG", 0) - self.point_eval_input_ordering = interpolate(firedrake.TrialFunction(P0DG_vom), P0DG_vom_i_o, matfree=self.matfree) + self.P0DG_vom_input_ordering = fs_type(self.vom.input_ordering, "DG", 0) + self.point_eval_input_ordering = interpolate(firedrake.TrialFunction(self.P0DG_vom), + self.P0DG_vom_input_ordering, matfree=self.matfree) - @PETSc.Log.EventDecorator() def _interpolate(self, output=None): - """Compute the interpolation. - """ from firedrake.assemble import assemble - adjoint = self.expr.is_adjoint - if adjoint: - f_src = self.dual_arg + if self.expr.is_adjoint: + f = self.dual_arg V_dest = self.expr_args[0].function_space().dual() else: - f_src = self.operand V_dest = self.V_dest output = output or Function(V_dest) - if not adjoint: - # f_src is already contained in self.point_eval_interpolate - f_src_at_dest_node_coords_src_mesh_decomp = assemble(self.point_eval) - f_src_at_dest_node_coords_dest_mesh_decomp = firedrake.Function( - self.point_eval_input_ordering.function_space() - ) - # We have to create the Function before interpolating so we can - # set default missing values (if requested). + if not self.expr.is_adjoint: + # We evaluate the operand at the node coordinates of the destination space + f_point_eval = assemble(self.point_eval) + + # We create the input-ordering Function before interpolating so we can + # set default missing values if required. + f_point_eval_input_ordering = Function(self.P0DG_vom_input_ordering) if self.default_missing_val is not None: - f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_wo[ - : - ] = self.default_missing_val + f_point_eval_input_ordering.dat.data_wo[:] = self.default_missing_val elif self.allow_missing_dofs: - # If we have allowed missing points we know we might end up - # with points in the target mesh that are not in the source - # mesh. However, since we haven't specified a default missing - # value we expect the interpolation to leave these points - # unchanged. By setting the dat values to NaN we can later - # identify these points and skip over them when assigning to - # the output function. - f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_wo[:] = numpy.nan - - interp = action(self.point_eval_input_ordering, f_src_at_dest_node_coords_src_mesh_decomp) - assemble(interp, tensor=f_src_at_dest_node_coords_dest_mesh_decomp) - - # we can now confidently assign this to a function on V_dest + # If we allow missing points there may be points in the target + # mesh that are not in the source mesh. If we don't specify a + # default missing value we set these to NaN so we can identify + # them later. + f_point_eval_input_ordering.dat.data_wo[:] = numpy.nan + + assemble(action(self.point_eval_input_ordering, f_point_eval), + tensor=f_point_eval_input_ordering) + + # We assign these values to the output function if self.allow_missing_dofs and self.default_missing_val is None: - indices = numpy.where( - ~numpy.isnan(f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_ro) - )[0] - output.dat.data_wo[ - indices - ] = f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_ro[indices] + indices = numpy.where(~numpy.isnan(f_point_eval_input_ordering.dat.data_ro))[0] + output.dat.data_wo[indices] = f_point_eval_input_ordering.dat.data_ro[indices] else: - output.dat.data_wo[ - : - ] = f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_ro[:] - + output.dat.data_wo[:] = f_point_eval_input_ordering.dat.data_ro[:] + + if self.rank == 0: + # We take the action of the dual_arg on the interpolated function + assert not isinstance(self.dual_arg, ufl.Coargument) + return assemble(action(self.dual_arg, output)) else: - # adjoint interpolation - - # f_src is a cofunction on V_dest.dual as originally specified when - # creating the interpolator. Our first adjoint operation is to - # assign the dat values to a P0DG cofunction on our input ordering - # VOM. This has the parallel decomposition V_dest on our orinally - # specified dest_mesh. We can therefore safely create a P0DG - # cofunction on the input-ordering VOM (which has this parallel - # decomposition and ordering) and assign the dat values. - f_src_at_dest_node_coords_dest_mesh_decomp = firedrake.Cofunction( - self.point_eval_input_ordering.function_space().dual() - ) - f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_wo[ + # f_src is a cofunction on V_dest.dual + assert isinstance(f, Cofunction) + # Our first adjoint operation is to assign the dat values to a + # P0DG cofunction on our input ordering VOM. + f_input_ordering = Cofunction(self.P0DG_vom_input_ordering.dual()) + f_input_ordering.dat.data_wo[ : - ] = f_src.dat.data_ro[:] + ] = f.dat.data_ro[:] - # The rest of the adjoint interpolation is merely the composition - # of the adjoint interpolators in the reverse direction. NOTE: I - # don't have to worry about skipping over missing points here - # because I'm going from the input ordering VOM to the original VOM + # The rest of the adjoint interpolation is the composition + # of the adjoint interpolators in the reverse direction. + # We don't worry about skipping over missing points here + # because we're going from the input ordering VOM to the original VOM # and all points from the input ordering VOM are in the original. - interp = action(expr_adjoint(self.point_eval_input_ordering), f_src_at_dest_node_coords_dest_mesh_decomp) + interp = action(expr_adjoint(self.point_eval_input_ordering), f_input_ordering) f_src_at_src_node_coords = assemble(interp) - # NOTE: if I wanted the default missing value to be applied to - # adjoint interpolation I would have to do it here. However, - # this would require me to implement default missing values for - # adjoint interpolation from a point evaluation interpolator - # which I haven't done. I wonder if it is necessary - perhaps the - # adjoint operator always sets all the values of the resulting - # cofunction? My initial attempt to insert setting the dat values - # prior to performing the multHermitian operation in - # SameMeshInterpolator.interpolate did not effect the result. For - # now, I say in the docstring that it only applies to forward - # interpolation. + + # We don't need to take the adjoint of self.point_eval because + # it was constructed using self.operand interp = action(self.point_eval, f_src_at_src_node_coords) assemble(interp, tensor=output) - return output @@ -1423,7 +1395,7 @@ def __init__(self, expr, bcs=None): self.expr = self.expr._ufl_expr_reconstruct_(self.operand, self.V_dest) Isub = {} - for indices, form in firedrake.formmanipulation.split_form(expr): + for indices, form in firedrake.formmanipulation.split_form(self.expr): if isinstance(form, ufl.ZeroBaseForm): # Ensure block sparsity continue @@ -1460,10 +1432,10 @@ def _get_callable(self): tensor = firedrake.AssembledMatrix(self.expr_args, self.bcs, petscmat) return tensor.M - def _interpolate(self, output=None, **kwargs): + def _interpolate(self, output=None): """Assemble the action.""" if self.rank == 0: - result = sum(self[i].assemble(**kwargs) for i in self) + result = sum(self[i].assemble() for i in self) return output.assign(result) if output else result if output is None: @@ -1471,10 +1443,10 @@ def _interpolate(self, output=None, **kwargs): if self.rank == 1: for k, sub_tensor in enumerate(output.subfunctions): - sub_tensor.assign(sum(self[i].assemble(**kwargs) for i in self if i[0] == k)) + sub_tensor.assign(sum(self[i].assemble() for i in self if i[0] == k)) elif self.rank == 2: for k, sub_tensor in enumerate(output.subfunctions): - sub_tensor.assign(sum(self[i]._interpolate(**kwargs) + sub_tensor.assign(sum(self[i]._interpolate() for i in self if i[0] == k)) return output \ No newline at end of file From d17d04984e6b2eab09f150d5a7521d0f9a3095ce Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Fri, 10 Oct 2025 16:51:59 +0100 Subject: [PATCH 099/125] remove vomontovomwrapper --- firedrake/interpolation.py | 216 +++++++++++++------------------------ 1 file changed, 73 insertions(+), 143 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 7bfa754c10..25212ed460 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -4,7 +4,7 @@ import abc from functools import partial, singledispatch -from typing import Hashable, Literal +from typing import Hashable, Literal, Callable from dataclasses import asdict, dataclass import FIAT @@ -348,9 +348,10 @@ def __init__(self, expr: Interpolate, bcs=None): # The cross-mesh interpolation matrix is the product of the # `self.point_eval_interpolate` and the permutation # given by `self.to_input_ordering_interpolate`. - symbolic = action(self.point_eval_input_ordering, self.point_eval) if self.expr.is_adjoint: - symbolic = expr_adjoint(symbolic) + symbolic = action(self.point_eval, self.point_eval_input_ordering) + else: + symbolic = action(self.point_eval_input_ordering, self.point_eval) self.handle = assemble(symbolic).petscmat self.callable = lambda: self @@ -397,8 +398,9 @@ def _get_symbolic_expressions(self): # Interpolate into the input-ordering VOM self.P0DG_vom_input_ordering = fs_type(self.vom.input_ordering, "DG", 0) - self.point_eval_input_ordering = interpolate(firedrake.TrialFunction(self.P0DG_vom), - self.P0DG_vom_input_ordering, matfree=self.matfree) + + arg = Argument(self.P0DG_vom, 0 if self.expr.is_adjoint else 1) + self.point_eval_input_ordering = interpolate(arg, self.P0DG_vom_input_ordering, matfree=self.matfree) def _interpolate(self, output=None): from firedrake.assemble import assemble @@ -418,13 +420,13 @@ def _interpolate(self, output=None): # set default missing values if required. f_point_eval_input_ordering = Function(self.P0DG_vom_input_ordering) if self.default_missing_val is not None: - f_point_eval_input_ordering.dat.data_wo[:] = self.default_missing_val + f_point_eval_input_ordering.assign(self.default_missing_val) elif self.allow_missing_dofs: # If we allow missing points there may be points in the target # mesh that are not in the source mesh. If we don't specify a # default missing value we set these to NaN so we can identify # them later. - f_point_eval_input_ordering.dat.data_wo[:] = numpy.nan + f_point_eval_input_ordering.assign(numpy.nan) assemble(action(self.point_eval_input_ordering, f_point_eval), tensor=f_point_eval_input_ordering) @@ -446,16 +448,14 @@ def _interpolate(self, output=None): # Our first adjoint operation is to assign the dat values to a # P0DG cofunction on our input ordering VOM. f_input_ordering = Cofunction(self.P0DG_vom_input_ordering.dual()) - f_input_ordering.dat.data_wo[ - : - ] = f.dat.data_ro[:] + f_input_ordering.dat.data_wo[:] = f.dat.data_ro[:] # The rest of the adjoint interpolation is the composition # of the adjoint interpolators in the reverse direction. # We don't worry about skipping over missing points here # because we're going from the input ordering VOM to the original VOM # and all points from the input ordering VOM are in the original. - interp = action(expr_adjoint(self.point_eval_input_ordering), f_input_ordering) + interp = action(self.point_eval_input_ordering, f_input_ordering) f_src_at_src_node_coords = assemble(interp) # We don't need to take the adjoint of self.point_eval because @@ -478,14 +478,8 @@ def __init__(self, expr, bcs=None): super().__init__(expr, bcs=bcs) subset = self.subset if subset is None: - if isinstance(expr, ufl.Interpolate): - operand, = expr.ufl_operands - else: - operand = expr - target_mesh = as_domain(self.V_dest) - source_mesh = extract_unique_domain(operand) or target_mesh - target = target_mesh.topology - source = source_mesh.topology + target = self.target_mesh.topology + source = self.source_mesh.topology if all(isinstance(m, firedrake.mesh.MeshTopology) for m in [target, source]) and target is not source: composed_map, result_integral_type = source.trans_mesh_entity_map(target, "cell", "everywhere", None) if result_integral_type != "cell": @@ -508,6 +502,13 @@ def __init__(self, expr, bcs=None): raise NotImplementedError("Can't interpolate onto traces.") def _get_tensor(self) -> op2.Mat | Function | Cofunction: + """Return the tensor to interpolate into. + + Returns + ------- + op2.Mat | Function | Cofunction + + """ if self.rank == 0: R = firedrake.FunctionSpace(self.target_mesh, "Real", 0) f = Function(R, dtype=utils.ScalarType) @@ -538,7 +539,13 @@ def _get_tensor(self) -> op2.Mat | Function | Cofunction: raise ValueError(f"Cannot interpolate an expression with {self.rank} arguments") return f - def _get_callable(self): + def _get_callable(self) -> Callable: + """Construct the callable that performs the interpolation. + + Returns + ------- + Callable + """ f = self._get_tensor() if isinstance(f, op2.Mat): tensor = f @@ -611,31 +618,31 @@ def __init__(self, expr: Interpolate, bcs=None): super().__init__(expr, bcs=bcs) def _get_callable(self): - target_mesh = as_domain(self.dual_arg) - source_mesh = extract_unique_domain(self.operand) or target_mesh - + self.mat = VomOntoVomMat(self) if self.rank == 2: # We make our own linear operator for this case using PETSc SFs tensor = None else: f = self._get_tensor() tensor = f.dat - wrapper = VomOntoVomWrapper(self.V_dest, source_mesh, target_mesh, self.operand, self.matfree) # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the # data, including the correct data size and dimensional information # (so for vector function spaces in 2 dimensions we might need a # concatenation of 2 MPI.DOUBLE types when we are in real mode) if tensor is not None: - wrapper.mpi_type, _ = get_dat_mpi_type(f.dat) + self.mat.mpi_type, _ = get_dat_mpi_type(f.dat) assert self.rank == 1 if self.expr.is_adjoint: assert isinstance(self.dual_arg, ufl.Cofunction) def callable(): - wrapper.adjoint_operation(self.dual_arg.dat, f.dat) + with self.dual_arg.dat.vec_ro as source_vec, f.dat.vec_wo as target_vec: + self.mat.handle.multHermitian(source_vec, target_vec) return f else: def callable(): - wrapper.forward_operation(f.dat) + coeff = self.mat.expr_as_coeff() + with coeff.dat.vec_ro as coeff_vec, f.dat.vec_wo as target_vec: + self.mat.handle.mult(coeff_vec, target_vec) return f else: assert self.rank == 2 @@ -646,7 +653,7 @@ def callable(): # after cofunctions are fully implemented, this will need to be # reconsidered. temp_source_func = firedrake.Function(self.expr_args[1].function_space()) - wrapper.mpi_type, _ = get_dat_mpi_type(temp_source_func.dat) + self.mat.mpi_type, _ = get_dat_mpi_type(temp_source_func.dat) # Leave wrapper inside a callable so we can access the handle # property. If matfree is True, then the handle is a PETSc SF @@ -654,7 +661,7 @@ def callable(): # will be a PETSc Mat representing the equivalent permutation # matrix def callable(): - return wrapper + return self.mat return callable @@ -1092,132 +1099,55 @@ def __init__(self, glob): self.ufl_domain = lambda: None -class VomOntoVomWrapper(object): - """Utility class for interpolating from one ``VertexOnlyMesh`` to it's - intput ordering ``VertexOnlyMesh``, or vice versa. +class VomOntoVomMat: + """Object that facilitates interpolation between two vertex-only meshes.""" + def __init__(self, interpolator: VomOntoVomInterpolator): + """Initialises the VomOntoVomMat. - Parameters - ---------- - V : `.FunctionSpace` - The P0DG function space (which may be vector or tensor valued) on the - source vertex-only mesh. - source_vom : `.VertexOnlyMesh` - The vertex-only mesh we interpolate from. - target_vom : `.VertexOnlyMesh` - The vertex-only mesh we interpolate to. - expr : `ufl.Expr` - The expression to interpolate. If ``arguments`` is not empty, those - arguments must be present within it. - matfree : bool - If ``False``, the matrix representating the permutation of the points is - constructed and used to perform the interpolation. If ``True``, then the - interpolation is performed using the broadcast and reduce operations on the - PETSc Star Forest. - """ + Parameters + ---------- + interpolator : VomOntoVomInterpolator + A :class:`VomOntoVomInterpolator` object. - def __init__(self, V, source_vom, target_vom, expr, matfree): - arguments = extract_arguments(expr) - reduce = False - if source_vom.input_ordering is target_vom: - reduce = True - original_vom = source_vom - elif target_vom.input_ordering is source_vom: - original_vom = target_vom + Raises + ------ + ValueError + If the source and target vertex-only meshes are not linked by input_ordering. + """ + if interpolator.source_mesh.input_ordering is interpolator.target_mesh: + self.forward_reduce = True + self.original_vom = interpolator.source_mesh + elif interpolator.target_mesh.input_ordering is interpolator.source_mesh: + self.forward_reduce = False + self.original_vom = interpolator.target_mesh else: raise ValueError( "The target vom and source vom must be linked by input ordering!" ) - self.V = V - self.source_vom = source_vom - self.expr = expr - self.arguments = arguments - self.reduce = reduce - # note that interpolation doesn't include halo cells - self.dummy_mat = VomOntoVomDummyMat( - original_vom.input_ordering_without_halos_sf, reduce, V, source_vom, expr, arguments - ) - if matfree: - # If matfree, we use the SF to perform the interpolation - self.handle = self.dummy_mat._wrap_dummy_mat() - else: - # Otherwise we create the permutation matrix - self.handle = self.dummy_mat._create_permutation_mat() + self.sf = self.original_vom.input_ordering_without_halos_sf + self.V = interpolator.V_dest + self.source_vom = interpolator.source_mesh + self.expr = interpolator.operand + self.arguments = extract_arguments(self.expr) - @property - def mpi_type(self): - """ - The MPI type to use for the PETSc SF. - - Should correspond to the underlying data type of the PETSc Vec. - """ - return self.handle.mpi_type - - @mpi_type.setter - def mpi_type(self, val): - self.dummy_mat.mpi_type = val - - def forward_operation(self, target_dat): - coeff = self.dummy_mat.expr_as_coeff() - with coeff.dat.vec_ro as coeff_vec, target_dat.vec_wo as target_vec: - self.handle.mult(coeff_vec, target_vec) - - def adjoint_operation(self, source_dat, target_dat): - """Apply the adjoint interpolation operation. - - Parameters - ---------- - source_dat : dat - The dat from the cofunction. - target_dat : dat - The dat to write the result to. - """ - with source_dat.vec_ro as source_vec, target_dat.vec_wo as target_vec: - self.handle.multHermitian(source_vec, target_vec) - - -class VomOntoVomDummyMat(object): - """Dummy object to stand in for a PETSc ``Mat`` when we are interpolating - between vertex-only meshes. - - Parameters - ---------- - sf: PETSc.sf - The PETSc Star Forest (SF) to use for the operation - forward_reduce : bool - If ``True``, the action of the operator (accessed via the `mult` - method) is to perform a SF reduce from the source vec to the target - vec, whilst the adjoint action (accessed via the `multHermitian` - method) is to perform a SF broadcast from the source vec to the target - vec. If ``False``, the opposite is true. - V : `.FunctionSpace` - The P0DG function space (which may be vector or tensor valued) on the - source vertex-only mesh. - source_vom : `.VertexOnlyMesh` - The vertex-only mesh we interpolate from. - expr : `ufl.Expr` - The expression to interpolate. If ``arguments`` is not empty, those - arguments must be present within it. - arguments : list of `ufl.Argument` - The arguments in the expression. - """ - - def __init__(self, sf, forward_reduce, V, source_vom, expr, arguments): - self.sf = sf - self.forward_reduce = forward_reduce - self.V = V - self.source_vom = source_vom - self.expr = expr - self.arguments = arguments # Calculate correct local and global sizes for the matrix - nroots, leaves, _ = sf.getGraph() + nroots, leaves, _ = self.sf.getGraph() self.nleaves = len(leaves) - self._local_sizes = V.comm.allgather(nroots) + self._local_sizes = self.V.comm.allgather(nroots) self.source_size = (self.V.block_size * nroots, self.V.block_size * sum(self._local_sizes)) self.target_size = ( self.V.block_size * self.nleaves, - self.V.block_size * V.comm.allreduce(self.nleaves, op=MPI.SUM), + self.V.block_size * self.V.comm.allreduce(self.nleaves, op=MPI.SUM), ) + if interpolator.matfree: + # If matfree, we use the SF to perform the interpolation + self.handle = self._wrap_python_mat() + else: + # Otherwise we create the permutation matrix + self.handle = self._create_permutation_mat() + + @property def mpi_type(self): """ @@ -1354,7 +1284,7 @@ def _create_permutation_mat(self): mat.transpose() return mat - def _wrap_dummy_mat(self): + def _wrap_python_mat(self): mat = PETSc.Mat().create(comm=self.V.comm) if self.forward_reduce: mat_size = (self.source_size, self.target_size) @@ -1367,7 +1297,7 @@ def _wrap_dummy_mat(self): return mat def duplicate(self, mat=None, op=None): - return self._wrap_dummy_mat() + return self._wrap_python_mat() class MixedInterpolator(Interpolator): From 4c4cf49c530d51b31d0d820646663543d660721f Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Sat, 11 Oct 2025 15:08:20 +0100 Subject: [PATCH 100/125] fixes --- firedrake/interpolation.py | 111 +++++++++--------- .../regression/test_interpolate_cross_mesh.py | 12 +- 2 files changed, 68 insertions(+), 55 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 25212ed460..eca33c59d6 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -115,6 +115,7 @@ def __init__(self, expr: Expr, V: WithGeometry | ufl.BaseForm, **kwargs): if isinstance(V, WithGeometry): # Need to create a Firedrake Coargument so it has a .function_space() method V = Argument(V.dual(), 1 if self.is_adjoint else 0) + self.target_space = V.arguments()[0].function_space() if expr.ufl_shape != self.target_space.value_shape: raise ValueError(f"Shape mismatch: Expression shape {expr.ufl_shape}, FunctionSpace shape {self.target_space.value_shape}.") @@ -188,17 +189,8 @@ def __init__(self, expr: Interpolate, bcs=None): self.matfree = expr.options.matfree self.bcs = bcs self.callable = None + self.access = expr.options.access - access = expr.options.access - if not isinstance(dual_arg, ufl.Coargument): - # Matrix-free assembly of 0-form or 1-form requires INC access - if access and access != op2.INC: - raise ValueError("Matfree adjoint interpolation requires INC access") - access = op2.INC - elif access is None: - # Default access for forward 1-form or 2-form (forward and adjoint) - access = op2.WRITE - self.access = access @abc.abstractmethod def _interpolate(self, *args, **kwargs): @@ -216,6 +208,7 @@ def assemble(self, tensor=None): # Assembling the operator res = tensor.petscmat if tensor else PETSc.Mat() # Get the interpolation matrix + self._get_callable() op2mat = self.callable() petsc_mat = op2mat.handle if tensor: @@ -300,12 +293,12 @@ class CrossMeshInterpolator(Interpolator): @no_annotations def __init__(self, expr: Interpolate, bcs=None): - from firedrake.assemble import assemble super().__init__(expr, bcs) - if self.access != op2.WRITE: - # raise NotImplementedError( - # "Access other than op2.WRITE not implemented for cross-mesh interpolation." - # ) + if self.access and self.access != op2.WRITE: + raise NotImplementedError( + "Access other than op2.WRITE not implemented for cross-mesh interpolation." + ) + else: self.access = op2.WRITE if self.bcs: raise NotImplementedError("bcs not implemented.") @@ -317,7 +310,6 @@ def __init__(self, expr: Interpolate, bcs=None): raise NotImplementedError( "Can only interpolate into spaces with point evaluation nodes." ) - if self.allow_missing_dofs: self.missing_points_behaviour = MissingPointsBehaviour.IGNORE else: @@ -344,17 +336,6 @@ def __init__(self, expr: Interpolate, bcs=None): self._get_symbolic_expressions() - if self.rank == 2: - # The cross-mesh interpolation matrix is the product of the - # `self.point_eval_interpolate` and the permutation - # given by `self.to_input_ordering_interpolate`. - if self.expr.is_adjoint: - symbolic = action(self.point_eval, self.point_eval_input_ordering) - else: - symbolic = action(self.point_eval_input_ordering, self.point_eval) - self.handle = assemble(symbolic).petscmat - self.callable = lambda: self - def _get_symbolic_expressions(self): """Constructs the symbolic Interpolate expressions for cross-mesh interpolation. @@ -392,15 +373,27 @@ def _get_symbolic_expressions(self): self.P0DG_vom = fs_type(self.vom, "DG", 0) self.point_eval = interpolate(self.operand, self.P0DG_vom) - if self.rank == 2: - # If assembling the operator, we need the concrete permutation matrix - self.matfree = False + # If assembling the operator, we need the concrete permutation matrix + matfree = False if self.rank == 2 else self.matfree # Interpolate into the input-ordering VOM self.P0DG_vom_input_ordering = fs_type(self.vom.input_ordering, "DG", 0) arg = Argument(self.P0DG_vom, 0 if self.expr.is_adjoint else 1) - self.point_eval_input_ordering = interpolate(arg, self.P0DG_vom_input_ordering, matfree=self.matfree) + self.point_eval_input_ordering = interpolate(arg, self.P0DG_vom_input_ordering, matfree=matfree) + + def _get_callable(self): + from firedrake.assemble import assemble + assert self.rank == 2 + # The cross-mesh interpolation matrix is the product of the + # `self.point_eval_interpolate` and the permutation + # given by `self.to_input_ordering_interpolate`. + if self.expr.is_adjoint: + symbolic = action(self.point_eval, self.point_eval_input_ordering) + else: + symbolic = action(self.point_eval_input_ordering, self.point_eval) + self.handle = assemble(symbolic).petscmat + self.callable = lambda: self def _interpolate(self, output=None): from firedrake.assemble import assemble @@ -426,7 +419,7 @@ def _interpolate(self, output=None): # mesh that are not in the source mesh. If we don't specify a # default missing value we set these to NaN so we can identify # them later. - f_point_eval_input_ordering.assign(numpy.nan) + f_point_eval_input_ordering.dat.data_wo[:] = numpy.nan assemble(action(self.point_eval_input_ordering, f_point_eval), tensor=f_point_eval_input_ordering) @@ -496,12 +489,16 @@ def __init__(self, expr, bcs=None): pass self.subset = subset - try: - self.callable = self._get_callable() - except FIAT.hdiv_trace.TraceError: - raise NotImplementedError("Can't interpolate onto traces.") + if not isinstance(self.dual_arg, ufl.Coargument): + # Matrix-free assembly of 0-form or 1-form requires INC access + if self.access and self.access != op2.INC: + raise ValueError("Matfree adjoint interpolation requires INC access") + self.access = op2.INC + elif self.access is None: + # Default access for forward 1-form or 2-form (forward and adjoint) + self.access = op2.WRITE - def _get_tensor(self) -> op2.Mat | Function | Cofunction: + def _get_tensor(self, output=None) -> op2.Mat | Function | Cofunction: """Return the tensor to interpolate into. Returns @@ -513,15 +510,19 @@ def _get_tensor(self) -> op2.Mat | Function | Cofunction: R = firedrake.FunctionSpace(self.target_mesh, "Real", 0) f = Function(R, dtype=utils.ScalarType) elif self.rank == 1: - V_dest = self.expr_args[0].function_space().dual() - f = Function(V_dest) - if self.access in {firedrake.MIN, firedrake.MAX}: - finfo = numpy.finfo(f.dat.dtype) - if self.access == firedrake.MIN: - val = firedrake.Constant(finfo.max) - else: - val = firedrake.Constant(finfo.min) - f.assign(val) + if output: + V_dest = output.function_space() + f = output + else: + V_dest = self.expr_args[0].function_space().dual() + f = Function(V_dest) + if self.access in {firedrake.MIN, firedrake.MAX}: + finfo = numpy.finfo(f.dat.dtype) + if self.access == firedrake.MIN: + val = firedrake.Constant(finfo.max) + else: + val = firedrake.Constant(finfo.min) + f.assign(val) elif self.rank == 2: Vrow = self.expr_args[0].function_space() Vcol = self.expr_args[1].function_space() @@ -539,14 +540,14 @@ def _get_tensor(self) -> op2.Mat | Function | Cofunction: raise ValueError(f"Cannot interpolate an expression with {self.rank} arguments") return f - def _get_callable(self) -> Callable: + def _get_callable(self, output=None) -> Callable: """Construct the callable that performs the interpolation. Returns ------- Callable """ - f = self._get_tensor() + f = self._get_tensor(output=output) if isinstance(f, op2.Mat): tensor = f else: @@ -592,7 +593,7 @@ def callable(loops, f): l() return f - return partial(callable, loops, f) + self.callable = partial(callable, loops, f) @PETSc.Log.EventDecorator() def _interpolate(self, output=None): @@ -601,6 +602,7 @@ def _interpolate(self, output=None): For arguments, see :class:`.Interpolator`. """ assert self.rank < 2 + self._get_callable(output=output) assembled_interpolator = self.callable() if output: output.assign(assembled_interpolator) @@ -617,13 +619,13 @@ class VomOntoVomInterpolator(SameMeshInterpolator): def __init__(self, expr: Interpolate, bcs=None): super().__init__(expr, bcs=bcs) - def _get_callable(self): + def _get_callable(self, output=None): self.mat = VomOntoVomMat(self) if self.rank == 2: # We make our own linear operator for this case using PETSc SFs tensor = None else: - f = self._get_tensor() + f = self._get_tensor(output=output) tensor = f.dat # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the # data, including the correct data size and dimensional information @@ -663,7 +665,7 @@ def callable(): def callable(): return self.mat - return callable + self.callable = callable @utils.known_pyop2_safe @@ -1129,6 +1131,7 @@ def __init__(self, interpolator: VomOntoVomInterpolator): self.source_vom = interpolator.source_mesh self.expr = interpolator.operand self.arguments = extract_arguments(self.expr) + self.is_adjoint = interpolator.expr.is_adjoint # Calculate correct local and global sizes for the matrix nroots, leaves, _ = self.sf.getGraph() @@ -1147,7 +1150,6 @@ def __init__(self, interpolator: VomOntoVomInterpolator): # Otherwise we create the permutation matrix self.handle = self._create_permutation_mat() - @property def mpi_type(self): """ @@ -1280,7 +1282,7 @@ def _create_permutation_mat(self): cols = (self.V.block_size * perm[:, None] + numpy.arange(self.V.block_size, dtype=utils.IntType)[None, :]).reshape(-1) mat.setValuesCSR(rows, cols, numpy.ones_like(cols, dtype=utils.IntType)) mat.assemble() - if self.forward_reduce: + if self.forward_reduce and not self.is_adjoint: mat.transpose() return mat @@ -1357,6 +1359,7 @@ def _get_callable(self): shape = tuple(len(a.function_space()) for a in self.expr_args) blocks = numpy.full(shape, PETSc.Mat(), dtype=object) for i in self: + self[i]._get_callable() blocks[i] = self[i].callable().handle petscmat = PETSc.Mat().createNest(blocks) tensor = firedrake.AssembledMatrix(self.expr_args, self.bcs, petscmat) diff --git a/tests/firedrake/regression/test_interpolate_cross_mesh.py b/tests/firedrake/regression/test_interpolate_cross_mesh.py index e35367e08c..78d3688516 100644 --- a/tests/firedrake/regression/test_interpolate_cross_mesh.py +++ b/tests/firedrake/regression/test_interpolate_cross_mesh.py @@ -339,12 +339,18 @@ def test_exact_refinement(): expr_in_V_fine = x**2 + y**2 + 1 f_fine = Function(V_fine).interpolate(expr_in_V_fine) + # Build interpolation matrices in both directions + coarse_to_fine = assemble(interpolate(TrialFunction(V_coarse), V_fine)) + coarse_to_fine_adjoint = assemble(interpolate(TestFunction(V_coarse), TrialFunction(V_fine.dual()))) + # If we now interpolate f_coarse into V_fine we should get a function # which has no interpolation error versus f_fine because we were able to # exactly represent expr_in_V_coarse in V_coarse and V_coarse is a subset # of V_fine f_coarse_on_fine = assemble(interpolate(f_coarse, V_fine)) assert np.allclose(f_coarse_on_fine.dat.data_ro, f_fine.dat.data_ro) + f_coarse_on_fine_mat = assemble(coarse_to_fine @ f_coarse) + assert np.allclose(f_coarse_on_fine_mat.dat.data_ro, f_fine.dat.data_ro) # Adjoint interpolation takes us from V_fine^* to V_coarse^* so we should # also get an exact result here. @@ -354,6 +360,10 @@ def test_exact_refinement(): assert np.allclose( cofunction_fine_on_coarse.dat.data_ro, cofunction_coarse.dat.data_ro ) + cofunction_fine_on_coarse_mat = assemble(action(coarse_to_fine_adjoint, cofunction_fine)) + assert np.allclose( + cofunction_fine_on_coarse_mat.dat.data_ro, cofunction_coarse.dat.data_ro + ) # Now we test with expressions which are NOT exactly representable in the # function spaces by introducing a cube term. This can't be represented @@ -686,7 +696,7 @@ def test_interpolate_matrix_cross_mesh(): assert f_interp3.function_space() == V assert np.allclose(f_interp3.dat.data_ro, g.dat.data_ro) - +@pytest.mark.parallel([1, 3]) def test_interpolate_matrix_cross_mesh_adjoint(): mesh_fine = UnitSquareMesh(4, 4) mesh_coarse = UnitSquareMesh(2, 2) From 58326de410b224049ebb899c2220843a5c954bae Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Sat, 11 Oct 2025 15:28:17 +0100 Subject: [PATCH 101/125] tidy import --- firedrake/interpolation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index eca33c59d6..034f970e5b 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -10,7 +10,7 @@ import FIAT import ufl import finat.ufl -from ufl.algorithms import extract_arguments, extract_coefficients +from ufl.algorithms import extract_arguments from ufl.domain import as_domain, extract_unique_domain from ufl.classes import Expr @@ -26,7 +26,7 @@ import firedrake from firedrake import tsfc_interface, utils -from firedrake.ufl_expr import Argument, Coargument, action, adjoint as expr_adjoint +from firedrake.ufl_expr import Argument, Coargument, action from firedrake.cofunction import Cofunction from firedrake.function import Function from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology From a78b652f1c06db7b415385c08c7ebc9bdb89b8f1 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Sat, 11 Oct 2025 15:41:35 +0100 Subject: [PATCH 102/125] lint --- firedrake/interpolation.py | 22 +++++++++---------- .../regression/test_interpolate_cross_mesh.py | 1 + 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 034f970e5b..3ae7da4108 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -32,7 +32,6 @@ from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology from firedrake.petsc import PETSc from firedrake.halo import _get_mtype as get_dat_mpi_type -from firedrake.cofunction import Cofunction from firedrake.functionspaceimpl import WithGeometry from mpi4py import MPI @@ -60,9 +59,9 @@ class InterpolateOptions: the target mesh is a :func:`.VertexOnlyMesh`. access : pyop2.types.access.Access, default op2.WRITE The pyop2 access descriptor for combining updates to shared - DoFs. Possible values include ``WRITE``, ``MIN``, ``MAX``, and ``INC``. - Only ``WRITE`` is supported at present when interpolating across meshes - unless the target mesh is a :func:`.VertexOnlyMesh`. Only ``INC`` is + DoFs. Possible values include ``WRITE``, ``MIN``, ``MAX``, and ``INC``. + Only ``WRITE`` is supported at present when interpolating across meshes + unless the target mesh is a :func:`.VertexOnlyMesh`. Only ``INC`` is supported for the matrix-free adjoint interpolation. allow_missing_dofs : bool, default False For interpolation across meshes: allow degrees of freedom (aka DoFs/nodes) @@ -191,7 +190,6 @@ def __init__(self, expr: Interpolate, bcs=None): self.callable = None self.access = expr.options.access - @abc.abstractmethod def _interpolate(self, *args, **kwargs): """ @@ -397,7 +395,7 @@ def _get_callable(self): def _interpolate(self, output=None): from firedrake.assemble import assemble - if self.expr.is_adjoint: + if self.expr.is_adjoint: f = self.dual_arg V_dest = self.expr_args[0].function_space().dual() else: @@ -416,12 +414,12 @@ def _interpolate(self, output=None): f_point_eval_input_ordering.assign(self.default_missing_val) elif self.allow_missing_dofs: # If we allow missing points there may be points in the target - # mesh that are not in the source mesh. If we don't specify a - # default missing value we set these to NaN so we can identify + # mesh that are not in the source mesh. If we don't specify a + # default missing value we set these to NaN so we can identify # them later. f_point_eval_input_ordering.dat.data_wo[:] = numpy.nan - assemble(action(self.point_eval_input_ordering, f_point_eval), + assemble(action(self.point_eval_input_ordering, f_point_eval), tensor=f_point_eval_input_ordering) # We assign these values to the output function @@ -430,7 +428,7 @@ def _interpolate(self, output=None): output.dat.data_wo[indices] = f_point_eval_input_ordering.dat.data_ro[indices] else: output.dat.data_wo[:] = f_point_eval_input_ordering.dat.data_ro[:] - + if self.rank == 0: # We take the action of the dual_arg on the interpolated function assert not isinstance(self.dual_arg, ufl.Coargument) @@ -438,7 +436,7 @@ def _interpolate(self, output=None): else: # f_src is a cofunction on V_dest.dual assert isinstance(f, Cofunction) - # Our first adjoint operation is to assign the dat values to a + # Our first adjoint operation is to assign the dat values to a # P0DG cofunction on our input ordering VOM. f_input_ordering = Cofunction(self.P0DG_vom_input_ordering.dual()) f_input_ordering.dat.data_wo[:] = f.dat.data_ro[:] @@ -636,6 +634,7 @@ def _get_callable(self, output=None): assert self.rank == 1 if self.expr.is_adjoint: assert isinstance(self.dual_arg, ufl.Cofunction) + def callable(): with self.dual_arg.dat.vec_ro as source_vec, f.dat.vec_wo as target_vec: self.mat.handle.multHermitian(source_vec, target_vec) @@ -1382,4 +1381,3 @@ def _interpolate(self, output=None): sub_tensor.assign(sum(self[i]._interpolate() for i in self if i[0] == k)) return output - \ No newline at end of file diff --git a/tests/firedrake/regression/test_interpolate_cross_mesh.py b/tests/firedrake/regression/test_interpolate_cross_mesh.py index 78d3688516..82d7a7da3c 100644 --- a/tests/firedrake/regression/test_interpolate_cross_mesh.py +++ b/tests/firedrake/regression/test_interpolate_cross_mesh.py @@ -696,6 +696,7 @@ def test_interpolate_matrix_cross_mesh(): assert f_interp3.function_space() == V assert np.allclose(f_interp3.dat.data_ro, g.dat.data_ro) + @pytest.mark.parallel([1, 3]) def test_interpolate_matrix_cross_mesh_adjoint(): mesh_fine = UnitSquareMesh(4, 4) From 387ad5e00c4101e6d30da1217cb410932a2fade2 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Mon, 13 Oct 2025 12:16:59 +0100 Subject: [PATCH 103/125] attempt fix for bc fix bcs --- firedrake/bcs.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/firedrake/bcs.py b/firedrake/bcs.py index 17c199c1d8..ad47415008 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -6,7 +6,8 @@ import ufl from ufl import as_ufl, as_tensor -from finat.ufl import VectorElement +from finat.ufl import VectorElement, EnrichedElement +from finat.physically_mapped import DirectlyDefinedElement, PhysicallyMappedElement import finat import pyop2 as op2 @@ -357,14 +358,12 @@ def function_arg(self, g): elif isinstance(g, ufl.classes.Expr): if g.ufl_shape != V.value_shape: raise RuntimeError(f"Provided boundary value {g} does not match shape of space") - try: + disallowed_elements = PhysicallyMappedElement | DirectlyDefinedElement | EnrichedElement + if all(not isinstance(element, disallowed_elements) for element in V.ufl_element().sub_elements): self._function_arg = firedrake.Function(V) - # Use `Interpolator` instead of assembling an `Interpolate` form - # as the expression compilation needs to happen at this stage to - # determine if we should use interpolation or projection - # -> e.g. interpolation may not be supported for the element. - self._function_arg_update = firedrake.Interpolator(g, self._function_arg)._interpolate - except (NotImplementedError, AttributeError): + interpolate_expr = firedrake.interpolate(g, V) + self._function_arg_update = lambda: firedrake.assemble(interpolate_expr, tensor=self._function_arg) + else: # Element doesn't implement interpolation self._function_arg = firedrake.Function(V).project(g) self._function_arg_update = firedrake.Projector(g, self._function_arg).project From 7f1d5095dbaf9032bb1703ccdd93b10030714cb3 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Mon, 13 Oct 2025 14:47:35 +0100 Subject: [PATCH 104/125] fix --- firedrake/bcs.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/firedrake/bcs.py b/firedrake/bcs.py index ad47415008..46e19a0964 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -359,7 +359,11 @@ def function_arg(self, g): if g.ufl_shape != V.value_shape: raise RuntimeError(f"Provided boundary value {g} does not match shape of space") disallowed_elements = PhysicallyMappedElement | DirectlyDefinedElement | EnrichedElement - if all(not isinstance(element, disallowed_elements) for element in V.ufl_element().sub_elements): + if len(V.ufl_element().sub_elements) > 0: + elements = V.ufl_element().sub_elements + else: + elements = [V.ufl_element()] + if all(not isinstance(element, disallowed_elements) for element in elements): self._function_arg = firedrake.Function(V) interpolate_expr = firedrake.interpolate(g, V) self._function_arg_update = lambda: firedrake.assemble(interpolate_expr, tensor=self._function_arg) From 78b058a4b56ecb4389f83c6cbdc4d7935dc0c509 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Mon, 13 Oct 2025 15:04:43 +0100 Subject: [PATCH 105/125] make `_get_interpolator` public --- firedrake/assemble.py | 4 ++-- firedrake/interpolation.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 149a4f3f5a..8ef822a1dd 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -23,7 +23,7 @@ from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit from firedrake.functionspaceimpl import WithGeometry, FunctionSpace, FiredrakeDualSpace from firedrake.functionspacedata import entity_dofs_key, entity_permutations_key -from firedrake.interpolation import _get_interpolator +from firedrake.interpolation import get_interpolator from firedrake.petsc import PETSc from firedrake.slate import slac, slate from firedrake.slate.slac.kernel_builder import CellFacetKernelArg, LayerCountKernelArg @@ -573,7 +573,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args): if rank > 2: raise ValueError("Cannot assemble an Interpolate with more than two arguments") - interpolator = _get_interpolator(expr) + interpolator = get_interpolator(expr) return interpolator.assemble(tensor=tensor) elif tensor and isinstance(expr, (firedrake.Function, firedrake.Cofunction, firedrake.MatrixBase)): return tensor.assign(expr) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 3ae7da4108..ebbef3f603 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -218,7 +218,7 @@ def assemble(self, tensor=None): return self._interpolate(output=tensor) -def _get_interpolator(expr: Interpolate, bcs=None) -> Interpolator: +def get_interpolator(expr: Interpolate, bcs=None) -> Interpolator: arguments = expr.arguments() has_mixed_arguments = any(len(arg.function_space()) > 1 for arg in arguments) if len(arguments) == 2 and has_mixed_arguments: From c9f4ac06b7f8a03a35a4f643ad36a3b7870ad88d Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Mon, 13 Oct 2025 15:14:03 +0100 Subject: [PATCH 106/125] change bcs --- firedrake/bcs.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/firedrake/bcs.py b/firedrake/bcs.py index 46e19a0964..220082a9e9 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -23,6 +23,7 @@ from firedrake.formmanipulation import ExtractSubBlock from firedrake.adjoint_utils.dirichletbc import DirichletBCMixin from firedrake.petsc import PETSc +from firedrake.interpolation import get_interpolator __all__ = ['DirichletBC', 'homogenize', 'EquationBC'] @@ -358,16 +359,13 @@ def function_arg(self, g): elif isinstance(g, ufl.classes.Expr): if g.ufl_shape != V.value_shape: raise RuntimeError(f"Provided boundary value {g} does not match shape of space") - disallowed_elements = PhysicallyMappedElement | DirectlyDefinedElement | EnrichedElement - if len(V.ufl_element().sub_elements) > 0: - elements = V.ufl_element().sub_elements - else: - elements = [V.ufl_element()] - if all(not isinstance(element, disallowed_elements) for element in elements): + try: self._function_arg = firedrake.Function(V) - interpolate_expr = firedrake.interpolate(g, V) - self._function_arg_update = lambda: firedrake.assemble(interpolate_expr, tensor=self._function_arg) - else: + interpolator = get_interpolator(firedrake.interpolate(g, V)) + # Call this here to check if the element supports interpolation + interpolator._get_callable() + self._function_arg_update = lambda: interpolator.assemble(tensor=self._function_arg) + except (ValueError, NotImplementedError): # Element doesn't implement interpolation self._function_arg = firedrake.Function(V).project(g) self._function_arg_update = firedrake.Projector(g, self._function_arg).project From 2340cea4d70ed6bd54ae2aace3082c571001885a Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Mon, 13 Oct 2025 18:56:08 +0100 Subject: [PATCH 107/125] fixes --- firedrake/interpolation.py | 8 +++++--- tests/firedrake/regression/test_interpolate.py | 4 +++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index ebbef3f603..ed58360901 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -13,6 +13,7 @@ from ufl.algorithms import extract_arguments from ufl.domain import as_domain, extract_unique_domain from ufl.classes import Expr +from ufl.duals import is_dual from pyop2 import op2 from pyop2.caching import memory_and_disk_cache @@ -109,8 +110,9 @@ def __init__(self, expr: Expr, V: WithGeometry | ufl.BaseForm, **kwargs): Additional interpolation options. See :class:`InterpolateOptions` for available parameters and their descriptions. """ - expr_args = extract_arguments(ufl.as_ufl(expr)) - self.is_adjoint = len(expr_args) and expr_args[0].number() == 0 + expr = ufl.as_ufl(expr) + expr_arg_numbers = {arg.number() for arg in extract_arguments(expr) if not is_dual(arg)} + self.is_adjoint = len(expr_arg_numbers) and expr_arg_numbers == {0} if isinstance(V, WithGeometry): # Need to create a Firedrake Coargument so it has a .function_space() method V = Argument(V.dual(), 1 if self.is_adjoint else 0) @@ -1342,7 +1344,7 @@ def __init__(self, expr, bcs=None): # Take the action of each sub-cofunction against each block form = action(form, dual_split[indices[-1:]]) - Isub[indices] = _get_interpolator(form, bcs=sub_bcs) + Isub[indices] = get_interpolator(form, bcs=sub_bcs) self._sub_interpolators = Isub self.callable = self._get_callable diff --git a/tests/firedrake/regression/test_interpolate.py b/tests/firedrake/regression/test_interpolate.py index 40a831ad47..47b3dc7a6d 100644 --- a/tests/firedrake/regression/test_interpolate.py +++ b/tests/firedrake/regression/test_interpolate.py @@ -346,13 +346,15 @@ def test_adjoint_Pk(rank, mat_type, degree, cell, shape): if rank == 0: operand = Function(Pk).interpolate(expr) + dual_arg = TestFunction(Pkp1.dual()) else: operand = TestFunction(Pk) + dual_arg = TrialFunction(Pkp1.dual()) if mat_type == "matfree": interp = interpolate(operand, v) else: - adj_interp = assemble(interpolate(operand, TrialFunction(Pkp1.dual()))) + adj_interp = assemble(interpolate(operand, dual_arg)) if rank == 0: interp = action(v, adj_interp) else: From 79793553a86d52edf57ad867be5b78685406f423 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Tue, 14 Oct 2025 12:07:01 +0100 Subject: [PATCH 108/125] updates --- firedrake/bcs.py | 4 +- firedrake/interpolation.py | 306 +++++++++++++++++++++++-------------- 2 files changed, 192 insertions(+), 118 deletions(-) diff --git a/firedrake/bcs.py b/firedrake/bcs.py index 220082a9e9..c4fd56007e 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -23,7 +23,6 @@ from firedrake.formmanipulation import ExtractSubBlock from firedrake.adjoint_utils.dirichletbc import DirichletBCMixin from firedrake.petsc import PETSc -from firedrake.interpolation import get_interpolator __all__ = ['DirichletBC', 'homogenize', 'EquationBC'] @@ -341,6 +340,7 @@ def reconstruct(self, field=None, V=None, g=None, sub_domain=None, use_split=Fal @function_arg.setter def function_arg(self, g): '''Set the value of this boundary condition.''' + from firedrake.interpolation import get_interpolator try: # Clear any previously set update function del self._function_arg_update @@ -363,7 +363,7 @@ def function_arg(self, g): self._function_arg = firedrake.Function(V) interpolator = get_interpolator(firedrake.interpolate(g, V)) # Call this here to check if the element supports interpolation - interpolator._get_callable() + interpolator._build_callable() self._function_arg_update = lambda: interpolator.assemble(tensor=self._function_arg) except (ValueError, NotImplementedError): # Element doesn't implement interpolation diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index ed58360901..ebc8a86294 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -4,8 +4,9 @@ import abc from functools import partial, singledispatch -from typing import Hashable, Literal, Callable +from typing import Hashable, Literal, Callable, Iterable from dataclasses import asdict, dataclass +from numbers import Number import FIAT import ufl @@ -34,17 +35,16 @@ from firedrake.petsc import PETSc from firedrake.halo import _get_mtype as get_dat_mpi_type from firedrake.functionspaceimpl import WithGeometry +from firedrake.matrix import MatrixBase +from firedrake.bcs import BCBase from mpi4py import MPI from pyadjoint import stop_annotating, no_annotations __all__ = ( "interpolate", - "Interpolator", "Interpolate", "DofNotDefinedError", - "CrossMeshInterpolator", - "SameMeshInterpolator", ) @@ -139,7 +139,7 @@ def options(self): @PETSc.Log.EventDecorator() -def interpolate(expr: Expr, V: WithGeometry | ufl.BaseForm, **kwargs): +def interpolate(expr: Expr, V: WithGeometry | ufl.BaseForm, **kwargs) -> Interpolate: """Returns a UFL expression for the interpolation operation of ``expr`` into ``V``. Parameters @@ -161,26 +161,82 @@ def interpolate(expr: Expr, V: WithGeometry | ufl.BaseForm, **kwargs): return Interpolate(expr, V, **kwargs) +def get_interpolator(expr: Interpolate, bcs: Iterable[BCBase] | None = None) -> "Interpolator": + """Create an Interpolator. + + Parameters + ---------- + expr : Interpolate + Symbolic interpolation expression. + bcs : Iterable[BCBase] | None, optional + An optional list of boundary conditions to zero-out in the + output function space. Interpolator rows or columns which are + associated with boundary condition nodes are zeroed out when this is + specified. By default None. + + Returns + ------- + Interpolator + + """ + arguments = expr.arguments() + has_mixed_arguments = any(len(arg.function_space()) > 1 for arg in arguments) + if len(arguments) == 2 and has_mixed_arguments: + return MixedInterpolator(expr, bcs=bcs) + + operand, = expr.ufl_operands + target_mesh = expr.target_space.mesh() + source_mesh = extract_unique_domain(operand) or target_mesh + submesh_interp_implemented = ( + all(isinstance(m.topology, firedrake.mesh.MeshTopology) for m in [target_mesh, source_mesh]) + and target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1] + and target_mesh.topological_dimension() == source_mesh.topological_dimension() + ) + if target_mesh is source_mesh or submesh_interp_implemented: + return SameMeshInterpolator(expr, bcs=bcs) + + target_topology = target_mesh.topology + source_topology = source_mesh.topology + + if isinstance(target_topology, VertexOnlyMeshTopology): + if isinstance(source_topology, VertexOnlyMeshTopology): + return VomOntoVomInterpolator(expr, bcs=bcs) + if target_mesh.geometric_dimension() != source_mesh.geometric_dimension(): + raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") + if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh: + raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target") + return SameMeshInterpolator(expr, bcs=bcs) + + if has_mixed_arguments or len(expr.target_space) > 1: + return MixedInterpolator(expr, bcs=bcs) + + return CrossMeshInterpolator(expr, bcs=bcs) + + class Interpolator(abc.ABC): + """Initialise the interpolator. Should not be instantiated directly; use the + :func:`get_interpolator` function. - def __init__(self, expr: Interpolate, bcs=None): - """Initialise Interpolator. + Parameters + ---------- + expr : Interpolate + The symbolic interpolation expression. + bcs : Iterable[BCBase], optional + An optional list of boundary conditions to zero-out in the + output function space. Interpolator rows or columns which are + associated with boundary condition nodes are zeroed out when this is + specified. By default None. - Parameters - ---------- - expr : Interpolate - The symbolic interpolation expression. - bcs : list, optional - List of boundary conditions to zero-out in the output function space. By default None. - """ + """ + def __init__(self, expr: Interpolate, bcs: Iterable[BCBase] | None = None): dual_arg, operand = expr.argument_slots() self.expr = expr self.expr_args = expr.arguments() self.rank = len(self.expr_args) self.operand = operand self.dual_arg = dual_arg - self.V_dest = dual_arg.function_space().dual() - self.target_mesh = as_domain(self.V_dest) + self.target_space = dual_arg.function_space().dual() + self.target_mesh = as_domain(self.target_space) self.source_mesh = extract_unique_domain(operand) or self.target_mesh # Interpolation options @@ -193,22 +249,62 @@ def __init__(self, expr: Interpolate, bcs=None): self.access = expr.options.access @abc.abstractmethod - def _interpolate(self, *args, **kwargs): + def _interpolate( + self, output: Function | Cofunction | MatrixBase | None = None + ) -> Function | Cofunction | Number: + """Compute the interpolation action. + + Parameters + ---------- + tensor : Function | Cofunction | MatrixBase, optional + Tensor to hold the interpolated result. + + Returns + ------- + Function | Cofunction | Number + The function, cofunction, or scalar resulting from the + interpolation. + """ - Compute the interpolation operation of interest. + pass - .. note:: - This method is called when an :class:`Interpolate` object is being assembled. + @abc.abstractmethod + def _build_callable(self) -> None: + """Builds callable to perform interpolation. + Stores the callable in self.callable """ pass - def assemble(self, tensor=None): - """Assemble the operator (or its action).""" + def assemble( + self, tensor: Function | Cofunction | MatrixBase | None = None + ) -> Function | Cofunction | MatrixBase | Number: + """Assemble the interpolation. The result depends on the rank (number of arguments) + of the :class:`Interpolate` expression: + + * rank-2: assemble the operator and return a matrix + * rank-1: assemble the action and return a function or cofunction + * rank-0: assemble the action and return a scalar by applying the dual argument + + Parameters + ---------- + tensor : Function | Cofunction | MatrixBase, optional + Pre-allocated storage to receive the interpolated result. For rank-2 + expressions this is expected to be a + :class:`~firedrake.assemble.AssembledMatrix`-compatible object whose + ``petscmat`` will be populated. For lower-rank expressions this is + a :class:`~firedrake.Function` or :class:`~firedrake.Cofunction`. + + Returns + ------- + Function | Cofunction | MatrixBase | Number + The function, cofunction, matrix, or scalar resulting from the + interpolation. + """ if self.rank == 2: # Assembling the operator res = tensor.petscmat if tensor else PETSc.Mat() # Get the interpolation matrix - self._get_callable() + self._build_callable() op2mat = self.callable() petsc_mat = op2mat.handle if tensor: @@ -220,41 +316,6 @@ def assemble(self, tensor=None): return self._interpolate(output=tensor) -def get_interpolator(expr: Interpolate, bcs=None) -> Interpolator: - arguments = expr.arguments() - has_mixed_arguments = any(len(arg.function_space()) > 1 for arg in arguments) - if len(arguments) == 2 and has_mixed_arguments: - return MixedInterpolator(expr, bcs=bcs) - - operand, = expr.ufl_operands - target_mesh = as_domain(expr.target_space) - source_mesh = extract_unique_domain(operand) or target_mesh - submesh_interp_implemented = ( - all(isinstance(m.topology, firedrake.mesh.MeshTopology) for m in [target_mesh, source_mesh]) - and target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1] - and target_mesh.topological_dimension() == source_mesh.topological_dimension() - ) - if target_mesh is source_mesh or submesh_interp_implemented: - return SameMeshInterpolator(expr, bcs=bcs) - - target_topology = target_mesh.topology - source_topology = source_mesh.topology - - if isinstance(target_topology, VertexOnlyMeshTopology): - if isinstance(source_topology, VertexOnlyMeshTopology): - return VomOntoVomInterpolator(expr, bcs=bcs) - if target_mesh.geometric_dimension() != source_mesh.geometric_dimension(): - raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") - if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh: - raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target") - return SameMeshInterpolator(expr, bcs=bcs) - - if has_mixed_arguments or len(expr.target_space) > 1: - return MixedInterpolator(expr, bcs=bcs) - - return CrossMeshInterpolator(expr, bcs=bcs) - - class DofNotDefinedError(Exception): r"""Raised when attempting to interpolate across function spaces where the target function space contains degrees of freedom (i.e. nodes) which cannot @@ -292,7 +353,7 @@ class CrossMeshInterpolator(Interpolator): """ @no_annotations - def __init__(self, expr: Interpolate, bcs=None): + def __init__(self, expr: Interpolate, bcs: Iterable[BCBase] | None = None): super().__init__(expr, bcs) if self.access and self.access != op2.WRITE: raise NotImplementedError( @@ -301,8 +362,8 @@ def __init__(self, expr: Interpolate, bcs=None): else: self.access = op2.WRITE if self.bcs: - raise NotImplementedError("bcs not implemented.") - if self.V_dest.ufl_element().mapping() != "identity": + raise NotImplementedError("bcs not implemented for cross-mesh interpolation.") + if self.target_space.ufl_element().mapping() != "identity": # Identity mapping between reference cell and physical coordinates # implies point evaluation nodes. A more general version would # require finding the global coordinates of all quadrature points @@ -318,25 +379,26 @@ def __init__(self, expr: Interpolate, bcs=None): if self.source_mesh.geometric_dimension() != self.target_mesh.geometric_dimension(): raise ValueError("Geometric dimensions of source and destination meshes must match.") - dest_element = self.V_dest.ufl_element() + dest_element = self.target_space.ufl_element() if isinstance(dest_element, finat.ufl.MixedElement): if isinstance(dest_element, (finat.ufl.VectorElement, finat.ufl.TensorElement)): # In this case all sub elements are equal base_element = dest_element.sub_elements[0] if base_element.reference_value_shape != (): raise NotImplementedError( - "Can't yet cross-mesh interpolate onto function spaces made from VectorElements or TensorElements made from sub elements with value shape other than ()." + "Can't yet cross-mesh interpolate onto function spaces made from VectorElements " + "or TensorElements made from sub elements with value shape other than ()." ) self.dest_element = base_element else: - raise NotImplementedError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator") + raise NotImplementedError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator.") else: # scalar fiat/finat element self.dest_element = dest_element - self._get_symbolic_expressions() + self._build_symbolic_expressions() - def _get_symbolic_expressions(self): + def _build_symbolic_expressions(self) -> None: """Constructs the symbolic Interpolate expressions for cross-mesh interpolation. Raises @@ -346,9 +408,9 @@ def _get_symbolic_expressions(self): in the source function space. """ from firedrake.assemble import assemble - # Immerse coordinates of V_dest point evaluation dofs in src_mesh - V_dest_vec = firedrake.VectorFunctionSpace(self.target_mesh, self.dest_element) - f_dest_node_coords = assemble(interpolate(self.target_mesh.coordinates, V_dest_vec)) + # Immerse coordinates of target space point evaluation dofs in src_mesh + target_space_vec = firedrake.VectorFunctionSpace(self.target_mesh, self.dest_element) + f_dest_node_coords = assemble(interpolate(self.target_mesh.coordinates, target_space_vec)) dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, self.target_mesh.geometric_dimension()) try: self.vom = firedrake.VertexOnlyMesh( @@ -361,7 +423,7 @@ def _get_symbolic_expressions(self): raise DofNotDefinedError(self.source_mesh, self.target_mesh) # Get the correct type of function space - shape = self.V_dest.ufl_function_space().value_shape + shape = self.target_space.ufl_function_space().value_shape if len(shape) == 0: fs_type = firedrake.FunctionSpace elif len(shape) == 1: @@ -382,7 +444,7 @@ def _get_symbolic_expressions(self): arg = Argument(self.P0DG_vom, 0 if self.expr.is_adjoint else 1) self.point_eval_input_ordering = interpolate(arg, self.P0DG_vom_input_ordering, matfree=matfree) - def _get_callable(self): + def _build_callable(self): from firedrake.assemble import assemble assert self.rank == 2 # The cross-mesh interpolation matrix is the product of the @@ -395,14 +457,11 @@ def _get_callable(self): self.handle = assemble(symbolic).petscmat self.callable = lambda: self - def _interpolate(self, output=None): + def _interpolate(self, output: Function | Cofunction | None = None) -> Function | Cofunction | Number: from firedrake.assemble import assemble - if self.expr.is_adjoint: - f = self.dual_arg - V_dest = self.expr_args[0].function_space().dual() - else: - V_dest = self.V_dest + # self.expr.function() is None in the 0-form case + V_dest = self.expr.function_space() or self.target_space output = output or Function(V_dest) if not self.expr.is_adjoint: @@ -437,6 +496,7 @@ def _interpolate(self, output=None): return assemble(action(self.dual_arg, output)) else: # f_src is a cofunction on V_dest.dual + f = self.dual_arg assert isinstance(f, Cofunction) # Our first adjoint operation is to assign the dat values to a # P0DG cofunction on our input ordering VOM. @@ -451,8 +511,6 @@ def _interpolate(self, output=None): interp = action(self.point_eval_input_ordering, f_input_ordering) f_src_at_src_node_coords = assemble(interp) - # We don't need to take the adjoint of self.point_eval because - # it was constructed using self.operand interp = action(self.point_eval, f_src_at_src_node_coords) assemble(interp, tensor=output) return output @@ -498,7 +556,7 @@ def __init__(self, expr, bcs=None): # Default access for forward 1-form or 2-form (forward and adjoint) self.access = op2.WRITE - def _get_tensor(self, output=None) -> op2.Mat | Function | Cofunction: + def _get_tensor(self) -> op2.Mat | Function | Cofunction: """Return the tensor to interpolate into. Returns @@ -510,19 +568,14 @@ def _get_tensor(self, output=None) -> op2.Mat | Function | Cofunction: R = firedrake.FunctionSpace(self.target_mesh, "Real", 0) f = Function(R, dtype=utils.ScalarType) elif self.rank == 1: - if output: - V_dest = output.function_space() - f = output - else: - V_dest = self.expr_args[0].function_space().dual() - f = Function(V_dest) - if self.access in {firedrake.MIN, firedrake.MAX}: - finfo = numpy.finfo(f.dat.dtype) - if self.access == firedrake.MIN: - val = firedrake.Constant(finfo.max) - else: - val = firedrake.Constant(finfo.min) - f.assign(val) + f = Function(self.expr.function_space()) + if self.access in {firedrake.MIN, firedrake.MAX}: + finfo = numpy.finfo(f.dat.dtype) + if self.access == firedrake.MIN: + val = firedrake.Constant(finfo.max) + else: + val = firedrake.Constant(finfo.min) + f.assign(val) elif self.rank == 2: Vrow = self.expr_args[0].function_space() Vcol = self.expr_args[1].function_space() @@ -540,18 +593,15 @@ def _get_tensor(self, output=None) -> op2.Mat | Function | Cofunction: raise ValueError(f"Cannot interpolate an expression with {self.rank} arguments") return f - def _get_callable(self, output=None) -> Callable: + def _build_callable(self, output=None) -> None: """Construct the callable that performs the interpolation. Returns ------- Callable """ - f = self._get_tensor(output=output) - if isinstance(f, op2.Mat): - tensor = f - else: - tensor = f.dat + f = output or self._get_tensor() + tensor = f if isinstance(f, op2.Mat) else f.dat loops = [] @@ -568,7 +618,7 @@ def _get_callable(self, output=None) -> Callable: else: # Split in the cofunction: split_form can only split in the coargument # Replace the cofunction with a coargument to construct the Jacobian - interp = self.expr._ufl_expr_reconstruct_(self.operand, self.V_dest) + interp = self.expr._ufl_expr_reconstruct_(self.operand, self.target_space) # Split the Jacobian into blocks interp_split = dict(firedrake.formmanipulation.split_form(interp)) # Split the cofunction @@ -583,7 +633,7 @@ def _get_callable(self, output=None) -> Callable: arguments = sub_expr.arguments() sub_space = sub_expr.argument_slots()[0].function_space().dual() sub_tensor = tensor[indices[0]] if self.rank == 1 else tensor - loops.extend(_interpolator(sub_space, sub_tensor, sub_expr, self.subset, arguments, self.access, bcs=self.bcs)) + loops.extend(build_interpolation_callables(sub_space, sub_tensor, sub_expr, self.subset, arguments, self.access, bcs=self.bcs)) if self.bcs and self.rank == 1: loops.extend(partial(bc.apply, f) for bc in self.bcs) @@ -602,7 +652,7 @@ def _interpolate(self, output=None): For arguments, see :class:`.Interpolator`. """ assert self.rank < 2 - self._get_callable(output=output) + self._build_callable(output=output) assembled_interpolator = self.callable() if output: output.assign(assembled_interpolator) @@ -619,21 +669,22 @@ class VomOntoVomInterpolator(SameMeshInterpolator): def __init__(self, expr: Interpolate, bcs=None): super().__init__(expr, bcs=bcs) - def _get_callable(self, output=None): + def _build_callable(self, output=None): self.mat = VomOntoVomMat(self) if self.rank == 2: # We make our own linear operator for this case using PETSc SFs tensor = None else: - f = self._get_tensor(output=output) + f = output or self._get_tensor() + assert isinstance(f, Function | Cofunction) tensor = f.dat # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the # data, including the correct data size and dimensional information # (so for vector function spaces in 2 dimensions we might need a # concatenation of 2 MPI.DOUBLE types when we are in real mode) if tensor is not None: - self.mat.mpi_type, _ = get_dat_mpi_type(f.dat) assert self.rank == 1 + self.mat.mpi_type = get_dat_mpi_type(f.dat)[0] if self.expr.is_adjoint: assert isinstance(self.dual_arg, ufl.Cofunction) @@ -656,9 +707,8 @@ def callable(): # after cofunctions are fully implemented, this will need to be # reconsidered. temp_source_func = firedrake.Function(self.expr_args[1].function_space()) - self.mat.mpi_type, _ = get_dat_mpi_type(temp_source_func.dat) - - # Leave wrapper inside a callable so we can access the handle + self.mat.mpi_type = get_dat_mpi_type(temp_source_func.dat)[0] + # Leave mat inside a callable so we can access the handle # property. If matfree is True, then the handle is a PETSc SF # pretending to be a PETSc Mat. If matfree is False, then this # will be a PETSc Mat representing the equivalent permutation @@ -670,7 +720,32 @@ def callable(): @utils.known_pyop2_safe -def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): +def build_interpolation_callables(V: WithGeometry, tensor, expr, subset, arguments, access, bcs=None) -> tuple[Callable, ...]: + """Builds callables to perform interpolation. + + Parameters + ---------- + V : WithGeometry + _description_ + tensor : _type_ + _description_ + expr : _type_ + _description_ + subset : _type_ + _description_ + arguments : _type_ + _description_ + access : _type_ + _description_ + bcs : _type_, optional + _description_, by default None + + Returns + ------- + tuple[Callable, ...] + Tuple of callables + + """ if not isinstance(expr, ufl.Interpolate): raise ValueError("Expecting to interpolate a ufl.Interpolate") dual_arg, operand = expr.argument_slots() @@ -679,7 +754,7 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): to_element = create_element(V.ufl_element()) except KeyError: # FInAT only elements - raise NotImplementedError("Don't know how to create FIAT element for %s" % V.ufl_element()) + raise NotImplementedError(f"Don't know how to create FIAT element for {V.ufl_element()}") if access is op2.READ: raise ValueError("Can't have READ access for output function") @@ -1128,7 +1203,7 @@ def __init__(self, interpolator: VomOntoVomInterpolator): "The target vom and source vom must be linked by input ordering!" ) self.sf = self.original_vom.input_ordering_without_halos_sf - self.V = interpolator.V_dest + self.V = interpolator.target_space self.source_vom = interpolator.source_mesh self.expr = interpolator.operand self.arguments = extract_arguments(self.expr) @@ -1325,7 +1400,7 @@ def __init__(self, expr, bcs=None): # Split the dual argument dual_split = dict(firedrake.formmanipulation.split_form(self.dual_arg)) # Create the Jacobian to be split into blocks - self.expr = self.expr._ufl_expr_reconstruct_(self.operand, self.V_dest) + self.expr = self.expr._ufl_expr_reconstruct_(self.operand, self.target_space) Isub = {} for indices, form in firedrake.formmanipulation.split_form(self.expr): @@ -1347,7 +1422,6 @@ def __init__(self, expr, bcs=None): Isub[indices] = get_interpolator(form, bcs=sub_bcs) self._sub_interpolators = Isub - self.callable = self._get_callable def __getitem__(self, item): return self._sub_interpolators[item] @@ -1355,16 +1429,16 @@ def __getitem__(self, item): def __iter__(self): return iter(self._sub_interpolators) - def _get_callable(self): + def _build_callable(self): """Assemble the operator.""" shape = tuple(len(a.function_space()) for a in self.expr_args) blocks = numpy.full(shape, PETSc.Mat(), dtype=object) for i in self: - self[i]._get_callable() + self[i]._build_callable() blocks[i] = self[i].callable().handle petscmat = PETSc.Mat().createNest(blocks) tensor = firedrake.AssembledMatrix(self.expr_args, self.bcs, petscmat) - return tensor.M + self.callable = lambda: tensor.M def _interpolate(self, output=None): """Assemble the action.""" From b3ce8f3fcc1717bd4946f7afd45978910a8c604e Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Tue, 14 Oct 2025 12:40:00 +0100 Subject: [PATCH 109/125] remove _interpolate WIP fixes --- firedrake/interpolation.py | 221 +++++++++++++++++-------------------- 1 file changed, 101 insertions(+), 120 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index ebc8a86294..d8bf04fad3 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -249,29 +249,8 @@ def __init__(self, expr: Interpolate, bcs: Iterable[BCBase] | None = None): self.access = expr.options.access @abc.abstractmethod - def _interpolate( - self, output: Function | Cofunction | MatrixBase | None = None - ) -> Function | Cofunction | Number: - """Compute the interpolation action. - - Parameters - ---------- - tensor : Function | Cofunction | MatrixBase, optional - Tensor to hold the interpolated result. - - Returns - ------- - Function | Cofunction | Number - The function, cofunction, or scalar resulting from the - interpolation. - - """ - pass - - @abc.abstractmethod - def _build_callable(self) -> None: - """Builds callable to perform interpolation. - Stores the callable in self.callable + def _build_callable(self, output=None) -> None: + """Builds callable to perform interpolation. Stored in ``self.callable``. """ pass @@ -300,20 +279,28 @@ def assemble( The function, cofunction, matrix, or scalar resulting from the interpolation. """ + self._build_callable(output=tensor) + assembled_interpolator = self.callable() if self.rank == 2: # Assembling the operator + assert isinstance(tensor, MatrixBase | None) res = tensor.petscmat if tensor else PETSc.Mat() # Get the interpolation matrix - self._build_callable() - op2mat = self.callable() - petsc_mat = op2mat.handle + petsc_mat = assembled_interpolator.handle if tensor: petsc_mat.copy(tensor.petscmat) else: res = petsc_mat return tensor or firedrake.AssembledMatrix(self.expr_args, self.bcs, res) else: - return self._interpolate(output=tensor) + assert isinstance(tensor, Function | Cofunction | None) + if tensor: + tensor.assign(assembled_interpolator) + return tensor + if self.rank == 0: + return assembled_interpolator.dat.data.item() + else: + return assembled_interpolator class DofNotDefinedError(Exception): @@ -399,7 +386,7 @@ def __init__(self, expr: Interpolate, bcs: Iterable[BCBase] | None = None): self._build_symbolic_expressions() def _build_symbolic_expressions(self) -> None: - """Constructs the symbolic Interpolate expressions for cross-mesh interpolation. + """Constructs the symbolic ``Interpolate`` expressions for cross-mesh interpolation. Raises ------ @@ -444,76 +431,76 @@ def _build_symbolic_expressions(self) -> None: arg = Argument(self.P0DG_vom, 0 if self.expr.is_adjoint else 1) self.point_eval_input_ordering = interpolate(arg, self.P0DG_vom_input_ordering, matfree=matfree) - def _build_callable(self): - from firedrake.assemble import assemble - assert self.rank == 2 - # The cross-mesh interpolation matrix is the product of the - # `self.point_eval_interpolate` and the permutation - # given by `self.to_input_ordering_interpolate`. - if self.expr.is_adjoint: - symbolic = action(self.point_eval, self.point_eval_input_ordering) - else: - symbolic = action(self.point_eval_input_ordering, self.point_eval) - self.handle = assemble(symbolic).petscmat - self.callable = lambda: self - - def _interpolate(self, output: Function | Cofunction | None = None) -> Function | Cofunction | Number: + def _build_callable(self, output=None): from firedrake.assemble import assemble - # self.expr.function() is None in the 0-form case V_dest = self.expr.function_space() or self.target_space - output = output or Function(V_dest) - - if not self.expr.is_adjoint: - # We evaluate the operand at the node coordinates of the destination space - f_point_eval = assemble(self.point_eval) - - # We create the input-ordering Function before interpolating so we can - # set default missing values if required. - f_point_eval_input_ordering = Function(self.P0DG_vom_input_ordering) - if self.default_missing_val is not None: - f_point_eval_input_ordering.assign(self.default_missing_val) - elif self.allow_missing_dofs: - # If we allow missing points there may be points in the target - # mesh that are not in the source mesh. If we don't specify a - # default missing value we set these to NaN so we can identify - # them later. - f_point_eval_input_ordering.dat.data_wo[:] = numpy.nan - - assemble(action(self.point_eval_input_ordering, f_point_eval), - tensor=f_point_eval_input_ordering) - - # We assign these values to the output function - if self.allow_missing_dofs and self.default_missing_val is None: - indices = numpy.where(~numpy.isnan(f_point_eval_input_ordering.dat.data_ro))[0] - output.dat.data_wo[indices] = f_point_eval_input_ordering.dat.data_ro[indices] - else: - output.dat.data_wo[:] = f_point_eval_input_ordering.dat.data_ro[:] + f = output or Function(V_dest) - if self.rank == 0: - # We take the action of the dual_arg on the interpolated function - assert not isinstance(self.dual_arg, ufl.Coargument) - return assemble(action(self.dual_arg, output)) + if self.rank == 2: + # The cross-mesh interpolation matrix is the product of the + # `self.point_eval_interpolate` and the permutation + # given by `self.to_input_ordering_interpolate`. + if self.expr.is_adjoint: + symbolic = action(self.point_eval, self.point_eval_input_ordering) + else: + symbolic = action(self.point_eval_input_ordering, self.point_eval) + self.handle = assemble(symbolic).petscmat + self.callable = lambda: self else: - # f_src is a cofunction on V_dest.dual - f = self.dual_arg - assert isinstance(f, Cofunction) - # Our first adjoint operation is to assign the dat values to a - # P0DG cofunction on our input ordering VOM. - f_input_ordering = Cofunction(self.P0DG_vom_input_ordering.dual()) - f_input_ordering.dat.data_wo[:] = f.dat.data_ro[:] - - # The rest of the adjoint interpolation is the composition - # of the adjoint interpolators in the reverse direction. - # We don't worry about skipping over missing points here - # because we're going from the input ordering VOM to the original VOM - # and all points from the input ordering VOM are in the original. - interp = action(self.point_eval_input_ordering, f_input_ordering) - f_src_at_src_node_coords = assemble(interp) - - interp = action(self.point_eval, f_src_at_src_node_coords) - assemble(interp, tensor=output) - return output + if self.expr.is_adjoint: + assert self.rank == 1 + # f_src is a cofunction on V_dest.dual + cofunc = self.dual_arg + assert isinstance(cofunc, Cofunction) + # Our first adjoint operation is to assign the dat values to a + # P0DG cofunction on our input ordering VOM. + f_input_ordering = Cofunction(self.P0DG_vom_input_ordering.dual()) + f_input_ordering.dat.data_wo[:] = cofunc.dat.data_ro[:] + + # The rest of the adjoint interpolation is the composition + # of the adjoint interpolators in the reverse direction. + # We don't worry about skipping over missing points here + # because we're going from the input ordering VOM to the original VOM + # and all points from the input ordering VOM are in the original. + def callable(): + f_src_at_src_node_coords = assemble(action(self.point_eval_input_ordering, f_input_ordering)) + assemble(action(self.point_eval, f_src_at_src_node_coords), tensor=f) + return f + else: + # We evaluate the operand at the node coordinates of the destination space + f_point_eval = assemble(self.point_eval) + + # We create the input-ordering Function before interpolating so we can + # set default missing values if required. + f_point_eval_input_ordering = Function(self.P0DG_vom_input_ordering) + if self.default_missing_val is not None: + f_point_eval_input_ordering.assign(self.default_missing_val) + elif self.allow_missing_dofs: + # If we allow missing points there may be points in the target + # mesh that are not in the source mesh. If we don't specify a + # default missing value we set these to NaN so we can identify + # them later. + f_point_eval_input_ordering.dat.data_wo[:] = numpy.nan + + def callable(): + assemble(action(self.point_eval_input_ordering, f_point_eval), + tensor=f_point_eval_input_ordering) + + # We assign these values to the output function + if self.allow_missing_dofs and self.default_missing_val is None: + indices = numpy.where(~numpy.isnan(f_point_eval_input_ordering.dat.data_ro))[0] + f.dat.data_wo[indices] = f_point_eval_input_ordering.dat.data_ro[indices] + else: + f.dat.data_wo[:] = f_point_eval_input_ordering.dat.data_ro[:] + + if self.rank == 0: + # We take the action of the dual_arg on the interpolated function + assert not isinstance(self.dual_arg, ufl.Coargument) + return assemble(action(self.dual_arg, f)) + else: + return f + self.callable = callable class SameMeshInterpolator(Interpolator): @@ -1429,31 +1416,25 @@ def __getitem__(self, item): def __iter__(self): return iter(self._sub_interpolators) - def _build_callable(self): + def _build_callable(self, output=None): """Assemble the operator.""" - shape = tuple(len(a.function_space()) for a in self.expr_args) - blocks = numpy.full(shape, PETSc.Mat(), dtype=object) - for i in self: - self[i]._build_callable() - blocks[i] = self[i].callable().handle - petscmat = PETSc.Mat().createNest(blocks) - tensor = firedrake.AssembledMatrix(self.expr_args, self.bcs, petscmat) - self.callable = lambda: tensor.M - - def _interpolate(self, output=None): - """Assemble the action.""" - if self.rank == 0: - result = sum(self[i].assemble() for i in self) - return output.assign(result) if output else result - - if output is None: - output = firedrake.Function(self.expr_args[-1].function_space().dual()) - - if self.rank == 1: - for k, sub_tensor in enumerate(output.subfunctions): - sub_tensor.assign(sum(self[i].assemble() for i in self if i[0] == k)) - elif self.rank == 2: - for k, sub_tensor in enumerate(output.subfunctions): - sub_tensor.assign(sum(self[i]._interpolate() - for i in self if i[0] == k)) - return output + f = output or Function(self.expr_args[-1].function_space().dual()) + if self.rank == 2: + shape = tuple(len(a.function_space()) for a in self.expr_args) + blocks = numpy.full(shape, PETSc.Mat(), dtype=object) + for i in self: + self[i]._build_callable() + blocks[i] = self[i].callable().handle + petscmat = PETSc.Mat().createNest(blocks) + tensor = firedrake.AssembledMatrix(self.expr_args, self.bcs, petscmat) + self.callable = lambda: tensor.M + elif self.rank == 1: + def callable(): + for k, sub_tensor in enumerate(f.subfunctions): + sub_tensor.assign(sum(self[i].assemble() for i in self if i[0] == k)) + return f + self.callable = callable + else: + def callable(): + return sum(self[i].assemble() for i in self) + self.callable = callable From ca93fbd8a83a34808680e0fa8b7be81f8bc2973a Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Tue, 14 Oct 2025 13:31:47 +0100 Subject: [PATCH 110/125] fixes --- firedrake/interpolation.py | 70 +++++++++++++++----------------------- 1 file changed, 28 insertions(+), 42 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index d8bf04fad3..1fb3d6f475 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -249,8 +249,18 @@ def __init__(self, expr: Interpolate, bcs: Iterable[BCBase] | None = None): self.access = expr.options.access @abc.abstractmethod - def _build_callable(self, output=None) -> None: + def _build_callable(self, output: Function | Cofunction | MatrixBase | None = None) -> None: """Builds callable to perform interpolation. Stored in ``self.callable``. + + If ``self.rank == 2``, then ``self.callable()`` must return an object with a ``handle`` + attribute that stores a PETSc matrix. If ``self.rank == 1``, then `self.callable()` must + return a ``Function`` or ``Cofunction`` (in the forward and adjoint cases respectively). + If ``self.rank == 0``, then ``self.callable()`` must return a number. + + Parameters + ---------- + output : Function | Cofunction | MatrixBase | None, optional + Optional tensor to store the result in, by default None """ pass @@ -268,8 +278,8 @@ def assemble( ---------- tensor : Function | Cofunction | MatrixBase, optional Pre-allocated storage to receive the interpolated result. For rank-2 - expressions this is expected to be a - :class:`~firedrake.assemble.AssembledMatrix`-compatible object whose + expressions this is expected to be a subclass of + :class:`~firedrake.matrix.MatrixBase` whose ``petscmat`` will be populated. For lower-rank expressions this is a :class:`~firedrake.Function` or :class:`~firedrake.Cofunction`. @@ -280,13 +290,13 @@ def assemble( interpolation. """ self._build_callable(output=tensor) - assembled_interpolator = self.callable() + result = self.callable() if self.rank == 2: # Assembling the operator assert isinstance(tensor, MatrixBase | None) res = tensor.petscmat if tensor else PETSc.Mat() # Get the interpolation matrix - petsc_mat = assembled_interpolator.handle + petsc_mat = result.handle if tensor: petsc_mat.copy(tensor.petscmat) else: @@ -294,13 +304,10 @@ def assemble( return tensor or firedrake.AssembledMatrix(self.expr_args, self.bcs, res) else: assert isinstance(tensor, Function | Cofunction | None) - if tensor: - tensor.assign(assembled_interpolator) + if tensor and isinstance(result, Function | Cofunction): + tensor.assign(result) return tensor - if self.rank == 0: - return assembled_interpolator.dat.data.item() - else: - return assembled_interpolator + return result class DofNotDefinedError(Exception): @@ -485,7 +492,7 @@ def callable(): def callable(): assemble(action(self.point_eval_input_ordering, f_point_eval), - tensor=f_point_eval_input_ordering) + tensor=f_point_eval_input_ordering) # We assign these values to the output function if self.allow_missing_dofs and self.default_missing_val is None: @@ -581,12 +588,6 @@ def _get_tensor(self) -> op2.Mat | Function | Cofunction: return f def _build_callable(self, output=None) -> None: - """Construct the callable that performs the interpolation. - - Returns - ------- - Callable - """ f = output or self._get_tensor() tensor = f if isinstance(f, op2.Mat) else f.dat @@ -628,28 +629,10 @@ def _build_callable(self, output=None) -> None: def callable(loops, f): for l in loops: l() - return f + return f.dat.data.item() if self.rank == 0 else f self.callable = partial(callable, loops, f) - @PETSc.Log.EventDecorator() - def _interpolate(self, output=None): - """Compute the interpolation. - - For arguments, see :class:`.Interpolator`. - """ - assert self.rank < 2 - self._build_callable(output=output) - assembled_interpolator = self.callable() - if output: - output.assign(assembled_interpolator) - return output - - if self.rank == 0: - return assembled_interpolator.dat.data.item() - else: - return assembled_interpolator - class VomOntoVomInterpolator(SameMeshInterpolator): @@ -1418,7 +1401,8 @@ def __iter__(self): def _build_callable(self, output=None): """Assemble the operator.""" - f = output or Function(self.expr_args[-1].function_space().dual()) + V_dest = self.expr.function_space() or self.target_space + f = output or Function(V_dest) if self.rank == 2: shape = tuple(len(a.function_space()) for a in self.expr_args) blocks = numpy.full(shape, PETSc.Mat(), dtype=object) @@ -1427,14 +1411,16 @@ def _build_callable(self, output=None): blocks[i] = self[i].callable().handle petscmat = PETSc.Mat().createNest(blocks) tensor = firedrake.AssembledMatrix(self.expr_args, self.bcs, petscmat) - self.callable = lambda: tensor.M + callable = lambda: tensor.M elif self.rank == 1: def callable(): for k, sub_tensor in enumerate(f.subfunctions): sub_tensor.assign(sum(self[i].assemble() for i in self if i[0] == k)) return f - self.callable = callable else: + assert self.rank == 0 def callable(): - return sum(self[i].assemble() for i in self) - self.callable = callable + result = sum(self[i].assemble() for i in self) + assert isinstance(result, Number) + return result + self.callable = callable From 39aef1493bd322f80fdfa572ee9802dce7cb5aca Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Tue, 14 Oct 2025 13:34:25 +0100 Subject: [PATCH 111/125] add `get_interpolator` to `__all__` --- firedrake/interpolation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 1fb3d6f475..6ea62f0cc8 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -44,6 +44,7 @@ __all__ = ( "interpolate", "Interpolate", + "get_interpolator", "DofNotDefinedError", ) @@ -730,7 +731,7 @@ def build_interpolation_callables(V: WithGeometry, tensor, expr, subset, argumen raise ValueError("Can't have READ access for output function") # NOTE: The par_loop is always over the target mesh cells. - target_mesh = as_domain(V) + target_mesh = V.mesh() source_mesh = extract_unique_domain(operand) or target_mesh if isinstance(target_mesh.topology, VertexOnlyMeshTopology): # For trans-mesh interpolation we use a FInAT QuadratureElement as the From 3658338d6f4049e8c0cf094e59814dc216a8516a Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Tue, 14 Oct 2025 14:30:19 +0100 Subject: [PATCH 112/125] lint; type hints and docstrings --- firedrake/bcs.py | 5 +- firedrake/interpolation.py | 175 +++++++++++++++++++++++++++---------- 2 files changed, 132 insertions(+), 48 deletions(-) diff --git a/firedrake/bcs.py b/firedrake/bcs.py index c4fd56007e..a89387cc62 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -6,8 +6,7 @@ import ufl from ufl import as_ufl, as_tensor -from finat.ufl import VectorElement, EnrichedElement -from finat.physically_mapped import DirectlyDefinedElement, PhysicallyMappedElement +from finat.ufl import VectorElement import finat import pyop2 as op2 @@ -362,7 +361,7 @@ def function_arg(self, g): try: self._function_arg = firedrake.Function(V) interpolator = get_interpolator(firedrake.interpolate(g, V)) - # Call this here to check if the element supports interpolation + # Call this here to check if the element supports interpolation interpolator._build_callable() self._function_arg_update = lambda: interpolator.assemble(tensor=self._function_arg) except (ValueError, NotImplementedError): diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 6ea62f0cc8..d09102a67a 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -31,7 +31,7 @@ from firedrake.ufl_expr import Argument, Coargument, action from firedrake.cofunction import Cofunction from firedrake.function import Function -from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology +from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology, MeshGeometry from firedrake.petsc import PETSc from firedrake.halo import _get_mtype as get_dat_mpi_type from firedrake.functionspaceimpl import WithGeometry @@ -268,7 +268,7 @@ def _build_callable(self, output: Function | Cofunction | MatrixBase | None = No def assemble( self, tensor: Function | Cofunction | MatrixBase | None = None ) -> Function | Cofunction | MatrixBase | Number: - """Assemble the interpolation. The result depends on the rank (number of arguments) + """Assemble the interpolation. The result depends on the rank (number of arguments) of the :class:`Interpolate` expression: * rank-2: assemble the operator and return a matrix @@ -279,7 +279,7 @@ def assemble( ---------- tensor : Function | Cofunction | MatrixBase, optional Pre-allocated storage to receive the interpolated result. For rank-2 - expressions this is expected to be a subclass of + expressions this is expected to be a subclass of :class:`~firedrake.matrix.MatrixBase` whose ``petscmat`` will be populated. For lower-rank expressions this is a :class:`~firedrake.Function` or :class:`~firedrake.Cofunction`. @@ -492,7 +492,7 @@ def callable(): f_point_eval_input_ordering.dat.data_wo[:] = numpy.nan def callable(): - assemble(action(self.point_eval_input_ordering, f_point_eval), + assemble(action(self.point_eval_input_ordering, f_point_eval), tensor=f_point_eval_input_ordering) # We assign these values to the output function @@ -619,10 +619,8 @@ def _build_callable(self, output=None) -> None: for indices, sub_expr in expressions.items(): if isinstance(sub_expr, ufl.ZeroBaseForm): continue - arguments = sub_expr.arguments() - sub_space = sub_expr.argument_slots()[0].function_space().dual() sub_tensor = tensor[indices[0]] if self.rank == 1 else tensor - loops.extend(build_interpolation_callables(sub_space, sub_tensor, sub_expr, self.subset, arguments, self.access, bcs=self.bcs)) + loops.extend(build_interpolation_callables(sub_expr, sub_tensor, self.subset, self.access, bcs=self.bcs)) if self.bcs and self.rank == 1: loops.extend(partial(bc.apply, f) for bc in self.bcs) @@ -682,8 +680,8 @@ def callable(): # Leave mat inside a callable so we can access the handle # property. If matfree is True, then the handle is a PETSc SF # pretending to be a PETSc Mat. If matfree is False, then this - # will be a PETSc Mat representing the equivalent permutation - # matrix + # will be a PETSc Mat representing the equivalent permutation matrix + def callable(): return self.mat @@ -691,36 +689,40 @@ def callable(): @utils.known_pyop2_safe -def build_interpolation_callables(V: WithGeometry, tensor, expr, subset, arguments, access, bcs=None) -> tuple[Callable, ...]: - """Builds callables to perform interpolation. +def build_interpolation_callables( + expr: ufl.Interpolate, + tensor: op2.Dat | op2.Mat | op2.Global, + access: Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC], + subset: op2.Subset | None = None, + bcs: Iterable[BCBase] | None = None +) -> tuple[Callable, ...]: + """Returns tuple of callables which calculate the interpolation. Parameters ---------- - V : WithGeometry - _description_ - tensor : _type_ - _description_ - expr : _type_ - _description_ - subset : _type_ - _description_ - arguments : _type_ - _description_ - access : _type_ - _description_ - bcs : _type_, optional - _description_, by default None + expr : ufl.Interpolate + The symbolic interpolation expression. + tensor : op2.Dat | op2.Mat | op2.Global + Object to hold the result of the interpolation. + access : Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC] + op2 access descriptor + subset : op2.Subset | None, optional + An optional subset to apply the interpolation over, by default None. + bcs : Iterable[BCBase] | None, optional + An optional list of boundary conditions to zero-out in the + output function space. Interpolator rows or columns which are + associated with boundary condition nodes are zeroed out when this is + specified. By default None, by default None. Returns ------- tuple[Callable, ...] - Tuple of callables - - """ + Tuple of callables which perform the interpolation. + """ if not isinstance(expr, ufl.Interpolate): raise ValueError("Expecting to interpolate a ufl.Interpolate") dual_arg, operand = expr.argument_slots() - + V = dual_arg.function_space().dual() try: to_element = create_element(V.ufl_element()) except KeyError: @@ -825,6 +827,7 @@ def build_interpolation_callables(V: WithGeometry, tensor, expr, subset, argumen copyin = () copyout = () + arguments = expr.arguments() if isinstance(tensor, op2.Global): parloop_args.append(tensor(access)) elif isinstance(tensor, op2.Dat): @@ -897,7 +900,7 @@ def build_interpolation_callables(V: WithGeometry, tensor, expr, subset, argumen return copyin + callables + (parloop_compute_callable, ) + copyout -def get_interp_node_map(source_mesh, target_mesh, fs): +def get_interp_node_map(source_mesh: MeshGeometry, target_mesh: MeshGeometry, fs: WithGeometry) -> op2.Map | None: """Return the map between cells of the target mesh and nodes of the function space. If the function space is defined on the source mesh then the node map is composed @@ -999,7 +1002,7 @@ def rebuild_te(element, expr_cell, rt_var_name): transpose=element._transpose) -def compose_map_and_cache(map1, map2): +def compose_map_and_cache(map1: op2.Map, map2: op2.Map | None) -> op2.ComposedMap | None: """ Retrieve a :class:`pyop2.ComposedMap` map from the cache of map1 using map2 as the cache key. The composed map maps from the iterset @@ -1022,7 +1025,7 @@ def compose_map_and_cache(map1, map2): return cmap -def vom_cell_parent_node_map_extruded(vertex_only_mesh, extruded_cell_node_map): +def vom_cell_parent_node_map_extruded(vertex_only_mesh: MeshGeometry, extruded_cell_node_map: op2.Map) -> op2.Map: """Build a map from the cells of a vertex only mesh to the nodes of the nodes on the source mesh where the source mesh is extruded. @@ -1210,13 +1213,25 @@ def mpi_type(self): def mpi_type(self, val): self._mpi_type = val - def expr_as_coeff(self, source_vec=None): - """ - Return a coefficient that corresponds to the expression used at + def expr_as_coeff(self, source_vec: PETSc.Vec | None = None) -> Function: + """Return a coefficient that corresponds to the expression used at construction, where the expression has been interpolated into the P0DG function space on the source vertex-only mesh. Will fail if there are no arguments. + + Parameters + ---------- + source_vec : PETSc.Vec | None, optional + Optional vector used to replace arguments in the expression. + By default None. + + Returns + ------- + Function + A Function representing the expression as a coefficient on the + source vertex-only mesh. + """ # Since we always output a coefficient when we don't have arguments in # the expression, we should evaluate the expression on the source mesh @@ -1243,7 +1258,16 @@ def expr_as_coeff(self, source_vec=None): coeff = firedrake.Function(P0DG).interpolate(coeff_expr) return coeff - def reduce(self, source_vec, target_vec): + def reduce(self, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> None: + """Reduce data in source_vec using the PETSc SF. + + Parameters + ---------- + source_vec : PETSc.Vec + The vector to reduce. + target_vec : PETSc.Vec + The vector to store the result in. + """ source_arr = source_vec.getArray(readonly=True) target_arr = target_vec.getArray() self.sf.reduceBegin( @@ -1259,7 +1283,16 @@ def reduce(self, source_vec, target_vec): MPI.REPLACE, ) - def broadcast(self, source_vec, target_vec): + def broadcast(self, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> None: + """Broadcast data in source_vec using the PETSc SF. + + Parameters + ---------- + source_vec : PETSc.Vec + The vector to broadcast. + target_vec : PETSc.Vec + The vector to store the result in. + """ source_arr = source_vec.getArray(readonly=True) target_arr = target_vec.getArray() self.sf.bcastBegin( @@ -1275,8 +1308,20 @@ def broadcast(self, source_vec, target_vec): MPI.REPLACE, ) - def mult(self, mat, source_vec, target_vec): - # need to evaluate expression before doing mult + def mult(self, mat: PETSc.Mat, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> None: + """Applies the interpolation operator. + + Parameters + ---------- + mat : PETSc.Mat + Required by petsc4py but unused. + source_vec : PETSc.Vec + The vector to interpolate. + target_vec : PETSc.Vec + The vector to store the result in. + """ + # Need to convert the expression into a coefficient + # so that we can broadcast/reduce it coeff = self.expr_as_coeff(source_vec) with coeff.dat.vec_ro as coeff_vec: if self.forward_reduce: @@ -1284,10 +1329,35 @@ def mult(self, mat, source_vec, target_vec): else: self.broadcast(coeff_vec, target_vec) - def multHermitian(self, mat, source_vec, target_vec): + def multHermitian(self, mat: PETSc.Mat, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> None: + """Applies the adjoint of the interpolation operator. + Since ``VomOntoVomMat`` represents a permutation, it is + real-valued and thus the adjoint is the transpose. + + Parameters + ---------- + mat : PETSc.Mat + Required by petsc4py but unused. + source_vec : PETSc.Vec + The vector to adjoint interpolate. + target_vec : PETSc.Vec + The vector to store the result in. + """ self.multTranspose(mat, source_vec, target_vec) - def multTranspose(self, mat, source_vec, target_vec): + def multTranspose(self, mat: PETSc.Mat, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> None: + """Applies the tranpose of the interpolation operator. Called by `self.multHermitian`. + + Parameters + ---------- + mat : PETSc.Mat + Required by petsc4py but unused. + source_vec : PETSc.Vec + The vector to transpose interpolate. + target_vec : PETSc.Vec + The vector to store the result in. + + """ # can only do adjoint if our expression exclusively contains a # single argument, making the application of the adjoint operator # straightforward (haven't worked out how to do this otherwise!) @@ -1314,9 +1384,17 @@ def multTranspose(self, mat, source_vec, target_vec): target_vec.zeroEntries() self.reduce(source_vec, target_vec) - def _create_permutation_mat(self): + def _create_permutation_mat(self) -> PETSc.Mat: """Creates the PETSc matrix that represents the interpolation operator from a vertex-only mesh to - its input ordering vertex-only mesh""" + its input ordering vertex-only mesh. + + Returns + ------- + PETSc.Mat + PETSc seqaij matrix + """ + # To create the permutation matrix we broadcast an array of indices contiguous across + # all ranks and then use these indices to set the values of the matrix directly. mat = PETSc.Mat().createAIJ((self.target_size, self.source_size), nnz=1, comm=self.V.comm) mat.setUp() start = sum(self._local_sizes[:self.V.comm.rank]) @@ -1333,7 +1411,14 @@ def _create_permutation_mat(self): mat.transpose() return mat - def _wrap_python_mat(self): + def _wrap_python_mat(self) -> PETSc.Mat: + """Wraps this object as a PETSc Mat. Used for matfree interpolation. + + Returns + ------- + PETSc.Mat + A PETSc Mat of type python with this object as its context. + """ mat = PETSc.Mat().create(comm=self.V.comm) if self.forward_reduce: mat_size = (self.source_size, self.target_size) @@ -1401,7 +1486,6 @@ def __iter__(self): return iter(self._sub_interpolators) def _build_callable(self, output=None): - """Assemble the operator.""" V_dest = self.expr.function_space() or self.target_space f = output or Function(V_dest) if self.rank == 2: @@ -1420,6 +1504,7 @@ def callable(): return f else: assert self.rank == 0 + def callable(): result = sum(self[i].assemble() for i in self) assert isinstance(result, Number) From b9ab2e0fe74563dfae4c77eb4cad8d0306a8ee53 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Tue, 14 Oct 2025 15:06:45 +0100 Subject: [PATCH 113/125] fix --- firedrake/interpolation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index d09102a67a..e2cd9cc14c 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -620,7 +620,7 @@ def _build_callable(self, output=None) -> None: if isinstance(sub_expr, ufl.ZeroBaseForm): continue sub_tensor = tensor[indices[0]] if self.rank == 1 else tensor - loops.extend(build_interpolation_callables(sub_expr, sub_tensor, self.subset, self.access, bcs=self.bcs)) + loops.extend(_build_interpolation_callables(sub_expr, sub_tensor, self.access, self.subset, self.bcs)) if self.bcs and self.rank == 1: loops.extend(partial(bc.apply, f) for bc in self.bcs) @@ -689,7 +689,7 @@ def callable(): @utils.known_pyop2_safe -def build_interpolation_callables( +def _build_interpolation_callables( expr: ufl.Interpolate, tensor: op2.Dat | op2.Mat | op2.Global, access: Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC], From 93476e43011186a499f7dd64db680bf2ad37e9a6 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 16 Oct 2025 13:02:00 +0100 Subject: [PATCH 114/125] suggestions --- firedrake/bcs.py | 6 +++--- firedrake/interpolation.py | 41 ++++++++++++++++++-------------------- 2 files changed, 22 insertions(+), 25 deletions(-) diff --git a/firedrake/bcs.py b/firedrake/bcs.py index a89387cc62..8b70f3d400 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -1,7 +1,7 @@ # A module implementing strong (Dirichlet) boundary conditions. import numpy as np -import functools +from functools import partial, reduce import itertools import ufl @@ -167,7 +167,7 @@ def hermite_stride(bcnodes): # Edge conditions have only been tested with Lagrange elements. # Need to expand the list. bcnodes1.append(hermite_stride(self._function_space.boundary_nodes(ss))) - bcnodes1 = functools.reduce(np.intersect1d, bcnodes1) + bcnodes1 = reduce(np.intersect1d, bcnodes1) bcnodes.append(bcnodes1) return np.concatenate(bcnodes) @@ -363,7 +363,7 @@ def function_arg(self, g): interpolator = get_interpolator(firedrake.interpolate(g, V)) # Call this here to check if the element supports interpolation interpolator._build_callable() - self._function_arg_update = lambda: interpolator.assemble(tensor=self._function_arg) + self._function_arg_update = partial(interpolator.assemble, tensor=self._function_arg) except (ValueError, NotImplementedError): # Element doesn't implement interpolation self._function_arg = firedrake.Function(V).project(g) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index e2cd9cc14c..79beb99dbf 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -305,10 +305,7 @@ def assemble( return tensor or firedrake.AssembledMatrix(self.expr_args, self.bcs, res) else: assert isinstance(tensor, Function | Cofunction | None) - if tensor and isinstance(result, Function | Cofunction): - tensor.assign(result) - return tensor - return result + return tensor.assign(result) if tensor else result class DofNotDefinedError(Exception): @@ -454,13 +451,15 @@ def _build_callable(self, output=None): else: symbolic = action(self.point_eval_input_ordering, self.point_eval) self.handle = assemble(symbolic).petscmat - self.callable = lambda: self + def callable() -> CrossMeshInterpolator: + return self else: if self.expr.is_adjoint: assert self.rank == 1 # f_src is a cofunction on V_dest.dual cofunc = self.dual_arg assert isinstance(cofunc, Cofunction) + # Our first adjoint operation is to assign the dat values to a # P0DG cofunction on our input ordering VOM. f_input_ordering = Cofunction(self.P0DG_vom_input_ordering.dual()) @@ -471,7 +470,7 @@ def _build_callable(self, output=None): # We don't worry about skipping over missing points here # because we're going from the input ordering VOM to the original VOM # and all points from the input ordering VOM are in the original. - def callable(): + def callable() -> Cofunction: f_src_at_src_node_coords = assemble(action(self.point_eval_input_ordering, f_input_ordering)) assemble(action(self.point_eval, f_src_at_src_node_coords), tensor=f) return f @@ -491,7 +490,7 @@ def callable(): # them later. f_point_eval_input_ordering.dat.data_wo[:] = numpy.nan - def callable(): + def callable() -> Function | Number: assemble(action(self.point_eval_input_ordering, f_point_eval), tensor=f_point_eval_input_ordering) @@ -508,7 +507,7 @@ def callable(): return assemble(action(self.dual_arg, f)) else: return f - self.callable = callable + self.callable = callable class SameMeshInterpolator(Interpolator): @@ -645,7 +644,6 @@ def _build_callable(self, output=None): tensor = None else: f = output or self._get_tensor() - assert isinstance(f, Function | Cofunction) tensor = f.dat # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the # data, including the correct data size and dimensional information @@ -656,13 +654,16 @@ def _build_callable(self, output=None): self.mat.mpi_type = get_dat_mpi_type(f.dat)[0] if self.expr.is_adjoint: assert isinstance(self.dual_arg, ufl.Cofunction) + assert isinstance(f, Cofunction) - def callable(): + def callable() -> Cofunction: with self.dual_arg.dat.vec_ro as source_vec, f.dat.vec_wo as target_vec: self.mat.handle.multHermitian(source_vec, target_vec) return f else: - def callable(): + assert isinstance(f, Function) + + def callable() -> Function: coeff = self.mat.expr_as_coeff() with coeff.dat.vec_ro as coeff_vec, f.dat.vec_wo as target_vec: self.mat.handle.mult(coeff_vec, target_vec) @@ -682,7 +683,7 @@ def callable(): # pretending to be a PETSc Mat. If matfree is False, then this # will be a PETSc Mat representing the equivalent permutation matrix - def callable(): + def callable() -> VomOntoVomMat: return self.mat self.callable = callable @@ -1494,19 +1495,15 @@ def _build_callable(self, output=None): for i in self: self[i]._build_callable() blocks[i] = self[i].callable().handle - petscmat = PETSc.Mat().createNest(blocks) - tensor = firedrake.AssembledMatrix(self.expr_args, self.bcs, petscmat) - callable = lambda: tensor.M + self.handle = PETSc.Mat().createNest(blocks) + def callable() -> MixedInterpolator: + return self elif self.rank == 1: - def callable(): + def callable() -> Function | Cofunction: for k, sub_tensor in enumerate(f.subfunctions): sub_tensor.assign(sum(self[i].assemble() for i in self if i[0] == k)) return f else: - assert self.rank == 0 - - def callable(): - result = sum(self[i].assemble() for i in self) - assert isinstance(result, Number) - return result + def callable() -> Number: + return sum(self[i].assemble() for i in self) self.callable = callable From 0934e61c21c5a2d9bb43d333318ac32de204b1db Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 16 Oct 2025 13:32:03 +0100 Subject: [PATCH 115/125] fixes --- firedrake/assemble.py | 268 ++++++++++++++++++++++++++++++++---------- 1 file changed, 204 insertions(+), 64 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 8ef822a1dd..6f573e1a29 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -19,7 +19,7 @@ from firedrake import (extrusion_utils as eutils, matrix, parameters, solving, tsfc_interface, utils) from firedrake.adjoint_utils import annotate_assemble -from firedrake.ufl_expr import extract_unique_domain +from firedrake.ufl_expr import extract_domains from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit from firedrake.functionspaceimpl import WithGeometry, FunctionSpace, FiredrakeDualSpace from firedrake.functionspacedata import entity_dofs_key, entity_permutations_key @@ -61,17 +61,23 @@ def assemble(expr, *args, **kwargs): `ufl.classes.Measure` in the form. For example, if a ``quadrature_degree`` of 4 is specified in this argument, but a degree of 3 is requested in the measure, the latter will be used. - mat_type : str + mat_type : str | None String indicating how a 2-form (matrix) should be assembled -- either as a monolithic matrix (``"aij"`` or ``"baij"``), - a block matrix (``"nest"``), or left as a `matrix.ImplicitMatrix` giving + a block matrix (``"nest"``), or left as a :class:`firedrake.matrix.ImplicitMatrix` giving matrix-free actions (``'matfree'``). If not supplied, the default value in - ``parameters["default_matrix_type"]`` is used. BAIJ differs - from AIJ in that only the block sparsity rather than the dof + ``parameters["default_matrix_type"]`` is used. ``"baij"``` differs + from ``"aij"`` in that only the block sparsity rather than the DoF sparsity is constructed. This can result in some memory savings, but does not work with all PETSc preconditioners. - BAIJ matrices only make sense for non-mixed matrices. - sub_mat_type : str + ``"baij"`` matrices only make sense for non-mixed matrices with arguments + on a :func:`firedrake.functionspace.VectorFunctionSpace`. + + NOTE + ---- + For the assembly of a 0-form or 1-form arising from the action of a 2-form, + the default matrix type is ``"matfree"``. + sub_mat_type : str | None String indicating the matrix type to use *inside* a nested block matrix. Only makes sense if ``mat_type`` is ``nest``. May be one of ``"aij"`` or ``"baij"``. If @@ -155,7 +161,10 @@ def get_assembler(form, *args, **kwargs): is_base_form_preprocessed = kwargs.pop('is_base_form_preprocessed', False) fc_params = kwargs.get('form_compiler_parameters', None) if isinstance(form, ufl.form.BaseForm) and not is_base_form_preprocessed: - mat_type = kwargs.get('mat_type', None) + # If not assembling a matrix, internal BaseForm nodes are matfree by default + # Otherwise, the default matrix type is firedrake.parameters["default_matrix_type"] + default_mat_type = "matfree" if len(form.arguments()) < 2 else None + mat_type = kwargs.get('mat_type', default_mat_type) # Preprocess the DAG and restructure the DAG # Only pre-process `form` once beforehand to avoid pre-processing for each assembly call form = BaseFormAssembler.preprocess_base_form(form, mat_type=mat_type, form_compiler_parameters=fc_params) @@ -556,8 +565,6 @@ def base_form_assembly_visitor(self, expr, tensor, *args): result = expr.assemble(assembly_opts=opts) return tensor.assign(result) if tensor else result elif isinstance(expr, ufl.Interpolate): - if not isinstance(expr, firedrake.Interpolate): - expr = firedrake.Interpolate(*reversed(expr.dual_args())) # Replace assembled children _, operand = expr.argument_slots() v, *assembled_operand = args @@ -1022,7 +1029,7 @@ def parloops(self, tensor): self._bcs, local_kernel, subdomain_id, - self.all_integer_subdomain_ids[local_kernel.indices], + self.all_integer_subdomain_ids[local_kernel.indices][local_kernel.kinfo.domain_number], diagonal=self.diagonal, ) pyop2_tensor = self._as_pyop2_type(tensor, local_kernel.indices) @@ -1044,14 +1051,15 @@ def local_kernels(self): """ try: - topology, = set(d.topology for d in self._form.ufl_domains()) + topology, = set(d.topology.submesh_ancesters[-1] for d in self._form.ufl_domains()) except ValueError: raise NotImplementedError("All integration domains must share a mesh topology") for o in itertools.chain(self._form.arguments(), self._form.coefficients()): - domain = extract_unique_domain(o) - if domain is not None and domain.topology != topology: - raise NotImplementedError("Assembly with multiple meshes is not supported") + domains = extract_domains(o) + for domain in domains: + if domain is not None and domain.topology.submesh_ancesters[-1] != topology: + raise NotImplementedError("Assembly with multiple meshes is not supported") if isinstance(self._form, ufl.Form): kernels = tsfc_interface.compile_form( @@ -1363,12 +1371,12 @@ def _make_maps_and_regions(self): else: maps_and_regions = defaultdict(lambda: defaultdict(set)) for assembler in self._all_assemblers: - all_meshes = assembler._form.ufl_domains() + all_meshes = extract_domains(assembler._form) for local_kernel, subdomain_id in assembler.local_kernels: i, j = local_kernel.indices mesh = all_meshes[local_kernel.kinfo.domain_number] # integration domain integral_type = local_kernel.kinfo.integral_type - all_subdomain_ids = assembler.all_integer_subdomain_ids[local_kernel.indices] + all_subdomain_ids = assembler.all_integer_subdomain_ids[local_kernel.indices][local_kernel.kinfo.domain_number] # Make Sparsity independent of the subdomain of integration for better reusability; # subdomain_id is passed here only to determine the integration_type on the target domain # (see ``entity_node_map``). @@ -1544,6 +1552,10 @@ def _global_kernel_cache_key(form, local_knl, subdomain_id, all_integer_subdomai # N.B. Generating the global kernel is not a collective operation so the # communicator does not need to be a part of this cache key. + # Maps in the cached global kernel depend on concrete mesh data. + all_meshes = extract_domains(form) + domain_ids = tuple(mesh.ufl_id() for mesh in all_meshes) + if isinstance(form, ufl.Form): sig = form.signature() elif isinstance(form, slate.TensorBase): @@ -1563,7 +1575,8 @@ def _global_kernel_cache_key(form, local_knl, subdomain_id, all_integer_subdomai else: subdomain_key.append((k, i)) - return ((sig, subdomain_id) + return (domain_ids + + (sig, subdomain_id) + tuple(subdomain_key) + tuplify(all_integer_subdomain_ids) + cachetools.keys.hashkey(local_knl, **kwargs)) @@ -1600,8 +1613,15 @@ def __init__(self, form, local_knl, subdomain_id, all_integer_subdomain_ids, dia self._diagonal = diagonal self._unroll = unroll + self._active_coordinates = _FormHandler.iter_active_coordinates(form, local_knl.kinfo) + self._active_cell_orientations = _FormHandler.iter_active_cell_orientations(form, local_knl.kinfo) + self._active_cell_sizes = _FormHandler.iter_active_cell_sizes(form, local_knl.kinfo) self._active_coefficients = _FormHandler.iter_active_coefficients(form, local_knl.kinfo) self._constants = _FormHandler.iter_constants(form, local_knl.kinfo) + self._active_exterior_facets = _FormHandler.iter_active_exterior_facets(form, local_knl.kinfo) + self._active_interior_facets = _FormHandler.iter_active_interior_facets(form, local_knl.kinfo) + self._active_orientations_exterior_facet = _FormHandler.iter_active_orientations_exterior_facet(form, local_knl.kinfo) + self._active_orientations_interior_facet = _FormHandler.iter_active_orientations_interior_facet(form, local_knl.kinfo) self._map_arg_cache = {} # Cache for holding :class:`op2.MapKernelArg` instances. @@ -1615,8 +1635,15 @@ def build(self): for arg in self._kinfo.arguments] # we should use up all of the coefficients and constants + assert_empty(self._active_coordinates) + assert_empty(self._active_cell_orientations) + assert_empty(self._active_cell_sizes) assert_empty(self._active_coefficients) assert_empty(self._constants) + assert_empty(self._active_exterior_facets) + assert_empty(self._active_interior_facets) + assert_empty(self._active_orientations_exterior_facet) + assert_empty(self._active_orientations_interior_facet) iteration_regions = {"exterior_facet_top": op2.ON_TOP, "exterior_facet_bottom": op2.ON_BOTTOM, @@ -1641,7 +1668,8 @@ def _integral_type(self): @cached_property def _mesh(self): - return self._form.ufl_domains()[self._kinfo.domain_number] + all_meshes = extract_domains(self._form) + return all_meshes[self._kinfo.domain_number] @cached_property def _needs_subset(self): @@ -1746,7 +1774,22 @@ def _as_global_kernel_arg_output(_, self): @_as_global_kernel_arg.register(kernel_args.CoordinatesKernelArg) def _as_global_kernel_arg_coordinates(_, self): - V = self._mesh.coordinates.function_space() + coord = next(self._active_coordinates) + V = coord.function_space() + return self._make_dat_global_kernel_arg(V) + + +@_as_global_kernel_arg.register(kernel_args.CellOrientationsKernelArg) +def _as_global_kernel_arg_cell_orientations(_, self): + c = next(self._active_cell_orientations) + V = c.function_space() + return self._make_dat_global_kernel_arg(V) + + +@_as_global_kernel_arg.register(kernel_args.CellSizesKernelArg) +def _as_global_kernel_arg_cell_sizes(_, self): + c = next(self._active_cell_sizes) + V = c.function_space() return self._make_dat_global_kernel_arg(V) @@ -1774,30 +1817,48 @@ def _as_global_kernel_arg_constant(_, self): return op2.GlobalKernelArg((value_size,)) -@_as_global_kernel_arg.register(kernel_args.CellSizesKernelArg) -def _as_global_kernel_arg_cell_sizes(_, self): - V = self._mesh.cell_sizes.function_space() - return self._make_dat_global_kernel_arg(V) - - @_as_global_kernel_arg.register(kernel_args.ExteriorFacetKernelArg) def _as_global_kernel_arg_exterior_facet(_, self): - return op2.DatKernelArg((1,)) + mesh = next(self._active_exterior_facets) + if mesh is self._mesh: + return op2.DatKernelArg((1,)) + else: + m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids) + assert integral_type == "exterior_facet" + return op2.DatKernelArg((1,), m._global_kernel_arg) @_as_global_kernel_arg.register(kernel_args.InteriorFacetKernelArg) def _as_global_kernel_arg_interior_facet(_, self): - return op2.DatKernelArg((2,)) + mesh = next(self._active_interior_facets) + if mesh is self._mesh: + return op2.DatKernelArg((2,)) + else: + m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids) + assert integral_type == "interior_facet" + return op2.DatKernelArg((2,), m._global_kernel_arg) -@_as_global_kernel_arg.register(kernel_args.ExteriorFacetOrientationKernelArg) -def _as_global_kernel_arg_exterior_facet_orientation(_, self): - return op2.DatKernelArg((1,)) +@_as_global_kernel_arg.register(kernel_args.OrientationsExteriorFacetKernelArg) +def _(_, self): + mesh = next(self._active_orientations_exterior_facet) + if mesh is self._mesh: + return op2.DatKernelArg((1,)) + else: + m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids) + assert integral_type == "exterior_facet" + return op2.DatKernelArg((1,), m._global_kernel_arg) -@_as_global_kernel_arg.register(kernel_args.InteriorFacetOrientationKernelArg) -def _as_global_kernel_arg_interior_facet_orientation(_, self): - return op2.DatKernelArg((2,)) +@_as_global_kernel_arg.register(kernel_args.OrientationsInteriorFacetKernelArg) +def _(_, self): + mesh = next(self._active_orientations_interior_facet) + if mesh is self._mesh: + return op2.DatKernelArg((2,)) + else: + m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids) + assert integral_type == "interior_facet" + return op2.DatKernelArg((2,), m._global_kernel_arg) @_as_global_kernel_arg.register(CellFacetKernelArg) @@ -1809,12 +1870,6 @@ def _as_global_kernel_arg_cell_facet(_, self): return op2.DatKernelArg((num_facets, 2)) -@_as_global_kernel_arg.register(kernel_args.CellOrientationsKernelArg) -def _as_global_kernel_arg_cell_orientations(_, self): - V = self._mesh.cell_orientations().function_space() - return self._make_dat_global_kernel_arg(V) - - @_as_global_kernel_arg.register(LayerCountKernelArg) def _as_global_kernel_arg_layer_count(_, self): return op2.GlobalKernelArg((1,)) @@ -1848,8 +1903,15 @@ def __init__(self, form, bcs, local_knl, subdomain_id, self._diagonal = diagonal self._bcs = bcs + self._active_coordinates = _FormHandler.iter_active_coordinates(form, local_knl.kinfo) + self._active_cell_orientations = _FormHandler.iter_active_cell_orientations(form, local_knl.kinfo) + self._active_cell_sizes = _FormHandler.iter_active_cell_sizes(form, local_knl.kinfo) self._active_coefficients = _FormHandler.iter_active_coefficients(form, local_knl.kinfo) self._constants = _FormHandler.iter_constants(form, local_knl.kinfo) + self._active_exterior_facets = _FormHandler.iter_active_exterior_facets(form, local_knl.kinfo) + self._active_interior_facets = _FormHandler.iter_active_interior_facets(form, local_knl.kinfo) + self._active_orientations_exterior_facet = _FormHandler.iter_active_orientations_exterior_facet(form, local_knl.kinfo) + self._active_orientations_interior_facet = _FormHandler.iter_active_orientations_interior_facet(form, local_knl.kinfo) def build(self, tensor: op2.Global | op2.Dat | op2.Mat) -> op2.Parloop: """Construct the parloop. @@ -1983,7 +2045,8 @@ def _indexed_function_spaces(self): @cached_property def _mesh(self): - return self._form.ufl_domains()[self._kinfo.domain_number] + all_meshes = extract_domains(self._form) + return all_meshes[self._kinfo.domain_number] @cached_property def _iterset(self): @@ -2055,7 +2118,21 @@ def _as_parloop_arg_output(_, self): @_as_parloop_arg.register(kernel_args.CoordinatesKernelArg) def _as_parloop_arg_coordinates(_, self): - func = self._mesh.coordinates + func = next(self._active_coordinates) + map_ = self._get_map(func.function_space()) + return op2.DatParloopArg(func.dat, map_) + + +@_as_parloop_arg.register(kernel_args.CellOrientationsKernelArg) +def _as_parloop_arg_cell_orientations(_, self): + func = next(self._active_cell_orientations) + map_ = self._get_map(func.function_space()) + return op2.DatParloopArg(func.dat, map_) + + +@_as_parloop_arg.register(kernel_args.CellSizesKernelArg) +def _as_parloop_arg_cell_sizes(_, self): + func = next(self._active_cell_sizes) map_ = self._get_map(func.function_space()) return op2.DatParloopArg(func.dat, map_) @@ -2076,38 +2153,48 @@ def _as_parloop_arg_constant(arg, self): return op2.GlobalParloopArg(const.dat) -@_as_parloop_arg.register(kernel_args.CellOrientationsKernelArg) -def _as_parloop_arg_cell_orientations(_, self): - func = self._mesh.cell_orientations() - m = self._get_map(func.function_space()) - return op2.DatParloopArg(func.dat, m) - - -@_as_parloop_arg.register(kernel_args.CellSizesKernelArg) -def _as_parloop_arg_cell_sizes(_, self): - func = self._mesh.cell_sizes - m = self._get_map(func.function_space()) - return op2.DatParloopArg(func.dat, m) - - @_as_parloop_arg.register(kernel_args.ExteriorFacetKernelArg) def _as_parloop_arg_exterior_facet(_, self): - return op2.DatParloopArg(self._mesh.exterior_facets.local_facet_dat) + mesh = next(self._active_exterior_facets) + if mesh is self._mesh: + m = None + else: + m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids) + assert integral_type == "exterior_facet" + return op2.DatParloopArg(mesh.exterior_facets.local_facet_dat, m) @_as_parloop_arg.register(kernel_args.InteriorFacetKernelArg) def _as_parloop_arg_interior_facet(_, self): - return op2.DatParloopArg(self._mesh.interior_facets.local_facet_dat) + mesh = next(self._active_interior_facets) + if mesh is self._mesh: + m = None + else: + m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids) + assert integral_type == "interior_facet" + return op2.DatParloopArg(mesh.interior_facets.local_facet_dat, m) -@_as_parloop_arg.register(kernel_args.ExteriorFacetOrientationKernelArg) -def _as_parloop_arg_exterior_facet_orientation(_, self): - return op2.DatParloopArg(self._mesh.exterior_facets.local_facet_orientation_dat) +@_as_parloop_arg.register(kernel_args.OrientationsExteriorFacetKernelArg) +def _(_, self): + mesh = next(self._active_orientations_exterior_facet) + if mesh is self._mesh: + m = None + else: + m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids) + assert integral_type == "exterior_facet" + return op2.DatParloopArg(mesh.exterior_facets.local_facet_orientation_dat, m) -@_as_parloop_arg.register(kernel_args.InteriorFacetOrientationKernelArg) -def _as_parloop_arg_interior_facet_orientation(_, self): - return op2.DatParloopArg(self._mesh.interior_facets.local_facet_orientation_dat) +@_as_parloop_arg.register(kernel_args.OrientationsInteriorFacetKernelArg) +def _(_, self): + mesh = next(self._active_orientations_interior_facet) + if mesh is self._mesh: + m = None + else: + m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids) + assert integral_type == "interior_facet" + return op2.DatParloopArg(mesh.interior_facets.local_facet_orientation_dat, m) @_as_parloop_arg.register(CellFacetKernelArg) @@ -2129,6 +2216,27 @@ def _as_parloop_arg_layer_count(_, self): class _FormHandler: """Utility class for inspecting forms and local kernels.""" + @staticmethod + def iter_active_coordinates(form, kinfo): + """Yield the form coordinates referenced in ``kinfo``.""" + all_meshes = extract_domains(form) + for i in kinfo.active_domain_numbers.coordinates: + yield all_meshes[i].coordinates + + @staticmethod + def iter_active_cell_orientations(form, kinfo): + """Yield the form cell orientations referenced in ``kinfo``.""" + all_meshes = extract_domains(form) + for i in kinfo.active_domain_numbers.cell_orientations: + yield all_meshes[i].cell_orientations() + + @staticmethod + def iter_active_cell_sizes(form, kinfo): + """Yield the form cell sizes referenced in ``kinfo``.""" + all_meshes = extract_domains(form) + for i in kinfo.active_domain_numbers.cell_sizes: + yield all_meshes[i].cell_sizes + @staticmethod def iter_active_coefficients(form, kinfo): """Yield the form coefficients referenced in ``kinfo``.""" @@ -2147,6 +2255,38 @@ def iter_constants(form, kinfo): for constant_index in kinfo.constant_numbers: yield all_constants[constant_index] + @staticmethod + def iter_active_exterior_facets(form, kinfo): + """Yield the form exterior facets referenced in ``kinfo``.""" + all_meshes = extract_domains(form) + for i in kinfo.active_domain_numbers.exterior_facets: + mesh = all_meshes[i] + yield mesh + + @staticmethod + def iter_active_interior_facets(form, kinfo): + """Yield the form interior facets referenced in ``kinfo``.""" + all_meshes = extract_domains(form) + for i in kinfo.active_domain_numbers.interior_facets: + mesh = all_meshes[i] + yield mesh + + @staticmethod + def iter_active_orientations_exterior_facet(form, kinfo): + """Yield the form exterior facet orientations referenced in ``kinfo``.""" + all_meshes = extract_domains(form) + for i in kinfo.active_domain_numbers.orientations_exterior_facet: + mesh = all_meshes[i] + yield mesh + + @staticmethod + def iter_active_orientations_interior_facet(form, kinfo): + """Yield the form interior facet orientations referenced in ``kinfo``.""" + all_meshes = extract_domains(form) + for i in kinfo.active_domain_numbers.orientations_interior_facet: + mesh = all_meshes[i] + yield mesh + @staticmethod def index_function_spaces(form, indices): """Return the function spaces of the form's arguments, indexed From 6b87a6b567ef3b1b59e96ec749113530116989cc Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 16 Oct 2025 13:42:16 +0100 Subject: [PATCH 116/125] lint --- firedrake/interpolation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 79beb99dbf..31db22b4dc 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -451,6 +451,7 @@ def _build_callable(self, output=None): else: symbolic = action(self.point_eval_input_ordering, self.point_eval) self.handle = assemble(symbolic).petscmat + def callable() -> CrossMeshInterpolator: return self else: @@ -1496,6 +1497,7 @@ def _build_callable(self, output=None): self[i]._build_callable() blocks[i] = self[i].callable().handle self.handle = PETSc.Mat().createNest(blocks) + def callable() -> MixedInterpolator: return self elif self.rank == 1: From 15bdd53a671cd9c228a87c8bd9ce796bf3710f2c Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 16 Oct 2025 15:51:51 +0100 Subject: [PATCH 117/125] pass bcs to `interpolate`, zero cofunction fix fix --- firedrake/interpolation.py | 68 ++++++++++++--------------- firedrake/preconditioners/hiptmair.py | 4 +- 2 files changed, 33 insertions(+), 39 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 31db22b4dc..3d8e7d05cc 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -87,12 +87,18 @@ class InterpolateOptions: If ``False``, then construct the permutation matrix for interpolating between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast and reduce operations. + bcs : Iterable[BCBase] | None, optional + An optional list of boundary conditions to zero-out in the + output function space. Interpolator rows or columns which are + associated with boundary condition nodes are zeroed out when this is + specified. By default None. """ subset: op2.Subset | None = None access: Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC] | None = None allow_missing_dofs: bool = False default_missing_val: float | None = None matfree: bool = True + bcs: Iterable[BCBase] | None = None class Interpolate(ufl.Interpolate): @@ -162,18 +168,13 @@ def interpolate(expr: Expr, V: WithGeometry | ufl.BaseForm, **kwargs) -> Interpo return Interpolate(expr, V, **kwargs) -def get_interpolator(expr: Interpolate, bcs: Iterable[BCBase] | None = None) -> "Interpolator": +def get_interpolator(expr: Interpolate) -> "Interpolator": """Create an Interpolator. Parameters ---------- expr : Interpolate Symbolic interpolation expression. - bcs : Iterable[BCBase] | None, optional - An optional list of boundary conditions to zero-out in the - output function space. Interpolator rows or columns which are - associated with boundary condition nodes are zeroed out when this is - specified. By default None. Returns ------- @@ -183,7 +184,7 @@ def get_interpolator(expr: Interpolate, bcs: Iterable[BCBase] | None = None) -> arguments = expr.arguments() has_mixed_arguments = any(len(arg.function_space()) > 1 for arg in arguments) if len(arguments) == 2 and has_mixed_arguments: - return MixedInterpolator(expr, bcs=bcs) + return MixedInterpolator(expr) operand, = expr.ufl_operands target_mesh = expr.target_space.mesh() @@ -194,24 +195,24 @@ def get_interpolator(expr: Interpolate, bcs: Iterable[BCBase] | None = None) -> and target_mesh.topological_dimension() == source_mesh.topological_dimension() ) if target_mesh is source_mesh or submesh_interp_implemented: - return SameMeshInterpolator(expr, bcs=bcs) + return SameMeshInterpolator(expr) target_topology = target_mesh.topology source_topology = source_mesh.topology if isinstance(target_topology, VertexOnlyMeshTopology): if isinstance(source_topology, VertexOnlyMeshTopology): - return VomOntoVomInterpolator(expr, bcs=bcs) + return VomOntoVomInterpolator(expr) if target_mesh.geometric_dimension() != source_mesh.geometric_dimension(): raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh: raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target") - return SameMeshInterpolator(expr, bcs=bcs) + return SameMeshInterpolator(expr) if has_mixed_arguments or len(expr.target_space) > 1: - return MixedInterpolator(expr, bcs=bcs) + return MixedInterpolator(expr) - return CrossMeshInterpolator(expr, bcs=bcs) + return CrossMeshInterpolator(expr) class Interpolator(abc.ABC): @@ -222,14 +223,9 @@ class Interpolator(abc.ABC): ---------- expr : Interpolate The symbolic interpolation expression. - bcs : Iterable[BCBase], optional - An optional list of boundary conditions to zero-out in the - output function space. Interpolator rows or columns which are - associated with boundary condition nodes are zeroed out when this is - specified. By default None. """ - def __init__(self, expr: Interpolate, bcs: Iterable[BCBase] | None = None): + def __init__(self, expr: Interpolate): dual_arg, operand = expr.argument_slots() self.expr = expr self.expr_args = expr.arguments() @@ -245,7 +241,7 @@ def __init__(self, expr: Interpolate, bcs: Iterable[BCBase] | None = None): self.allow_missing_dofs = expr.options.allow_missing_dofs self.default_missing_val = expr.options.default_missing_val self.matfree = expr.options.matfree - self.bcs = bcs + self.bcs = expr.options.bcs self.callable = None self.access = expr.options.access @@ -345,8 +341,8 @@ class CrossMeshInterpolator(Interpolator): """ @no_annotations - def __init__(self, expr: Interpolate, bcs: Iterable[BCBase] | None = None): - super().__init__(expr, bcs) + def __init__(self, expr: Interpolate): + super().__init__(expr) if self.access and self.access != op2.WRITE: raise NotImplementedError( "Access other than op2.WRITE not implemented for cross-mesh interpolation." @@ -520,8 +516,8 @@ class SameMeshInterpolator(Interpolator): """ @no_annotations - def __init__(self, expr, bcs=None): - super().__init__(expr, bcs=bcs) + def __init__(self, expr): + super().__init__(expr) subset = self.subset if subset is None: target = self.target_mesh.topology @@ -594,9 +590,6 @@ def _build_callable(self, output=None) -> None: loops = [] - if self.access == op2.INC: - loops.append(tensor.zero) - # Arguments in the operand are allowed to be from a MixedFunctionSpace # We need to split the target space V and generate separate kernels if self.rank == 2: @@ -635,8 +628,8 @@ def callable(loops, f): class VomOntoVomInterpolator(SameMeshInterpolator): - def __init__(self, expr: Interpolate, bcs=None): - super().__init__(expr, bcs=bcs) + def __init__(self, expr: Interpolate): + super().__init__(expr) def _build_callable(self, output=None): self.mat = VomOntoVomMat(self) @@ -899,7 +892,10 @@ def _build_interpolation_callables( if isinstance(tensor, op2.Mat): return parloop_compute_callable, tensor.assemble else: - return copyin + callables + (parloop_compute_callable, ) + copyout + extra = copyin + callables + if access == op2.INC: + extra += (tensor.zero,) + return extra + (parloop_compute_callable, ) + copyout def get_interp_node_map(source_mesh: MeshGeometry, target_mesh: MeshGeometry, fs: WithGeometry) -> op2.Map | None: @@ -1446,11 +1442,9 @@ class MixedInterpolator(Interpolator): V The :class:`.FunctionSpace` or :class:`.Function` to interpolate into. - bcs - A list of boundary conditions. """ - def __init__(self, expr, bcs=None): - super().__init__(expr, bcs=bcs) + def __init__(self, expr): + super().__init__(expr) # We need a Coargument in order to split the Interpolate needs_action = not any(isinstance(a, Coargument) for a in self.expr_args) @@ -1467,17 +1461,17 @@ def __init__(self, expr, bcs=None): continue vi, _ = form.argument_slots() Vtarget = vi.function_space().dual() - if bcs and self.rank != 0: + if self.bcs and self.rank != 0: args = form.arguments() Vsource = args[1 - vi.number()].function_space() - sub_bcs = [bc for bc in bcs if bc.function_space() in {Vsource, Vtarget}] + sub_bcs = [bc for bc in self.bcs if bc.function_space() in {Vsource, Vtarget}] else: sub_bcs = None if needs_action: # Take the action of each sub-cofunction against each block form = action(form, dual_split[indices[-1:]]) - - Isub[indices] = get_interpolator(form, bcs=sub_bcs) + form.options.bcs = sub_bcs + Isub[indices] = get_interpolator(form) self._sub_interpolators = Isub diff --git a/firedrake/preconditioners/hiptmair.py b/firedrake/preconditioners/hiptmair.py index 14ec77fe1a..6c24b9cd84 100644 --- a/firedrake/preconditioners/hiptmair.py +++ b/firedrake/preconditioners/hiptmair.py @@ -10,7 +10,7 @@ from firedrake.preconditioners.hypre_ams import chop from firedrake.preconditioners.facet_split import restrict from firedrake.parameters import parameters -from firedrake.interpolation import Interpolator +from firedrake.interpolation import interpolate from ufl.algorithms.ad import expand_derivatives import firedrake.dmhooks as dmhooks import firedrake.utils as utils @@ -202,7 +202,7 @@ def coarsen(self, pc): coarse_space_bcs = tuple(coarse_space_bcs) if G_callback is None: - interp_petscmat = chop(Interpolator(dminus(trial), V, bcs=bcs + coarse_space_bcs).callable().handle) + interp_petscmat = chop(assemble(interpolate(dminus(trial), V, bcs=bcs + coarse_space_bcs)).mat()) else: interp_petscmat = G_callback(coarse_space, V, coarse_space_bcs, bcs) From d54994e26ea8e9d7c55c8c0c134bd10784664e7a Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 16 Oct 2025 16:13:50 +0100 Subject: [PATCH 118/125] Interpolate: map dual argument to reference values for codegen --- tsfc/driver.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/tsfc/driver.py b/tsfc/driver.py index eda810bdf7..effa1bd9bb 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -215,20 +215,28 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, if isinstance(to_element, (PhysicallyMappedElement, DirectlyDefinedElement)): raise NotImplementedError("Don't know how to interpolate onto zany spaces, sorry") + if domain is None: + domain = extract_unique_domain(expression) + assert domain is not None + orig_coefficients = extract_coefficients(expression) - if isinstance(expression, ufl.Interpolate): - v, operand = expression.argument_slots() - else: - operand = expression - v = ufl.FunctionSpace(extract_unique_domain(operand), ufl_element) + v, operand = expression.argument_slots() - # Map into reference space + # Map v into reference space + if ufl_element.mapping() != "identity": + Vref = ufl.FunctionSpace(domain, finat.ufl.WithMapping(ufl_element, "identity")) + if isinstance(v, ufl.Cofunction): + v = ufl.Cofunction(Vref.dual()) + else: + v = ufl.Coargument(Vref.dual(), number=v.number()) + + # Map operand into reference space operand = apply_mapping(operand, ufl_element, domain) # Apply UFL preprocessing operand = ufl_utils.preprocess_expression(operand, complex_mode=complex_mode) - # Reconstructed Interpolate with mapped operand + # Reconstructed Interpolate in the reference space expression = ufl.Interpolate(operand, v) # Initialise kernel builder @@ -243,9 +251,6 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, assert len(argument_multiindices) == len(arguments) # Replace coordinates (if any) unless otherwise specified by kwarg - if domain is None: - domain = extract_unique_domain(expression) - assert domain is not None builder._domain_integral_type_map = {domain: "cell"} # Collect required coefficients and determine numbering @@ -259,6 +264,7 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, elements = [f.ufl_element() for f in (*coefficients, *arguments)] + # Replace coordinates (if any) unless otherwise specified by kwarg needs_external_coords = False if has_type(expression, GeometricQuantity) or any(map(fem.needs_coordinate_mapping, elements)): # Create a fake coordinate coefficient for a domain. From 8667567885a7f7eec047a15c5f190ddc54391edb Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 16 Oct 2025 18:39:45 +0100 Subject: [PATCH 119/125] add zero form optimisation back in --- firedrake/interpolation.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 3d8e7d05cc..e84e2e43db 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -291,14 +291,12 @@ def assemble( if self.rank == 2: # Assembling the operator assert isinstance(tensor, MatrixBase | None) - res = tensor.petscmat if tensor else PETSc.Mat() # Get the interpolation matrix petsc_mat = result.handle if tensor: petsc_mat.copy(tensor.petscmat) - else: - res = petsc_mat - return tensor or firedrake.AssembledMatrix(self.expr_args, self.bcs, res) + return tensor + return firedrake.AssembledMatrix(self.expr_args, self.bcs, petsc_mat) else: assert isinstance(tensor, Function | Cofunction | None) return tensor.assign(result) if tensor else result @@ -610,8 +608,6 @@ def _build_callable(self, output=None) -> None: # Interpolate each sub expression into each function space for indices, sub_expr in expressions.items(): - if isinstance(sub_expr, ufl.ZeroBaseForm): - continue sub_tensor = tensor[indices[0]] if self.rank == 1 else tensor loops.extend(_build_interpolation_callables(sub_expr, sub_tensor, self.access, self.subset, self.bcs)) @@ -670,7 +666,7 @@ def callable() -> Function: # safely use the argument function space. NOTE: If this changes # after cofunctions are fully implemented, this will need to be # reconsidered. - temp_source_func = firedrake.Function(self.expr_args[1].function_space()) + temp_source_func = Function(self.expr_args[1].function_space()) self.mat.mpi_type = get_dat_mpi_type(temp_source_func.dat)[0] # Leave mat inside a callable so we can access the handle # property. If matfree is True, then the handle is a PETSc SF @@ -685,7 +681,7 @@ def callable() -> VomOntoVomMat: @utils.known_pyop2_safe def _build_interpolation_callables( - expr: ufl.Interpolate, + expr: ufl.Interpolate | ufl.ZeroBaseForm, tensor: op2.Dat | op2.Mat | op2.Global, access: Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC], subset: op2.Subset | None = None, @@ -695,8 +691,9 @@ def _build_interpolation_callables( Parameters ---------- - expr : ufl.Interpolate - The symbolic interpolation expression. + expr : ufl.Interpolate | ufl.ZeroBaseForm + The symbolic interpolation expression, or a zero form. Zero forms + are simplified here to avoid code generation when access is WRITE or INC. tensor : op2.Dat | op2.Mat | op2.Global Object to hold the result of the interpolation. access : Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC] @@ -714,6 +711,16 @@ def _build_interpolation_callables( tuple[Callable, ...] Tuple of callables which perform the interpolation. """ + if isinstance(expr, ufl.ZeroBaseForm): + # Zero simplification, avoid code-generation + if access is op2.INC: + return () + elif access is op2.WRITE: + return (partial(tensor.zero, subset=subset),) + # Unclear how to avoid codegen for MIN and MAX + # Reconstruct the expression as an Interpolate + V = expr.arguments()[-1].function_space().dual() + expr = interpolate(ufl.zero(V.value_shape), V) if not isinstance(expr, ufl.Interpolate): raise ValueError("Expecting to interpolate a ufl.Interpolate") dual_arg, operand = expr.argument_slots() From 59bcb4d384a40a812787cea7b44078ee2352e4fe Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Thu, 16 Oct 2025 18:48:08 +0100 Subject: [PATCH 120/125] conjugate test function --- tests/firedrake/regression/test_interpolate_cross_mesh.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/firedrake/regression/test_interpolate_cross_mesh.py b/tests/firedrake/regression/test_interpolate_cross_mesh.py index 82d7a7da3c..65974c5e54 100644 --- a/tests/firedrake/regression/test_interpolate_cross_mesh.py +++ b/tests/firedrake/regression/test_interpolate_cross_mesh.py @@ -705,7 +705,7 @@ def test_interpolate_matrix_cross_mesh_adjoint(): V_coarse = FunctionSpace(mesh_coarse, "CG", 1) V_fine = FunctionSpace(mesh_fine, "CG", 1) - cofunc_fine = assemble(TestFunction(V_fine) * dx) + cofunc_fine = assemble(conj(TestFunction(V_fine)) * dx) interp = assemble(interpolate(TestFunction(V_coarse), TrialFunction(V_fine.dual()))) cofunc_coarse = assemble(Action(interp, cofunc_fine)) @@ -713,7 +713,7 @@ def test_interpolate_matrix_cross_mesh_adjoint(): assert cofunc_coarse.function_space() == V_coarse.dual() # Compare cofunc_fine with direct interpolation - cofunc_coarse_direct = assemble(TestFunction(V_coarse) * dx) + cofunc_coarse_direct = assemble(conj(TestFunction(V_coarse)) * dx) assert np.allclose(cofunc_coarse.dat.data_ro, cofunc_coarse_direct.dat.data_ro) From 86ea04925c589760ed8a8c751443c24cfb25915b Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Fri, 17 Oct 2025 12:36:48 +0100 Subject: [PATCH 121/125] fixes --- firedrake/bcs.py | 5 ++- firedrake/function.py | 2 +- .../firedrake/regression/test_interp_dual.py | 30 ++++++++--------- .../firedrake/regression/test_interpolate.py | 33 +++++++++++++++++++ .../submesh/test_submesh_interpolate.py | 6 ++-- 5 files changed, 53 insertions(+), 23 deletions(-) diff --git a/firedrake/bcs.py b/firedrake/bcs.py index 8b70f3d400..c1aaf424fb 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -339,7 +339,6 @@ def reconstruct(self, field=None, V=None, g=None, sub_domain=None, use_split=Fal @function_arg.setter def function_arg(self, g): '''Set the value of this boundary condition.''' - from firedrake.interpolation import get_interpolator try: # Clear any previously set update function del self._function_arg_update @@ -360,11 +359,11 @@ def function_arg(self, g): raise RuntimeError(f"Provided boundary value {g} does not match shape of space") try: self._function_arg = firedrake.Function(V) - interpolator = get_interpolator(firedrake.interpolate(g, V)) + interpolator = firedrake.get_interpolator(firedrake.interpolate(g, V)) # Call this here to check if the element supports interpolation interpolator._build_callable() self._function_arg_update = partial(interpolator.assemble, tensor=self._function_arg) - except (ValueError, NotImplementedError): + except (NotImplementedError, AttributeError): # Element doesn't implement interpolation self._function_arg = firedrake.Function(V).project(g) self._function_arg_update = firedrake.Projector(g, self._function_arg).project diff --git a/firedrake/function.py b/firedrake/function.py index bea8e965de..ce9b4b1538 100644 --- a/firedrake/function.py +++ b/firedrake/function.py @@ -716,7 +716,7 @@ def __init__(self, mesh: MeshGeometry, points: np.ndarray | list, tolerance: flo The mesh on which to embed the points. points : numpy.ndarray | list Array or list of points to evaluate at. - tolerance : Optional[float] + tolerance : float | None Tolerance to use when checking if a point is in a cell. If ``None`` (the default), the ``tolerance`` of the ``mesh`` is used. missing_points_behaviour : str diff --git a/tests/firedrake/regression/test_interp_dual.py b/tests/firedrake/regression/test_interp_dual.py index ccd4de13f0..b58eb3c0e1 100644 --- a/tests/firedrake/regression/test_interp_dual.py +++ b/tests/firedrake/regression/test_interp_dual.py @@ -54,7 +54,7 @@ def test_assemble_interp_adjoint_tensor(mesh, V1, f1): def test_assemble_interp_operator(V2, f1): # Check type - If1 = Interpolate(f1, Argument(V2.dual(), 0)) + If1 = interpolate(f1, V2) assert isinstance(If1, ufl.Interpolate) # -- I(f1, V2) -- # @@ -89,7 +89,7 @@ def test_assemble_interp_matrix(V1, V2, f1): def test_assemble_interp_tlm(V1, V2, f1): # -- Action(I(v1, V2), f1) -- # v1 = TrialFunction(V1) - Iv1 = Interpolate(v1, Argument(V2.dual(), 0)) + Iv1 = interpolate(v1, V2) b = assemble(interpolate(f1, V2)) assembled_action_Iv1 = assemble(action(Iv1, f1)) @@ -99,7 +99,7 @@ def test_assemble_interp_tlm(V1, V2, f1): def test_assemble_interp_adjoint_matrix(V1, V2): # -- Adjoint(I(v1, V2)) -- # v1 = TrialFunction(V1) - Iv1 = Interpolate(v1, Argument(V2.dual(), 0)) + Iv1 = interpolate(v1, V2) v2 = TestFunction(V2) c2 = assemble(conj(v2) * dx) @@ -120,11 +120,11 @@ def test_assemble_interp_adjoint_matrix(V1, V2): def test_assemble_interp_adjoint_model(V1, V2): # -- Action(Adjoint(I(v1, v2)), fstar) -- # v1 = TrialFunction(V1) - Iv1 = Interpolate(v1, Argument(V2.dual(), 0)) + Iv1 = interpolate(v1, V2) fstar = Cofunction(V2.dual()) v = Argument(V1, 0) - Ivfstar = assemble(Interpolate(v, fstar)) + Ivfstar = assemble(interpolate(v, fstar)) # Action(Adjoint(I(v1, v2)), fstar) <=> I(v, fstar) res = assemble(action(adjoint(Iv1), fstar)) assert np.allclose(res.dat.data, Ivfstar.dat.data) @@ -167,9 +167,9 @@ def test_assemble_base_form_operator_expressions(mesh): f2 = Function(V1).interpolate(sin(2*pi*y)) f3 = Function(V1).interpolate(cos(2*pi*x)) - If1 = Interpolate(f1, Argument(V2.dual(), 0)) - If2 = Interpolate(f2, Argument(V2.dual(), 0)) - If3 = Interpolate(f3, Argument(V2.dual(), 0)) + If1 = interpolate(f1, V2) + If2 = interpolate(f2, V2) + If3 = interpolate(f3, V2) # Sum of BaseFormOperators (1-form) res = assemble(If1 + If2 + If3) @@ -184,8 +184,8 @@ def test_assemble_base_form_operator_expressions(mesh): # Sum of BaseFormOperator (2-form) v1 = TrialFunction(V1) - Iv1 = Interpolate(v1, V2) - Iv2 = Interpolate(v1, V2) + Iv1 = interpolate(v1, V2) + Iv2 = interpolate(v1, V2) res = assemble(Iv1 + Iv2) mat_Iv1 = assemble(Iv1) mat_Iv2 = assemble(Iv2) @@ -210,7 +210,7 @@ def test_check_identity(mesh): V1 = FunctionSpace(mesh, "CG", 1) v2 = TestFunction(V2) v1 = TestFunction(V1) - a = assemble(Interpolate(v1, conj(v2)*dx)) + a = assemble(interpolate(v1, conj(v2)*dx)) b = assemble(conj(v1)*dx) assert np.allclose(a.dat.data, b.dat.data) @@ -234,7 +234,7 @@ def test_solve_interp_f(mesh): # -- Solution where the source term is interpolated via `ufl.Interpolate` u2 = Function(V1) - If = Interpolate(f1, Argument(V2.dual(), 0)) + If = interpolate(f1, V2) # This requires assembling If F2 = inner(grad(u2), grad(w))*dx + inner(u2, w)*dx - inner(If, w)*dx solve(F2 == 0, u2) @@ -267,7 +267,7 @@ def test_solve_interp_u(mesh): # -- Solution where u2 is interpolated via `ufl.Interpolate` (mat-free) u2 = Function(V1) # Iu is the identity - Iu = Interpolate(u2, Argument(V1.dual(), 0)) + Iu = interpolate(u2, V1) # This requires assembling the action the Jacobian of Iu F2 = inner(grad(u2), grad(w))*dx + inner(Iu, w)*dx - inner(f, w)*dx solve(F2 == 0, u2, solver_parameters={"mat_type": "matfree", @@ -278,7 +278,7 @@ def test_solve_interp_u(mesh): # Same problem with grad(Iu) instead of grad(Iu) u2 = Function(V1) # Iu is the identity - Iu = Interpolate(u2, Argument(V1.dual(), 0)) + Iu = interpolate(u2, V1) # This requires assembling the action the Jacobian of Iu F2 = inner(grad(Iu), grad(w))*dx + inner(Iu, w)*dx - inner(f, w)*dx solve(F2 == 0, u2, solver_parameters={"mat_type": "matfree", @@ -341,7 +341,7 @@ def test_interp_dual_mixed(source_space, target_space): expected = assemble(F_target) F_source = inner(b, v)*dx - I_source = Interpolate(expr, F_source) + I_source = interpolate(expr, F_source) c = Cofunction(W.dual()) c.assign(99) diff --git a/tests/firedrake/regression/test_interpolate.py b/tests/firedrake/regression/test_interpolate.py index 47b3dc7a6d..ec925d312a 100644 --- a/tests/firedrake/regression/test_interpolate.py +++ b/tests/firedrake/regression/test_interpolate.py @@ -327,6 +327,7 @@ def test_trace(): assert np.allclose(x_tr_cg.dat.data, x_tr_dir.dat.data) +@pytest.mark.parallel([1, 3]) @pytest.mark.parametrize("rank", (0, 1)) @pytest.mark.parametrize("mat_type", ("matfree", "aij")) @pytest.mark.parametrize("degree", (1, 3)) @@ -566,3 +567,35 @@ def test_mixed_matrix(mode): result_explicit = assemble(action(a, u)) for x, y in zip(result_explicit.subfunctions, result_matfree.subfunctions): assert np.allclose(x.dat.data, y.dat.data) + + +@pytest.mark.parallel(2) +@pytest.mark.parametrize("mode", ["forward", "adjoint"]) +@pytest.mark.parametrize("family,degree", [("CG", 1), ("DG", 0)]) +def test_interpolator_reuse(family, degree, mode): + mesh = UnitSquareMesh(1, 1) + V = FunctionSpace(mesh, family, degree) + rg = RandomGenerator(PCG64(seed=123456789)) + if mode == "forward": + u = Function(V) + expr = interpolate(u, V) + + elif mode == "adjoint": + u = Function(V.dual()) + expr = interpolate(TestFunction(V), u) + + I = get_interpolator(expr) + + for k in range(3): + u.assign(rg.uniform(u.function_space())) + expected = u.dat.data.copy() + + tensor = Function(expr.function_space()) + result = I.assemble(tensor=tensor) + assert result is tensor + + # Test that the input was not modified + assert np.allclose(u.dat.data, expected) + + # Test for correctness + assert np.allclose(result.dat.data, expected) diff --git a/tests/firedrake/submesh/test_submesh_interpolate.py b/tests/firedrake/submesh/test_submesh_interpolate.py index 5c29aa917c..92e422cf6a 100644 --- a/tests/firedrake/submesh/test_submesh_interpolate.py +++ b/tests/firedrake/submesh/test_submesh_interpolate.py @@ -1,7 +1,6 @@ import pytest from firedrake import * import numpy as np -from mpi4py import MPI from ufl.conditional import GT, LT from os.path import abspath, dirname, join @@ -328,12 +327,11 @@ def test_submesh_interpolate_adjoint(fe_fesub): expected_primal = assemble(action(I, u1)) test1 = np.allclose(Iu1.dat.data, expected_primal.dat.data) - test1 = V2.comm.allreduce(test1, MPI.LAND) - assert test1 == expected_to_pass + assert test1 or not expected_to_pass result_forward_1 = assemble(action(ustar2, Iu1)) test0 = np.isclose(result_forward_1, expected) - assert test0 == expected_to_pass + assert test0 or not expected_to_pass # Test adjoint 1-form ustar2I = assemble(interpolate(TestFunction(V1), ustar2, allow_missing_dofs=True)) From 0e2e10791a22fcb651c7ceccb1903ae3b0c3cce8 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Sat, 18 Oct 2025 16:00:20 +0100 Subject: [PATCH 122/125] suggestions / tidy --- firedrake/interpolation.py | 184 ++++++++++++++++++------------------- 1 file changed, 92 insertions(+), 92 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index e84e2e43db..16dc6591d4 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -12,7 +12,7 @@ import ufl import finat.ufl from ufl.algorithms import extract_arguments -from ufl.domain import as_domain, extract_unique_domain +from ufl.domain import extract_unique_domain from ufl.classes import Expr from ufl.duals import is_dual @@ -119,7 +119,7 @@ def __init__(self, expr: Expr, V: WithGeometry | ufl.BaseForm, **kwargs): """ expr = ufl.as_ufl(expr) expr_arg_numbers = {arg.number() for arg in extract_arguments(expr) if not is_dual(arg)} - self.is_adjoint = len(expr_arg_numbers) and expr_arg_numbers == {0} + self.is_adjoint = expr_arg_numbers == {0} if isinstance(V, WithGeometry): # Need to create a Firedrake Coargument so it has a .function_space() method V = Argument(V.dual(), 1 if self.is_adjoint else 0) @@ -233,7 +233,7 @@ def __init__(self, expr: Interpolate): self.operand = operand self.dual_arg = dual_arg self.target_space = dual_arg.function_space().dual() - self.target_mesh = as_domain(self.target_space) + self.target_mesh = self.target_space.mesh() self.source_mesh = extract_unique_domain(operand) or self.target_mesh # Interpolation options @@ -246,7 +246,7 @@ def __init__(self, expr: Interpolate): self.access = expr.options.access @abc.abstractmethod - def _build_callable(self, output: Function | Cofunction | MatrixBase | None = None) -> None: + def _build_callable(self, tensor: Function | Cofunction | MatrixBase | None = None) -> None: """Builds callable to perform interpolation. Stored in ``self.callable``. If ``self.rank == 2``, then ``self.callable()`` must return an object with a ``handle`` @@ -256,8 +256,8 @@ def _build_callable(self, output: Function | Cofunction | MatrixBase | None = No Parameters ---------- - output : Function | Cofunction | MatrixBase | None, optional - Optional tensor to store the result in, by default None + tensor : Function | Cofunction | MatrixBase | None, optional + Optional tensor to store the result in, by default None. """ pass @@ -274,11 +274,11 @@ def assemble( Parameters ---------- tensor : Function | Cofunction | MatrixBase, optional - Pre-allocated storage to receive the interpolated result. For rank-2 + Optional tensor to store the interpolated result. For rank-2 expressions this is expected to be a subclass of - :class:`~firedrake.matrix.MatrixBase` whose - ``petscmat`` will be populated. For lower-rank expressions this is - a :class:`~firedrake.Function` or :class:`~firedrake.Cofunction`. + :class:`~firedrake.matrix.MatrixBase`. For lower-rank expressions + this is a :class:`~firedrake.Function` or :class:`~firedrake.Cofunction`, + for forward and adjoint interpolation respectively. Returns ------- @@ -286,7 +286,7 @@ def assemble( The function, cofunction, matrix, or scalar resulting from the interpolation. """ - self._build_callable(output=tensor) + self._build_callable(tensor=tensor) result = self.callable() if self.rank == 2: # Assembling the operator @@ -367,7 +367,7 @@ def __init__(self, expr: Interpolate): dest_element = self.target_space.ufl_element() if isinstance(dest_element, finat.ufl.MixedElement): - if isinstance(dest_element, (finat.ufl.VectorElement, finat.ufl.TensorElement)): + if isinstance(dest_element, finat.ufl.VectorElement | finat.ufl.TensorElement): # In this case all sub elements are equal base_element = dest_element.sub_elements[0] if base_element.reference_value_shape != (): @@ -377,7 +377,7 @@ def __init__(self, expr: Interpolate): ) self.dest_element = base_element else: - raise NotImplementedError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator.") + raise NotImplementedError("Interpolation with MixedFunctionSpace requires MixedInterpolator.") else: # scalar fiat/finat element self.dest_element = dest_element @@ -430,11 +430,11 @@ def _build_symbolic_expressions(self) -> None: arg = Argument(self.P0DG_vom, 0 if self.expr.is_adjoint else 1) self.point_eval_input_ordering = interpolate(arg, self.P0DG_vom_input_ordering, matfree=matfree) - def _build_callable(self, output=None): + def _build_callable(self, tensor=None): from firedrake.assemble import assemble - # self.expr.function() is None in the 0-form case + # self.expr.function_space() is None in the 0-form case V_dest = self.expr.function_space() or self.target_space - f = output or Function(V_dest) + f = tensor or Function(V_dest) if self.rank == 2: # The cross-mesh interpolation matrix is the product of the @@ -448,60 +448,60 @@ def _build_callable(self, output=None): def callable() -> CrossMeshInterpolator: return self + elif self.expr.is_adjoint: + assert self.rank == 1 + # f_src is a cofunction on V_dest.dual + cofunc = self.dual_arg + assert isinstance(cofunc, Cofunction) + + # Our first adjoint operation is to assign the dat values to a + # P0DG cofunction on our input ordering VOM. + f_input_ordering = Cofunction(self.P0DG_vom_input_ordering.dual()) + f_input_ordering.dat.data_wo[:] = cofunc.dat.data_ro[:] + + # The rest of the adjoint interpolation is the composition + # of the adjoint interpolators in the reverse direction. + # We don't worry about skipping over missing points here + # because we're going from the input ordering VOM to the original VOM + # and all points from the input ordering VOM are in the original. + def callable() -> Cofunction: + f_src_at_src_node_coords = assemble(action(self.point_eval_input_ordering, f_input_ordering)) + assemble(action(self.point_eval, f_src_at_src_node_coords), tensor=f) + return f else: - if self.expr.is_adjoint: - assert self.rank == 1 - # f_src is a cofunction on V_dest.dual - cofunc = self.dual_arg - assert isinstance(cofunc, Cofunction) - - # Our first adjoint operation is to assign the dat values to a - # P0DG cofunction on our input ordering VOM. - f_input_ordering = Cofunction(self.P0DG_vom_input_ordering.dual()) - f_input_ordering.dat.data_wo[:] = cofunc.dat.data_ro[:] - - # The rest of the adjoint interpolation is the composition - # of the adjoint interpolators in the reverse direction. - # We don't worry about skipping over missing points here - # because we're going from the input ordering VOM to the original VOM - # and all points from the input ordering VOM are in the original. - def callable() -> Cofunction: - f_src_at_src_node_coords = assemble(action(self.point_eval_input_ordering, f_input_ordering)) - assemble(action(self.point_eval, f_src_at_src_node_coords), tensor=f) + assert self.rank in {0, 1} + # We evaluate the operand at the node coordinates of the destination space + f_point_eval = assemble(self.point_eval) + + # We create the input-ordering Function before interpolating so we can + # set default missing values if required. + f_point_eval_input_ordering = Function(self.P0DG_vom_input_ordering) + if self.default_missing_val is not None: + f_point_eval_input_ordering.assign(self.default_missing_val) + elif self.allow_missing_dofs: + # If we allow missing points there may be points in the target + # mesh that are not in the source mesh. If we don't specify a + # default missing value we set these to NaN so we can identify + # them later. + f_point_eval_input_ordering.dat.data_wo[:] = numpy.nan + + def callable() -> Function | Number: + assemble(action(self.point_eval_input_ordering, f_point_eval), + tensor=f_point_eval_input_ordering) + + # We assign these values to the output function + if self.allow_missing_dofs and self.default_missing_val is None: + indices = numpy.where(~numpy.isnan(f_point_eval_input_ordering.dat.data_ro))[0] + f.dat.data_wo[indices] = f_point_eval_input_ordering.dat.data_ro[indices] + else: + f.dat.data_wo[:] = f_point_eval_input_ordering.dat.data_ro[:] + + if self.rank == 0: + # We take the action of the dual_arg on the interpolated function + assert not isinstance(self.dual_arg, ufl.Coargument) + return assemble(action(self.dual_arg, f)) + else: return f - else: - # We evaluate the operand at the node coordinates of the destination space - f_point_eval = assemble(self.point_eval) - - # We create the input-ordering Function before interpolating so we can - # set default missing values if required. - f_point_eval_input_ordering = Function(self.P0DG_vom_input_ordering) - if self.default_missing_val is not None: - f_point_eval_input_ordering.assign(self.default_missing_val) - elif self.allow_missing_dofs: - # If we allow missing points there may be points in the target - # mesh that are not in the source mesh. If we don't specify a - # default missing value we set these to NaN so we can identify - # them later. - f_point_eval_input_ordering.dat.data_wo[:] = numpy.nan - - def callable() -> Function | Number: - assemble(action(self.point_eval_input_ordering, f_point_eval), - tensor=f_point_eval_input_ordering) - - # We assign these values to the output function - if self.allow_missing_dofs and self.default_missing_val is None: - indices = numpy.where(~numpy.isnan(f_point_eval_input_ordering.dat.data_ro))[0] - f.dat.data_wo[indices] = f_point_eval_input_ordering.dat.data_ro[indices] - else: - f.dat.data_wo[:] = f_point_eval_input_ordering.dat.data_ro[:] - - if self.rank == 0: - # We take the action of the dual_arg on the interpolated function - assert not isinstance(self.dual_arg, ufl.Coargument) - return assemble(action(self.dual_arg, f)) - else: - return f self.callable = callable @@ -582,9 +582,9 @@ def _get_tensor(self) -> op2.Mat | Function | Cofunction: raise ValueError(f"Cannot interpolate an expression with {self.rank} arguments") return f - def _build_callable(self, output=None) -> None: - f = output or self._get_tensor() - tensor = f if isinstance(f, op2.Mat) else f.dat + def _build_callable(self, tensor=None) -> None: + f = tensor or self._get_tensor() + op2_tensor = f if isinstance(f, op2.Mat) else f.dat loops = [] @@ -596,6 +596,7 @@ def _build_callable(self, output=None) -> None: # Split in the coargument expressions = dict(firedrake.formmanipulation.split_form(self.expr)) else: + assert isinstance(self.dual_arg, Cofunction) # Split in the cofunction: split_form can only split in the coargument # Replace the cofunction with a coargument to construct the Jacobian interp = self.expr._ufl_expr_reconstruct_(self.operand, self.target_space) @@ -608,8 +609,8 @@ def _build_callable(self, output=None) -> None: # Interpolate each sub expression into each function space for indices, sub_expr in expressions.items(): - sub_tensor = tensor[indices[0]] if self.rank == 1 else tensor - loops.extend(_build_interpolation_callables(sub_expr, sub_tensor, self.access, self.subset, self.bcs)) + sub_op2_tensor = op2_tensor[indices[0]] if self.rank == 1 else op2_tensor + loops.extend(_build_interpolation_callables(sub_expr, sub_op2_tensor, self.access, self.subset, self.bcs)) if self.bcs and self.rank == 1: loops.extend(partial(bc.apply, f) for bc in self.bcs) @@ -627,19 +628,19 @@ class VomOntoVomInterpolator(SameMeshInterpolator): def __init__(self, expr: Interpolate): super().__init__(expr) - def _build_callable(self, output=None): + def _build_callable(self, tensor=None): self.mat = VomOntoVomMat(self) if self.rank == 2: # We make our own linear operator for this case using PETSc SFs - tensor = None + op2_tensor = None else: - f = output or self._get_tensor() - tensor = f.dat + f = tensor or self._get_tensor() + op2_tensor = f.dat # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the # data, including the correct data size and dimensional information # (so for vector function spaces in 2 dimensions we might need a # concatenation of 2 MPI.DOUBLE types when we are in real mode) - if tensor is not None: + if op2_tensor is not None: assert self.rank == 1 self.mat.mpi_type = get_dat_mpi_type(f.dat)[0] if self.expr.is_adjoint: @@ -1440,17 +1441,16 @@ def duplicate(self, mat=None, op=None): class MixedInterpolator(Interpolator): - """A reusable interpolation object between MixedFunctionSpaces. - - Parameters - ---------- - expr - The underlying ufl.Interpolate or the operand to the ufl.Interpolate. - V - The :class:`.FunctionSpace` or :class:`.Function` to - interpolate into. + """Interpolator between MixedFunctionSpaces. """ - def __init__(self, expr): + def __init__(self, expr: Interpolate): + """Initialise MixedInterpolator. Should not be called directly; use `get_interpolator` + + Parameters + ---------- + expr : Interpolate + Symbolic Interpolate expression. + """ super().__init__(expr) # We need a Coargument in order to split the Interpolate @@ -1461,7 +1461,7 @@ def __init__(self, expr): # Create the Jacobian to be split into blocks self.expr = self.expr._ufl_expr_reconstruct_(self.operand, self.target_space) - Isub = {} + Isub: dict[tuple[int, int], Interpolator] = {} for indices, form in firedrake.formmanipulation.split_form(self.expr): if isinstance(form, ufl.ZeroBaseForm): # Ensure block sparsity @@ -1482,15 +1482,15 @@ def __init__(self, expr): self._sub_interpolators = Isub - def __getitem__(self, item): + def __getitem__(self, item: tuple[int, int]) -> Interpolator: return self._sub_interpolators[item] def __iter__(self): return iter(self._sub_interpolators) - def _build_callable(self, output=None): + def _build_callable(self, tensor=None): V_dest = self.expr.function_space() or self.target_space - f = output or Function(V_dest) + f = tensor or Function(V_dest) if self.rank == 2: shape = tuple(len(a.function_space()) for a in self.expr_args) blocks = numpy.full(shape, PETSc.Mat(), dtype=object) From 7694538ac7350225ef93d9c629e3fbd30ee8e5d2 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Mon, 20 Oct 2025 17:30:32 +0100 Subject: [PATCH 123/125] fixes docs --- firedrake/interpolation.py | 194 ++++++++++++++++++++----------------- 1 file changed, 106 insertions(+), 88 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 16dc6591d4..5dedefcec1 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -36,7 +36,7 @@ from firedrake.halo import _get_mtype as get_dat_mpi_type from firedrake.functionspaceimpl import WithGeometry from firedrake.matrix import MatrixBase -from firedrake.bcs import BCBase +from firedrake.bcs import DirichletBC from mpi4py import MPI from pyadjoint import stop_annotating, no_annotations @@ -46,6 +46,8 @@ "Interpolate", "get_interpolator", "DofNotDefinedError", + "InterpolateOptions", + "Interpolator" ) @@ -53,19 +55,19 @@ class InterpolateOptions: """Options for interpolation operations. - Attributes + Parameters ---------- - subset : pyop2.types.set.Subset, optional + subset : pyop2.types.set.Subset or None An optional subset to apply the interpolation over. Cannot, at present, be used when interpolating across meshes unless the target mesh is a :func:`.VertexOnlyMesh`. - access : pyop2.types.access.Access, default op2.WRITE + access : pyop2.types.access.Access or None The pyop2 access descriptor for combining updates to shared DoFs. Possible values include ``WRITE``, ``MIN``, ``MAX``, and ``INC``. Only ``WRITE`` is supported at present when interpolating across meshes unless the target mesh is a :func:`.VertexOnlyMesh`. Only ``INC`` is supported for the matrix-free adjoint interpolation. - allow_missing_dofs : bool, default False + allow_missing_dofs : bool For interpolation across meshes: allow degrees of freedom (aka DoFs/nodes) in the target mesh that cannot be defined on the source mesh. For example, where nodes are point evaluations, points in the target mesh @@ -77,17 +79,17 @@ class InterpolateOptions: This does not affect adjoint interpolation. Ignored if interpolating within the same mesh or onto a :func:`.VertexOnlyMesh` (the behaviour of a :func:`.VertexOnlyMesh` in this scenario is, at present, set when it is created). - default_missing_val : float, optional + default_missing_val : float or None For interpolation across meshes: the optional value to assign to DoFs in the target mesh that are outside the source mesh. If this is not set then the values are either (a) unchanged if some ``output`` is given to the :meth:`interpolate` method or (b) set to zero. Ignored if interpolating within the same mesh or onto a :func:`.VertexOnlyMesh`. - matfree : bool, default True + matfree : bool If ``False``, then construct the permutation matrix for interpolating between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast and reduce operations. - bcs : Iterable[BCBase] | None, optional + bcs : Iterable[DirichletBC] or None An optional list of boundary conditions to zero-out in the output function space. Interpolator rows or columns which are associated with boundary condition nodes are zeroed out when this is @@ -98,7 +100,7 @@ class InterpolateOptions: allow_missing_dofs: bool = False default_missing_val: float | None = None matfree: bool = True - bcs: Iterable[BCBase] | None = None + bcs: Iterable[DirichletBC] | None = None class Interpolate(ufl.Interpolate): @@ -141,7 +143,14 @@ def _ufl_expr_reconstruct_( return ufl.Interpolate._ufl_expr_reconstruct_(self, expr, v=v, **interp_data) @property - def options(self): + def options(self) -> InterpolateOptions: + """Access the interpolation options. + + Returns + ------- + InterpolateOptions + An :class:`InterpolateOptions` instance containing the interpolation options. + """ return self._options @@ -168,55 +177,8 @@ def interpolate(expr: Expr, V: WithGeometry | ufl.BaseForm, **kwargs) -> Interpo return Interpolate(expr, V, **kwargs) -def get_interpolator(expr: Interpolate) -> "Interpolator": - """Create an Interpolator. - - Parameters - ---------- - expr : Interpolate - Symbolic interpolation expression. - - Returns - ------- - Interpolator - - """ - arguments = expr.arguments() - has_mixed_arguments = any(len(arg.function_space()) > 1 for arg in arguments) - if len(arguments) == 2 and has_mixed_arguments: - return MixedInterpolator(expr) - - operand, = expr.ufl_operands - target_mesh = expr.target_space.mesh() - source_mesh = extract_unique_domain(operand) or target_mesh - submesh_interp_implemented = ( - all(isinstance(m.topology, firedrake.mesh.MeshTopology) for m in [target_mesh, source_mesh]) - and target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1] - and target_mesh.topological_dimension() == source_mesh.topological_dimension() - ) - if target_mesh is source_mesh or submesh_interp_implemented: - return SameMeshInterpolator(expr) - - target_topology = target_mesh.topology - source_topology = source_mesh.topology - - if isinstance(target_topology, VertexOnlyMeshTopology): - if isinstance(source_topology, VertexOnlyMeshTopology): - return VomOntoVomInterpolator(expr) - if target_mesh.geometric_dimension() != source_mesh.geometric_dimension(): - raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") - if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh: - raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target") - return SameMeshInterpolator(expr) - - if has_mixed_arguments or len(expr.target_space) > 1: - return MixedInterpolator(expr) - - return CrossMeshInterpolator(expr) - - class Interpolator(abc.ABC): - """Initialise the interpolator. Should not be instantiated directly; use the + """Base class for calculating interpolation. Should not be instantiated directly; use the :func:`get_interpolator` function. Parameters @@ -256,7 +218,7 @@ def _build_callable(self, tensor: Function | Cofunction | MatrixBase | None = No Parameters ---------- - tensor : Function | Cofunction | MatrixBase | None, optional + tensor Optional tensor to store the result in, by default None. """ pass @@ -273,16 +235,16 @@ def assemble( Parameters ---------- - tensor : Function | Cofunction | MatrixBase, optional + tensor : Function | Cofunction | MatrixBase Optional tensor to store the interpolated result. For rank-2 expressions this is expected to be a subclass of :class:`~firedrake.matrix.MatrixBase`. For lower-rank expressions - this is a :class:`~firedrake.Function` or :class:`~firedrake.Cofunction`, + this is a :class:`~firedrake.function.Function` or :class:`~firedrake.cofunction.Cofunction`, for forward and adjoint interpolation respectively. Returns ------- - Function | Cofunction | MatrixBase | Number + Function | Cofunction | MatrixBase | numbers.Number The function, cofunction, matrix, or scalar resulting from the interpolation. """ @@ -302,6 +264,54 @@ def assemble( return tensor.assign(result) if tensor else result +def get_interpolator(expr: Interpolate) -> Interpolator: + """Create an Interpolator. + + Parameters + ---------- + expr : Interpolate + Symbolic interpolation expression. + + Returns + ------- + Interpolator + An appropriate :class:`Interpolator` subclass for the given + interpolation expression. + """ + arguments = expr.arguments() + has_mixed_arguments = any(len(arg.function_space()) > 1 for arg in arguments) + if len(arguments) == 2 and has_mixed_arguments: + return MixedInterpolator(expr) + + operand, = expr.ufl_operands + target_mesh = expr.target_space.mesh() + source_mesh = extract_unique_domain(operand) or target_mesh + submesh_interp_implemented = ( + all(isinstance(m.topology, firedrake.mesh.MeshTopology) for m in [target_mesh, source_mesh]) + and target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1] + and target_mesh.topological_dimension() == source_mesh.topological_dimension() + ) + if target_mesh is source_mesh or submesh_interp_implemented: + return SameMeshInterpolator(expr) + + target_topology = target_mesh.topology + source_topology = source_mesh.topology + + if isinstance(target_topology, VertexOnlyMeshTopology): + if isinstance(source_topology, VertexOnlyMeshTopology): + return VomOntoVomInterpolator(expr) + if target_mesh.geometric_dimension() != source_mesh.geometric_dimension(): + raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") + if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh: + raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target") + return SameMeshInterpolator(expr) + + if has_mixed_arguments or len(expr.target_space) > 1: + return MixedInterpolator(expr) + + return CrossMeshInterpolator(expr) + + class DofNotDefinedError(Exception): r"""Raised when attempting to interpolate across function spaces where the target function space contains degrees of freedom (i.e. nodes) which cannot @@ -551,7 +561,7 @@ def _get_tensor(self) -> op2.Mat | Function | Cofunction: Returns ------- op2.Mat | Function | Cofunction - + The tensor to interpolate into. """ if self.rank == 0: R = firedrake.FunctionSpace(self.target_mesh, "Real", 0) @@ -686,7 +696,7 @@ def _build_interpolation_callables( tensor: op2.Dat | op2.Mat | op2.Global, access: Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC], subset: op2.Subset | None = None, - bcs: Iterable[BCBase] | None = None + bcs: Iterable[DirichletBC] | None = None ) -> tuple[Callable, ...]: """Returns tuple of callables which calculate the interpolation. @@ -699,9 +709,9 @@ def _build_interpolation_callables( Object to hold the result of the interpolation. access : Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC] op2 access descriptor - subset : op2.Subset | None, optional + subset : op2.Subset | None An optional subset to apply the interpolation over, by default None. - bcs : Iterable[BCBase] | None, optional + bcs : Iterable[DirichletBC] | None An optional list of boundary conditions to zero-out in the output function space. Interpolator rows or columns which are associated with boundary condition nodes are zeroed out when this is @@ -1338,7 +1348,7 @@ def mult(self, mat: PETSc.Mat, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> def multHermitian(self, mat: PETSc.Mat, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> None: """Applies the adjoint of the interpolation operator. Since ``VomOntoVomMat`` represents a permutation, it is - real-valued and thus the adjoint is the transpose. + real-valued and thus the Hermitian adjoint is the transpose. Parameters ---------- @@ -1399,8 +1409,8 @@ def _create_permutation_mat(self) -> PETSc.Mat: PETSc.Mat PETSc seqaij matrix """ - # To create the permutation matrix we broadcast an array of indices contiguous across - # all ranks and then use these indices to set the values of the matrix directly. + # To create the permutation matrix we broadcast an array of indices which are contiguous + # across all ranks and then use these indices to set the values of the matrix directly. mat = PETSc.Mat().createAIJ((self.target_size, self.source_size), nnz=1, comm=self.V.comm) mat.setUp() start = sum(self._local_sizes[:self.V.comm.rank]) @@ -1436,15 +1446,28 @@ def _wrap_python_mat(self) -> PETSc.Mat: mat.setUp() return mat - def duplicate(self, mat=None, op=None): + def duplicate(self, mat: PETSc.Mat | None = None, op: PETSc.Mat.DuplicateOption | None = None) -> PETSc.Mat: + """Duplicates the matrix. Needed to wrap as a PETSc Python Mat. + + Parameters + ---------- + mat : PETSc.Mat | None, optional + Unused, by default None + op : PETSc.Mat.DuplicateOption | None, optional + Unused, by default None + + Returns + ------- + PETSc.Mat + VomOntoVomMat wrapped as a PETSc Mat of type python. + """ return self._wrap_python_mat() class MixedInterpolator(Interpolator): - """Interpolator between MixedFunctionSpaces. - """ + """Interpolator between MixedFunctionSpaces.""" def __init__(self, expr: Interpolate): - """Initialise MixedInterpolator. Should not be called directly; use `get_interpolator` + """Initialise MixedInterpolator. Should not be called directly; use `get_interpolator`. Parameters ---------- @@ -1461,7 +1484,8 @@ def __init__(self, expr: Interpolate): # Create the Jacobian to be split into blocks self.expr = self.expr._ufl_expr_reconstruct_(self.operand, self.target_space) - Isub: dict[tuple[int, int], Interpolator] = {} + # Get sub-interpolators for each block + self.Isub: dict[tuple[int, int], Interpolator] = {} for indices, form in firedrake.formmanipulation.split_form(self.expr): if isinstance(form, ufl.ZeroBaseForm): # Ensure block sparsity @@ -1478,15 +1502,7 @@ def __init__(self, expr: Interpolate): # Take the action of each sub-cofunction against each block form = action(form, dual_split[indices[-1:]]) form.options.bcs = sub_bcs - Isub[indices] = get_interpolator(form) - - self._sub_interpolators = Isub - - def __getitem__(self, item: tuple[int, int]) -> Interpolator: - return self._sub_interpolators[item] - - def __iter__(self): - return iter(self._sub_interpolators) + self.Isub[indices] = get_interpolator(form) def _build_callable(self, tensor=None): V_dest = self.expr.function_space() or self.target_space @@ -1494,9 +1510,9 @@ def _build_callable(self, tensor=None): if self.rank == 2: shape = tuple(len(a.function_space()) for a in self.expr_args) blocks = numpy.full(shape, PETSc.Mat(), dtype=object) - for i in self: - self[i]._build_callable() - blocks[i] = self[i].callable().handle + for indices, interp in self.Isub.items(): + interp._build_callable() + blocks[indices] = interp.callable().handle self.handle = PETSc.Mat().createNest(blocks) def callable() -> MixedInterpolator: @@ -1504,9 +1520,11 @@ def callable() -> MixedInterpolator: elif self.rank == 1: def callable() -> Function | Cofunction: for k, sub_tensor in enumerate(f.subfunctions): - sub_tensor.assign(sum(self[i].assemble() for i in self if i[0] == k)) + sub_tensor.assign(sum( + interp.assemble() for indices, interp in self.Isub.items() if indices[0] == k + )) return f else: def callable() -> Number: - return sum(self[i].assemble() for i in self) + return sum(interp.assemble() for interp in self.Isub.values()) self.callable = callable From 0cfc6e90806697792a9580ed48bfb936ba2cc8ce Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Tue, 21 Oct 2025 11:46:27 +0100 Subject: [PATCH 124/125] lint --- firedrake/interpolation.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 5dedefcec1..7ea703c74e 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -238,7 +238,7 @@ def assemble( tensor : Function | Cofunction | MatrixBase Optional tensor to store the interpolated result. For rank-2 expressions this is expected to be a subclass of - :class:`~firedrake.matrix.MatrixBase`. For lower-rank expressions + :class:`~firedrake.matrix.MatrixBase`. For lower-rank expressions this is a :class:`~firedrake.function.Function` or :class:`~firedrake.cofunction.Cofunction`, for forward and adjoint interpolation respectively. @@ -496,9 +496,7 @@ def callable() -> Cofunction: f_point_eval_input_ordering.dat.data_wo[:] = numpy.nan def callable() -> Function | Number: - assemble(action(self.point_eval_input_ordering, f_point_eval), - tensor=f_point_eval_input_ordering) - + assemble(action(self.point_eval_input_ordering, f_point_eval), tensor=f_point_eval_input_ordering) # We assign these values to the output function if self.allow_missing_dofs and self.default_missing_val is None: indices = numpy.where(~numpy.isnan(f_point_eval_input_ordering.dat.data_ro))[0] @@ -1409,7 +1407,7 @@ def _create_permutation_mat(self) -> PETSc.Mat: PETSc.Mat PETSc seqaij matrix """ - # To create the permutation matrix we broadcast an array of indices which are contiguous + # To create the permutation matrix we broadcast an array of indices which are contiguous # across all ranks and then use these indices to set the values of the matrix directly. mat = PETSc.Mat().createAIJ((self.target_size, self.source_size), nnz=1, comm=self.V.comm) mat.setUp() From 2e7df2faa4aeae441919e761a1a912c698a17a49 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 22 Oct 2025 13:33:29 +0100 Subject: [PATCH 125/125] Fix TSFC tests --- tests/tsfc/test_interpolation_factorisation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tsfc/test_interpolation_factorisation.py b/tests/tsfc/test_interpolation_factorisation.py index 4355c24b1f..e0352a48be 100644 --- a/tests/tsfc/test_interpolation_factorisation.py +++ b/tests/tsfc/test_interpolation_factorisation.py @@ -2,7 +2,7 @@ import numpy import pytest -from ufl import (Mesh, FunctionSpace, Coefficient, +from ufl import (Interpolate, Mesh, FunctionSpace, Coefficient, interval, quadrilateral, hexahedron) from finat.ufl import FiniteElement, VectorElement, TensorElement @@ -30,7 +30,7 @@ def flop_count(mesh, source, target): Vtarget = FunctionSpace(mesh, target) Vsource = FunctionSpace(mesh, source) to_element = create_element(Vtarget.ufl_element()) - expr = Coefficient(Vsource) + expr = Interpolate(Coefficient(Vsource), Vtarget) kernel = compile_expression_dual_evaluation(expr, to_element, Vtarget.ufl_element()) return kernel.flop_count