diff --git a/docs/src/Changelog.md b/docs/src/Changelog.md index 6fd79f4df..2f5e568f2 100644 --- a/docs/src/Changelog.md +++ b/docs/src/Changelog.md @@ -22,6 +22,7 @@ When releasing a new version, move the "Unreleased" changes to a new version sec ### Added +- A more robust promotion system for `storagetype`s to better handle working with unions and other abstract tensor map types ([#370](https://github.com/QuantumKitHub/TensorKit.jl/pull/370)). ### Changed diff --git a/src/tensors/abstracttensor.jl b/src/tensors/abstracttensor.jl index 7f068369a..2d7239460 100644 --- a/src/tensors/abstracttensor.jl +++ b/src/tensors/abstracttensor.jl @@ -38,12 +38,30 @@ function InnerProductStyle(::Type{TT}) where {TT <: AbstractTensorMap} return InnerProductStyle(spacetype(TT)) end +# storage types and promotion system +# ---------------------------------- @doc """ storagetype(t::AbstractTensorMap) -> Type{A<:AbstractVector} storagetype(T::Type{<:AbstractTensorMap}) -> Type{A<:AbstractVector} Return the type of vector that stores the data of a tensor. +If this is not overloaded for a given tensor type, the default value of `storagetype(scalartype(t))` is returned. + +See also [`similarstoragetype`](@ref). """ storagetype +storagetype(t) = storagetype(typeof(t)) +function storagetype(::Type{T}) where {T <: AbstractTensorMap} + if T isa Union + # attempt to be slightly more specific by promoting unions + Ma = storagetype(T.a) + Mb = storagetype(T.b) + return promote_storagetype(Ma, Mb) + else + # fallback definition by using scalartype + return similarstoragetype(scalartype(T)) + end +end +storagetype(T::Type) = throw(MethodError(storagetype, T)) # storage type determination and promotion - hooks for specializing # the default implementation tries to leverarge inference and `similar` @@ -69,6 +87,8 @@ appropriate storage types. Additionally this registers the default storage type used in constructor-like calls, and therefore will return the exact same type for a `DenseVector` input. The latter is used in `similar`-like calls, and therefore will return the type of calling `similar` on the given `DenseVector`, which need not coincide with the original type. + +See also [`promote_storagetype`](@ref). """ similarstoragetype # implement in type domain @@ -102,6 +122,74 @@ similarstoragetype(::Type{D}, ::Type{T}) where {D <: AbstractDict{<:Sector, <:Ab # default storage type for numbers similarstoragetype(::Type{T}) where {T <: Number} = Vector{T} +@doc """ + promote_storagetype([T], A, B, C...) + promote_storagetype([T], TA, TB, TC...) + +Determine an appropriate storage type for the combination of tensors `A` and `B`, or tensors of type `TA` and `TB`. +Optionally, a scalartype `T` for the destination can be supplied that might differ from the inputs. +""" promote_storagetype + +@inline promote_storagetype(A::AbstractTensorMap, B::AbstractTensorMap, Cs::AbstractTensorMap...) = + promote_storagetype(storagetype(A), storagetype(B), map(storagetype, Cs)...) +@inline promote_storagetype(::Type{T}, A::AbstractTensorMap, B::AbstractTensorMap, Cs::AbstractTensorMap...) where {T <: Number} = + promote_storagetype(similarstoragetype(A, T), similarstoragetype(B, T), map(Base.Fix2(similarstoragetype, T), Cs)...) + +@inline function promote_storagetype( + ::Type{A}, ::Type{B}, Cs::Type{<:AbstractTensorMap}... + ) where {A <: AbstractTensorMap, B <: AbstractTensorMap} + return promote_storagetype(storagetype(A), storagetype(B), map(storagetype, Cs)...) +end +@inline function promote_storagetype( + ::Type{T}, ::Type{A}, ::Type{B}, Cs::Type{<:AbstractTensorMap}... + ) where {T <: Number, A <: AbstractTensorMap, B <: AbstractTensorMap} + return promote_storagetype(similarstoragetype(A, T), similarstoragetype(B, T), map(Base.Fix2(similarstoragetype, T), Cs)...) +end + +# promotion system in the same spirit as base/promotion.jl +promote_storagetype(::Type{Base.Bottom}, ::Type{Base.Bottom}) = Base.Bottom +promote_storagetype(::Type{T}, ::Type{T}) where {T} = T +promote_storagetype(::Type{T}, ::Type{Base.Bottom}) where {T} = T +promote_storagetype(::Type{Base.Bottom}, ::Type{T}) where {T} = T + +function promote_storagetype(::Type{T}, ::Type{S}) where {T, S} + @inline + # Try promote_storage_rule in both orders. Typically only one is defined, + # and there is a fallback returning Bottom below, so the common case is + # promote_storagetype(T, S) => + # promote_storage_result(T, S, result, Bottom) => + # typejoin(result, Bottom) => result + return promote_storage_result(T, S, promote_storage_rule(T, S), promote_storage_rule(S, T)) +end + +@inline promote_storagetype(T, S, U) = promote_storagetype(promote_storagetype(T, S), U) +@inline promote_storagetype(T, S, U, V...) = promote_storagetype(promote_storagetype(T, S), U, V...) + +@doc """ + promote_storage_rule(type1, type2) + +Specifies what type should be used by [`promote_storagetype`](@ref) when given values of types `type1` and +`type2`. This function should not be called directly, but should have definitions added to +it for new types as appropriate. +""" promote_storage_rule + +promote_storage_rule(::Type, ::Type) = Base.Bottom +# Define some methods to avoid needing to enumerate unrelated possibilities when presented +# with Type{<:T}, and return a value in general accordance with the result given by promote_type +promote_storage_rule(::Type{Base.Bottom}, slurp...) = Base.Bottom +promote_storage_rule(::Type{Base.Bottom}, ::Type{Base.Bottom}, slurp...) = Base.Bottom # not strictly necessary, since the next method would match unambiguously anyways +promote_storage_rule(::Type{Base.Bottom}, ::Type{T}, slurp...) where {T} = T +promote_storage_rule(::Type{T}, ::Type{Base.Bottom}, slurp...) where {T} = T + +promote_storage_result(::Type, ::Type, ::Type{T}, ::Type{S}) where {T, S} = (@inline; promote_storagetype(T, S)) +# If no promote_storage_rule is defined, both directions give Bottom => error +promote_storage_result(T::Type, S::Type, ::Type{Base.Bottom}, ::Type{Base.Bottom}) = + throw(ArgumentError("No promotion rule defined for storagetype `$T` and `$S`")) + +# promotion rules for common vector types +promote_storage_rule(::Type{T}, ::Type{S}) where {T <: DenseVector, S <: DenseVector} = + T === S ? T : throw(ArgumentError("No promotion rule defined for storagetype `$T` and `$S`")) + # tensor characteristics: space and index information #----------------------------------------------------- """ @@ -224,8 +312,7 @@ end # tensor characteristics: work on instances and pass to type #------------------------------------------------------------ InnerProductStyle(t::AbstractTensorMap) = InnerProductStyle(typeof(t)) -storagetype(t) = storagetype(typeof(t)) -storagetype(T::Type) = throw(MethodError(storagetype, T)) + blocktype(t::AbstractTensorMap) = blocktype(typeof(t)) numout(t::AbstractTensorMap) = numout(typeof(t)) diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index 9c92c67a1..97fb2fbba 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -57,9 +57,17 @@ end space(b::BraidingTensor) = b.adjoint ? b.V1 ⊗ b.V2 ← b.V2 ⊗ b.V1 : b.V2 ⊗ b.V1 ← b.V1 ⊗ b.V2 -# TODO: this will probably give issues with GPUs, so we should try to avoid -# calling this method alltogether -storagetype(::Type{BraidingTensor{T, S}}) where {T, S} = Vector{T} +# specializations to ignore the storagetype of BraidingTensor +promote_storagetype(::Type{A}, ::Type{B}) where {A <: BraidingTensor, B <: AbstractTensorMap} = storagetype(B) +promote_storagetype(::Type{A}, ::Type{B}) where {A <: AbstractTensorMap, B <: BraidingTensor} = storagetype(A) +promote_storagetype(::Type{A}, ::Type{B}) where {A <: BraidingTensor, B <: BraidingTensor} = storagetype(A) + +promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: BraidingTensor, B <: AbstractTensorMap} = + similarstoragetype(B, T) +promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: AbstractTensorMap, B <: BraidingTensor} = + similarstoragetype(A, T) +promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: BraidingTensor, B <: BraidingTensor} = + similarstoragetype(A, T) function Base.getindex(b::BraidingTensor) sectortype(b) === Trivial || throw(SectorMismatch()) diff --git a/src/tensors/diagonal.jl b/src/tensors/diagonal.jl index 8232710c5..7652a2fba 100644 --- a/src/tensors/diagonal.jl +++ b/src/tensors/diagonal.jl @@ -272,9 +272,8 @@ function TO.tensorcontract_type( ::Index2Tuple{1, 1} ) S = check_spacetype(A, B) - TC′ = promote_permute(TC, sectortype(S)) - M = promote_storagetype(similarstoragetype(A, TC′), similarstoragetype(B, TC′)) - return DiagonalTensorMap{TC, S, M} + M = promote_storagetype(promote_permute(TC, sectortype(S)), A, B) + return DiagonalTensorMap{scalartype(M), S, M} end function TO.tensoralloc( @@ -303,9 +302,8 @@ end function compose_dest(A::DiagonalTensorMap, B::DiagonalTensorMap) S = check_spacetype(A, B) - TC = TO.promote_contract(scalartype(A), scalartype(B), One) - M = promote_storagetype(similarstoragetype(A, TC), similarstoragetype(B, TC)) - TTC = DiagonalTensorMap{TC, S, M} + M = promote_storagetype(TO.promote_contract(scalartype(A), scalartype(B), One), A, B) + TTC = DiagonalTensorMap{scalartype(M), S, M} structure = codomain(A) ← domain(B) return TO.tensoralloc(TTC, structure, Val(false)) end diff --git a/src/tensors/linalg.jl b/src/tensors/linalg.jl index 5ac948802..900ee84fe 100644 --- a/src/tensors/linalg.jl +++ b/src/tensors/linalg.jl @@ -23,8 +23,7 @@ LinearAlgebra.normalize(t::AbstractTensorMap, p::Real = 2) = scale(t, inv(norm(t # permutations, which might require complex scalartypes even if the inputs are real. function compose_dest(A::AbstractTensorMap, B::AbstractTensorMap) S = check_spacetype(A, B) - TC = TO.promote_contract(scalartype(A), scalartype(B), One) - M = promote_storagetype(similarstoragetype(A, TC), similarstoragetype(B, TC)) + M = promote_storagetype(TO.promote_contract(scalartype(A), scalartype(B), One), A, B) TTC = tensormaptype(S, numout(A), numin(B), M) structure = codomain(A) ← domain(B) return TO.tensoralloc(TTC, structure, Val(false)) diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index 12281d758..615c67775 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -560,7 +560,6 @@ end function Base.promote_rule( ::Type{<:TT₁}, ::Type{<:TT₂} ) where {S, N₁, N₂, TT₁ <: TensorMap{<:Any, S, N₁, N₂}, TT₂ <: TensorMap{<:Any, S, N₁, N₂}} - T = VectorInterface.promote_add(scalartype(TT₁), scalartype(TT₂)) - A = promote_storagetype(similarstoragetype(TT₁, T), similarstoragetype(TT₂, T)) + A = promote_storagetype(VectorInterface.promote_add(scalartype(TT₁), scalartype(TT₂)), TT₁, TT₂) return tensormaptype(S, N₁, N₂, A) end diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index 81894d140..0820fe1af 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -153,16 +153,10 @@ function TO.tensorcontract_type( ::Index2Tuple{N₁, N₂} ) where {N₁, N₂} S = check_spacetype(A, B) - TC′ = promote_permute(TC, sectortype(S)) - M = promote_storagetype(similarstoragetype(A, TC′), similarstoragetype(B, TC′)) + M = promote_storagetype(promote_permute(TC, sectortype(S)), A, B) return tensormaptype(S, N₁, N₂, M) end -# TODO: handle actual promotion rule system -function promote_storagetype(::Type{M₁}, ::Type{M₂}) where {M₁, M₂} - return M₁ === M₂ ? M₁ : throw(ArgumentError("Cannot determine storage type for combining `$M₁` and `$M₂`")) -end - function TO.tensorcontract_structure( A::AbstractTensorMap, pA::Index2Tuple, conjA::Bool, B::AbstractTensorMap, pB::Index2Tuple, conjB::Bool,