From c4dcf2d9720292c96241cd5ff6b40cbfa541bb48 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Sun, 15 Mar 2026 20:31:53 +0200 Subject: [PATCH] fix: allow sequences for center --- .basedpyright/baseline.json | 16 ---------------- doc/conf.py | 1 + sumpy/point_calculus.py | 18 +++++++++--------- 3 files changed, 10 insertions(+), 25 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 4376065f..d183da00 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -13687,22 +13687,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 24, - "endColumn": 50, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 25, - "endColumn": 52, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { diff --git a/doc/conf.py b/doc/conf.py index 8f13feca..9013483a 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -37,6 +37,7 @@ "Array1D": "class:numpy.ndarray", "Array2D": "class:numpy.ndarray", "ArrayND": "class:numpy.ndarray", + "ToArray1D": "class:numpy.ndarray", "np.floating": "class:numpy.floating", "np.complexfloating": "class:numpy.complexfloating", "np.inexact": "class:numpy.inexact", diff --git a/sumpy/point_calculus.py b/sumpy/point_calculus.py index a83eef26..fb177185 100644 --- a/sumpy/point_calculus.py +++ b/sumpy/point_calculus.py @@ -35,7 +35,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Sequence - from optype.numpy import Array1D, Array2D, ArrayND + from optype.numpy import Array1D, Array2D, ArrayND, ToArray1D __doc__ = """ @@ -99,28 +99,28 @@ class CalculusPatch: _pshape: tuple[int, ...] def __init__(self, - center: Array1D[np.floating[Any]], + center: ToArray1D[np.floating[Any]], h: float = 1e-1, order: int = 4, nodes: NodesKind = "chebyshev") -> None: - self.center = center - dtype = center.dtype + center = np.asarray(center) + assert center.ndim == 1 npoints = order + 1 if nodes == "equispaced": - points_1d = np.linspace(-h/2, h/2, npoints, dtype=dtype) + points_1d = np.linspace(-h/2, h/2, npoints) weights_1d = None elif nodes == "chebyshev": - a = np.arange(npoints, dtype=dtype) + a = np.arange(npoints) points_1d = (h/2)*np.cos((2*(a+1)-1)/(2*npoints)*np.pi) weights_1d = None elif nodes == "legendre": from scipy.special import legendre points_1d, weights_1d, _ = legendre(npoints).weights.T - points_1d = (points_1d * (h/2)).astype(dtype) - weights_1d = (weights_1d * (h/2)).astype(dtype) + points_1d = points_1d * (h/2) + weights_1d = weights_1d * (h/2) else: raise ValueError(f"invalid node set: {nodes}") @@ -130,8 +130,8 @@ def __init__(self, self._points_1d = points_1d self._weights_1d = weights_1d - self.dim = dim = len(self.center) self.center = center + self.dim = dim = len(self.center) points_shaped = np.array(np.meshgrid( *[center[i] + points_1d for i in range(dim)],