diff --git a/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl b/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl
index f5efb98bb..1b2932f97 100644
--- a/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl
+++ b/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl
@@ -10,7 +10,7 @@ using TensorKit.Factorizations
 
 using TensorKit.Strided
 using TensorKit.Factorizations: AbstractAlgorithm
 using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check
-import TensorKit: randisometry, rand, randn
+import TensorKit: randisometry, rand, randn, _copyto!, _add_general_kernel_nonthreaded!, blocktype
 using TensorKit: MatrixAlgebraKit
 
@@ -18,5 +18,6 @@ using Random
 
 include("cutensormap.jl")
 include("truncation.jl")
+include("auxiliary.jl")
 
 end
diff --git a/ext/TensorKitCUDAExt/auxiliary.jl b/ext/TensorKitCUDAExt/auxiliary.jl
new file mode 100644
index 000000000..0b11a962f
--- /dev/null
+++ b/ext/TensorKitCUDAExt/auxiliary.jl
@@ -0,0 +1,29 @@
+function TensorKit._copyto!(A::StridedView{TA, 1, <:CuArray{TA}}, B::StridedView{TB, 2, <:CuArray{TB}}) where {TA, TB}
+    length(A) == length(B) || throw(DimensionMismatch(lazy"length of A ($(length(A))) does not match length of B ($(length(B)))"))
+
+    Adata = parent(A)
+    Astr = stride(A, 1)
+    IA = A.offset
+
+    Bdata = parent(B)
+    Bstr = strides(B)
+
+    IB_1 = B.offset
+    # build index arrays
+    IAs = Int[]
+    IBs = Int[]
+    @inbounds for _ in axes(B, 2)
+        IB = IB_1
+        for _ in axes(B, 1)
+            IA += Astr
+            push!(IAs, IA)
+            IB += Bstr[1]
+            push!(IBs, IB)
+        end
+        IB_1 += Bstr[2]
+    end
+    # upload the index vectors so the gather/scatter runs on the device
+    Adata[CuArray(IAs)] .= Bdata[CuArray(IBs)]
+
+    return A
+end
diff --git a/ext/TensorKitCUDAExt/cutensormap.jl b/ext/TensorKitCUDAExt/cutensormap.jl
index f065c2ec1..37b2e90cb 100644
--- a/ext/TensorKitCUDAExt/cutensormap.jl
+++ b/ext/TensorKitCUDAExt/cutensormap.jl
@@ -17,6 +17,10 @@ function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::Abstr
     return TensorKit.TensorMapWithStorage{T, A}(A(h_t.data), V)
 end
 
+function TensorKit.blocktype(::Type{<:CuTensorMap{T, S}}) where {T, S}
+    return CuMatrix{T, CUDA.DeviceMemory}
+end
+
 for (fname, felt) in ((:zeros, :zero), (:ones, :one))
     @eval begin
         function CUDA.$fname(
@@ -102,9 +106,21 @@ function TensorKit.scalar(t::CuTensorMap{T, S, 0, 0}) where {T, S}
 end
 
 function Base.convert(
-        TT::Type{CuTensorMap{T, S, N₁, N₂}},
-        t::AbstractTensorMap{<:Any, S, N₁, N₂}
-    ) where {T, S, N₁, N₂}
+        TT::Type{TensorMap{T, S, N₁, N₂, A}},
+        t::TensorMap{T, S, N₁, N₂, AA}
+    ) where {T, S, N₁, N₂, A <: CuArray{T}, AA}
+    if typeof(t) === TT
+        return t
+    else
+        tnew = TT(undef, space(t))
+        return copy!(tnew, t)
+    end
+end
+
+function Base.convert(
+        TT::Type{TensorMap{T, S, N₁, N₂, A}},
+        t::AdjointTensorMap
+    ) where {T, S, N₁, N₂, A <: CuArray{T}}
     if typeof(t) === TT
         return t
     else
@@ -140,6 +156,11 @@ end
 
 TensorKit.promote_storage_rule(::Type{CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} =
     CuArray{T, N, CUDA.default_memory}
+TensorKit.promote_storage_rule(::Type{<:CuArray{T, N}}, ::Type{CuArray{T, N}}) where {T, N} =
+    CuArray{T, N, CUDA.default_memory}
+# disambiguate the case where both arguments are the unparameterized CuArray{T, N}
+TensorKit.promote_storage_rule(::Type{CuArray{T, N}}, ::Type{CuArray{T, N}}) where {T, N} =
+    CuArray{T, N, CUDA.default_memory}
 
 # CuTensorMap exponentation:
 
@@ -168,3 +186,24 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
         return tf
     end
 end
+
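+# CUDA method for the generic nonthreaded fusion-tree transform kernel: same
+# control flow as the CPU method in TensorKit, except that the coefficient
+# matrix of a multi-block subtransformer is uploaded to the GPU before dispatch.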
+function TensorKit._add_general_kernel_nonthreaded!(
+        tdst::CuTensorMap, tsrc::CuTensorMap, p, transformer::TensorKit.GenericTreeTransformer, α, β, backend...
+    )
+    # preallocate buffers
+    buffers = TensorKit.allocate_buffers(tdst, tsrc, transformer)
+
+    for subtransformer in transformer.data
+        # Special case without intermediate buffers whenever there is only a single block
+        if length(subtransformer[1]) == 1
+            TensorKit._add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...)
+        else
+            cu_subtransformer = tuple(CUDA.adapt(CuArray, subtransformer[1]), subtransformer[2:end]...)
+            TensorKit._add_transform_multi!(tdst, tsrc, p, cu_subtransformer, buffers, α, β, backend...)
+        end
+    end
+    return nothing
+end
diff --git a/src/auxiliary/auxiliary.jl b/src/auxiliary/auxiliary.jl
index a7105cda6..797a55505 100644
--- a/src/auxiliary/auxiliary.jl
+++ b/src/auxiliary/auxiliary.jl
@@ -60,7 +60,7 @@ end
 # Low-overhead implementation of `copyto!` for specific case of `stride(B, 1) < stride(B, 2)`
 # used in indexmanipulations: avoids the overhead of Strided.jl
 function _copyto!(A::StridedView{<:Any, 1}, B::StridedView{<:Any, 2})
-    length(A) == length(B) || throw(DimensionMismatch())
+    length(A) == length(B) || throw(DimensionMismatch(lazy"length of A ($(length(A))) does not match length of B ($(length(B)))"))
 
     Adata = parent(A)
     Astr = stride(A, 1)
diff --git a/src/tensors/abstracttensor.jl b/src/tensors/abstracttensor.jl
index 2d7239460..eb8c787e7 100644
--- a/src/tensors/abstracttensor.jl
+++ b/src/tensors/abstracttensor.jl
@@ -53,9 +53,10 @@ storagetype(t) = storagetype(typeof(t))
 function storagetype(::Type{T}) where {T <: AbstractTensorMap}
     if T isa Union
         # attempt to be slightly more specific by promoting unions
-        Ma = storagetype(T.a)
-        Mb = storagetype(T.b)
-        return promote_storagetype(Ma, Mb)
+        return promote_storagetype(storagetype(T.a), storagetype(T.b))
     else
         # fallback definition by using scalartype
-        return similarstoragetype(scalartype(T))
+        M = similarstoragetype(scalartype(T))
+        # the fallback can itself produce a Union; attempt to promote it further
+        M isa Union && return promote_storagetype(M.a, M.b)
+        return M
@@ -103,8 +105,9 @@ similarstoragetype(X::Type, ::Type{T}) where {T <: Number} =
 
 # implement on tensors
 similarstoragetype(::Type{TT}) where {TT <: AbstractTensorMap} = similarstoragetype(storagetype(TT))
-similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number} =
-    similarstoragetype(storagetype(TT), T)
+function similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number}
+    return similarstoragetype(storagetype(TT), T)
+end
 
 # implement on arrays
 similarstoragetype(::Type{A}) where {A <: DenseVector{<:Number}} = A
diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl
index 0070bc2d4..9d7a05af5 100644
--- a/src/tensors/braidingtensor.jl
+++ b/src/tensors/braidingtensor.jl
@@ -171,12 +171,17 @@ end
 
 has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false
 function add_transform!(
         tdst::AbstractTensorMap,
-        tsrc::BraidingTensor, (p₁, p₂)::Index2Tuple,
+        tsrc::BraidingTensor{T, S},
+        (p₁, p₂)::Index2Tuple,
         fusiontreetransform, α::Number, β::Number, backend::AbstractBackend...
-    )
+    ) where {T, S}
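+    # materialize the BraidingTensor as a plain TensorMap with the scalar and
+    # storage type of the destination, so a GPU destination gets a GPU source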
+    tsrc_map = TensorMapWithStorage{scalartype(tdst), storagetype(tdst)}(undef, (tsrc.V2 ⊗ tsrc.V1) ← (tsrc.V1 ⊗ tsrc.V2))
+    copy!(tsrc_map, tsrc)
     return add_transform!(
-        tdst, TensorMap(tsrc), (p₁, p₂), fusiontreetransform, α, β,
+        tdst, tsrc_map, (p₁, p₂), fusiontreetransform, α, β,
         backend...
     )
 end
diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl
index 0820fe1af..a9074ca43 100644
--- a/src/tensors/tensoroperations.jl
+++ b/src/tensors/tensoroperations.jl
@@ -419,8 +419,10 @@ end
 # Scalar implementation
 #-----------------------
 function scalar(t::AbstractTensorMap{T, S, 0, 0}) where {T, S}
-    Bs = collect(blocks(t))
-    inds = findall(!iszero ∘ last, Bs)
-    isempty(inds) && return zero(scalartype(t))
-    return only(last(Bs[only(inds)]))
+    # collect block data to host memory so the nonzero check and the final
+    # element access avoid GPU scalar indexing
+    Bs = [collect(last(b)) for b in blocks(t)]
+    nz_Bs = filter(B -> any(!iszero, B), Bs)
+    isempty(nz_Bs) && return zero(scalartype(t))
+    return only(only(nz_Bs))
 end
diff --git a/src/tensors/treetransformers.jl b/src/tensors/treetransformers.jl
index 36cd3926d..b1d2008b5 100644
--- a/src/tensors/treetransformers.jl
+++ b/src/tensors/treetransformers.jl
@@ -46,7 +46,7 @@ function AbelianTreeTransformer(transform, p, Vdst, Vsrc)
 end
 
 const _GenericTransformerData{T, N} = Tuple{
-    Matrix{T},
+    DenseMatrix{T},
     Tuple{NTuple{N, Int}, Vector{Tuple{NTuple{N, Int}, Int}}},
     Tuple{NTuple{N, Int}, Vector{Tuple{NTuple{N, Int}, Int}}},
 }
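
# ----------------------------------------------------------------------
# Hedged usage sketch (appended for review; not part of the patch). It
# illustrates the code paths this diff touches, assuming the
# TensorKitCUDAExt extension is loaded, a CUDA device is available, and
# the `CUDA.zeros(T, space)` signature from the (elided) cutensormap.jl
# hunk above; the space `V` is illustrative only.
using TensorKit, CUDA

V = ℂ^2
t = CUDA.zeros(Float64, V ⊗ V ← V ⊗ V)  # GPU-backed TensorMap (CuTensorMap)
t2 = permute(t, ((2, 1), (4, 3)))       # exercises the CUDA tree-transform kernel
s = @tensor t[a b; a b]                 # fully contracted result; `scalar` path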