Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2546,13 +2546,15 @@ array blackman(int M, StreamOrDevice s /* = {} */) {
return add(subtract(alpha, term1, s), term2, s);
}

/** Returns a sorted copy of the flattened array. */
/** Returns a sorted copy of the flattened array.
* The sort is stable and NaN values are placed at the end. */
array sort(const array& a, StreamOrDevice s /* = {} */) {
int size = a.size();
return sort(reshape(a, {size}, s), 0, s);
}

/** Returns a sorted copy of the array along a given axis. */
/** Returns a sorted copy of the array along a given axis.
* The sort is stable and NaN values are placed at the end. */
array sort(const array& a, int axis, StreamOrDevice s /* = {} */) {
// Check for valid axis
if (axis + static_cast<int>(a.ndim()) < 0 ||
Expand All @@ -2567,13 +2569,15 @@ array sort(const array& a, int axis, StreamOrDevice s /* = {} */) {
a.shape(), a.dtype(), std::make_shared<Sort>(to_stream(s), axis), {a});
}

/** Returns indices that sort the flattened array. */
/** Returns indices that sort the flattened array.
* The sort is stable and NaN values are placed at the end. */
array argsort(const array& a, StreamOrDevice s /* = {} */) {
int size = a.size();
return argsort(reshape(a, {size}, s), 0, s);
}

/** Returns indices that sort the array along a given axis. */
/** Returns indices that sort the array along a given axis.
* The sort is stable and NaN values are placed at the end. */
array argsort(const array& a, int axis, StreamOrDevice s /* = {} */) {
// Check for valid axis
if (axis + static_cast<int>(a.ndim()) < 0 ||
Expand Down
20 changes: 16 additions & 4 deletions mlx/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -770,16 +770,28 @@ inline array argmax(const array& a, StreamOrDevice s = {}) {
MLX_API array
argmax(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {});

/** Returns a sorted copy of the flattened array. */
/**
* Returns a sorted copy of the flattened array.
* The sort is stable and NaN values are placed at the end.
*/
MLX_API array sort(const array& a, StreamOrDevice s = {});

/** Returns a sorted copy of the array along a given axis. */
/**
* Returns a sorted copy of the array along a given axis.
* The sort is stable and NaN values are placed at the end.
*/
MLX_API array sort(const array& a, int axis, StreamOrDevice s = {});

/** Returns indices that sort the flattened array. */
/**
* Returns indices that sort the flattened array.
* The sort is stable and NaN values are placed at the end.
*/
MLX_API array argsort(const array& a, StreamOrDevice s = {});

/** Returns indices that sort the array along a given axis. */
/**
* Returns indices that sort the array along a given axis.
* The sort is stable and NaN values are placed at the end.
*/
MLX_API array argsort(const array& a, int axis, StreamOrDevice s = {});

/**
Expand Down
6 changes: 6 additions & 0 deletions python/src/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2821,6 +2821,9 @@ void init_ops(nb::module_& m) {
R"pbdoc(
Returns a sorted copy of the array.

The sort is stable, meaning equal elements preserve their relative
order. ``NaN`` values are placed at the end.

Args:
a (array): Input array.
axis (int or None, optional): Optional axis to sort over.
Expand Down Expand Up @@ -2848,6 +2851,9 @@ void init_ops(nb::module_& m) {
R"pbdoc(
Returns the indices that sort the array.

The sort is stable, meaning equal elements preserve their relative
order. ``NaN`` values are placed at the end.

Args:
a (array): Input array.
axis (int or None, optional): Optional axis to sort over.
Expand Down