Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions ext/StridedViewsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ using StridedViews
using CUDA
using CUDA: Adapt, CuPtr

const CuStridedView{T,N,A<:CuArray{T}} = StridedView{T,N,A}
const CuStridedView{T, N, A <: CuArray{T}} = StridedView{T, N, A}

function Adapt.adapt_structure(::Type{T}, A::StridedView) where {T}
return StridedView(Adapt.adapt_structure(T, parent(A)),
A.size, A.strides, A.offset, A.op)
return StridedView(
Adapt.adapt_structure(T, parent(A)),
A.size, A.strides, A.offset, A.op
)
end

function Base.pointer(x::CuStridedView{T}) where {T}
Expand Down
2 changes: 1 addition & 1 deletion src/StridedViews.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module StridedViews

import Base: parent, size, strides, tail, setindex
using Base: @propagate_inbounds, RangeIndex, Dims
const SliceIndex = Union{RangeIndex,Colon}
const SliceIndex = Union{RangeIndex, Colon}

using LinearAlgebra
export StridedView, sreshape, sview, isstrided
Expand Down
34 changes: 21 additions & 13 deletions src/auxiliary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
#---------------------
# Check whether p is a valid permutation of length N
_isperm(N::Integer, p::AbstractVector) = (length(p) == N && isperm(p))
_isperm(N::Integer, p::NTuple{M,Integer}) where {M} = (M == N && isperm(p))
_isperm(N::Integer, p::NTuple{M, Integer}) where {M} = (M == N && isperm(p))
_isperm(N::Integer, p) = false

# Compute the memory index given a list of cartesian indices and corresponding strides
@inline function _computeind(indices::NTuple{N,Int}, strides::NTuple{N,Int}) where {N}
@inline function _computeind(indices::NTuple{N, Int}, strides::NTuple{N, Int}) where {N}
return (indices[1] - 1) * strides[1] + _computeind(tail(indices), tail(strides))
end
_computeind(indices::Tuple{}, strides::Tuple{}) = 1
Expand All @@ -23,7 +23,7 @@ function _simplifydims(size::Dims{N}, strides::Dims{N}) where {N}
return (tailsize..., 1), (tailstrides..., 1)
elseif size[1] * strides[1] == tailstrides[1]
return (size[1] * tailsize[1], tail(tailsize)..., 1),
(strides[1], tail(tailstrides)..., 1)
(strides[1], tail(tailstrides)..., 1)
else
return (size[1], tailsize...), (strides[1], tailstrides...)
end
Expand Down Expand Up @@ -58,7 +58,7 @@ end
#------------------------------
# Compute the new dimensions of a strided view given the original size and the view slicing
# indices
@inline function _computeviewsize(oldsize::NTuple{N,Int}, I::NTuple{N,SliceIndex}) where {N}
@inline function _computeviewsize(oldsize::NTuple{N, Int}, I::NTuple{N, SliceIndex}) where {N}
if isa(I[1], Int)
return _computeviewsize(tail(oldsize), tail(I))
elseif isa(I[1], Colon)
Expand All @@ -71,23 +71,29 @@ _computeviewsize(::Tuple{}, ::Tuple{}) = ()

# Compute the new strides of a (strided) view given the original strides and the view
# slicing indices
@inline function _computeviewstrides(oldstrides::NTuple{N,Int},
I::NTuple{N,SliceIndex}) where {N}
@inline function _computeviewstrides(
oldstrides::NTuple{N, Int},
I::NTuple{N, SliceIndex}
) where {N}
if isa(I[1], Integer)
return _computeviewstrides(tail(oldstrides), tail(I))
elseif isa(I[1], Colon)
return (oldstrides[1], _computeviewstrides(tail(oldstrides), tail(I))...)
else
return (oldstrides[1] * step(I[1]),
_computeviewstrides(tail(oldstrides), tail(I))...)
return (
oldstrides[1] * step(I[1]),
_computeviewstrides(tail(oldstrides), tail(I))...,
)
end
end
_computeviewstrides(::Tuple{}, ::Tuple{}) = ()

# Compute the additional offset of a (strided) view given the original strides and the view
# slicing indices
@inline function _computeviewoffset(strides::NTuple{N,Int},
I::NTuple{N,SliceIndex}) where {N}
@inline function _computeviewoffset(
strides::NTuple{N, Int},
I::NTuple{N, SliceIndex}
) where {N}
if isa(I[1], Colon)
return _computeviewoffset(tail(strides), tail(I))
else
Expand All @@ -101,8 +107,10 @@ _computeviewoffset(::Tuple{}, ::Tuple{}) = 0
# Compute the new strides of a (strided) reshape given the original strides and new and
# original sizes
_computereshapestrides(newsize::Tuple{}, oldsize::Tuple{}, strides::Tuple{}) = strides
function _computereshapestrides(newsize::Tuple{}, oldsize::Dims{N},
strides::Dims{N}) where {N}
function _computereshapestrides(
newsize::Tuple{}, oldsize::Dims{N},
strides::Dims{N}
) where {N}
all(isequal(1), oldsize) || throw(DimensionMismatch())
return ()
end
Expand All @@ -112,7 +120,7 @@ function _computereshapestrides(newsize::Dims, oldsize::Tuple{}, strides::Tuple{
end
function _computereshapestrides(newsize::Dims, oldsize::Dims{N}, strides::Dims{N}) where {N}
d, r = divrem(oldsize[1], newsize[1])
if r == 0
return if r == 0
s1 = strides[1]
if d == 1
# not shrinking the following tuples helps type inference
Expand Down
98 changes: 53 additions & 45 deletions src/stridedview.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,43 +24,47 @@ _adjoint(::FT) = conj

# StridedView type definition
#-----------------------------
struct StridedView{T,N,A<:DenseArray,F<:Union{FN,FC,FA,FT}} <: AbstractArray{T,N}
struct StridedView{T, N, A <: DenseArray, F <: Union{FN, FC, FA, FT}} <: AbstractArray{T, N}
parent::A
size::NTuple{N,Int}
strides::NTuple{N,Int}
size::NTuple{N, Int}
strides::NTuple{N, Int}
offset::Int
op::F
end

# Constructors
#--------------
function StridedView(parent::DenseArray,
size::NTuple{N,Int}=size(parent),
strides::NTuple{N,Int}=strides(parent),
offset::Int=0,
op::F=identity) where {N,F}
function StridedView(
parent::DenseArray,
size::NTuple{N, Int} = size(parent),
strides::NTuple{N, Int} = strides(parent),
offset::Int = 0,
op::F = identity
) where {N, F}
T = Base.promote_op(op, eltype(parent))
return StridedView{T}(parent, size, strides, offset, op)
end
function StridedView{T}(parent::DenseArray,
size::NTuple{N,Int}=size(parent),
strides::NTuple{N,Int}=strides(parent),
offset::Int=0,
op::F=identity) where {T,N,F}
function StridedView{T}(
parent::DenseArray,
size::NTuple{N, Int} = size(parent),
strides::NTuple{N, Int} = strides(parent),
offset::Int = 0,
op::F = identity
) where {T, N, F}
parent′ = _normalizeparent(parent)
strides′ = _normalizestrides(size, strides)
return StridedView{T,N,typeof(parent′),F}(parent′, size, strides′, offset, op)
return StridedView{T, N, typeof(parent′), F}(parent′, size, strides′, offset, op)
end

StridedView(a::StridedView) = a
StridedView(a::Adjoint) = StridedView(a')'
StridedView(a::Transpose) = transpose(StridedView(transpose(a)))
StridedView(a::Base.SubArray) = sview(StridedView(a.parent), a.indices...)
StridedView(a::Base.ReshapedArray) = sreshape(StridedView(a.parent), a.dims)
function StridedView(a::Base.PermutedDimsArray{T,N,P}) where {T,N,P}
function StridedView(a::Base.PermutedDimsArray{T, N, P}) where {T, N, P}
return permutedims(StridedView(a.parent), P)
end
function StridedView(a::Base.ReinterpretArray{T,N}) where {T,N}
function StridedView(a::Base.ReinterpretArray{T, N}) where {T, N}
b = StridedView(a.parent)
S = eltype(b)
isbitstype(T) && isbitstype(S) && sizeof(T) == sizeof(S) ||
Expand All @@ -83,8 +87,10 @@ function isstrided(a::Base.ReshapedArray)
newsize = a.dims
oldsize = size(a.parent)
any(isequal(0), newsize) && return true
newstrides = _computereshapestrides(newsize,
_simplifydims(oldsize, _strides(a.parent))...)
newstrides = _computereshapestrides(
newsize,
_simplifydims(oldsize, _strides(a.parent))...
)
return !isnothing(newstrides)
end
isstrided(a::Base.PermutedDimsArray) = isstrided(a.parent)
Expand All @@ -93,11 +99,11 @@ isstrided(a::AbstractArray) = false
# work around annoying Base behavior: it doesn't define strides for complex adjoints
# because of the recursiveness of the definitions, we need to redefine all of them
_strides(a::DenseArray) = strides(a)
_strides(a::Adjoint{<:Any,<:AbstractVector}) = (stride(a.parent, 2), stride(a.parent, 1))
_strides(a::Adjoint{<:Any,<:AbstractMatrix}) = reverse(strides(a.parent))
_strides(a::Transpose{<:Any,<:AbstractVector}) = (stride(a.parent, 2), stride(a.parent, 1))
_strides(a::Transpose{<:Any,<:AbstractMatrix}) = reverse(strides(a.parent))
function _strides(a::PermutedDimsArray{T,N,perm}) where {T,N,perm}
_strides(a::Adjoint{<:Any, <:AbstractVector}) = (stride(a.parent, 2), stride(a.parent, 1))
_strides(a::Adjoint{<:Any, <:AbstractMatrix}) = reverse(strides(a.parent))
_strides(a::Transpose{<:Any, <:AbstractVector}) = (stride(a.parent, 2), stride(a.parent, 1))
_strides(a::Transpose{<:Any, <:AbstractMatrix}) = reverse(strides(a.parent))
function _strides(a::PermutedDimsArray{T, N, perm}) where {T, N, perm}
s = _strides(parent(a))
return ntuple(d -> s[perm[d]], Val(N))
end
Expand All @@ -107,8 +113,8 @@ _strides(a::SubArray) = Base.substrides(_strides(a.parent), a.indices)
#-----------------------
Base.size(a::StridedView) = a.size
Base.strides(a::StridedView) = a.strides
Base.stride(a::StridedView{<:Any,0}, n::Int) = 1
function Base.stride(a::StridedView{<:Any,N}, n::Int) where {N}
Base.stride(a::StridedView{<:Any, 0}, n::Int) = 1
function Base.stride(a::StridedView{<:Any, N}, n::Int) where {N}
return (n <= N) ? a.strides[n] : a.strides[N] * a.size[N]
end
offset(a::StridedView) = a.offset
Expand All @@ -119,26 +125,28 @@ Base.parent(a::StridedView) = a.parent
Base.IndexStyle(::Type{<:StridedView}) = Base.IndexCartesian()

# Indexing with N integer arguments
@inline function Base.getindex(a::StridedView{<:Any,N}, I::Vararg{Int,N}) where {N}
@inline function Base.getindex(a::StridedView{<:Any, N}, I::Vararg{Int, N}) where {N}
@boundscheck checkbounds(a, I...)
i = ParentIndex(a.offset + _computeind(I, a.strides))
@inbounds r = getindex(a, i)
return r
end
@inline function Base.setindex!(a::StridedView{<:Any,N}, v, I::Vararg{Int,N}) where {N}
@inline function Base.setindex!(a::StridedView{<:Any, N}, v, I::Vararg{Int, N}) where {N}
@boundscheck checkbounds(a, I...)
i = ParentIndex(a.offset + _computeind(I, a.strides))
@inbounds setindex!(a, v, i)
return a
end

# Indexing with slice indices to create a new view
@inline function Base.getindex(a::StridedView{T,N}, I::Vararg{SliceIndex,N}) where {T,N}
return StridedView{T}(a.parent,
_computeviewsize(a.size, I),
_computeviewstrides(a.strides, I),
a.offset + _computeviewoffset(a.strides, I),
a.op)
@inline function Base.getindex(a::StridedView{T, N}, I::Vararg{SliceIndex, N}) where {T, N}
return StridedView{T}(
a.parent,
_computeviewsize(a.size, I),
_computeviewstrides(a.strides, I),
a.offset + _computeviewoffset(a.strides, I),
a.op
)
end

# Indexing directly into parent array
Expand All @@ -161,7 +169,7 @@ end
# Specific Base methods that are guaranteed to preserve`StridedView` objects
#----------------------------------------------------------------------------
Base.conj(a::StridedView{<:Real}) = a
function Base.conj(a::StridedView{T}) where {T<:Complex}
function Base.conj(a::StridedView{T}) where {T <: Complex}
return StridedView{T}(a.parent, a.size, a.strides, a.offset, _conj(a.op))
end
function Base.conj(a::StridedView)
Expand All @@ -171,22 +179,22 @@ function Base.conj(a::StridedView)
return StridedView{T}(a.parent, a.size, a.strides, a.offset, newop)
end

@inline function Base.permutedims(a::StridedView{T,N}, p) where {T,N}
@inline function Base.permutedims(a::StridedView{T, N}, p) where {T, N}
_isperm(N, p) || throw(ArgumentError("Invalid permutation of length $N: $p"))
newsize = ntuple(n -> size(a, p[n]), Val(N))
newstrides = ntuple(n -> stride(a, p[n]), Val(N))
return StridedView{T}(a.parent, newsize, newstrides, a.offset, a.op)
end

LinearAlgebra.transpose(a::StridedView{<:Number,2}) = permutedims(a, (2, 1))
LinearAlgebra.adjoint(a::StridedView{<:Number,2}) = permutedims(conj(a), (2, 1))
function LinearAlgebra.adjoint(a::StridedView{<:Any,2}) # act recursively, like Base
LinearAlgebra.transpose(a::StridedView{<:Number, 2}) = permutedims(a, (2, 1))
LinearAlgebra.adjoint(a::StridedView{<:Number, 2}) = permutedims(conj(a), (2, 1))
function LinearAlgebra.adjoint(a::StridedView{<:Any, 2}) # act recursively, like Base
S = Base.promote_op(a.op, eltype(a))
newop = _adjoint(a.op)
T = Base.promote_op(newop, S)
return permutedims(StridedView{T}(a.parent, a.size, a.strides, a.offset, newop), (2, 1))
end
function LinearAlgebra.transpose(a::StridedView{<:Any,2}) # act recursively, like Base
function LinearAlgebra.transpose(a::StridedView{<:Any, 2}) # act recursively, like Base
S = Base.promote_op(a.op, eltype(a))
newop = _transpose(a.op)
T = Base.promote_op(newop, S)
Expand All @@ -212,15 +220,15 @@ end
# Creating or transforming StridedView by slicing
#-------------------------------------------------
# we cannot use Base.view, as this also accepts indices that might not preserve stridedness
sview(a::StridedView{<:Any,N}, I::Vararg{SliceIndex,N}) where {N} = getindex(a, I...)
sview(a::StridedView{<:Any, N}, I::Vararg{SliceIndex, N}) where {N} = getindex(a, I...)
sview(a::StridedView, I::SliceIndex) = getindex(sreshape(a, (length(a),)), I)

# for StridedView and index arguments which preserve stridedness, we do replace Base.view
# with sview
Base.view(a::StridedView{<:Any,N}, I::Vararg{SliceIndex,N}) where {N} = getindex(a, I...)
Base.view(a::StridedView{<:Any, N}, I::Vararg{SliceIndex, N}) where {N} = getindex(a, I...)

# `sview` can be used as a constructor when acting on `AbstractArray` objects
@inline function sview(a::AbstractArray{<:Any,N}, I::Vararg{SliceIndex,N}) where {N}
@inline function sview(a::AbstractArray{<:Any, N}, I::Vararg{SliceIndex, N}) where {N}
return getindex(StridedView(a), I...)
end
@inline function sview(a::AbstractArray, I::SliceIndex)
Expand All @@ -230,7 +238,7 @@ end
# Creating or transforming StridedView by reshaping
#---------------------------------------------------
# An error struct for non-strided reshapes
struct ReshapeException{N₁,N₂} <: Exception
struct ReshapeException{N₁, N₂} <: Exception
newsize::Dims{N₁}
oldsize::Dims{N₂}
strides::Dims{N₂}
Expand Down Expand Up @@ -265,7 +273,7 @@ end

# Other methods: `similar`, `copy`
#----------------------------------
function Base.similar(a::StridedView, ::Type{T}, dims::NTuple{N,Int}) where {N,T}
function Base.similar(a::StridedView, ::Type{T}, dims::NTuple{N, Int}) where {N, T}
return StridedView(similar(a.parent, T, dims))
end
Base.copy(a::StridedView) = copyto!(similar(a), a)
Expand All @@ -275,7 +283,7 @@ Base.copy(a::StridedView) = copyto!(similar(a), a)
function Base.unsafe_convert(::Type{Ptr{T}}, a::StridedView{T}) where {T}
return convert(Ptr{T}, pointer(a.parent, a.offset + 1))
end
function Base.elsize(::Type{<:StridedView{T,N,A}}) where {T,N,A}
function Base.elsize(::Type{<:StridedView{T, N, A}}) where {T, N, A}
return Base.elsize(A)
end
Base.dataids(a::StridedView) = Base.dataids(a.parent)
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ Random.seed!(1234)
end
end
@test reshape(B8, (1, 1, 1)) == reshape(A8, (1, 1, 1)) ==
StridedView(reshape(A8, (1, 1, 1))) == sreshape(A8, (1, 1, 1))
StridedView(reshape(A8, (1, 1, 1))) == sreshape(A8, (1, 1, 1))
@test reshape(B8, ()) == reshape(A8, ())
end

Expand Down Expand Up @@ -245,7 +245,7 @@ end
@test view(B, :, 1:5, 3, 1:5) === sview(B, :, 1:5, 3, 1:5) === B[:, 1:5, 3, 1:5]
@test view(B, :, 1:5, 3, 1:5) == StridedView(view(A, :, 1:5, 3, 1:5))
@test pointer(view(B, :, 1:5, 3, 1:5)) ==
pointer(StridedView(view(A, :, 1:5, 3, 1:5)))
pointer(StridedView(view(A, :, 1:5, 3, 1:5)))
@test StridedViews.offset(view(B, :, 1:5, 3, 1:5)) == 2 * stride(B, 3)
end
end
Expand Down Expand Up @@ -291,5 +291,5 @@ Aqua.test_all(StridedViews)

if isempty(VERSION.prerelease)
using JET
JET.test_package(StridedViews; target_modules=(StridedViews,))
JET.test_package(StridedViews; target_modules = (StridedViews,))
end
Loading