Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
71 commits
Select commit Hold shift + click to select a range
545835d
add `arrayify` for adjoint tensor
lkdvos Jan 17, 2026
e3255e4
add vectorinterface rules
lkdvos Jan 17, 2026
5e867a5
add tensoroperations rules
lkdvos Jan 17, 2026
6c0b08e
add indexmanipulations rules
lkdvos Jan 18, 2026
4be4e8f
add mul rules
lkdvos Jan 20, 2026
822ef1c
temporarily disable Fibonacci (complex) spaces
lkdvos Jan 20, 2026
55eaffa
bump TupleTools compat
lkdvos Jan 20, 2026
2b6f2c2
add twist! rule
lkdvos Jan 21, 2026
cbbfd33
add flip rule
lkdvos Jan 21, 2026
d406f3d
vector spaces arent vector spaces!
lkdvos Jan 21, 2026
351272e
insert and remove units
lkdvos Jan 21, 2026
85d0423
mark a bunch of things as non-differentiable
lkdvos Jan 21, 2026
7e0860c
rewrite rule for `tensortrace!` in terms of `trace_permute!`
lkdvos Jan 21, 2026
e54dc46
dont need rules for `tensoradd!`
lkdvos Jan 21, 2026
50c3f49
add planaroperations
lkdvos Jan 22, 2026
7d446e7
rewrite rule `tensorcontract` in terms of `blas_contract!`
lkdvos Jan 22, 2026
a60c768
add rule `tr`
lkdvos Jan 22, 2026
40c04ca
give up on planartrace for now
lkdvos Jan 22, 2026
4975dbd
add rule `inv`
lkdvos Jan 22, 2026
e580b1b
is_primitive in namespace
lkdvos Jan 22, 2026
fdb16e4
share more code
lkdvos Jan 22, 2026
00cb504
split AD tests to reduce CI pressure
lkdvos Jan 22, 2026
6fefa70
add missing imports
lkdvos Jan 22, 2026
58d0cce
remove the use of the internal `Mooncake._rdata`
lkdvos Jan 26, 2026
defee0e
add comments about `NoRData()`
lkdvos Jan 26, 2026
4a559a2
add TODO
lkdvos Jan 26, 2026
1131568
correctly implement `_needs_tangent`
lkdvos Jan 29, 2026
3e2ea41
update to Mooncake 0.5
lkdvos Jan 29, 2026
f56e8ec
add TensorMap tangent type
lkdvos Jan 29, 2026
e9f9867
fix stupid tolerance mistake
lkdvos Jan 29, 2026
085fd43
enable complex tests
lkdvos Jan 29, 2026
51cf502
add tangent type test
lkdvos Jan 29, 2026
de106c1
correct arrayify
lkdvos Jan 29, 2026
a20ed96
fix indexmanipulations
lkdvos Jan 29, 2026
23956af
bump versions
lkdvos Jan 29, 2026
90204c1
deal with more complex sector shenanigans
lkdvos Jan 29, 2026
b610b19
properly accumulate
lkdvos Jan 29, 2026
68873bb
nicer _needs_tangent
lkdvos Jan 30, 2026
748b5fa
remove source
lkdvos Jan 30, 2026
9cbef37
fix TensorOperations
lkdvos Jan 30, 2026
eab2da3
remove duplicate method
lkdvos Jan 30, 2026
4501efc
fix arg order
lkdvos Jan 30, 2026
6ac538f
add missing ChainRules import
lkdvos Jan 30, 2026
a4dafb0
add JET compat
lkdvos Jan 31, 2026
916d39d
some cleanup
lkdvos Feb 1, 2026
415e0ca
more handling of scalartypes
lkdvos Feb 1, 2026
90f8217
more testing
lkdvos Feb 1, 2026
5eb7824
add specialization for `MAK.zero!`
lkdvos Feb 2, 2026
aa64c44
add tests on factorizations
lkdvos Feb 2, 2026
006d4c3
add DiagonalTensorMap tangent type
lkdvos Feb 2, 2026
8c2e540
specialize SVD pullback implementations
lkdvos Feb 2, 2026
aab9d1a
careful about projections
lkdvos Feb 3, 2026
85eab0c
disable mooncake tests on Apple
lkdvos Feb 3, 2026
2737794
add missing diagonal constructor
lkdvos Feb 3, 2026
98ff451
update some tests
lkdvos Feb 4, 2026
b40f36b
fix arg order
lkdvos Feb 4, 2026
0368042
qreduce test weight
lkdvos Feb 11, 2026
e118e30
fix vectorinterface implementations
lkdvos Feb 11, 2026
094b80d
clean up and comment tangent code
lkdvos Feb 13, 2026
9715de5
fix missing primal_to_tangent_internal!! implementation
lkdvos Feb 14, 2026
1b07c78
make better use of intermediates in alpha pullbacks
lkdvos Feb 14, 2026
3e0e021
clean up getfield rules
lkdvos Feb 14, 2026
9b09f71
clean up constructor rules
lkdvos Feb 14, 2026
afb69bc
unmark planar_trace as primitive
lkdvos Feb 14, 2026
9c49196
guard against non-differentiable number types
lkdvos Feb 14, 2026
8a4ebe0
correct small mistake
lkdvos Feb 18, 2026
35d44b7
remove duplicate definition
lkdvos Feb 18, 2026
b84e84f
rdatas have to be `NoRData`
lkdvos Feb 18, 2026
44c2fa3
field symbols should be better
lkdvos Feb 18, 2026
20faba4
handle `NoTangent` correctly
lkdvos Feb 18, 2026
0ad41f2
bugfix
lkdvos Feb 18, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ jobs:
- symmetries
- tensors
- other
- autodiff
- mooncake
- chainrules
os:
- ubuntu-latest
- macOS-latest
Expand All @@ -55,7 +56,8 @@ jobs:
- symmetries
- tensors
- other
- autodiff
- mooncake
- chainrules
os:
- ubuntu-latest
- macOS-latest
Expand Down
14 changes: 9 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[extensions]
TensorKitAdaptExt = "Adapt"
Expand All @@ -34,6 +34,7 @@ TensorKitMooncakeExt = "Mooncake"

[compat]
Adapt = "4"
AllocCheck = "0.2.3"
Aqua = "0.6, 0.7, 0.8"
ArgParse = "1.2.0"
CUDA = "5.9"
Expand All @@ -42,10 +43,11 @@ ChainRulesTestUtils = "1"
Combinatorics = "1"
FiniteDifferences = "0.12"
GPUArrays = "11.3.1"
JET = "0.9, 0.10, 0.11"
LRUCache = "1.0.2"
LinearAlgebra = "1"
MatrixAlgebraKit = "0.6.3"
Mooncake = "0.4.183"
MatrixAlgebraKit = "0.6.4"
Mooncake = "0.5"
OhMyThreads = "0.8.0"
Printf = "1"
Random = "1"
Expand All @@ -56,14 +58,15 @@ TensorKitSectors = "0.3.5"
TensorOperations = "5.1"
Test = "1"
TestExtras = "0.2,0.3"
TupleTools = "1.1"
TupleTools = "1.5"
VectorInterface = "0.4.8, 0.5"
Zygote = "0.7"
cuTENSOR = "2"
julia = "1.10"

[extras]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand All @@ -72,6 +75,7 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Expand All @@ -82,4 +86,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[targets]
test = ["ArgParse", "Adapt", "Aqua", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake"]
test = ["ArgParse", "Adapt", "Aqua", "AllocCheck", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake", "JET"]
14 changes: 11 additions & 3 deletions ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
module TensorKitMooncakeExt

using Mooncake
using Mooncake: @zero_derivative, DefaultCtx, ReverseMode, NoRData, CoDual, arrayify, primal
using Mooncake: @zero_derivative, @is_primitive,
DefaultCtx, MinimalCtx, ReverseMode, NoFData, NoRData, NoTangent,
CoDual, Dual, arrayify, primal, tangent, zero_fcodual
using TensorKit
import TensorKit as TK
using VectorInterface
using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize
import TensorOperations as TO
using VectorInterface: One, Zero
using MatrixAlgebraKit
using TupleTools

using Random: AbstractRNG

include("utility.jl")
include("tangent.jl")
include("linalg.jl")
include("indexmanipulations.jl")
include("vectorinterface.jl")
include("tensoroperations.jl")
include("planaroperations.jl")
include("factorizations.jl")

end
63 changes: 63 additions & 0 deletions ext/TensorKitMooncakeExt/factorizations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
for f in (:svd_compact, :svd_full)
f_pullback = Symbol(f, :_pullback)
@eval begin
@is_primitive DefaultCtx ReverseMode Tuple{typeof($f), AbstractTensorMap, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractTensorMap}, alg_dalg::CoDual)
A, dA = arrayify(A_dA)
alg = primal(alg_dalg)

USVᴴ = $f(A, primal(alg_dalg))
USVᴴ_dUSVᴴ = Mooncake.zero_fcodual(USVᴴ)
dUSVᴴ = last.(arrayify.(USVᴴ, tangent(USVᴴ_dUSVᴴ)))

function $f_pullback(::NoRData)
MatrixAlgebraKit.svd_pullback!(dA, A, USVᴴ, dUSVᴴ)
MatrixAlgebraKit.zero!.(dUSVᴴ)
return ntuple(Returns(NoRData()), 3)
end

return USVᴴ_dUSVᴴ, $f_pullback
end
end

# mutating version is not guaranteed to actually mutate
# so we can simply use the non-mutating version instead and avoid having to worry about
# storing copies and restoring state
f! = Symbol(f, :!)
f!_pullback = Symbol(f!, :_pullback)
@eval begin
@is_primitive DefaultCtx ReverseMode Tuple{typeof($f!), AbstractTensorMap, Any, MatrixAlgebraKit.AbstractAlgorithm}
Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual{<:AbstractTensorMap}, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) =
Mooncake.rrule!!(Mooncake.zero_fcodual($f), A_dA, alg_dalg)
end
end

@is_primitive DefaultCtx ReverseMode Tuple{typeof(svd_trunc), AbstractTensorMap, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(
::CoDual{typeof(svd_trunc)},
A_dA::CoDual{<:AbstractTensorMap},
alg_dalg::CoDual{<:MatrixAlgebraKit.TruncatedAlgorithm}
)
A, dA = arrayify(A_dA)
alg = primal(alg_dalg)

USVᴴ = svd_compact(A, alg.alg)
USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴ[2]), ind)

USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual((USVᴴtrunc..., ϵ))
dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(tangent(USVᴴtrunc_dUSVᴴtrunc))))

function svd_trunc_pullback((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real})
abs(dϵ) ≤ MatrixAlgebraKit.defaulttol(dϵ) ||
@warn "Gradient for `svd_trunc` ignores non-zero tangents for truncation error"
MatrixAlgebraKit.svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind)
return ntuple(Returns(NoRData()), 3)
end

return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_pullback
end

@is_primitive DefaultCtx ReverseMode Tuple{typeof(svd_trunc!), AbstractTensorMap, Any, MatrixAlgebraKit.AbstractAlgorithm}
Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual{<:AbstractTensorMap}, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) =
Mooncake.rrule!!(Mooncake.zero_fcodual(svd_trunc), A_dA, alg_dalg)
Loading
Loading