From c7e4731334c2088bf33cc4dec4fa77be8cb8b910 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 12 Mar 2026 14:56:44 -0500 Subject: [PATCH 1/6] Upsample as early as possible, to avoid aliasing error Co-authored-by: Shawn Lin --- pytential/symbolic/mappers.py | 57 +++++++++- test/test_normal_interpolation.py | 166 ++++++++++++++++++++++++++++++ 2 files changed, 220 insertions(+), 3 deletions(-) create mode 100644 test/test_normal_interpolation.py diff --git a/pytential/symbolic/mappers.py b/pytential/symbolic/mappers.py index d437b9113..2e1dbe655 100644 --- a/pytential/symbolic/mappers.py +++ b/pytential/symbolic/mappers.py @@ -740,6 +740,55 @@ def map_int_g(self, expr: pp.IntG): # {{{ InterpolationPreprocessor +@dataclass +class EarlyInterpolationAdder( + # This is deliberately inheriting from the pymbolic mapper, + # based on the assumption that all the pymbolic-defined operations + # will apply elementwise. Pytential nodes will end up in + # handle_unsupported_expression below. + IdentityMapperBase[[]], + CSECachingMapperMixin[Expression, []]): + """Used from within :class:`InterpolationPreprocessor`. Rather than + interpolate the result of a computation, push interpolation as far + 'upstream' as possible, to minimize aliasing error. + """ + from_dd: DOFDescriptor + to_dd: DOFDescriptor + + @override + def map_variable(self, expr: p.Variable): + return pp.interpolate(expr, self.from_dd, self.to_dd) + + @override + def map_call(self, + expr: p.Call, + ) -> Expression: + parameters = tuple(self.rec(child) for child in expr.parameters) + if all(child is orig_child for child, orig_child in + zip(expr.parameters, parameters, strict=True)): + return expr + + return type(expr)(expr.function, parameters) + + @override + def handle_unsupported_expression(self, expr: p.ExpressionNode) -> Expression: + return pp.interpolate(expr, self.from_dd, self.to_dd) + + @override + def map_common_subexpression_uncached(self, + expr: p.CommonSubexpression, /, + ) -> Expression: + result = self.rec(expr.child) + if result is expr.child: + return expr + + return type(expr)( + result, + expr.prefix, + expr.scope, + **expr.get_extra_properties()) + + class InterpolationPreprocessor(IdentityMapper): """Handle expressions that require upsampling or downsampling by inserting a :class:`~pytential.symbolic.primitives.Interpolation`. This is used to @@ -801,16 +850,18 @@ def map_int_g(self, expr: pp.IntG): from_dd = expr.source.to_stage1() to_dd = from_dd.to_quad_stage2() + interp_adder = EarlyInterpolationAdder(from_dd, to_dd) densities = tuple( - pp.interpolate(self.rec_arith(density), from_dd, to_dd) + interp_adder.rec_arith(self.rec_arith(density)) for density in expr.densities) from_dd = from_dd.copy(discr_stage=self.from_discr_stage) + interp_adder = EarlyInterpolationAdder(from_dd, to_dd) kernel_arguments = constantdict({ name: componentwise( - lambda aexpr: pp.interpolate( + lambda aexpr: interp_adder.rec_arith( self.rec_arith( - self.tagger.rec_arith(aexpr)), from_dd, to_dd), + self.tagger.rec_arith(aexpr))), arg_expr) for name, arg_expr in expr.kernel_arguments.items()}) diff --git a/test/test_normal_interpolation.py b/test/test_normal_interpolation.py new file mode 100644 index 000000000..9e8a07f7b --- /dev/null +++ b/test/test_normal_interpolation.py @@ -0,0 +1,166 @@ +from __future__ import annotations + + +__copyright__ = "Copyright (C) 2026 Shawn/Chaoqi Lin" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from typing import TYPE_CHECKING + +import numpy as np + +from arraycontext import flatten, pytest_generate_tests_for_array_contexts +from meshmode.discretization import Discretization +from meshmode.discretization.poly_element import InterpolatoryQuadratureGroupFactory +from meshmode.mesh import TensorProductElementGroup +from meshmode.mesh.generation import generate_sphere +from sumpy.expansion.local import LineTaylorLocalExpansion +from sumpy.kernel import DirectionalSourceDerivative, LaplaceKernel +from sumpy.qbx import LayerPotential + +from pytential import GeometryCollection, bind, sym +from pytential.array_context import PytestPyOpenCLArrayContextFactory +from pytential.qbx import QBXLayerPotentialSource + + +if TYPE_CHECKING: + from arraycontext.context import ArrayContextFactory + + +pytest_generate_tests = pytest_generate_tests_for_array_contexts([ + PytestPyOpenCLArrayContextFactory, +]) + + +def test_no_aliasing_in_kernel_arguments(actx_factory: ArrayContextFactory): + """Disagreement between sumpy and pytential would previously arise because + pytential would upsample an already-computed normal, as opposed to recomputing + the normal on the upsampled grid. Doing the former leads to avoidable + aliasing error that would limit pytential to attaining about 5 digits + in this specific test. + """ + actx = actx_factory() + + order = 4 + qbx_order = 4 + level = 3 + ambient_dim = 3 + base_knl = LaplaceKernel(3) + dlp_knl = DirectionalSourceDerivative(base_knl, dir_vec_name="dsource_vec") + + mesh = generate_sphere(1.0, order, uniform_refinement_rounds=level, + group_cls=TensorProductElementGroup) + pre_density_discr = Discretization( + actx, mesh, InterpolatoryQuadratureGroupFactory(order)) + + fine_orders = [order, order*2, order*4] + rows = [] + + for fine_order in fine_orders: + qbx = QBXLayerPotentialSource( + pre_density_discr, fine_order=fine_order, + qbx_order=qbx_order, fmm_order=False) + places = GeometryCollection({"qbx": qbx}, auto_where="qbx") + + target_discr = places.get_discretization("qbx", sym.QBX_SOURCE_STAGE1) + source_discr = places.get_discretization( + "qbx", sym.QBX_SOURCE_QUAD_STAGE2) + source_dd = sym.DOFDescriptor("qbx", sym.QBX_SOURCE_QUAD_STAGE2) + + # --- pytential sym.D --- + sigma = target_discr.zeros(actx) + 1 + bound_op = bind(places, sym.D(base_knl, sym.var("sigma"), + qbx_forced_limit=-1)) + result_pyt = bound_op(actx, sigma=sigma) + err_pyt = float(np.max(np.abs( + actx.to_numpy(result_pyt[0]).ravel() - (-1.0)))) + + # --- sumpy direct --- + expn = LineTaylorLocalExpansion(base_knl, qbx_order) + lpot = LayerPotential( + expansion=expn, + source_kernels=(dlp_knl,), + target_kernels=(dlp_knl,), + ) + + targets = actx.thaw(target_discr.nodes()) + sources = actx.thaw(source_discr.nodes()) + normals_src = bind(places, sym.normal( + ambient_dim, dofdesc=source_dd))(actx).as_vector(object) + waa = bind(places, sym.weights_and_area_elements( + ambient_dim=ambient_dim, dim=ambient_dim - 1, + dofdesc=source_dd))(actx) + expansion_radii = bind( + places, sym.expansion_radii(ambient_dim))(actx) + centers_in = bind( + places, sym.expansion_centers(ambient_dim, -1))(actx) + + targets_h = actx.to_numpy( + flatten(targets, actx)).reshape(ambient_dim, -1) + sources_h = actx.to_numpy( + flatten(sources, actx)).reshape(ambient_dim, -1) + centers_h = actx.to_numpy( + flatten(centers_in, actx)).reshape(ambient_dim, -1) + radii_h = actx.to_numpy(flatten(expansion_radii, actx)) + waa_h = actx.to_numpy(flatten(waa, actx)) + normals_h = actx.to_numpy( + flatten(normals_src, actx)).reshape(ambient_dim, -1) + + result_sumpy = lpot(actx, + targets=actx.from_numpy(targets_h), + sources=actx.from_numpy(sources_h), + centers=actx.from_numpy(centers_h), + strengths=(actx.from_numpy(waa_h),), + expansion_radii=actx.from_numpy(radii_h), + dsource_vec=actx.from_numpy(normals_h), + ) + err_sumpy = float(np.max(np.abs( + actx.to_numpy(result_sumpy[0]).ravel() - (-1.0)))) + + pyt_flat = actx.to_numpy(result_pyt[0]).ravel() + sumpy_flat = actx.to_numpy(result_sumpy[0]).ravel() + diff = float(np.max(np.abs(pyt_flat - sumpy_flat))) + + rows.append((fine_order, err_sumpy, err_pyt, diff)) + + if diff >= 1e-14: + header = (f" {'fine_order':>10s} {'err sumpy':>12s} " + f"{'err pytential':>14s} {'|sumpy-pyt|':>12s}") + lines = [header, f" {'-'*54}"] + for fine_order, err_sumpy, err_pytential, diff in rows: + lines.append( + f" {fine_order:10d} {err_sumpy:12.2e} " + f"{err_pytential:14.2e} {diff:12.2e}") + table = "\n".join(lines) + raise AssertionError( + f"DLP results disagree at fine_order={fine_order}\n{table}") + + +if __name__ == "__main__": + import sys + + from pytential.array_context import _acf # noqa: F401 + + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + from pytest import main + main([__file__]) From de34340cb02fe04c14303ee3249e25cd504025e8 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 12 Mar 2026 16:30:12 -0500 Subject: [PATCH 2/6] A few annotations in GraphvizMapper --- pytential/symbolic/mappers.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pytential/symbolic/mappers.py b/pytential/symbolic/mappers.py index 2e1dbe655..5e2305188 100644 --- a/pytential/symbolic/mappers.py +++ b/pytential/symbolic/mappers.py @@ -1076,9 +1076,6 @@ class PrettyStringifyMapper( # {{{ graphviz class GraphvizMapper(GraphvizMapperBase): - def __init__(self): - super().__init__() - def map_pytential_leaf(self, expr): self.lines.append( '{} [label="{}", shape=box];'.format( @@ -1090,7 +1087,7 @@ def map_pytential_leaf(self, expr): map_ones = map_pytential_leaf - def map_map_node_sum(self, expr): + def map_map_node_sum(self, expr: pp.NodeSum): self.lines.append( '{} [label="{}",shape=circle];'.format( self.get_id(expr), type(expr).__name__)) @@ -1106,7 +1103,7 @@ def map_map_node_sum(self, expr): map_q_weight = map_pytential_leaf - def map_int_g(self, expr): + def map_int_g(self, expr: pp.IntG): descr = "Int[%s->%s]@(%d) (%s)" % ( stringify_where(expr.source), stringify_where(expr.target), From 432328fa9db666abb942fff538faa784323dedd7 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 12 Mar 2026 16:31:03 -0500 Subject: [PATCH 3/6] Add missing type arg in StringifyMapper --- pytential/symbolic/mappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytential/symbolic/mappers.py b/pytential/symbolic/mappers.py index 5e2305188..d82b63d69 100644 --- a/pytential/symbolic/mappers.py +++ b/pytential/symbolic/mappers.py @@ -940,7 +940,7 @@ def stringify_where(where: DOFDescriptorLike): return str(pp.as_dofdesc(where)) -class StringifyMapper(BaseStringifyMapper): +class StringifyMapper(BaseStringifyMapper[[]]): def map_ones(self, expr: pp.Ones, enclosing_prec: int): return "Ones[%s]" % stringify_where(expr.dofdesc) From 345f6fd1ad9a18fce4cfd127d47eba4ea035096d Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 12 Mar 2026 16:32:15 -0500 Subject: [PATCH 4/6] Break out discretization-level CSEs in bound operator code printing --- pytential/symbolic/compiler.py | 59 +++++++++++++++++++++++----------- 1 file changed, 41 insertions(+), 18 deletions(-) diff --git a/pytential/symbolic/compiler.py b/pytential/symbolic/compiler.py index ae8e12b06..f3f37c997 100644 --- a/pytential/symbolic/compiler.py +++ b/pytential/symbolic/compiler.py @@ -22,7 +22,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - +from abc import ABC, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING, Generic, Literal, TypeVar, cast @@ -33,7 +33,12 @@ from pymbolic import ArithmeticExpression from pytools.obj_array import ObjectArray, ObjectArray1D, ShapeT, from_numpy -from pytential.symbolic.mappers import CachedIdentityMapper, DependencyMapper +from pytential.symbolic.mappers import ( + CachedIdentityMapper, + DependencyMapper, + PrettyStringifyMapper, + StringifyMapper, +) from pytential.symbolic.primitives import ( DOFDescriptor, IntG, @@ -44,6 +49,7 @@ if TYPE_CHECKING: from collections.abc import ( + Callable, Collection, Hashable, Iterator, @@ -76,7 +82,7 @@ # {{{ statements @dataclass(frozen=True, eq=False) -class Statement: +class Statement(ABC): """ .. autoattribute:: names .. autoattribute:: exprs @@ -93,23 +99,25 @@ class Statement: priority: int """The priority of the statement.""" + @abstractmethod def get_assignees(self) -> set[str]: """ :returns: names of variables that are assigned to in this statement. """ - raise NotImplementedError( - f"get_assignees for '{self.__class__.__name__}'") + @abstractmethod def get_dependencies(self, dep_mapper: DependencyMapper) -> set[prim.Variable]: """ :returns: variables that are dependencies of the assignees. """ - raise NotImplementedError( - f"get_dependencies for '{self.__class__.__name__}'") + + @abstractmethod + def stringify(self, expr_mapper: Callable[[Expression | Operand], str]) -> str: + ... @override - def __str__(self) -> str: - raise NotImplementedError + def __str__(self): + return self.stringify(StringifyMapper()) @dataclass(frozen=True, eq=False) @@ -152,14 +160,17 @@ def get_dependencies(self, dep_mapper: DependencyMapper) -> set[prim.Variable]: return deps @override - def __str__(self) -> str: + def stringify(self, expr_mapper: Callable[[Expression | Operand], str]) -> str: comment = self.comment if len(self.names) == 1: if comment: comment = f"/* {comment} */ " - return "{} <- {}{}".format(self.names[0], comment, self.exprs[0]) + return "{} <- {}{}".format( + self.names[0], + comment, + expr_mapper(self.exprs[0])) else: do_not_return = self.do_not_return if do_not_return is None: @@ -176,7 +187,7 @@ def __str__(self) -> str: else: dnr_indicator = "" - lines.append(f" {n} <{dnr_indicator}- {e}") + lines.append(f" {n} <{dnr_indicator}- {expr_mapper(e)}") lines.append("}") return "\n".join(lines) @@ -266,14 +277,12 @@ def get_dependencies(self, dep_mapper: DependencyMapper) -> set[prim.Variable]: return result @override - def __str__(self) -> str: + def stringify(self, expr_mapper: Callable[[Expression | Operand], str]) -> str: args = [f"source={self.source}"] for i, density in enumerate(self.densities): args.append(f"density{i}={density}") - from pytential.symbolic.mappers import StringifyMapper, stringify_where - strify = StringifyMapper() - + from pytential.symbolic.mappers import stringify_where lines: list[str] = [] for o in self.outputs: if o.target_name != self.source: @@ -308,7 +317,7 @@ def __str__(self) -> str: lines.append(line) for arg_name, arg_expr in self.kernel_arguments.items(): - arg_expr_lines = strify(arg_expr).split("\n") + arg_expr_lines = expr_mapper(arg_expr).split("\n") lines.append(" {} = {}".format(arg_name, arg_expr_lines[0])) lines.extend(" " + s for s in arg_expr_lines[1:]) @@ -417,9 +426,23 @@ def statements(self) -> list[Statement]: @override def __str__(self) -> str: + strify_mapper = PrettyStringifyMapper() lines: list[str] = [] for insn in self.statements: - lines.extend(str(insn).split("\n")) + lines.extend(insn.stringify(strify_mapper).split("\n")) + + if strify_mapper.cse_name_list: + # FIXME: There's potential here for name clashes between the 'code' + # and 'discretization CSE' parts. It's just presentation, so if it's + # bothersome, near here is the place to fix it. + lines = [ + "DISCRETIZATION-LEVEL COMMON SUBEXPRESSIONS:", + *[ + f"{name} <- {cse_expr_str}" + for name, cse_expr_str in strify_mapper.cse_name_list], + "-"*75, + *lines] + lines.append(f"RESULT: {self.result}") return "\n".join(lines) From f1730875e0a317523cdf177271fffa486cc18f3d Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 12 Mar 2026 17:05:31 -0500 Subject: [PATCH 5/6] Annotate DiscretizationStageTagger --- pytential/symbolic/mappers.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pytential/symbolic/mappers.py b/pytential/symbolic/mappers.py index d82b63d69..43dbb85a4 100644 --- a/pytential/symbolic/mappers.py +++ b/pytential/symbolic/mappers.py @@ -571,7 +571,9 @@ class DiscretizationStageTagger(IdentityMapper): :attr:`~pytential.symbolic.dof_desc.DOFDescriptor.discr_stage`. """ - def __init__(self, discr_stage): + discr_stage: DiscretizationStage + + def __init__(self, discr_stage: DiscretizationStage): if discr_stage not in { pp.QBX_SOURCE_STAGE1, pp.QBX_SOURCE_STAGE2, @@ -581,7 +583,8 @@ def __init__(self, discr_stage): self.discr_stage = discr_stage - def map_node_coordinate_component(self, expr): + @override + def map_node_coordinate_component(self, expr: pp.NodeCoordinateComponent): dofdesc = expr.dofdesc if dofdesc.discr_stage == self.discr_stage: return expr @@ -590,7 +593,8 @@ def map_node_coordinate_component(self, expr): expr.ambient_axis, dofdesc.copy(discr_stage=self.discr_stage)) - def map_num_reference_derivative(self, expr): + @override + def map_num_reference_derivative(self, expr: pp.NumReferenceDerivative): dofdesc = expr.dofdesc if dofdesc.discr_stage == self.discr_stage: return expr From 523cbde3221ddcf1eb6bf2ffafed43775fe6b859 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 12 Mar 2026 15:00:33 -0500 Subject: [PATCH 6/6] Update baseline --- .basedpyright/baseline.json | 80 ++++--------------------------------- 1 file changed, 8 insertions(+), 72 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 15007cfb5..269c975cf 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -12121,22 +12121,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 24, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 41, - "endColumn": 76, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -12209,38 +12193,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 35, - "endColumn": 45, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 35, - "endColumn": 45, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 34, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 34, - "endColumn": 43, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -16885,14 +16837,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 19, - "endColumn": 41, - "lineCount": 1 - } - }, { "code": "reportUnknownVariableType", "range": { @@ -17157,6 +17101,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 8, + "endColumn": 19, + "lineCount": 1 + } + }, { "code": "reportUnknownMemberType", "range": { @@ -17189,14 +17141,6 @@ "lineCount": 1 } }, - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 58, - "endColumn": 74, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -17213,14 +17157,6 @@ "lineCount": 1 } }, - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 36, - "endColumn": 54, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": {