Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ When releasing a new version, move the "Unreleased" changes to a new version sec

### Added

- A more robust promotion system for `storagetype`s to better handle working with unions and other abstract tensor map types ([#370](https://github.com/QuantumKitHub/TensorKit.jl/pull/370)).

### Changed

Expand Down
91 changes: 89 additions & 2 deletions src/tensors/abstracttensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,30 @@ function InnerProductStyle(::Type{TT}) where {TT <: AbstractTensorMap}
return InnerProductStyle(spacetype(TT))
end

# storage types and promotion system
# ----------------------------------
@doc """
storagetype(t::AbstractTensorMap) -> Type{A<:AbstractVector}
storagetype(T::Type{<:AbstractTensorMap}) -> Type{A<:AbstractVector}

Return the type of vector that stores the data of a tensor.
If this is not overloaded for a given tensor type, the default value of `storagetype(scalartype(t))` is returned.

See also [`similarstoragetype`](@ref).
""" storagetype
storagetype(t) = storagetype(typeof(t))
function storagetype(::Type{T}) where {T <: AbstractTensorMap}
if T isa Union
# attempt to be slightly more specific by promoting unions
Ma = storagetype(T.a)
Mb = storagetype(T.b)
return promote_storagetype(Ma, Mb)
else
# fallback definition by using scalartype
return similarstoragetype(scalartype(T))
end
end
storagetype(T::Type) = throw(MethodError(storagetype, T))

# storage type determination and promotion - hooks for specializing
# the default implementation tries to leverarge inference and `similar`
Expand All @@ -69,6 +87,8 @@ appropriate storage types. Additionally this registers the default storage type
used in constructor-like calls, and therefore will return the exact same type for a `DenseVector`
input. The latter is used in `similar`-like calls, and therefore will return the type of calling
`similar` on the given `DenseVector`, which need not coincide with the original type.

See also [`promote_storagetype`](@ref).
""" similarstoragetype

# implement in type domain
Expand Down Expand Up @@ -102,6 +122,74 @@ similarstoragetype(::Type{D}, ::Type{T}) where {D <: AbstractDict{<:Sector, <:Ab
# default storage type for numbers
similarstoragetype(::Type{T}) where {T <: Number} = Vector{T}

@doc """
promote_storagetype([T], A, B, C...)
promote_storagetype([T], TA, TB, TC...)

Determine an appropriate storage type for the combination of tensors `A` and `B`, or tensors of type `TA` and `TB`.
Optionally, a scalartype `T` for the destination can be supplied that might differ from the inputs.
""" promote_storagetype

@inline promote_storagetype(A::AbstractTensorMap, B::AbstractTensorMap, Cs::AbstractTensorMap...) =
promote_storagetype(storagetype(A), storagetype(B), map(storagetype, Cs)...)
@inline promote_storagetype(::Type{T}, A::AbstractTensorMap, B::AbstractTensorMap, Cs::AbstractTensorMap...) where {T <: Number} =
promote_storagetype(similarstoragetype(A, T), similarstoragetype(B, T), map(Base.Fix2(similarstoragetype, T), Cs)...)

@inline function promote_storagetype(
::Type{A}, ::Type{B}, Cs::Type{<:AbstractTensorMap}...
) where {A <: AbstractTensorMap, B <: AbstractTensorMap}
return promote_storagetype(storagetype(A), storagetype(B), map(storagetype, Cs)...)
end
@inline function promote_storagetype(
::Type{T}, ::Type{A}, ::Type{B}, Cs::Type{<:AbstractTensorMap}...
) where {T <: Number, A <: AbstractTensorMap, B <: AbstractTensorMap}
return promote_storagetype(similarstoragetype(A, T), similarstoragetype(B, T), map(Base.Fix2(similarstoragetype, T), Cs)...)
end

# promotion system in the same spirit as base/promotion.jl
promote_storagetype(::Type{Base.Bottom}, ::Type{Base.Bottom}) = Base.Bottom
promote_storagetype(::Type{T}, ::Type{T}) where {T} = T
promote_storagetype(::Type{T}, ::Type{Base.Bottom}) where {T} = T
promote_storagetype(::Type{Base.Bottom}, ::Type{T}) where {T} = T

function promote_storagetype(::Type{T}, ::Type{S}) where {T, S}
@inline
# Try promote_storage_rule in both orders. Typically only one is defined,
# and there is a fallback returning Bottom below, so the common case is
# promote_storagetype(T, S) =>
# promote_storage_result(T, S, result, Bottom) =>
# typejoin(result, Bottom) => result
return promote_storage_result(T, S, promote_storage_rule(T, S), promote_storage_rule(S, T))
end

@inline promote_storagetype(T, S, U) = promote_storagetype(promote_storagetype(T, S), U)
@inline promote_storagetype(T, S, U, V...) = promote_storagetype(promote_storagetype(T, S), U, V...)

@doc """
promote_storage_rule(type1, type2)

Specifies what type should be used by [`promote_storagetype`](@ref) when given values of types `type1` and
`type2`. This function should not be called directly, but should have definitions added to
it for new types as appropriate.
""" promote_storage_rule

promote_storage_rule(::Type, ::Type) = Base.Bottom
# Define some methods to avoid needing to enumerate unrelated possibilities when presented
# with Type{<:T}, and return a value in general accordance with the result given by promote_type
promote_storage_rule(::Type{Base.Bottom}, slurp...) = Base.Bottom
promote_storage_rule(::Type{Base.Bottom}, ::Type{Base.Bottom}, slurp...) = Base.Bottom # not strictly necessary, since the next method would match unambiguously anyways
promote_storage_rule(::Type{Base.Bottom}, ::Type{T}, slurp...) where {T} = T
promote_storage_rule(::Type{T}, ::Type{Base.Bottom}, slurp...) where {T} = T

promote_storage_result(::Type, ::Type, ::Type{T}, ::Type{S}) where {T, S} = (@inline; promote_storagetype(T, S))
# If no promote_storage_rule is defined, both directions give Bottom => error
promote_storage_result(T::Type, S::Type, ::Type{Base.Bottom}, ::Type{Base.Bottom}) =
throw(ArgumentError("No promotion rule defined for storagetype `$T` and `$S`"))

# promotion rules for common vector types
promote_storage_rule(::Type{T}, ::Type{S}) where {T <: DenseVector, S <: DenseVector} =
T === S ? T : throw(ArgumentError("No promotion rule defined for storagetype `$T` and `$S`"))

# tensor characteristics: space and index information
#-----------------------------------------------------
"""
Expand Down Expand Up @@ -224,8 +312,7 @@ end
# tensor characteristics: work on instances and pass to type
#------------------------------------------------------------
InnerProductStyle(t::AbstractTensorMap) = InnerProductStyle(typeof(t))
storagetype(t) = storagetype(typeof(t))
storagetype(T::Type) = throw(MethodError(storagetype, T))

blocktype(t::AbstractTensorMap) = blocktype(typeof(t))

numout(t::AbstractTensorMap) = numout(typeof(t))
Expand Down
14 changes: 11 additions & 3 deletions src/tensors/braidingtensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,17 @@ end

space(b::BraidingTensor) = b.adjoint ? b.V1 ⊗ b.V2 ← b.V2 ⊗ b.V1 : b.V2 ⊗ b.V1 ← b.V1 ⊗ b.V2

# TODO: this will probably give issues with GPUs, so we should try to avoid
# calling this method alltogether
storagetype(::Type{BraidingTensor{T, S}}) where {T, S} = Vector{T}
# specializations to ignore the storagetype of BraidingTensor
promote_storagetype(::Type{A}, ::Type{B}) where {A <: BraidingTensor, B <: AbstractTensorMap} = storagetype(B)
promote_storagetype(::Type{A}, ::Type{B}) where {A <: AbstractTensorMap, B <: BraidingTensor} = storagetype(A)
promote_storagetype(::Type{A}, ::Type{B}) where {A <: BraidingTensor, B <: BraidingTensor} = storagetype(A)

promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: BraidingTensor, B <: AbstractTensorMap} =
similarstoragetype(B, T)
promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: AbstractTensorMap, B <: BraidingTensor} =
similarstoragetype(A, T)
promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: BraidingTensor, B <: BraidingTensor} =
similarstoragetype(A, T)

function Base.getindex(b::BraidingTensor)
sectortype(b) === Trivial || throw(SectorMismatch())
Expand Down
10 changes: 4 additions & 6 deletions src/tensors/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,8 @@ function TO.tensorcontract_type(
::Index2Tuple{1, 1}
)
S = check_spacetype(A, B)
TC′ = promote_permute(TC, sectortype(S))
M = promote_storagetype(similarstoragetype(A, TC′), similarstoragetype(B, TC′))
return DiagonalTensorMap{TC, S, M}
M = promote_storagetype(promote_permute(TC, sectortype(S)), A, B)
return DiagonalTensorMap{scalartype(M), S, M}
end

function TO.tensoralloc(
Expand Down Expand Up @@ -303,9 +302,8 @@ end

function compose_dest(A::DiagonalTensorMap, B::DiagonalTensorMap)
S = check_spacetype(A, B)
TC = TO.promote_contract(scalartype(A), scalartype(B), One)
M = promote_storagetype(similarstoragetype(A, TC), similarstoragetype(B, TC))
TTC = DiagonalTensorMap{TC, S, M}
M = promote_storagetype(TO.promote_contract(scalartype(A), scalartype(B), One), A, B)
TTC = DiagonalTensorMap{scalartype(M), S, M}
structure = codomain(A) ← domain(B)
return TO.tensoralloc(TTC, structure, Val(false))
end
Expand Down
3 changes: 1 addition & 2 deletions src/tensors/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ LinearAlgebra.normalize(t::AbstractTensorMap, p::Real = 2) = scale(t, inv(norm(t
# permutations, which might require complex scalartypes even if the inputs are real.
function compose_dest(A::AbstractTensorMap, B::AbstractTensorMap)
S = check_spacetype(A, B)
TC = TO.promote_contract(scalartype(A), scalartype(B), One)
M = promote_storagetype(similarstoragetype(A, TC), similarstoragetype(B, TC))
M = promote_storagetype(TO.promote_contract(scalartype(A), scalartype(B), One), A, B)
TTC = tensormaptype(S, numout(A), numin(B), M)
structure = codomain(A) ← domain(B)
return TO.tensoralloc(TTC, structure, Val(false))
Expand Down
3 changes: 1 addition & 2 deletions src/tensors/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,6 @@ end
function Base.promote_rule(
::Type{<:TT₁}, ::Type{<:TT₂}
) where {S, N₁, N₂, TT₁ <: TensorMap{<:Any, S, N₁, N₂}, TT₂ <: TensorMap{<:Any, S, N₁, N₂}}
T = VectorInterface.promote_add(scalartype(TT₁), scalartype(TT₂))
A = promote_storagetype(similarstoragetype(TT₁, T), similarstoragetype(TT₂, T))
A = promote_storagetype(VectorInterface.promote_add(scalartype(TT₁), scalartype(TT₂)), TT₁, TT₂)
return tensormaptype(S, N₁, N₂, A)
end
8 changes: 1 addition & 7 deletions src/tensors/tensoroperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,16 +153,10 @@ function TO.tensorcontract_type(
::Index2Tuple{N₁, N₂}
) where {N₁, N₂}
S = check_spacetype(A, B)
TC′ = promote_permute(TC, sectortype(S))
M = promote_storagetype(similarstoragetype(A, TC′), similarstoragetype(B, TC′))
M = promote_storagetype(promote_permute(TC, sectortype(S)), A, B)
return tensormaptype(S, N₁, N₂, M)
end

# TODO: handle actual promotion rule system
function promote_storagetype(::Type{M₁}, ::Type{M₂}) where {M₁, M₂}
return M₁ === M₂ ? M₁ : throw(ArgumentError("Cannot determine storage type for combining `$M₁` and `$M₂`"))
end

function TO.tensorcontract_structure(
A::AbstractTensorMap, pA::Index2Tuple, conjA::Bool,
B::AbstractTensorMap, pB::Index2Tuple, conjB::Bool,
Expand Down