We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents fc8777f + ea068a0 commit d90d81fCopy full SHA for d90d81f
4 files changed
array_api_compat/cupy/_aliases.py
@@ -125,6 +125,20 @@ def astype(
125
return out.copy() if copy and out is x else out
126
127
128
+# cupy.count_nonzero does not have keepdims
129
+def count_nonzero(
130
+ x: ndarray,
131
+ axis=None,
132
+ keepdims=False
133
+) -> ndarray:
134
+ result = cp.count_nonzero(x, axis)
135
+ if keepdims:
136
+ if axis is None:
137
+ return cp.reshape(result, [1]*x.ndim)
138
+ return cp.expand_dims(result, axis)
139
+ return result
140
+
141
142
# These functions are completely new here. If the library already has them
143
# (i.e., numpy 2.0), use the library version instead of our wrapper.
144
if hasattr(cp, 'vecdot'):
@@ -146,6 +160,6 @@ def astype(
146
160
'acos', 'acosh', 'asin', 'asinh', 'atan',
147
161
'atan2', 'atanh', 'bitwise_left_shift',
148
162
'bitwise_invert', 'bitwise_right_shift',
149
- 'bool', 'concat', 'pow', 'sign']
163
+ 'bool', 'concat', 'count_nonzero', 'pow', 'sign']
150
164
151
165
_all_ignore = ['cp', 'get_xp']
array_api_compat/dask/array/_aliases.py
@@ -335,6 +335,21 @@ def argsort(
335
return restore(x)
336
337
338
+# dask.array.count_nonzero does not have keepdims
339
340
+ x: Array,
341
342
343
+) -> Array:
344
+ result = da.count_nonzero(x, axis)
345
346
347
+ return da.reshape(result, [1]*x.ndim)
348
+ return da.expand_dims(result, axis)
349
350
351
352
353
__all__ = _aliases.__all__ + [
354
'__array_namespace_info__', 'asarray', 'astype', 'acos',
355
'acosh', 'asin', 'asinh', 'atan', 'atan2',
@@ -343,6 +358,6 @@ def argsort(
358
'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64',
359
'uint8', 'uint16', 'uint32', 'uint64',
360
'complex64', 'complex128', 'iinfo', 'finfo',
- 'can_cast', 'result_type']
361
+ 'can_cast', 'count_nonzero', 'result_type']
362
363
_all_ignore = ["Callable", "array_namespace", "get_xp", "da", "np"]
array_api_compat/numpy/_aliases.py
@@ -127,6 +127,19 @@ def astype(
return x.astype(dtype=dtype, copy=copy)
+# count_nonzero returns a python int for axis=None and keepdims=False
+# https://github.com/numpy/numpy/issues/17562
+ x : ndarray,
+ result = np.count_nonzero(x, axis=axis, keepdims=keepdims)
+ if axis is None and not keepdims:
+ return np.asarray(result)
145
if hasattr(np, 'vecdot'):
@@ -148,6 +161,6 @@ def astype(
- 'bool', 'concat', 'pow']
+ 'bool', 'concat', 'count_nonzero', 'pow']
152
153
166
_all_ignore = ['np', 'get_xp']
array_api_compat/torch/_aliases.py
@@ -521,15 +521,22 @@ def diff(
521
return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append)
522
523
524
-# torch uses `dim` instead of `axis`
+# torch uses `dim` instead of `axis`, does not have keepdims
525
def count_nonzero(
526
x: array,
527
/,
528
*,
529
axis: Optional[Union[int, Tuple[int, ...]]] = None,
530
keepdims: bool = False,
531
) -> array:
532
- return torch.count_nonzero(x, dim=axis, keepdims=keepdims)
+ result = torch.count_nonzero(x, dim=axis)
533
534
+ if axis is not None:
535
+ return result.unsqueeze(axis)
536
+ return _axis_none_keepdims(result, x.ndim, keepdims)
537
+ else:
538
539
540
541
542
def where(condition: array, x1: array, x2: array, /) -> array:
0 commit comments