Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 37 additions & 7 deletions src/csrc/umath/binary_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,7 @@ quad_ldexp_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtypes[]
Py_INCREF(given_descrs[0]);
loop_descrs[0] = given_descrs[0];

// Input 1: Use NPY_INTP (int64 on 64-bit, int32 on 32-bit) to match platform integer size
// This ensures we can handle the full range of PyArray_PyLongDType without data loss
// Input 1: Use NPY_INTP to match the registered PyArray_IntpDType
loop_descrs[1] = PyArray_DescrFromType(NPY_INTP);

// Output: QuadPrecDType with same backend as input
Expand Down Expand Up @@ -408,7 +407,7 @@ create_quad_ldexp_ufunc(PyObject *numpy, const char *ufunc_name)
return -1;
}

PyArray_DTypeMeta *dtypes[3] = {&QuadPrecDType, &PyArray_PyLongDType, &QuadPrecDType};
PyArray_DTypeMeta *dtypes[3] = {&QuadPrecDType, &PyArray_IntpDType, &QuadPrecDType};

PyType_Slot slots[] = {
{NPY_METH_resolve_descriptors, (void *)&quad_ldexp_resolve_descriptors},
Expand All @@ -433,12 +432,12 @@ create_quad_ldexp_ufunc(PyObject *numpy, const char *ufunc_name)
}

PyObject *promoter_capsule =
PyCapsule_New((void *)&quad_ufunc_promoter, "numpy._ufunc_promoter", NULL);
PyCapsule_New((void *)&quad_ldexp_promoter, "numpy._ufunc_promoter", NULL);
if (promoter_capsule == NULL) {
return -1;
}

PyObject *DTypes = PyTuple_Pack(3, &PyArrayDescr_Type, &PyArray_PyLongDType, &PyArrayDescr_Type);
PyObject *DTypes = PyTuple_Pack(3, &QuadPrecDType, &PyArrayDescr_Type, &PyArrayDescr_Type);
if (DTypes == 0) {
Py_DECREF(promoter_capsule);
return -1;
Expand Down Expand Up @@ -495,13 +494,29 @@ create_quad_binary_2out_ufunc(PyObject *numpy, const char *ufunc_name)
return -1;
}

PyObject *DTypes = PyTuple_Pack(4, &PyArrayDescr_Type, &PyArrayDescr_Type,
// Register promoter for (QuadPrecDType, Any, Any, Any)
PyObject *DTypes = PyTuple_Pack(4, &QuadPrecDType, &PyArrayDescr_Type,
&PyArrayDescr_Type, &PyArrayDescr_Type);
if (DTypes == 0) {
Py_DECREF(promoter_capsule);
return -1;
}

if (PyUFunc_AddPromoter(ufunc, DTypes, promoter_capsule) < 0) {
Py_DECREF(promoter_capsule);
Py_DECREF(DTypes);
return -1;
}
Py_DECREF(DTypes);

// Register promoter for (Any, QuadPrecDType, Any, Any)
DTypes = PyTuple_Pack(4, &PyArrayDescr_Type, &QuadPrecDType,
&PyArrayDescr_Type, &PyArrayDescr_Type);
if (DTypes == 0) {
Py_DECREF(promoter_capsule);
return -1;
}

if (PyUFunc_AddPromoter(ufunc, DTypes, promoter_capsule) < 0) {
Py_DECREF(promoter_capsule);
Py_DECREF(DTypes);
Expand Down Expand Up @@ -551,7 +566,22 @@ create_quad_binary_ufunc(PyObject *numpy, const char *ufunc_name)
return -1;
}

PyObject *DTypes = PyTuple_Pack(3, &PyArrayDescr_Type, &PyArrayDescr_Type, &PyArrayDescr_Type);
// Register promoter for (QuadPrecDType, Any, Any)
PyObject *DTypes = PyTuple_Pack(3, &QuadPrecDType, &PyArrayDescr_Type, &PyArrayDescr_Type);
if (DTypes == 0) {
Py_DECREF(promoter_capsule);
return -1;
}

if (PyUFunc_AddPromoter(ufunc, DTypes, promoter_capsule) < 0) {
Py_DECREF(promoter_capsule);
Py_DECREF(DTypes);
return -1;
}
Py_DECREF(DTypes);

// Register promoter for (Any, QuadPrecDType, Any)
DTypes = PyTuple_Pack(3, &PyArrayDescr_Type, &QuadPrecDType, &PyArrayDescr_Type);
if (DTypes == 0) {
Py_DECREF(promoter_capsule);
return -1;
Expand Down
16 changes: 14 additions & 2 deletions src/csrc/umath/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,8 @@ init_matmul_ops(PyObject *numpy)
return -1;
}

PyObject *DTypes = PyTuple_Pack(3, &PyArrayDescr_Type, &PyArrayDescr_Type, &PyArrayDescr_Type);
// Register promoter for (QuadPrecDType, Any, Any)
PyObject *DTypes = PyTuple_Pack(3, &QuadPrecDType, &PyArrayDescr_Type, &PyArrayDescr_Type);
if (DTypes == NULL) {
Py_DECREF(promoter_capsule);
Py_DECREF(ufunc);
Expand All @@ -441,10 +442,21 @@ init_matmul_ops(PyObject *numpy)
if (PyUFunc_AddPromoter(ufunc, DTypes, promoter_capsule) < 0) {
PyErr_Clear();
}
else {
Py_DECREF(DTypes);

// Register promoter for (Any, QuadPrecDType, Any)
DTypes = PyTuple_Pack(3, &PyArrayDescr_Type, &QuadPrecDType, &PyArrayDescr_Type);
if (DTypes == NULL) {
Py_DECREF(promoter_capsule);
Py_DECREF(ufunc);
return -1;
}

if (PyUFunc_AddPromoter(ufunc, DTypes, promoter_capsule) < 0) {
PyErr_Clear();
}
Py_DECREF(DTypes);

Py_DECREF(promoter_capsule);
Py_DECREF(ufunc);

Expand Down
23 changes: 23 additions & 0 deletions src/include/umath/promoters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,27 @@ quad_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
}


inline int
quad_ldexp_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
PyArray_DTypeMeta *signature[], PyArray_DTypeMeta *new_op_dtypes[])
{
Py_INCREF(&QuadPrecDType);
new_op_dtypes[0] = &QuadPrecDType;

// Promote the exponent to PyArray_IntpDType (unless signature specifies otherwise)
if (signature[1] != NULL) {
Py_INCREF(signature[1]);
new_op_dtypes[1] = signature[1];
}
else {
Py_INCREF(&PyArray_IntpDType);
new_op_dtypes[1] = &PyArray_IntpDType;
}

Py_INCREF(&QuadPrecDType);
new_op_dtypes[2] = &QuadPrecDType;

return 0;
}

#endif
56 changes: 56 additions & 0 deletions tests/test_quaddtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -5880,6 +5880,62 @@ def test_logical_reduce_on_non_quad_arrays():
assert result == True


class TestPromoterNoInterference:
"""Regression tests for overly broad promoter registration.

Prior to the fix, promoters were registered with PyArrayDescr_Type in
all slots, matching ANY dtype combination. This caused the quaddtype
promoter to intercept operations on unrelated NumPy types (timedelta64,
float64, etc.), breaking normal NumPy functionality.

See https://github.com/numpy/numpy-quaddtype/issues/76
"""

def test_timedelta_modulus_raises_typeerror(self):
"""timedelta64 % int must raise TypeError, not be silently handled."""
with pytest.raises(TypeError, match="remainder"):
np.remainder(np.timedelta64(7, 'Y'), 15)

def test_timedelta_divide_preserves_dtype(self):
"""timedelta64 / int must return timedelta64, not float64."""
a = np.arange(1000, dtype="m8[s]")
result = a.sum() / len(a)
assert result.dtype.kind == 'm', (
f"Expected timedelta64 dtype, got {result.dtype}")

def test_timedelta_mean_correct(self):
"""timedelta mean must use timedelta division, not float promotion."""
a = np.arange(1000, dtype="m8[s]")
mean_val = a.mean()
sum_div = a.sum() / len(a)
np.testing.assert_array_equal(mean_val, sum_div)

@pytest.mark.parametrize("op", [
np.add, np.subtract, np.multiply, np.divide,
np.floor_divide, np.power, np.mod,
])
def test_binary_ufunc_float64_preserves_dtype(self, op):
"""Builtin float64 ops must not be affected by quad promoters."""
a = np.array([1.0, 2.0, 3.0], dtype=np.float64)
b = np.array([4.0, 5.0, 6.0], dtype=np.float64)
result = op(a, b)
assert result.dtype == np.float64

def test_matmul_float64_preserves_dtype(self):
a = np.eye(3, dtype=np.float64)
b = np.ones((3, 2), dtype=np.float64)
result = np.matmul(a, b)
assert result.dtype == np.float64
np.testing.assert_array_equal(result, b)

def test_divmod_float64_preserves_dtype(self):
a = np.array([7.0, 8.0], dtype=np.float64)
b = np.array([3.0, 3.0], dtype=np.float64)
q, r = np.divmod(a, b)
assert q.dtype == np.float64
assert r.dtype == np.float64


def test_sleef_purecfma_symbols():
"""Test that SLEEF PURECFMA symbols are present in the compiled module.

Expand Down
Loading