From 28e0b9189547a08b436a657abf22d2ec5dd9b8ec Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 12 Feb 2026 09:16:03 -0500 Subject: [PATCH] api: fix interpolate with complex dtype --- devito/passes/clusters/cse.py | 14 +++++++++++++- tests/test_interpolation.py | 15 +++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index d4d7f0a8b8..ba2c089ee9 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -36,6 +36,18 @@ def retrieve_ctemps(exprs, mode='all'): return search(exprs, lambda expr: isinstance(expr, CTemp), mode, 'dfs') +def cse_dtype(exprdtype, cdtype): + """ + Return the dtype of a CSE temporary given the dtype of the expression to be + captured and the cluster's dtype. + """ + if np.issubdtype(cdtype, np.complexfloating): + return np.promote_types(exprdtype, cdtype(0).real.__class__).type + else: + # Real cluster, can safely promote to the largest precision + return np.promote_types(exprdtype, cdtype).type + + @cluster_pass def cse(cluster, sregistry=None, options=None, **kwargs): """ @@ -86,7 +98,7 @@ def cse(cluster, sregistry=None, options=None, **kwargs): if cluster.is_fence: return cluster - make_dtype = lambda e: np.promote_types(e.dtype, dtype).type + make_dtype = lambda e: cse_dtype(e.dtype, dtype) make = lambda e: CTemp(name=sregistry.make_name(), dtype=make_dtype(e)) exprs = _cse(cluster, make, min_cost=min_cost, mode=mode) diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index 1cde2e13e2..8f59ca0076 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -855,6 +855,21 @@ def test_point_symbol_types(dtype, expected): assert point_symbol.dtype is expected +@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) +def test_interp_complex(dtype): + grid = Grid((11, 11, 11)) + + sc = SparseFunction(name="sc", grid=grid, npoint=1, dtype=dtype) + sc.coordinates.data[:] = [.5, .5, .5] + + fc = Function(name="fc", grid=grid, npoint=2, dtype=dtype) + fc.data[:] = np.random.randn(*grid.shape) + 1j * np.random.randn(*grid.shape) + opC = Operator([sc.interpolate(expr=fc)], name="OpC") + opC() + + assert np.isclose(sc.data[0], fc.data[5, 5, 5]) + + class SD0(SubDomain): name = 'sd0'