From d5ccfac4b81dfa37ab5001a9a0fe9b16c4cbf45c Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 25 Feb 2026 15:26:38 -0500 Subject: [PATCH] Apply Runic formatting --- ext/StridedViewsCUDAExt.jl | 8 ++-- src/StridedViews.jl | 2 +- src/auxiliary.jl | 34 ++++++++----- src/stridedview.jl | 98 +++++++++++++++++++++----------------- test/runtests.jl | 6 +-- 5 files changed, 83 insertions(+), 65 deletions(-) diff --git a/ext/StridedViewsCUDAExt.jl b/ext/StridedViewsCUDAExt.jl index b96632d..503e750 100644 --- a/ext/StridedViewsCUDAExt.jl +++ b/ext/StridedViewsCUDAExt.jl @@ -4,11 +4,13 @@ using StridedViews using CUDA using CUDA: Adapt, CuPtr -const CuStridedView{T,N,A<:CuArray{T}} = StridedView{T,N,A} +const CuStridedView{T, N, A <: CuArray{T}} = StridedView{T, N, A} function Adapt.adapt_structure(::Type{T}, A::StridedView) where {T} - return StridedView(Adapt.adapt_structure(T, parent(A)), - A.size, A.strides, A.offset, A.op) + return StridedView( + Adapt.adapt_structure(T, parent(A)), + A.size, A.strides, A.offset, A.op + ) end function Base.pointer(x::CuStridedView{T}) where {T} diff --git a/src/StridedViews.jl b/src/StridedViews.jl index f863979..4235d3e 100644 --- a/src/StridedViews.jl +++ b/src/StridedViews.jl @@ -2,7 +2,7 @@ module StridedViews import Base: parent, size, strides, tail, setindex using Base: @propagate_inbounds, RangeIndex, Dims -const SliceIndex = Union{RangeIndex,Colon} +const SliceIndex = Union{RangeIndex, Colon} using LinearAlgebra export StridedView, sreshape, sview, isstrided diff --git a/src/auxiliary.jl b/src/auxiliary.jl index 64d1361..0b19941 100644 --- a/src/auxiliary.jl +++ b/src/auxiliary.jl @@ -2,11 +2,11 @@ #--------------------- # Check whether p is a valid permutation of length N _isperm(N::Integer, p::AbstractVector) = (length(p) == N && isperm(p)) -_isperm(N::Integer, p::NTuple{M,Integer}) where {M} = (M == N && isperm(p)) +_isperm(N::Integer, p::NTuple{M, Integer}) where {M} = (M == N && isperm(p)) _isperm(N::Integer, p) = false # Compute the memory index given a list of cartesian indices and corresponding strides -@inline function _computeind(indices::NTuple{N,Int}, strides::NTuple{N,Int}) where {N} +@inline function _computeind(indices::NTuple{N, Int}, strides::NTuple{N, Int}) where {N} return (indices[1] - 1) * strides[1] + _computeind(tail(indices), tail(strides)) end _computeind(indices::Tuple{}, strides::Tuple{}) = 1 @@ -23,7 +23,7 @@ function _simplifydims(size::Dims{N}, strides::Dims{N}) where {N} return (tailsize..., 1), (tailstrides..., 1) elseif size[1] * strides[1] == tailstrides[1] return (size[1] * tailsize[1], tail(tailsize)..., 1), - (strides[1], tail(tailstrides)..., 1) + (strides[1], tail(tailstrides)..., 1) else return (size[1], tailsize...), (strides[1], tailstrides...) end @@ -58,7 +58,7 @@ end #------------------------------ # Compute the new dimensions of a strided view given the original size and the view slicing # indices -@inline function _computeviewsize(oldsize::NTuple{N,Int}, I::NTuple{N,SliceIndex}) where {N} +@inline function _computeviewsize(oldsize::NTuple{N, Int}, I::NTuple{N, SliceIndex}) where {N} if isa(I[1], Int) return _computeviewsize(tail(oldsize), tail(I)) elseif isa(I[1], Colon) @@ -71,23 +71,29 @@ _computeviewsize(::Tuple{}, ::Tuple{}) = () # Compute the new strides of a (strided) view given the original strides and the view # slicing indices -@inline function _computeviewstrides(oldstrides::NTuple{N,Int}, - I::NTuple{N,SliceIndex}) where {N} +@inline function _computeviewstrides( + oldstrides::NTuple{N, Int}, + I::NTuple{N, SliceIndex} + ) where {N} if isa(I[1], Integer) return _computeviewstrides(tail(oldstrides), tail(I)) elseif isa(I[1], Colon) return (oldstrides[1], _computeviewstrides(tail(oldstrides), tail(I))...) else - return (oldstrides[1] * step(I[1]), - _computeviewstrides(tail(oldstrides), tail(I))...) + return ( + oldstrides[1] * step(I[1]), + _computeviewstrides(tail(oldstrides), tail(I))..., + ) end end _computeviewstrides(::Tuple{}, ::Tuple{}) = () # Compute the additional offset of a (strided) view given the original strides and the view # slicing indices -@inline function _computeviewoffset(strides::NTuple{N,Int}, - I::NTuple{N,SliceIndex}) where {N} +@inline function _computeviewoffset( + strides::NTuple{N, Int}, + I::NTuple{N, SliceIndex} + ) where {N} if isa(I[1], Colon) return _computeviewoffset(tail(strides), tail(I)) else @@ -101,8 +107,10 @@ _computeviewoffset(::Tuple{}, ::Tuple{}) = 0 # Compute the new strides of a (strided) reshape given the original strides and new and # original sizes _computereshapestrides(newsize::Tuple{}, oldsize::Tuple{}, strides::Tuple{}) = strides -function _computereshapestrides(newsize::Tuple{}, oldsize::Dims{N}, - strides::Dims{N}) where {N} +function _computereshapestrides( + newsize::Tuple{}, oldsize::Dims{N}, + strides::Dims{N} + ) where {N} all(isequal(1), oldsize) || throw(DimensionMismatch()) return () end @@ -112,7 +120,7 @@ function _computereshapestrides(newsize::Dims, oldsize::Tuple{}, strides::Tuple{ end function _computereshapestrides(newsize::Dims, oldsize::Dims{N}, strides::Dims{N}) where {N} d, r = divrem(oldsize[1], newsize[1]) - if r == 0 + return if r == 0 s1 = strides[1] if d == 1 # not shrinking the following tuples helps type inference diff --git a/src/stridedview.jl b/src/stridedview.jl index 69eef00..b745574 100644 --- a/src/stridedview.jl +++ b/src/stridedview.jl @@ -24,32 +24,36 @@ _adjoint(::FT) = conj # StridedView type definition #----------------------------- -struct StridedView{T,N,A<:DenseArray,F<:Union{FN,FC,FA,FT}} <: AbstractArray{T,N} +struct StridedView{T, N, A <: DenseArray, F <: Union{FN, FC, FA, FT}} <: AbstractArray{T, N} parent::A - size::NTuple{N,Int} - strides::NTuple{N,Int} + size::NTuple{N, Int} + strides::NTuple{N, Int} offset::Int op::F end # Constructors #-------------- -function StridedView(parent::DenseArray, - size::NTuple{N,Int}=size(parent), - strides::NTuple{N,Int}=strides(parent), - offset::Int=0, - op::F=identity) where {N,F} +function StridedView( + parent::DenseArray, + size::NTuple{N, Int} = size(parent), + strides::NTuple{N, Int} = strides(parent), + offset::Int = 0, + op::F = identity + ) where {N, F} T = Base.promote_op(op, eltype(parent)) return StridedView{T}(parent, size, strides, offset, op) end -function StridedView{T}(parent::DenseArray, - size::NTuple{N,Int}=size(parent), - strides::NTuple{N,Int}=strides(parent), - offset::Int=0, - op::F=identity) where {T,N,F} +function StridedView{T}( + parent::DenseArray, + size::NTuple{N, Int} = size(parent), + strides::NTuple{N, Int} = strides(parent), + offset::Int = 0, + op::F = identity + ) where {T, N, F} parent′ = _normalizeparent(parent) strides′ = _normalizestrides(size, strides) - return StridedView{T,N,typeof(parent′),F}(parent′, size, strides′, offset, op) + return StridedView{T, N, typeof(parent′), F}(parent′, size, strides′, offset, op) end StridedView(a::StridedView) = a @@ -57,10 +61,10 @@ StridedView(a::Adjoint) = StridedView(a')' StridedView(a::Transpose) = transpose(StridedView(transpose(a))) StridedView(a::Base.SubArray) = sview(StridedView(a.parent), a.indices...) StridedView(a::Base.ReshapedArray) = sreshape(StridedView(a.parent), a.dims) -function StridedView(a::Base.PermutedDimsArray{T,N,P}) where {T,N,P} +function StridedView(a::Base.PermutedDimsArray{T, N, P}) where {T, N, P} return permutedims(StridedView(a.parent), P) end -function StridedView(a::Base.ReinterpretArray{T,N}) where {T,N} +function StridedView(a::Base.ReinterpretArray{T, N}) where {T, N} b = StridedView(a.parent) S = eltype(b) isbitstype(T) && isbitstype(S) && sizeof(T) == sizeof(S) || @@ -83,8 +87,10 @@ function isstrided(a::Base.ReshapedArray) newsize = a.dims oldsize = size(a.parent) any(isequal(0), newsize) && return true - newstrides = _computereshapestrides(newsize, - _simplifydims(oldsize, _strides(a.parent))...) + newstrides = _computereshapestrides( + newsize, + _simplifydims(oldsize, _strides(a.parent))... + ) return !isnothing(newstrides) end isstrided(a::Base.PermutedDimsArray) = isstrided(a.parent) @@ -93,11 +99,11 @@ isstrided(a::AbstractArray) = false # work around annoying Base behavior: it doesn't define strides for complex adjoints # because of the recursiveness of the definitions, we need to redefine all of them _strides(a::DenseArray) = strides(a) -_strides(a::Adjoint{<:Any,<:AbstractVector}) = (stride(a.parent, 2), stride(a.parent, 1)) -_strides(a::Adjoint{<:Any,<:AbstractMatrix}) = reverse(strides(a.parent)) -_strides(a::Transpose{<:Any,<:AbstractVector}) = (stride(a.parent, 2), stride(a.parent, 1)) -_strides(a::Transpose{<:Any,<:AbstractMatrix}) = reverse(strides(a.parent)) -function _strides(a::PermutedDimsArray{T,N,perm}) where {T,N,perm} +_strides(a::Adjoint{<:Any, <:AbstractVector}) = (stride(a.parent, 2), stride(a.parent, 1)) +_strides(a::Adjoint{<:Any, <:AbstractMatrix}) = reverse(strides(a.parent)) +_strides(a::Transpose{<:Any, <:AbstractVector}) = (stride(a.parent, 2), stride(a.parent, 1)) +_strides(a::Transpose{<:Any, <:AbstractMatrix}) = reverse(strides(a.parent)) +function _strides(a::PermutedDimsArray{T, N, perm}) where {T, N, perm} s = _strides(parent(a)) return ntuple(d -> s[perm[d]], Val(N)) end @@ -107,8 +113,8 @@ _strides(a::SubArray) = Base.substrides(_strides(a.parent), a.indices) #----------------------- Base.size(a::StridedView) = a.size Base.strides(a::StridedView) = a.strides -Base.stride(a::StridedView{<:Any,0}, n::Int) = 1 -function Base.stride(a::StridedView{<:Any,N}, n::Int) where {N} +Base.stride(a::StridedView{<:Any, 0}, n::Int) = 1 +function Base.stride(a::StridedView{<:Any, N}, n::Int) where {N} return (n <= N) ? a.strides[n] : a.strides[N] * a.size[N] end offset(a::StridedView) = a.offset @@ -119,13 +125,13 @@ Base.parent(a::StridedView) = a.parent Base.IndexStyle(::Type{<:StridedView}) = Base.IndexCartesian() # Indexing with N integer arguments -@inline function Base.getindex(a::StridedView{<:Any,N}, I::Vararg{Int,N}) where {N} +@inline function Base.getindex(a::StridedView{<:Any, N}, I::Vararg{Int, N}) where {N} @boundscheck checkbounds(a, I...) i = ParentIndex(a.offset + _computeind(I, a.strides)) @inbounds r = getindex(a, i) return r end -@inline function Base.setindex!(a::StridedView{<:Any,N}, v, I::Vararg{Int,N}) where {N} +@inline function Base.setindex!(a::StridedView{<:Any, N}, v, I::Vararg{Int, N}) where {N} @boundscheck checkbounds(a, I...) i = ParentIndex(a.offset + _computeind(I, a.strides)) @inbounds setindex!(a, v, i) @@ -133,12 +139,14 @@ end end # Indexing with slice indices to create a new view -@inline function Base.getindex(a::StridedView{T,N}, I::Vararg{SliceIndex,N}) where {T,N} - return StridedView{T}(a.parent, - _computeviewsize(a.size, I), - _computeviewstrides(a.strides, I), - a.offset + _computeviewoffset(a.strides, I), - a.op) +@inline function Base.getindex(a::StridedView{T, N}, I::Vararg{SliceIndex, N}) where {T, N} + return StridedView{T}( + a.parent, + _computeviewsize(a.size, I), + _computeviewstrides(a.strides, I), + a.offset + _computeviewoffset(a.strides, I), + a.op + ) end # Indexing directly into parent array @@ -161,7 +169,7 @@ end # Specific Base methods that are guaranteed to preserve`StridedView` objects #---------------------------------------------------------------------------- Base.conj(a::StridedView{<:Real}) = a -function Base.conj(a::StridedView{T}) where {T<:Complex} +function Base.conj(a::StridedView{T}) where {T <: Complex} return StridedView{T}(a.parent, a.size, a.strides, a.offset, _conj(a.op)) end function Base.conj(a::StridedView) @@ -171,22 +179,22 @@ function Base.conj(a::StridedView) return StridedView{T}(a.parent, a.size, a.strides, a.offset, newop) end -@inline function Base.permutedims(a::StridedView{T,N}, p) where {T,N} +@inline function Base.permutedims(a::StridedView{T, N}, p) where {T, N} _isperm(N, p) || throw(ArgumentError("Invalid permutation of length $N: $p")) newsize = ntuple(n -> size(a, p[n]), Val(N)) newstrides = ntuple(n -> stride(a, p[n]), Val(N)) return StridedView{T}(a.parent, newsize, newstrides, a.offset, a.op) end -LinearAlgebra.transpose(a::StridedView{<:Number,2}) = permutedims(a, (2, 1)) -LinearAlgebra.adjoint(a::StridedView{<:Number,2}) = permutedims(conj(a), (2, 1)) -function LinearAlgebra.adjoint(a::StridedView{<:Any,2}) # act recursively, like Base +LinearAlgebra.transpose(a::StridedView{<:Number, 2}) = permutedims(a, (2, 1)) +LinearAlgebra.adjoint(a::StridedView{<:Number, 2}) = permutedims(conj(a), (2, 1)) +function LinearAlgebra.adjoint(a::StridedView{<:Any, 2}) # act recursively, like Base S = Base.promote_op(a.op, eltype(a)) newop = _adjoint(a.op) T = Base.promote_op(newop, S) return permutedims(StridedView{T}(a.parent, a.size, a.strides, a.offset, newop), (2, 1)) end -function LinearAlgebra.transpose(a::StridedView{<:Any,2}) # act recursively, like Base +function LinearAlgebra.transpose(a::StridedView{<:Any, 2}) # act recursively, like Base S = Base.promote_op(a.op, eltype(a)) newop = _transpose(a.op) T = Base.promote_op(newop, S) @@ -212,15 +220,15 @@ end # Creating or transforming StridedView by slicing #------------------------------------------------- # we cannot use Base.view, as this also accepts indices that might not preserve stridedness -sview(a::StridedView{<:Any,N}, I::Vararg{SliceIndex,N}) where {N} = getindex(a, I...) +sview(a::StridedView{<:Any, N}, I::Vararg{SliceIndex, N}) where {N} = getindex(a, I...) sview(a::StridedView, I::SliceIndex) = getindex(sreshape(a, (length(a),)), I) # for StridedView and index arguments which preserve stridedness, we do replace Base.view # with sview -Base.view(a::StridedView{<:Any,N}, I::Vararg{SliceIndex,N}) where {N} = getindex(a, I...) +Base.view(a::StridedView{<:Any, N}, I::Vararg{SliceIndex, N}) where {N} = getindex(a, I...) # `sview` can be used as a constructor when acting on `AbstractArray` objects -@inline function sview(a::AbstractArray{<:Any,N}, I::Vararg{SliceIndex,N}) where {N} +@inline function sview(a::AbstractArray{<:Any, N}, I::Vararg{SliceIndex, N}) where {N} return getindex(StridedView(a), I...) end @inline function sview(a::AbstractArray, I::SliceIndex) @@ -230,7 +238,7 @@ end # Creating or transforming StridedView by reshaping #--------------------------------------------------- # An error struct for non-strided reshapes -struct ReshapeException{N₁,N₂} <: Exception +struct ReshapeException{N₁, N₂} <: Exception newsize::Dims{N₁} oldsize::Dims{N₂} strides::Dims{N₂} @@ -265,7 +273,7 @@ end # Other methods: `similar`, `copy` #---------------------------------- -function Base.similar(a::StridedView, ::Type{T}, dims::NTuple{N,Int}) where {N,T} +function Base.similar(a::StridedView, ::Type{T}, dims::NTuple{N, Int}) where {N, T} return StridedView(similar(a.parent, T, dims)) end Base.copy(a::StridedView) = copyto!(similar(a), a) @@ -275,7 +283,7 @@ Base.copy(a::StridedView) = copyto!(similar(a), a) function Base.unsafe_convert(::Type{Ptr{T}}, a::StridedView{T}) where {T} return convert(Ptr{T}, pointer(a.parent, a.offset + 1)) end -function Base.elsize(::Type{<:StridedView{T,N,A}}) where {T,N,A} +function Base.elsize(::Type{<:StridedView{T, N, A}}) where {T, N, A} return Base.elsize(A) end Base.dataids(a::StridedView) = Base.dataids(a.parent) diff --git a/test/runtests.jl b/test/runtests.jl index 3821f34..10889bf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -138,7 +138,7 @@ Random.seed!(1234) end end @test reshape(B8, (1, 1, 1)) == reshape(A8, (1, 1, 1)) == - StridedView(reshape(A8, (1, 1, 1))) == sreshape(A8, (1, 1, 1)) + StridedView(reshape(A8, (1, 1, 1))) == sreshape(A8, (1, 1, 1)) @test reshape(B8, ()) == reshape(A8, ()) end @@ -245,7 +245,7 @@ end @test view(B, :, 1:5, 3, 1:5) === sview(B, :, 1:5, 3, 1:5) === B[:, 1:5, 3, 1:5] @test view(B, :, 1:5, 3, 1:5) == StridedView(view(A, :, 1:5, 3, 1:5)) @test pointer(view(B, :, 1:5, 3, 1:5)) == - pointer(StridedView(view(A, :, 1:5, 3, 1:5))) + pointer(StridedView(view(A, :, 1:5, 3, 1:5))) @test StridedViews.offset(view(B, :, 1:5, 3, 1:5)) == 2 * stride(B, 3) end end @@ -291,5 +291,5 @@ Aqua.test_all(StridedViews) if isempty(VERSION.prerelease) using JET - JET.test_package(StridedViews; target_modules=(StridedViews,)) + JET.test_package(StridedViews; target_modules = (StridedViews,)) end