Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ Printf = "1"
Random = "1"
SafeTestsets = "0.1"
ScopedValues = "1.3.0"
Strided = "2"
Strided = "2.3.3"
TensorKitSectors = "0.3.5"
TensorOperations = "5.1"
Test = "1"
Expand Down
2 changes: 1 addition & 1 deletion ext/TensorKitCUDAExt/TensorKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using TensorKit.Factorizations
using TensorKit.Strided
using TensorKit.Factorizations: AbstractAlgorithm
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check
import TensorKit: randisometry, rand, randn
import TensorKit: randisometry, rand, randn, _copyto!, _add_general_kernel_nonthreaded!, blocktype

using TensorKit: MatrixAlgebraKit

Expand Down
37 changes: 23 additions & 14 deletions ext/TensorKitCUDAExt/cutensormap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::Abstr
return TensorKit.TensorMapWithStorage{T, A}(A(h_t.data), V)
end

# Block type for CUDA-backed tensor maps: each symmetry block is stored as a dense
# device matrix with element type `T` in CUDA device memory.
TensorKit.blocktype(::Type{<:CuTensorMap{T, S}}) where {T, S} = CuMatrix{T, CUDA.DeviceMemory}

for (fname, felt) in ((:zeros, :zero), (:ones, :one))
@eval begin
function CUDA.$fname(
Expand Down Expand Up @@ -101,18 +105,6 @@ function TensorKit.scalar(t::CuTensorMap{T, S, 0, 0}) where {T, S}
return isempty(inds) ? zero(scalartype(t)) : @allowscalar @inbounds t.data[only(inds)]
end

# Convert an arbitrary tensor map over the same spaces into a `CuTensorMap` with the
# requested element type `T`. Returns `t` unchanged when it is already exactly of the
# target type; otherwise allocates a fresh map and copies the data over.
function Base.convert(
        TT::Type{CuTensorMap{T, S, N₁, N₂}},
        t::AbstractTensorMap{<:Any, S, N₁, N₂}
    ) where {T, S, N₁, N₂}
    # Fast path: nothing to do if `t` is already of the requested concrete type.
    typeof(t) === TT && return t
    # Allocate an uninitialized map on the same space and copy the contents in.
    return copy!(TT(undef, space(t)), t)
end

function LinearAlgebra.isposdef(t::CuTensorMap)
domain(t) == codomain(t) ||
throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same"))
Expand All @@ -138,10 +130,9 @@ function Base.promote_rule(
return CuTensorMap{T, S, N₁, N₂}
end

TensorKit.promote_storage_rule(::Type{CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} =
# Promotion of two CuArray storage types with matching eltype and rank yields a
# default-memory CuArray, discarding the (possibly differing) memory type parameters
# of the two inputs.
TensorKit.promote_storage_rule(::Type{<:CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} =
    CuArray{T, N, CUDA.default_memory}


# CuTensorMap exponentation:
function TensorKit.exp!(t::CuTensorMap)
domain(t) == codomain(t) ||
Expand All @@ -168,3 +159,21 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
return tf
end
end

# GPU specialization of the generic (non-threaded) fusion-tree transformation kernel.
# Follows the same structure as the CPU version, except that for multi-block
# subtransformers the transformation matrix is adapted to a `CuArray` first, so the
# subsequent multiplications stay device-side (avoids mixed CuMatrix * Matrix products;
# see the discussion on widening `_GenericTransformerData` to `DenseMatrix`).
function TensorKit._add_general_kernel_nonthreaded!(
        tdst::CuTensorMap, tsrc::CuTensorMap, p, transformer::TensorKit.GenericTreeTransformer, α, β, backend...
    )
    # preallocate buffers, shared by all subtransformers below
    buffers = TensorKit.allocate_buffers(tdst, tsrc, transformer)

    for subtransformer in transformer.data
        # Special case without intermediate buffers whenever there is only a single block
        if length(subtransformer[1]) == 1
            TensorKit._add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...)
        else
            # Move the transformation matrix (first tuple entry) onto the device; the
            # remaining entries are index bookkeeping and stay as-is.
            # NOTE(review): assumes subtransformer[1] is a host-side matrix — confirm.
            cu_subtransformer = tuple(CUDA.adapt(CuArray, subtransformer[1]), subtransformer[2:end]...)
            TensorKit._add_transform_multi!(tdst, tsrc, p, cu_subtransformer, buffers, α, β, backend...)
        end
    end
    return nothing
end
8 changes: 5 additions & 3 deletions src/auxiliary/auxiliary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,12 @@ function _interleave(a::NTuple{N}, b::NTuple{N}) where {N}
return (a[1], b[1], _interleave(tail(a), tail(b))...)
end

# Generic fallback: delegate to `Base.copyto!` for any container pair not covered by a
# specialized low-overhead method below.
_copyto!(A, B) = copyto!(A, B)

# Low-overhead implementation of `copyto!` for specific case of `stride(B, 1) < stride(B, 2)`
# used in indexmanipulations: avoids the overhead of Strided.jl
function _copyto!(A::StridedView{<:Any, 1}, B::StridedView{<:Any, 2})
length(A) == length(B) || throw(DimensionMismatch())
# for CPU-hosted Arrays # used in indexmanipulations: avoids the overhead of Strided.jl
function _copyto!(A::StridedView{TA, 1, AA}, B::StridedView{TB, 2, BB}) where {TA <: Number, TB <: Number, AA <: DenseArray{TA}, BB <: DenseArray{TB}}
length(A) == length(B) || throw(DimensionMismatch(lazy"length of A ($(length(A))) does not match length of B ($(length(B))"))

Adata = parent(A)
Astr = stride(A, 1)
Expand Down
13 changes: 8 additions & 5 deletions src/tensors/abstracttensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ storagetype(t) = storagetype(typeof(t))
function storagetype(::Type{T}) where {T <: AbstractTensorMap}
if T isa Union
# attempt to be slightly more specific by promoting unions
Ma = storagetype(T.a)
Mb = storagetype(T.b)
return promote_storagetype(Ma, Mb)
return promote_storagetype(T.a, T.b)
elseif eltype(T) isa Union
# attempt to be slightly more specific by promoting unions
TU = eltype(T)
return promote_storagetype(TU.a, TU.b)
else
# fallback definition by using scalartype
return similarstoragetype(scalartype(T))
Expand Down Expand Up @@ -103,8 +105,9 @@ similarstoragetype(X::Type, ::Type{T}) where {T <: Number} =

# implement on tensors
# One-argument form on tensor-map types: forward to the tensor's storage type.
similarstoragetype(::Type{TT}) where {TT <: AbstractTensorMap} = similarstoragetype(storagetype(TT))
similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number} =
similarstoragetype(storagetype(TT), T)
# Two-argument form: the storage type of `TT` with its scalar type replaced by `T`.
function similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number}
    return similarstoragetype(storagetype(TT), T)
end

# implement on arrays
similarstoragetype(::Type{A}) where {A <: DenseVector{<:Number}} = A
Expand Down
9 changes: 6 additions & 3 deletions src/tensors/braidingtensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,15 @@ end
# NOTE(review): BraidingTensor appears to be a lazy/structural tensor without backing
# data, so a permuted result can never share storage with it — hence always `false`.
has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false
function add_transform!(
tdst::AbstractTensorMap,
tsrc::BraidingTensor, (p₁, p₂)::Index2Tuple,
tsrc::BraidingTensor{T, S},
(p₁, p₂)::Index2Tuple,
fusiontreetransform,
α::Number, β::Number, backend::AbstractBackend...
)
) where {T, S}
tsrc_map = similar(tdst, storagetype(tdst), space(tsrc))
copy!(tsrc_map, tsrc)
return add_transform!(
tdst, TensorMap(tsrc), (p₁, p₂), fusiontreetransform, α, β,
tdst, tsrc_map, (p₁, p₂), fusiontreetransform, α, β,
backend...
)
end
Expand Down
2 changes: 1 addition & 1 deletion src/tensors/treetransformers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ function AbelianTreeTransformer(transform, p, Vdst, Vsrc)
end

const _GenericTransformerData{T, N} = Tuple{
Matrix{T},
DenseMatrix{T},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this change makes the types below abstractly typed, do we need this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, in order to allow device-side matrices to get passed in. Otherwise you get attempts to multiply CuMatrix * Matrix outside of constructors

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, but in that case we would really have to make that an additional type parameter in the GenericTreeTransformer struct -- these were introduced to hyper specialize and get maximal efficiency, so I don't think we can eat a type-instability here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, it would have been helpful to have had a comment or anything that this was why they were there

Tuple{NTuple{N, Int}, Vector{Tuple{NTuple{N, Int}, Int}}},
Tuple{NTuple{N, Int}, Vector{Tuple{NTuple{N, Int}, Int}}},
}
Expand Down
8 changes: 4 additions & 4 deletions test/cuda/tensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ for V in spacelist
next = @constinferred Nothing iterate(bs, state)
b2 = @constinferred block(t, first(blocksectors(t)))
@test b1 == b2
@test_broken eltype(bs) === Pair{typeof(c), typeof(b1)}
@test_broken typeof(b1) === TensorKit.blocktype(t)
@test eltype(bs) === Pair{typeof(c), typeof(b1)}
@test typeof(b1) === TensorKit.blocktype(t)
@test typeof(c) === sectortype(t)
end
end
Expand Down Expand Up @@ -162,8 +162,8 @@ for V in spacelist
next = @constinferred Nothing iterate(bs, state)
b2 = @constinferred block(t', first(blocksectors(t')))
@test b1 == b2
@test_broken eltype(bs) === Pair{typeof(c), typeof(b1)}
@test_broken typeof(b1) === TensorKit.blocktype(t')
@test eltype(bs) === Pair{typeof(c), typeof(b1)}
@test typeof(b1) === TensorKit.blocktype(t')
@test typeof(c) === sectortype(t)
# linear algebra
@test isa(@constinferred(norm(t)), real(T))
Expand Down
Loading