RFC: add support for a tuple of axes in expand_dims #760
Closed
Labels: API change, Accepted, RFC, topic: Manipulation
Status: Stage 0
Hello all! I raised this issue on array-api-compat earlier (data-apis/array-api-compat#105), but I think it might be more properly directed here.
In the array API, `expand_dims` supports only a single axis (https://data-apis.org/array-api/latest/API_specification/generated/array_api.expand_dims.html) rather than a tuple of axes. This differs from NumPy, CuPy, and JAX, which accept a tuple of axes; PyTorch, however, supports only a single axis. I don't know why the array API supports only a single axis, but the consequence is that `expand_dims` calls that pass a tuple of axes break in many places when code is ported to the array API.

In practice, `expand_dims` is just a light wrapper around `reshape`; see https://github.com/numpy/numpy/blob/3b246c6488cf246d488bbe5726ca58dc26b6ea74/numpy/lib/_shape_base_impl.py#L594. But it's not great to force every library to write its own version of `expand_dims`. Is the array API willing to update `expand_dims` to support a tuple of axes? If not, and if `expand_dims` will support only a single axis going forward, all users of `expand_dims` effectively have to copy and paste the NumPy implementation.

@lucascolley pointed out to me that when `expand_dims` was added to the array API, only NumPy supported a tuple of axes (see #42). That was four years ago, and the situation has changed, as described above.
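For illustration, here is a minimal sketch of the kind of helper users currently end up writing: a hypothetical `expand_dims_multi` that mirrors NumPy's tuple-of-axes semantics using only `reshape`, which the standard already provides. The function name and the `xp` parameter are illustrative, not part of any library. It assumes NumPy as the default namespace; portable code would instead pass in a namespace obtained from, for example, `array_api_compat.array_namespace(x)`.

```python
import numpy as np


def expand_dims_multi(x, axis, *, xp=None):
    """Insert a size-1 dimension at each position in `axis`.

    Hypothetical helper (not a library API): it reproduces NumPy's
    tuple-of-axes behaviour with nothing but `reshape`.
    """
    if xp is None:
        xp = np  # assumption: fall back to NumPy when no namespace is given
    if isinstance(axis, int):
        axis = (axis,)
    out_ndim = x.ndim + len(axis)

    # Negative axes are resolved against the *output* rank, as NumPy does.
    normalized = []
    for a in axis:
        if not -out_ndim <= a < out_ndim:
            raise IndexError(f"axis {a} is out of bounds for output of rank {out_ndim}")
        normalized.append(a % out_ndim)
    if len(set(normalized)) != len(normalized):
        raise ValueError("repeated axis")

    # New shape: 1 at each inserted position, original dims (in order) elsewhere.
    dims = iter(x.shape)
    new_shape = tuple(1 if i in normalized else next(dims) for i in range(out_ndim))
    return xp.reshape(x, new_shape)


x = np.zeros((2, 3))
print(expand_dims_multi(x, (0, -1)).shape)
```

Because inserted axes are resolved against the output rank, `(0, -1)` on a `(2, 3)` input yields shape `(1, 2, 3, 1)`, matching `np.expand_dims`. This is roughly the logic that every downstream library would otherwise have to duplicate.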