From c69b57b20e8d598373851612a84cc80526d0651b Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 18 Feb 2026 11:17:50 -0500 Subject: [PATCH 01/15] small refactor QR pullback add QR gauge projection --- src/pullbacks/qr.jl | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/pullbacks/qr.jl b/src/pullbacks/qr.jl index d92878bd..de9f660a 100644 --- a/src/pullbacks/qr.jl +++ b/src/pullbacks/qr.jl @@ -1,4 +1,11 @@ -function check_qr_cotangents(Q, R, ΔQ, ΔR, minmn::Int, p::Int; gauge_atol::Real = default_pullback_gauge_atol(ΔQ)) +qr_rank(R; rank_atol = default_pullback_rank_atol(R)) = + @something findlast(>=(rank_atol) ∘ abs, diagview(R)) 0 + +function check_qr_cotangents( + Q, R, ΔQ, ΔR, p::Int; + gauge_atol::Real = default_pullback_gauge_atol(ΔQ) + ) + minmn = min(size(Q, 1), size(R, 2)) if minmn > p # case where A is rank-deficient Δgauge = abs(zero(eltype(Q))) if !iszerotangent(ΔQ) @@ -7,11 +14,13 @@ function check_qr_cotangents(Q, R, ΔQ, ΔR, minmn::Int, p::Int; gauge_atol::Rea # columns of ΔQ should be zero for a gauge-invariant # cost function ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2)) - Δgauge = max(Δgauge, norm(ΔQ2, Inf)) + Δgauge_Q = norm(ΔQ2, Inf) + Δgauge = max(Δgauge, Δgauge_Q) end if !iszerotangent(ΔR) ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2)) - Δgauge = max(Δgauge, norm(ΔR22, Inf)) + Δgauge_R = norm(ΔR22, Inf) + Δgauge = max(Δgauge, Δgauge_R) end Δgauge ≤ gauge_atol || @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" @@ -29,7 +38,7 @@ function check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol::Real = default_ # Q2' * ΔQ2 as a gauge dependent quantity. Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf) Δgauge ≤ gauge_atol || - @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + @warn "`qr` full cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" return end @@ -60,9 +69,8 @@ function qr_pullback!( Q, R = QR m = size(Q, 1) n = size(R, 2) - minmn = min(m, n) Rd = diagview(R) - p = @something findlast(>=(rank_atol) ∘ abs, Rd) 0 + p = qr_rank(R) ΔQ, ΔR = ΔQR @@ -72,7 +80,7 @@ function qr_pullback!( ΔA1 = view(ΔA, :, 1:p) ΔA2 = view(ΔA, :, (p + 1):n) - check_qr_cotangents(Q, R, ΔQ, ΔR, minmn, p; gauge_atol) + check_qr_cotangents(Q, R, ΔQ, ΔR, p; gauge_atol) ΔQ̃ = zero!(similar(Q, (m, p))) if !iszerotangent(ΔQ) From 734f3901480b302b6722ca8f3cca550399843f5f Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 18 Feb 2026 11:23:16 -0500 Subject: [PATCH 02/15] testsuite reorganisation --- test/testsuite/TestSuite.jl | 29 ++++++++++++------- test/testsuite/{ => decompositions}/eig.jl | 0 test/testsuite/{ => decompositions}/eigh.jl | 0 test/testsuite/{ => decompositions}/lq.jl | 0 .../{ => decompositions}/orthnull.jl | 2 +- test/testsuite/{ => decompositions}/polar.jl | 0 test/testsuite/{ => decompositions}/qr.jl | 0 test/testsuite/{ => decompositions}/schur.jl | 0 test/testsuite/{ => decompositions}/svd.jl | 0 test/testsuite/{ => mooncake}/mooncake.jl | 0 10 files changed, 19 insertions(+), 12 deletions(-) rename test/testsuite/{ => decompositions}/eig.jl (100%) rename test/testsuite/{ => decompositions}/eigh.jl (100%) rename test/testsuite/{ => decompositions}/lq.jl (100%) rename test/testsuite/{ => decompositions}/orthnull.jl (99%) rename test/testsuite/{ => decompositions}/polar.jl (100%) rename test/testsuite/{ => decompositions}/qr.jl (100%) rename test/testsuite/{ => decompositions}/schur.jl (100%) rename test/testsuite/{ => decompositions}/svd.jl (100%) rename test/testsuite/{ => mooncake}/mooncake.jl (100%) diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index 12653096..ab8d8c59 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -86,17 +86,24 @@ instantiate_unitary(::Type{<:Diagonal}, A, sz) = Diagonal(fill!(similar(parent(A include("ad_utils.jl") -include("qr.jl") -include("lq.jl") -include("polar.jl") include("projections.jl") -include("schur.jl") -include("eig.jl") -include("eigh.jl") -include("orthnull.jl") -include("svd.jl") -include("mooncake.jl") -include("enzyme.jl") -include("chainrules.jl") + +# Decompositions +# -------------- +include("decompositions/qr.jl") +include("decompositions/lq.jl") +include("decompositions/polar.jl") +include("decompositions/schur.jl") +include("decompositions/eig.jl") +include("decompositions/eigh.jl") +include("decompositions/orthnull.jl") +include("decompositions/svd.jl") + +# Mooncake +# -------- +include("mooncake/mooncake.jl") +include("mooncake/qr.jl") +# include("enzyme.jl") +# include("chainrules.jl") end diff --git a/test/testsuite/eig.jl b/test/testsuite/decompositions/eig.jl similarity index 100% rename from test/testsuite/eig.jl rename to test/testsuite/decompositions/eig.jl diff --git a/test/testsuite/eigh.jl b/test/testsuite/decompositions/eigh.jl similarity index 100% rename from test/testsuite/eigh.jl rename to test/testsuite/decompositions/eigh.jl diff --git a/test/testsuite/lq.jl b/test/testsuite/decompositions/lq.jl similarity index 100% rename from test/testsuite/lq.jl rename to test/testsuite/decompositions/lq.jl diff --git a/test/testsuite/orthnull.jl b/test/testsuite/decompositions/orthnull.jl similarity index 99% rename from test/testsuite/orthnull.jl rename to test/testsuite/decompositions/orthnull.jl index 79349d2a..91a40ae2 100644 --- a/test/testsuite/orthnull.jl +++ b/test/testsuite/decompositions/orthnull.jl @@ -1,7 +1,7 @@ using TestExtras using LinearAlgebra -include("../linearmap.jl") +include("../../linearmap.jl") _left_orth_svd(x; kwargs...) = left_orth(x; alg = :svd, kwargs...) _left_orth_svd!(x, VC; kwargs...) = left_orth!(x, VC; alg = :svd, kwargs...) diff --git a/test/testsuite/polar.jl b/test/testsuite/decompositions/polar.jl similarity index 100% rename from test/testsuite/polar.jl rename to test/testsuite/decompositions/polar.jl diff --git a/test/testsuite/qr.jl b/test/testsuite/decompositions/qr.jl similarity index 100% rename from test/testsuite/qr.jl rename to test/testsuite/decompositions/qr.jl diff --git a/test/testsuite/schur.jl b/test/testsuite/decompositions/schur.jl similarity index 100% rename from test/testsuite/schur.jl rename to test/testsuite/decompositions/schur.jl diff --git a/test/testsuite/svd.jl b/test/testsuite/decompositions/svd.jl similarity index 100% rename from test/testsuite/svd.jl rename to test/testsuite/decompositions/svd.jl diff --git a/test/testsuite/mooncake.jl b/test/testsuite/mooncake/mooncake.jl similarity index 100% rename from test/testsuite/mooncake.jl rename to test/testsuite/mooncake/mooncake.jl From 023cfed154add8e2a5768f0f57632fe574419bb7 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 19 Feb 2026 09:15:44 -0500 Subject: [PATCH 03/15] add QR mooncake tests --- test/testsuite/mooncake/mooncake.jl | 45 +++----------- test/testsuite/mooncake/qr.jl | 94 +++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+), 38 deletions(-) create mode 100644 test/testsuite/mooncake/qr.jl diff --git a/test/testsuite/mooncake/mooncake.jl b/test/testsuite/mooncake/mooncake.jl index 29d65e31..368a7f87 100644 --- a/test/testsuite/mooncake/mooncake.jl +++ b/test/testsuite/mooncake/mooncake.jl @@ -214,6 +214,13 @@ function test_pullbacks_match(f!, f, A, args, Δargs, alg = nothing; rdata = Moo return end +function make_input_scratch!(f!, A, F, alg) + F′ = f!(A, F, alg) + MatrixAlgebraKit.zero!(A) + F === F′ || MatrixAlgebraKit.zero!.(F) + return F′ +end + function test_mooncake(T::Type, sz; kwargs...) summary_str = testargs_summary(T, sz) return @testset "Mooncake AD $summary_str" begin @@ -232,44 +239,6 @@ function test_mooncake(T::Type, sz; kwargs...) end end -function test_mooncake_qr( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "QR Mooncake AD rules $summary_str" begin - A = instantiate_matrix(T, sz) - @testset "qr_compact" begin - QR, ΔQR = ad_qr_compact_setup(A) - dQR = make_mooncake_tangent(ΔQR) - Mooncake.TestUtils.test_rule(rng, qr_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol, rtol) - test_pullbacks_match(qr_compact!, qr_compact, A, QR, ΔQR) - end - @testset "qr_null" begin - N, ΔN = ad_qr_null_setup(A) - dN = make_mooncake_tangent(copy(ΔN)) - Mooncake.TestUtils.test_rule(rng, qr_null, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dN, atol, rtol) - test_pullbacks_match(qr_null!, qr_null, A, N, ΔN) - end - @testset "qr_full" begin - QR, ΔQR = ad_qr_full_setup(A) - dQR = make_mooncake_tangent(ΔQR) - Mooncake.TestUtils.test_rule(rng, qr_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol, rtol) - test_pullbacks_match(qr_full!, qr_full, A, QR, ΔQR) - end - @testset "qr_compact - rank-deficient A" begin - m, n = size(A) - r = min(m, n) - 5 - Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) - QR, ΔQR = ad_qr_rank_deficient_compact_setup(Ard) - dQR = make_mooncake_tangent(ΔQR) - Mooncake.TestUtils.test_rule(rng, qr_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol, rtol) - test_pullbacks_match(qr_compact!, qr_compact, Ard, QR, ΔQR) - end - end -end - function test_mooncake_lq( T::Type, sz; atol::Real = 0, rtol::Real = precision(T), diff --git a/test/testsuite/mooncake/qr.jl b/test/testsuite/mooncake/qr.jl new file mode 100644 index 00000000..96932c6c --- /dev/null +++ b/test/testsuite/mooncake/qr.jl @@ -0,0 +1,94 @@ +function test_mooncake_qr( + T::Type, sz; + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Mooncake qr $summary_str" begin + test_mooncake_qr_compact(T, sz; kwargs...) + test_mooncake_qr_full(T, sz; kwargs...) + test_mooncake_qr_null(T, sz; kwargs...) + end +end + +function remove_qr_gauge_dependence!(ΔQ, A, Q, R) + m, n = size(A) + minmn = min(m, n) + Q₁ = @view Q[:, 1:minmn] + ΔQ₂ = @view ΔQ[:, (minmn + 1):end] + Q₁ᴴΔQ₂ = Q₁' * ΔQ₂ + mul!(ΔQ₂, Q₁, Q₁ᴴΔQ₂) + MatrixAlgebraKit.check_qr_full_cotangents(Q₁, ΔQ₂, Q₁ᴴΔQ₂) + return ΔQ +end + +function remove_qr_null_gauge_dependence!(ΔN, A, N) + Q, _ = qr_compact(A) + return mul!(ΔN, Q, Q' * ΔN) +end + +function test_mooncake_qr_compact( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "qr_compact" begin + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(qr_compact, A; positive = true) + QR = qr_compact(A, alg) + ΔQR = Mooncake.randn_tangent(rng, QR) + remove_qr_gauge_dependence!(ΔQR[1], A, QR...) + + Mooncake.TestUtils.test_rule( + rng, qr_compact, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔQR, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, qr_compact!, A, QR, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔQR, atol, rtol, is_primitive = false + ) + end +end + +function test_mooncake_qr_full( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "qr_full" begin + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(qr_full, A; positive = true) + QR = qr_full(A, alg) + ΔQR = Mooncake.randn_tangent(rng, QR) + remove_qr_gauge_dependence!(ΔQR[1], A, QR...) + + Mooncake.TestUtils.test_rule( + rng, qr_full, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔQR, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, qr_full!, A, QR, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔQR, atol, rtol, is_primitive = false + ) + end +end + +function test_mooncake_qr_null( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "qr_null" begin + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(qr_null, A; positive = true) + N = qr_null(A, alg) + ΔN = Mooncake.randn_tangent(rng, N) + remove_qr_null_gauge_dependence!(ΔN, A, N) + + Mooncake.TestUtils.test_rule( + rng, qr_null, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔN, atol, rtol + ) + N, ΔN = ad_qr_null_setup(A) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, qr_null!, A, N, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔN, atol, rtol, is_primitive = false + ) + end +end From 159da4d21b5a5f4a36a0d354d771bcc1b6788708 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 19 Feb 2026 09:25:45 -0500 Subject: [PATCH 04/15] Genius suggestion by @Jutho fixes everything --- test/testsuite/mooncake/mooncake.jl | 5 ++--- test/testsuite/mooncake/qr.jl | 6 +++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/test/testsuite/mooncake/mooncake.jl b/test/testsuite/mooncake/mooncake.jl index 368a7f87..386cae35 100644 --- a/test/testsuite/mooncake/mooncake.jl +++ b/test/testsuite/mooncake/mooncake.jl @@ -214,10 +214,9 @@ function test_pullbacks_match(f!, f, A, args, Δargs, alg = nothing; rdata = Moo return end -function make_input_scratch!(f!, A, F, alg) - F′ = f!(A, F, alg) +function make_input_scratch!(f!, A, alg) + F′ = f!(A, alg) MatrixAlgebraKit.zero!(A) - F === F′ || MatrixAlgebraKit.zero!.(F) return F′ end diff --git a/test/testsuite/mooncake/qr.jl b/test/testsuite/mooncake/qr.jl index 96932c6c..eb4f3712 100644 --- a/test/testsuite/mooncake/qr.jl +++ b/test/testsuite/mooncake/qr.jl @@ -42,7 +42,7 @@ function test_mooncake_qr_compact( mode = Mooncake.ReverseMode, output_tangent = ΔQR, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, qr_compact!, A, QR, alg; + rng, make_input_scratch!, qr_compact!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔQR, atol, rtol, is_primitive = false ) end @@ -64,7 +64,7 @@ function test_mooncake_qr_full( mode = Mooncake.ReverseMode, output_tangent = ΔQR, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, qr_full!, A, QR, alg; + rng, make_input_scratch!, qr_full!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔQR, atol, rtol, is_primitive = false ) end @@ -87,7 +87,7 @@ function test_mooncake_qr_null( ) N, ΔN = ad_qr_null_setup(A) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, qr_null!, A, N, alg; + rng, make_input_scratch!, qr_null!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔN, atol, rtol, is_primitive = false ) end From deda6053d91881ba9d0250104ee46020024fccd2 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 19 Feb 2026 09:31:55 -0500 Subject: [PATCH 05/15] Refactor Mooncake LQ tests --- test/testsuite/TestSuite.jl | 1 + test/testsuite/mooncake/lq.jl | 95 +++++++++++++++++++++++++++++ test/testsuite/mooncake/mooncake.jl | 36 ----------- 3 files changed, 96 insertions(+), 36 deletions(-) create mode 100644 test/testsuite/mooncake/lq.jl diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index ab8d8c59..9593bb7a 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -103,6 +103,7 @@ include("decompositions/svd.jl") # -------- include("mooncake/mooncake.jl") include("mooncake/qr.jl") +include("mooncake/lq.jl") # include("enzyme.jl") # include("chainrules.jl") diff --git a/test/testsuite/mooncake/lq.jl b/test/testsuite/mooncake/lq.jl new file mode 100644 index 00000000..542ce0cd --- /dev/null +++ b/test/testsuite/mooncake/lq.jl @@ -0,0 +1,95 @@ +function test_mooncake_lq( + T::Type, sz; + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Mooncake lq $summary_str" begin + test_mooncake_lq_compact(T, sz; kwargs...) + test_mooncake_lq_full(T, sz; kwargs...) + test_mooncake_lq_null(T, sz; kwargs...) + end +end + +function remove_lq_gauge_dependence!(ΔQ, A, L, Q) + m, n = size(A) + minmn = min(m, n) + Q₁ = @view Q[1:minmn, :] + ΔQ₂ = @view ΔQ[(minmn + 1):end, :] + ΔQ₂Q₁ᴴ = ΔQ₂ * Q₁' + mul!(ΔQ₂, ΔQ₂Q₁ᴴ, Q₁) + MatrixAlgebraKit.check_lq_full_cotangents(Q₁, ΔQ₂, ΔQ₂Q₁ᴴ) + return ΔQ +end + +function remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) + _, Q = lq_compact(A) + ΔNᴴQᴴ = ΔNᴴ * Q' + return mul!(ΔNᴴ, ΔNᴴQᴴ, Q) +end + +function test_mooncake_lq_compact( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "lq_compact" begin + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(lq_compact, A; positive = true) + LQ = lq_compact(A, alg) + ΔLQ = Mooncake.randn_tangent(rng, LQ) + remove_lq_gauge_dependence!(ΔLQ[2], A, LQ...) + + Mooncake.TestUtils.test_rule( + rng, lq_compact, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔLQ, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, lq_compact!, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔLQ, atol, rtol, is_primitive = false + ) + end +end + +function test_mooncake_lq_full( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "lq_full" begin + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(lq_full, A; positive = true) + LQ = lq_full(A, alg) + ΔLQ = Mooncake.randn_tangent(rng, LQ) + remove_lq_gauge_dependence!(ΔLQ[2], A, LQ...) + + Mooncake.TestUtils.test_rule( + rng, lq_full, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔLQ, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, lq_full!, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔLQ, atol, rtol, is_primitive = false + ) + end +end + +function test_mooncake_lq_null( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "lq_null" begin + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(lq_null, A; positive = true) + Nᴴ = lq_null(A, alg) + ΔNᴴ = Mooncake.randn_tangent(rng, Nᴴ) + remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) + + Mooncake.TestUtils.test_rule( + rng, lq_null, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔNᴴ, atol, rtol + ) + Nᴴ, ΔNᴴ = ad_lq_null_setup(A) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, lq_null!, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔNᴴ, atol, rtol, is_primitive = false + ) + end +end diff --git a/test/testsuite/mooncake/mooncake.jl b/test/testsuite/mooncake/mooncake.jl index 386cae35..df769016 100644 --- a/test/testsuite/mooncake/mooncake.jl +++ b/test/testsuite/mooncake/mooncake.jl @@ -238,42 +238,6 @@ function test_mooncake(T::Type, sz; kwargs...) end end -function test_mooncake_lq( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "LQ Mooncake AD rules $summary_str" begin - A = instantiate_matrix(T, sz) - @testset "lq_compact" begin - LQ, ΔLQ = ad_lq_compact_setup(A) - Mooncake.TestUtils.test_rule(rng, lq_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) - test_pullbacks_match(lq_compact!, lq_compact, A, LQ, ΔLQ) - end - @testset "lq_null" begin - Nᴴ, ΔNᴴ = ad_lq_null_setup(A) - dNᴴ = make_mooncake_tangent(ΔNᴴ) - Mooncake.TestUtils.test_rule(rng, lq_null, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dNᴴ, atol, rtol) - test_pullbacks_match(lq_null!, lq_null, A, Nᴴ, ΔNᴴ) - end - @testset "lq_full" begin - LQ, ΔLQ = ad_lq_full_setup(A) - dLQ = make_mooncake_tangent(ΔLQ) - Mooncake.TestUtils.test_rule(rng, lq_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol, rtol) - test_pullbacks_match(lq_full!, lq_full, A, LQ, ΔLQ) - end - @testset "lq_compact - rank-deficient A" begin - m, n = size(A) - r = min(m, n) - 5 - Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) - LQ, ΔLQ = ad_lq_rank_deficient_compact_setup(Ard) - dLQ = make_mooncake_tangent(ΔLQ) - Mooncake.TestUtils.test_rule(rng, lq_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol, rtol) - test_pullbacks_match(lq_compact!, lq_compact, Ard, LQ, ΔLQ) - end - end -end function test_mooncake_eig( T::Type, sz; From 98c9fc0f22d0cc25b8835c8a8acebbe0817745f4 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 19 Feb 2026 09:57:54 -0500 Subject: [PATCH 06/15] Refactor Mooncake Eig tests --- test/testsuite/TestSuite.jl | 2 + test/testsuite/ad_utils.jl | 13 +-- test/testsuite/mooncake/eig.jl | 102 ++++++++++++++++++ test/testsuite/mooncake/eigh.jl | 161 ++++++++++++++++++++++++++++ test/testsuite/mooncake/mooncake.jl | 137 ----------------------- 5 files changed, 267 insertions(+), 148 deletions(-) create mode 100644 test/testsuite/mooncake/eig.jl create mode 100644 test/testsuite/mooncake/eigh.jl diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index 9593bb7a..ba50f7a2 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -104,6 +104,8 @@ include("decompositions/svd.jl") include("mooncake/mooncake.jl") include("mooncake/qr.jl") include("mooncake/lq.jl") +include("mooncake/eig.jl") +include("mooncake/eigh.jl") # include("enzyme.jl") # include("chainrules.jl") diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index 31fb8ca1..3b0d3d61 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -8,15 +8,6 @@ function remove_svdgauge_dependence!( mul!(ΔU, U, gaugepart, -1, 1) return ΔU, ΔVᴴ end -function remove_eiggauge_dependence!( - ΔV, D, V; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) - ) - gaugepart = V' * ΔV - gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 - mul!(ΔV, V / (V' * V), gaugepart, -1, 1) - return ΔV -end function remove_eighgauge_dependence!( ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) @@ -204,7 +195,7 @@ function ad_eig_full_setup(A) D, V = DV Ddiag = diagview(D) ΔV = randn!(similar(A, complex(T), m, m)) - ΔV = remove_eiggauge_dependence!(ΔV, D, V) + ΔV = remove_eig_gauge_dependence!(ΔV, D, V) ΔD = randn!(similar(A, complex(T), m, m)) ΔD2 = Diagonal(randn!(similar(A, complex(T), m))) return DV, (ΔD, ΔV), (ΔD2, ΔV) @@ -216,7 +207,7 @@ function ad_eig_full_setup(A::Diagonal) DV = eig_full(A) D, V = DV ΔV = randn!(similar(A.diag, T, m, m)) - ΔV = remove_eiggauge_dependence!(ΔV, D, V) + ΔV = remove_eig_gauge_dependence!(ΔV, D, V) ΔD = Diagonal(randn!(similar(A.diag, T, m))) ΔD2 = Diagonal(randn!(similar(A.diag, T, m))) return DV, (ΔD, ΔV), (ΔD2, ΔV) diff --git a/test/testsuite/mooncake/eig.jl b/test/testsuite/mooncake/eig.jl new file mode 100644 index 00000000..50d08499 --- /dev/null +++ b/test/testsuite/mooncake/eig.jl @@ -0,0 +1,102 @@ +function test_mooncake_eig( + T::Type, sz; + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Mooncake eig $summary_str" begin + test_mooncake_eig_full(T, sz; kwargs...) + test_mooncake_eig_vals(T, sz; kwargs...) + test_mooncake_eig_trunc(T, sz; kwargs...) + end +end + +function test_mooncake_eig_full( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "eig_full" begin + A = make_eig_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(eig_full, A) + DV = eig_full(A, alg) + ΔDV = Mooncake.randn_tangent(rng, DV) + remove_eiggauge_dependence!(ΔDV[2], DV...) + + Mooncake.TestUtils.test_rule( + rng, eig_full, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔDV, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, eig_full!, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔDV, atol, rtol, is_primitive = false + ) + end +end + +function test_mooncake_eig_vals( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "eig_vals" begin + A = make_eig_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(eig_vals, A) + D = eig_vals(A, alg) + ΔD = Mooncake.randn_tangent(rng, D) + + Mooncake.TestUtils.test_rule( + rng, eig_vals, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔD, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, eig_vals!, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔD, atol, rtol, is_primitive = false + ) + end +end + +function test_mooncake_eig_trunc( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "eig_trunc" begin + A = make_eig_matrix(T, sz) + m = size(A, 1) + + alg = MatrixAlgebraKit.select_algorithm(eig_full, A) + DV = eig_full(A, alg) + ΔDV = Mooncake.randn_tangent(rng, DV) + remove_eig_gauge_dependence!(ΔDV[2], DV...) + + @testset "truncrank($r)" for r in round.(Int, range(1, m + 4, 4)) + trunc = truncrank(r; by = abs) + alg_trunc = TruncatedAlgorithm(alg, trunc) + + # truncate the gauge-corrected tangents + DVtrunc, ind = MatrixAlgebraKit.truncate(eig_trunc!, DV, trunc) + ΔDV_primal = Mooncake.tangent_to_primal!!(copy.(DV), ΔDV) + ΔDVtrunc_primal = (Diagonal(diagview(ΔDV_primal[1])[ind]), ΔDV_primal[2][:, ind]) + ΔDVtrunc = Mooncake.primal_to_tangent!!(Mooncake.zero_tangent(DVtrunc), ΔDVtrunc_primal) + + Mooncake.TestUtils.test_rule( + rng, eig_trunc_no_error, A, alg_trunc; + mode = Mooncake.ReverseMode, output_tangent = ΔDVtrunc, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, eig_trunc_no_error!, A, alg_trunc; + mode = Mooncake.ReverseMode, output_tangent = ΔDVtrunc, atol, rtol, is_primitive = false + ) + + DVϵ = eig_trunc(A, alg_trunc) + Δϵ = Mooncake.zero_tangent(DVϵ[end]) + ΔDVϵtrunc = (ΔDVtrunc..., Δϵ) + + Mooncake.TestUtils.test_rule( + rng, eig_trunc, A, alg_trunc; + mode = Mooncake.ReverseMode, output_tangent = ΔDVϵtrunc, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, eig_trunc!, A, alg_trunc; + mode = Mooncake.ReverseMode, output_tangent = ΔDVϵtrunc, atol, rtol, is_primitive = false + ) + end + end +end diff --git a/test/testsuite/mooncake/eigh.jl b/test/testsuite/mooncake/eigh.jl new file mode 100644 index 00000000..ea7a87f5 --- /dev/null +++ b/test/testsuite/mooncake/eigh.jl @@ -0,0 +1,161 @@ +function mc_copy_eigh_full(A, alg) + A = (A + A') / 2 + return eigh_full(A, alg) +end + +function mc_copy_eigh_full!(A, DV, alg) + A = (A + A') / 2 + return eigh_full!(A, DV, alg) +end + +function mc_copy_eigh_vals(A, alg) + A = (A + A') / 2 + return eigh_vals(A, alg) +end + +function mc_copy_eigh_vals!(A, D, alg) + A = (A + A') / 2 + return eigh_vals!(A, D, alg) +end + +function mc_copy_eigh_trunc(A, alg) + A = (A + A') / 2 + return eigh_trunc(A, alg) +end + +function mc_copy_eigh_trunc!(A, DV, alg) + A = (A + A') / 2 + return eigh_trunc!(A, DV, alg) +end + +function mc_copy_eigh_trunc_no_error(A, alg) + A = (A + A') / 2 + return eigh_trunc_no_error(A, alg) +end + +function mc_copy_eigh_trunc_no_error!(A, DV, alg) + A = (A + A') / 2 + return eigh_trunc_no_error!(A, DV, alg) +end + +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_full), A) = MatrixAlgebraKit.copy_input(eigh_full, A) +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_vals), A) = MatrixAlgebraKit.copy_input(eigh_vals, A) +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_trunc), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_trunc_no_error), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) + +function remove_eigh_gauge_dependence!( + ΔV, D, V; + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) + ) + gaugepart = V' * ΔV + gaugepart = project_antihermitian!(gaugepart) + gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 + mul!(ΔV, V, gaugepart, -1, 1) + return ΔV +end + +eigh_wrapper(f, A, alg) = f(project_hermitian(A), alg) +eigh!_wrapper(f!, A, alg) = (F = f!(project_hermitian!(A), alg); MatrixAlgebraKit.zero!(A); F) + +function test_mooncake_eigh( + T::Type, sz; + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Mooncake eigh $summary_str" begin + test_mooncake_eigh_full(T, sz; kwargs...) + test_mooncake_eigh_vals(T, sz; kwargs...) + test_mooncake_eigh_trunc(T, sz; kwargs...) + end +end + +function test_mooncake_eigh_full( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "eigh_full" begin + A = make_eigh_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(eigh_full, A) + DV = eigh_full(A, alg) + ΔDV = Mooncake.randn_tangent(rng, DV) + remove_eigh_gauge_dependence!(ΔDV[2], DV...) + + Mooncake.TestUtils.test_rule( + rng, eigh_wrapper, eigh_full, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔDV, is_primitive = false, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, eigh!_wrapper, eigh_full!, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔDV, atol, rtol, is_primitive = false + ) + end +end + +function test_mooncake_eigh_vals( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "eigh_vals" begin + A = make_eigh_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(eigh_vals, A) + D = eigh_vals(A, alg) + ΔD = Mooncake.randn_tangent(rng, D) + + Mooncake.TestUtils.test_rule( + rng, eigh_wrapper, eigh_vals, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔD, is_primitive = false, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, eigh!_wrapper, eigh_vals!, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔD, atol, rtol, is_primitive = false + ) + end +end + +function test_mooncake_eigh_trunc( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "eigh_trunc" begin + A = make_eigh_matrix(T, sz) + m = size(A, 1) + + alg = MatrixAlgebraKit.select_algorithm(eigh_full, A) + DV = eigh_full(A, alg) + ΔDV = Mooncake.randn_tangent(rng, DV) + remove_eigh_gauge_dependence!(ΔDV[2], DV...) + + @testset "truncrank($r)" for r in round.(Int, range(1, m + 4, 4)) + trunc = truncrank(r; by = abs) + alg_trunc = TruncatedAlgorithm(alg, trunc) + + # truncate the gauge-corrected tangents + DVtrunc, ind = MatrixAlgebraKit.truncate(eigh_trunc!, DV, trunc) + ΔDV_primal = Mooncake.tangent_to_primal!!(copy.(DV), ΔDV) + ΔDVtrunc_primal = (Diagonal(diagview(ΔDV_primal[1])[ind]), ΔDV_primal[2][:, ind]) + ΔDVtrunc = Mooncake.primal_to_tangent!!(Mooncake.zero_tangent(DVtrunc), ΔDVtrunc_primal) + + Mooncake.TestUtils.test_rule( + rng, eigh_wrapper, eigh_trunc_no_error, A, alg_trunc; + mode = Mooncake.ReverseMode, output_tangent = ΔDVtrunc, atol, rtol, is_primitive = false + ) + Mooncake.TestUtils.test_rule( + rng, eigh!_wrapper, eigh_trunc_no_error!, A, alg_trunc; + mode = Mooncake.ReverseMode, output_tangent = ΔDVtrunc, atol, rtol, is_primitive = false + ) + + DVϵ = eigh_trunc(A, alg_trunc) + Δϵ = Mooncake.zero_tangent(DVϵ[end]) + ΔDVϵtrunc = (ΔDVtrunc..., Δϵ) + + Mooncake.TestUtils.test_rule( + rng, eigh_wrapper, eigh_trunc, A, alg_trunc; + mode = Mooncake.ReverseMode, output_tangent = ΔDVϵtrunc, atol, rtol, is_primitive = false + ) + Mooncake.TestUtils.test_rule( + rng, eigh!_wrapper, eigh_trunc!, A, alg_trunc; + mode = Mooncake.ReverseMode, output_tangent = ΔDVϵtrunc, atol, rtol, is_primitive = false + ) + end + end +end diff --git a/test/testsuite/mooncake/mooncake.jl b/test/testsuite/mooncake/mooncake.jl index df769016..b2538191 100644 --- a/test/testsuite/mooncake/mooncake.jl +++ b/test/testsuite/mooncake/mooncake.jl @@ -6,50 +6,6 @@ using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD, eigh_trunc using LinearAlgebra: BlasFloat using GenericLinearAlgebra -function mc_copy_eigh_full(A; kwargs...) - A = (A + A') / 2 - return eigh_full(A; kwargs...) -end - -function mc_copy_eigh_full!(A, DV; kwargs...) - A = (A + A') / 2 - return eigh_full!(A, DV; kwargs...) -end - -function mc_copy_eigh_vals(A; kwargs...) - A = (A + A') / 2 - return eigh_vals(A; kwargs...) -end - -function mc_copy_eigh_vals!(A, D; kwargs...) - A = (A + A') / 2 - return eigh_vals!(A, D; kwargs...) -end - -function mc_copy_eigh_trunc(A, alg; kwargs...) - A = (A + A') / 2 - return eigh_trunc(A, alg; kwargs...) -end - -function mc_copy_eigh_trunc!(A, DV, alg; kwargs...) - A = (A + A') / 2 - return eigh_trunc!(A, DV, alg; kwargs...) -end - -function mc_copy_eigh_trunc_no_error(A, alg; kwargs...) - A = (A + A') / 2 - return eigh_trunc_no_error(A, alg; kwargs...) -end - -function mc_copy_eigh_trunc_no_error!(A, DV, alg; kwargs...) - A = (A + A') / 2 - return eigh_trunc_no_error!(A, DV, alg; kwargs...) -end - -MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_full), A) = MatrixAlgebraKit.copy_input(eigh_full, A) -MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_vals), A) = MatrixAlgebraKit.copy_input(eigh_vals, A) -MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_trunc), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) -MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_trunc_no_error), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) make_mooncake_tangent(ΔAelem::T) where {T <: Number} = ΔAelem make_mooncake_tangent(ΔA::AbstractMatrix) = ΔA @@ -239,99 +195,6 @@ function test_mooncake(T::Type, sz; kwargs...) end -function test_mooncake_eig( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "EIG Mooncake AD rules $summary_str" begin - A = make_eig_matrix(T, sz) - m = size(A, 1) - @testset "eig_full" begin - DV, ΔDV, ΔD2V = ad_eig_full_setup(A) - dDV = make_mooncake_tangent(ΔD2V) - Mooncake.TestUtils.test_rule(rng, eig_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dDV, atol, rtol) - test_pullbacks_match(eig_full!, eig_full, A, DV, ΔD2V) - end - @testset "eig_vals" begin - D, ΔD = ad_eig_vals_setup(A) - dD = make_mooncake_tangent(ΔD) - Mooncake.TestUtils.test_rule(rng, eig_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dD, atol, rtol) - test_pullbacks_match(eig_vals!, eig_vals, A, D, ΔD) - end - @testset "eig_trunc" begin - for r in 1:4:m - truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(r; by = abs)) - DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) - ϵ = zero(real(eltype(T))) - dDVerr = make_mooncake_tangent((copy.(ΔDVtrunc)..., ϵ)) - Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol, rtol) - test_pullbacks_match(eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔDVtrunc) - dDVtrunc = make_mooncake_tangent(ΔDVtrunc) - Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol) - test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg; ȳ = ΔDVtrunc) - end - truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(5; by = real)) - DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) - ϵ = zero(real(eltype(T))) - dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol, rtol) - test_pullbacks_match(eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔDVtrunc) - dDVtrunc = make_mooncake_tangent(ΔDVtrunc) - Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol) - test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg; ȳ = ΔDVtrunc) - end - end -end - -function test_mooncake_eigh( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "EIGH Mooncake AD rules $summary_str" begin - A = make_eigh_matrix(T, sz) - m = size(A, 1) - @testset "eigh_full" begin - DV, ΔDV, ΔD2V = ad_eigh_full_setup(A) - dDV = make_mooncake_tangent(ΔD2V) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_full, A; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol, rtol) - test_pullbacks_match(mc_copy_eigh_full!, mc_copy_eigh_full, A, DV, ΔD2V) - end - @testset "eigh_vals" begin - D, ΔD = ad_eigh_vals_setup(A) - dD = make_mooncake_tangent(ΔD) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_vals, A; mode = Mooncake.ReverseMode, output_tangent = dD, is_primitive = false, atol, rtol) - test_pullbacks_match(mc_copy_eigh_vals!, mc_copy_eigh_vals, A, D, ΔD) - end - @testset "eigh_trunc" begin - for r in 1:4:m - truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), truncrank(r; by = abs)) - DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) - ϵ = zero(real(eltype(T))) - dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol, rtol, is_primitive = false) - test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔDVtrunc) - dDVtrunc = make_mooncake_tangent(ΔDVtrunc) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol, is_primitive = false) - test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg; ȳ = ΔDVtrunc) - end - D = eigh_vals(A / 2) - truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), trunctol(; atol = maximum(abs, D) / 2)) - DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) - ϵ = zero(real(eltype(T))) - dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol, rtol, is_primitive = false) - test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔDVtrunc) - dDVtrunc = make_mooncake_tangent(ΔDVtrunc) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol, is_primitive = false) - test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg; ȳ = ΔDVtrunc) - end - end -end - function test_mooncake_svd( T::Type, sz; atol::Real = 0, rtol::Real = precision(T), From 64a4246c4a59719b7e10bc1977e41ade5e4a4bb3 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 19 Feb 2026 13:33:41 -0500 Subject: [PATCH 07/15] fix pullback implementations! --- .../MatrixAlgebraKitMooncakeExt.jl | 13 ++++++------- test/testsuite/mooncake/eig.jl | 12 +++++++++++- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index 3a113c20..f32e4258 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -239,7 +239,7 @@ for f in (:eig, :eigh) _warn_pullback_truncerror(dϵ) # compute pullbacks - $f_pullback!(dA, Ac, DVc, dDVtrunc, ind) + $f_pullback!(dA, Ac, DV, dDVtrunc, ind) zero!.(dDVtrunc) # since this is allocated in this function this is probably not required # restore state @@ -351,8 +351,8 @@ for f in (:eig, :eigh) dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc))) function $f_adjoint!(::NoRData) # compute pullbacks - $f_pullback!(dA, Ac, DVc, dDVtrunc, ind) - zero!.(dDVtrunc) # since this is allocated in this function this is probably not required + $f_pullback!(dA, Ac, DV, dDVtrunc, ind) + zero!.(dDV) # restore state copy!(A, Ac) @@ -425,7 +425,7 @@ for (f!, f) in ( S, dS = arrayify(USVᴴ[2], dUSVᴴ[2]) Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3]) USVᴴc = copy.(USVᴴ) - output = $f!(A, Mooncake.primal(alg_dalg)) + output = $f!(A, USVᴴ, Mooncake.primal(alg_dalg)) function svd_adjoint(::NoRData) copy!(A, Ac) if $(f! == svd_compact!) @@ -590,7 +590,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS _warn_pullback_truncerror(dϵ) # compute pullbacks - svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind) + svd_pullback!(dA, Ac, USVᴴ, dUSVᴴtrunc, ind) zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required zero!.(dUSVᴴ) @@ -717,8 +717,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc))) function svd_trunc_adjoint(::NoRData) # compute pullbacks - svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind) - zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required + svd_pullback!(dA, Ac, USVᴴ, dUSVᴴtrunc, ind) zero!.(dUSVᴴ) # restore state diff --git a/test/testsuite/mooncake/eig.jl b/test/testsuite/mooncake/eig.jl index 50d08499..6ee5cd91 100644 --- a/test/testsuite/mooncake/eig.jl +++ b/test/testsuite/mooncake/eig.jl @@ -10,6 +10,16 @@ function test_mooncake_eig( end end +function remove_eig_gauge_dependence!( + ΔV, D, V; + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) + ) + gaugepart = V' * ΔV + gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 + mul!(ΔV, V / (V' * V), gaugepart, -1, 1) + return ΔV +end + function test_mooncake_eig_full( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) @@ -19,7 +29,7 @@ function test_mooncake_eig_full( alg = MatrixAlgebraKit.select_algorithm(eig_full, A) DV = eig_full(A, alg) ΔDV = Mooncake.randn_tangent(rng, DV) - remove_eiggauge_dependence!(ΔDV[2], DV...) + remove_eig_gauge_dependence!(ΔDV[2], DV...) Mooncake.TestUtils.test_rule( rng, eig_full, A, alg; From 8b209c21a5b9e1991e5fd4f374755b9507c85505 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 19 Feb 2026 14:42:59 -0500 Subject: [PATCH 08/15] Refactor Mooncake SVD tests --- test/testsuite/TestSuite.jl | 1 + test/testsuite/mooncake/mooncake.jl | 55 --------- test/testsuite/mooncake/svd.jl | 177 ++++++++++++++++++++++++++++ 3 files changed, 178 insertions(+), 55 deletions(-) create mode 100644 test/testsuite/mooncake/svd.jl diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index ba50f7a2..e9eac578 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -106,6 +106,7 @@ include("mooncake/qr.jl") include("mooncake/lq.jl") include("mooncake/eig.jl") include("mooncake/eigh.jl") +include("mooncake/svd.jl") # include("enzyme.jl") # include("chainrules.jl") diff --git a/test/testsuite/mooncake/mooncake.jl b/test/testsuite/mooncake/mooncake.jl index b2538191..43cf0b5d 100644 --- a/test/testsuite/mooncake/mooncake.jl +++ b/test/testsuite/mooncake/mooncake.jl @@ -195,61 +195,6 @@ function test_mooncake(T::Type, sz; kwargs...) end -function test_mooncake_svd( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "SVD Mooncake AD rules $summary_str" begin - A = instantiate_matrix(T, sz) - minmn = min(size(A)...) - @testset "svd_compact" begin - USVᴴ, _, ΔUSVᴴ = ad_svd_compact_setup(A) - dUSVᴴ = make_mooncake_tangent(ΔUSVᴴ) - Mooncake.TestUtils.test_rule(rng, svd_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) - test_pullbacks_match(svd_compact!, svd_compact, A, USVᴴ, ΔUSVᴴ) - end - @testset "svd_full" begin - USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A) - dUSVᴴ = make_mooncake_tangent(ΔUSVᴴ) - Mooncake.TestUtils.test_rule(rng, svd_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) - test_pullbacks_match(svd_full!, svd_full, A, USVᴴ, ΔUSVᴴ) - end - @testset "svd_vals" begin - S, ΔS = ad_svd_vals_setup(A) - Mooncake.TestUtils.test_rule(rng, svd_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) - test_pullbacks_match(svd_vals!, svd_vals, A, S, ΔS) - end - @testset "svd_trunc" begin - @testset for r in 1:4:minmn - truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), truncrank(r)) - USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) - ϵ = zero(real(eltype(T))) - dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol, rtol) - test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔUSVᴴtrunc) - dUSVᴴ = make_mooncake_tangent(ΔUSVᴴtrunc) - Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) - test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg; ȳ = ΔUSVᴴtrunc) - end - @testset "trunctol" begin - A = instantiate_matrix(T, sz) - S, ΔS = ad_svd_vals_setup(A) - truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), trunctol(atol = S[1, 1] / 2)) - USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) - ϵ = zero(real(eltype(T))) - dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol, rtol) - test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔUSVᴴtrunc) - dUSVᴴ = make_mooncake_tangent(ΔUSVᴴtrunc) - Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) - test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg; ȳ = ΔUSVᴴtrunc) - end - end - end -end - function test_mooncake_polar( T::Type, sz; atol::Real = 0, rtol::Real = precision(T), diff --git a/test/testsuite/mooncake/svd.jl b/test/testsuite/mooncake/svd.jl new file mode 100644 index 00000000..d31477ae --- /dev/null +++ b/test/testsuite/mooncake/svd.jl @@ -0,0 +1,177 @@ +function remove_svd_gauge_dependence!( + ΔU, ΔVᴴ, U, S, Vᴴ; + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(S) + ) + minmn = length(diagview(S)) + U₁ = view(U, :, 1:minmn) + Vᴴ₁ = view(Vᴴ, 1:minmn, :) + ΔU₁ = view(ΔU, :, 1:minmn) + ΔVᴴ₁ = view(ΔVᴴ, 1:minmn, :) + Sdiag = diagview(S) + gaugepart = mul!(U₁' * ΔU₁, Vᴴ₁, ΔVᴴ₁', true, true) + gaugepart = project_antihermitian!(gaugepart) + gaugepart[abs.(transpose(Sdiag) .- Sdiag) .>= degeneracy_atol] .= 0 + mul!(ΔU₁, U₁, gaugepart, -1, 1) + ΔU[:, (minmn + 1):end] .= 0 + ΔVᴴ[(minmn + 1):end, :] .= 0 + return ΔU, ΔVᴴ +end + +function test_mooncake_svd( + T::Type, sz; + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Mooncake svd $summary_str" begin + test_mooncake_svd_compact(T, sz; kwargs...) + test_mooncake_svd_full(T, sz; kwargs...) + test_mooncake_svd_vals(T, sz; kwargs...) + test_mooncake_svd_trunc(T, sz; kwargs...) + end +end + +function test_mooncake_svd_compact( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "svd_compact" begin + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(svd_compact, A) + USVᴴ = svd_compact(A, alg) + ΔUSVᴴ = Mooncake.randn_tangent(rng, USVᴴ) + remove_svd_gauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...) + + Mooncake.TestUtils.test_rule( + rng, svd_compact, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴ, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, svd_compact!, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴ, atol, rtol, is_primitive = false + ) + end +end + +function test_mooncake_svd_full( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "svd_full" begin + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(svd_full, A) + USVᴴ = svd_full(A, alg) + ΔUSVᴴ = Mooncake.randn_tangent(rng, USVᴴ) + remove_svd_gauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...) + + Mooncake.TestUtils.test_rule( + rng, svd_full, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴ, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, svd_full!, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴ, atol, rtol, is_primitive = false + ) + end +end + +function test_mooncake_svd_vals( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "svd_vals" begin + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(svd_vals, A) + S = svd_vals(A, alg) + ΔS = Mooncake.randn_tangent(rng, S) + + Mooncake.TestUtils.test_rule( + rng, svd_vals, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔS, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, svd_vals!, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔS, atol, rtol, is_primitive = false + ) + end +end + +function test_mooncake_svd_trunc( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "svd_trunc" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + minmn = min(m, n) + + alg = MatrixAlgebraKit.select_algorithm(svd_compact, A) + USVᴴ = svd_compact(A, alg) + ΔUSVᴴ = Mooncake.randn_tangent(rng, USVᴴ) + remove_svd_gauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...) + + @testset "truncrank($r)" for r in round.(Int, range(1, minmn + 4, 4)) + trunc = truncrank(r) + alg_trunc = TruncatedAlgorithm(alg, trunc) + + # truncate the gauge-corrected tangents + USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, trunc) + ΔUSVᴴ_primal = Mooncake.tangent_to_primal!!(copy.(USVᴴ), ΔUSVᴴ) + ΔUSVᴴtrunc_primal = (ΔUSVᴴ_primal[1][:, ind], Diagonal(diagview(ΔUSVᴴ_primal[2])[ind]), ΔUSVᴴ_primal[3][ind, :]) + ΔUSVᴴtrunc = Mooncake.primal_to_tangent!!(Mooncake.zero_tangent(USVᴴtrunc), ΔUSVᴴtrunc_primal) + + Mooncake.TestUtils.test_rule( + rng, svd_trunc_no_error, A, alg_trunc; + mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴtrunc, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, svd_trunc_no_error!, A, alg_trunc; + mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴtrunc, atol, rtol, is_primitive = false + ) + + USVᴴϵ = svd_trunc(A, alg_trunc) + Δϵ = Mooncake.zero_tangent(USVᴴϵ[end]) + ΔUSVᴴϵtrunc = (ΔUSVᴴtrunc..., Δϵ) + + Mooncake.TestUtils.test_rule( + rng, svd_trunc, A, alg_trunc; + mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴϵtrunc, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, svd_trunc!, A, alg_trunc; + mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴϵtrunc, atol, rtol, is_primitive = false + ) + end + + @testset "trunctol" begin + trunc = trunctol(atol = diagview(USVᴴ[2])[1] / 2) + alg_trunc = TruncatedAlgorithm(alg, trunc) + + USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, trunc) + ΔUSVᴴ_primal = Mooncake.tangent_to_primal!!(copy.(USVᴴ), ΔUSVᴴ) + ΔUSVᴴtrunc_primal = (ΔUSVᴴ_primal[1][:, ind], Diagonal(diagview(ΔUSVᴴ_primal[2])[ind]), ΔUSVᴴ_primal[3][ind, :]) + ΔUSVᴴtrunc = Mooncake.primal_to_tangent!!(Mooncake.zero_tangent(USVᴴtrunc), ΔUSVᴴtrunc_primal) + + Mooncake.TestUtils.test_rule( + rng, svd_trunc_no_error, A, alg_trunc; + mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴtrunc, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, svd_trunc_no_error!, A, alg_trunc; + mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴtrunc, atol, rtol, is_primitive = false + ) + + USVᴴϵ = svd_trunc(A, alg_trunc) + Δϵ = Mooncake.zero_tangent(USVᴴϵ[end]) + ΔUSVᴴϵtrunc = (ΔUSVᴴtrunc..., Δϵ) + + Mooncake.TestUtils.test_rule( + rng, svd_trunc, A, alg_trunc; + mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴϵtrunc, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, svd_trunc!, A, alg_trunc; + mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴϵtrunc, atol, rtol, is_primitive = false + ) + end + end +end From f8d7cc689506a84a239dbbf8c5b2b5ecc6b2f703 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 19 Feb 2026 14:52:22 -0500 Subject: [PATCH 09/15] Refactor Mooncake Polar tests --- test/testsuite/TestSuite.jl | 1 + test/testsuite/mooncake/mooncake.jl | 26 ------------- test/testsuite/mooncake/polar.jl | 58 +++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 26 deletions(-) create mode 100644 test/testsuite/mooncake/polar.jl diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index e9eac578..09098482 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -107,6 +107,7 @@ include("mooncake/lq.jl") include("mooncake/eig.jl") include("mooncake/eigh.jl") include("mooncake/svd.jl") +include("mooncake/polar.jl") # include("enzyme.jl") # include("chainrules.jl") diff --git a/test/testsuite/mooncake/mooncake.jl b/test/testsuite/mooncake/mooncake.jl index 43cf0b5d..737c27ed 100644 --- a/test/testsuite/mooncake/mooncake.jl +++ b/test/testsuite/mooncake/mooncake.jl @@ -195,32 +195,6 @@ function test_mooncake(T::Type, sz; kwargs...) end -function test_mooncake_polar( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "Polar Mooncake AD rules $summary_str" begin - A = instantiate_matrix(T, sz) - m, n = size(A) - @testset "left_polar" begin - if m >= n - WP, ΔWP = ad_left_polar_setup(A) - Mooncake.TestUtils.test_rule(rng, left_polar, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) - test_pullbacks_match(left_polar!, left_polar, A, WP, ΔWP) - end - end - @testset "right_polar" begin - if m <= n - PWᴴ, ΔPWᴴ = ad_right_polar_setup(A) - Mooncake.TestUtils.test_rule(rng, right_polar, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) - test_pullbacks_match(right_polar!, right_polar, A, PWᴴ, ΔPWᴴ) - end - end - end -end - left_orth_qr(X) = left_orth(X; alg = :qr) left_orth_polar(X) = left_orth(X; alg = :polar) left_null_qr(X) = left_null(X; alg = :qr) diff --git a/test/testsuite/mooncake/polar.jl b/test/testsuite/mooncake/polar.jl new file mode 100644 index 00000000..469c6af0 --- /dev/null +++ b/test/testsuite/mooncake/polar.jl @@ -0,0 +1,58 @@ +function test_mooncake_polar( + T::Type, sz; + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Mooncake polar $summary_str" begin + test_mooncake_left_polar(T, sz; kwargs...) + test_mooncake_right_polar(T, sz; kwargs...) + end +end + +function test_mooncake_left_polar( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "left_polar" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + if m >= n + alg = MatrixAlgebraKit.select_algorithm(left_polar, A) + WP = left_polar(A, alg) + ΔWP = Mooncake.randn_tangent(rng, WP) + + Mooncake.TestUtils.test_rule( + rng, left_polar, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔWP, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, left_polar!, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔWP, atol, rtol, is_primitive = false + ) + end + end +end + +function test_mooncake_right_polar( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "right_polar" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + if m <= n + alg = MatrixAlgebraKit.select_algorithm(right_polar, A) + PWᴴ = right_polar(A, alg) + ΔPWᴴ = Mooncake.randn_tangent(rng, PWᴴ) + + Mooncake.TestUtils.test_rule( + rng, right_polar, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔPWᴴ, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, right_polar!, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔPWᴴ, atol, rtol, is_primitive = false + ) + end + end +end From 552559f05225fa38b72691ce2098a09e2753ce91 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 19 Feb 2026 14:55:27 -0500 Subject: [PATCH 10/15] make testsets verbose --- test/testsuite/mooncake/mooncake.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/testsuite/mooncake/mooncake.jl b/test/testsuite/mooncake/mooncake.jl index 737c27ed..35b86e6e 100644 --- a/test/testsuite/mooncake/mooncake.jl +++ b/test/testsuite/mooncake/mooncake.jl @@ -178,7 +178,7 @@ end function test_mooncake(T::Type, sz; kwargs...) summary_str = testargs_summary(T, sz) - return @testset "Mooncake AD $summary_str" begin + return @testset "Mooncake AD $summary_str" verbose = true begin test_mooncake_qr(T, sz; kwargs...) test_mooncake_lq(T, sz; kwargs...) if length(sz) == 1 || sz[1] == sz[2] From 9161bde94ee5fb036721a27ae02d4d2b1749c4ce Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 19 Feb 2026 15:07:27 -0500 Subject: [PATCH 11/15] Refactor Mooncake OrthNull tests --- test/testsuite/TestSuite.jl | 1 + test/testsuite/mooncake/mooncake.jl | 58 ---------- test/testsuite/mooncake/orthnull.jl | 158 ++++++++++++++++++++++++++++ 3 files changed, 159 insertions(+), 58 deletions(-) create mode 100644 test/testsuite/mooncake/orthnull.jl diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index 09098482..e03c50ae 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -108,6 +108,7 @@ include("mooncake/eig.jl") include("mooncake/eigh.jl") include("mooncake/svd.jl") include("mooncake/polar.jl") +include("mooncake/orthnull.jl") # include("enzyme.jl") # include("chainrules.jl") diff --git a/test/testsuite/mooncake/mooncake.jl b/test/testsuite/mooncake/mooncake.jl index 35b86e6e..42e92289 100644 --- a/test/testsuite/mooncake/mooncake.jl +++ b/test/testsuite/mooncake/mooncake.jl @@ -193,61 +193,3 @@ function test_mooncake(T::Type, sz; kwargs...) end end end - - -left_orth_qr(X) = left_orth(X; alg = :qr) -left_orth_polar(X) = left_orth(X; alg = :polar) -left_null_qr(X) = left_null(X; alg = :qr) -right_orth_lq(X) = right_orth(X; alg = :lq) -right_orth_polar(X) = right_orth(X; alg = :polar) -right_null_lq(X) = right_null(X; alg = :lq) - -MatrixAlgebraKit.copy_input(::typeof(left_orth_qr), A) = MatrixAlgebraKit.copy_input(left_orth, A) -MatrixAlgebraKit.copy_input(::typeof(left_orth_polar), A) = MatrixAlgebraKit.copy_input(left_orth, A) -MatrixAlgebraKit.copy_input(::typeof(left_null_qr), A) = MatrixAlgebraKit.copy_input(left_null, A) -MatrixAlgebraKit.copy_input(::typeof(right_orth_lq), A) = MatrixAlgebraKit.copy_input(right_orth, A) -MatrixAlgebraKit.copy_input(::typeof(right_orth_polar), A) = MatrixAlgebraKit.copy_input(right_orth, A) -MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_input(right_null, A) - -function test_mooncake_orthnull( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "Orthnull Mooncake AD rules $summary_str" begin - A = instantiate_matrix(T, sz) - m, n = size(A) - VC, ΔVC = ad_left_orth_setup(A) - CVᴴ, ΔCVᴴ = ad_right_orth_setup(A) - Mooncake.TestUtils.test_rule(rng, left_orth, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) - test_pullbacks_match(left_orth!, left_orth, A, VC, ΔVC) - Mooncake.TestUtils.test_rule(rng, right_orth, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) - test_pullbacks_match(right_orth!, right_orth, A, CVᴴ, ΔCVᴴ) - - Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) - test_pullbacks_match(((X, VC) -> left_orth!(X, VC; alg = :qr)), left_orth_qr, A, VC, ΔVC) - if m >= n - Mooncake.TestUtils.test_rule(rng, left_orth_polar, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) - test_pullbacks_match(((X, VC) -> left_orth!(X, VC; alg = :polar)), left_orth_polar, A, VC, ΔVC) - end - - N, ΔN = ad_left_null_setup(A) - dN = make_mooncake_tangent(ΔN) - Mooncake.TestUtils.test_rule(rng, left_null_qr, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false, output_tangent = dN) - test_pullbacks_match(((X, N) -> left_null!(X, N; alg = :qr)), left_null_qr, A, N, ΔN) - - Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) - test_pullbacks_match(((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :lq)), right_orth_lq, A, CVᴴ, ΔCVᴴ) - - if m <= n - Mooncake.TestUtils.test_rule(rng, right_orth_polar, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) - test_pullbacks_match(((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :polar)), right_orth_polar, A, CVᴴ, ΔCVᴴ) - end - - Nᴴ, ΔNᴴ = ad_right_null_setup(A) - dNᴴ = make_mooncake_tangent(ΔNᴴ) - Mooncake.TestUtils.test_rule(rng, right_null_lq, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false, output_tangent = dNᴴ) - test_pullbacks_match(((X, Nᴴ) -> right_null!(X, Nᴴ; alg = :lq)), right_null_lq, A, Nᴴ, ΔNᴴ) - end -end diff --git a/test/testsuite/mooncake/orthnull.jl b/test/testsuite/mooncake/orthnull.jl new file mode 100644 index 00000000..946f252b --- /dev/null +++ b/test/testsuite/mooncake/orthnull.jl @@ -0,0 +1,158 @@ +function remove_left_null_gauge_dependence!(ΔN, A, N) + Q, _ = qr_compact(A) + mul!(ΔN, Q, Q' * ΔN) + return ΔN +end + +function remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) + _, Q = lq_compact(A) + mul!(ΔNᴴ, ΔNᴴ * Q', Q) + return ΔNᴴ +end + +function test_mooncake_orthnull( + T::Type, sz; + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Mooncake orthnull $summary_str" begin + test_mooncake_left_orth(T, sz; kwargs...) + test_mooncake_right_orth(T, sz; kwargs...) + test_mooncake_left_null(T, sz; kwargs...) + test_mooncake_right_null(T, sz; kwargs...) + end +end + +function test_mooncake_left_orth( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "left_orth" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + + @testset "qr" begin + alg = MatrixAlgebraKit.select_algorithm(left_orth!, A, :qr) + VC = left_orth(A, alg) + ΔVC = Mooncake.randn_tangent(rng, VC) + + Mooncake.TestUtils.test_rule( + rng, left_orth, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔVC, is_primitive = false, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, left_orth!, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔVC, is_primitive = false, atol, rtol + ) + end + + if m >= n + @testset "polar" begin + alg = MatrixAlgebraKit.select_algorithm(left_orth!, A, :polar) + VC = left_orth(A, alg) + ΔVC = Mooncake.randn_tangent(rng, VC) + + Mooncake.TestUtils.test_rule( + rng, left_orth, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔVC, is_primitive = false, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, left_orth!, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔVC, is_primitive = false, atol, rtol + ) + end + end + end +end + +function test_mooncake_right_orth( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "right_orth" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + + @testset "lq" begin + alg = MatrixAlgebraKit.select_algorithm(right_orth!, A, :lq) + CVᴴ = right_orth(A, alg) + ΔCVᴴ = Mooncake.randn_tangent(rng, CVᴴ) + + Mooncake.TestUtils.test_rule( + rng, right_orth, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔCVᴴ, is_primitive = false, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, right_orth!, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔCVᴴ, is_primitive = false, atol, rtol + ) + end + + if m <= n + @testset "polar" begin + alg = MatrixAlgebraKit.select_algorithm(right_orth!, A, :polar) + CVᴴ = right_orth(A, alg) + ΔCVᴴ = Mooncake.randn_tangent(rng, CVᴴ) + + Mooncake.TestUtils.test_rule( + rng, right_orth, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔCVᴴ, is_primitive = false, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, right_orth!, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔCVᴴ, is_primitive = false, atol, rtol + ) + end + end + end +end + +function test_mooncake_left_null( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "left_null" begin + A = instantiate_matrix(T, sz) + + @testset "qr" begin + alg = MatrixAlgebraKit.select_algorithm(left_null!, A, :qr) + N = left_null(A, alg) + ΔN = Mooncake.randn_tangent(rng, N) + remove_left_null_gauge_dependence!(ΔN, A, N) + + Mooncake.TestUtils.test_rule( + rng, left_null, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔN, is_primitive = false, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, left_null!, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔN, is_primitive = false, atol, rtol + ) + end + end +end + +function test_mooncake_right_null( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "right_null" begin + A = instantiate_matrix(T, sz) + + @testset "lq" begin + alg = MatrixAlgebraKit.select_algorithm(right_null!, A, :lq) + Nᴴ = right_null(A, alg) + ΔNᴴ = Mooncake.randn_tangent(rng, Nᴴ) + remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) + + Mooncake.TestUtils.test_rule( + rng, right_null, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔNᴴ, is_primitive = false, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, make_input_scratch!, right_null!, A, alg; + mode = Mooncake.ReverseMode, output_tangent = ΔNᴴ, is_primitive = false, atol, rtol + ) + end + end +end From 6550d8beb1ff7f16f3f56a981f23fdcf45a7e8f8 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 19 Feb 2026 15:43:22 -0500 Subject: [PATCH 12/15] clean up --- test/testsuite/TestSuite.jl | 6 +- test/testsuite/mooncake/eig.jl | 47 ++++++-- test/testsuite/mooncake/eigh.jl | 104 ++++++++-------- test/testsuite/mooncake/lq.jl | 54 +++++++-- test/testsuite/mooncake/mooncake.jl | 180 ++-------------------------- test/testsuite/mooncake/orthnull.jl | 49 +++++++- test/testsuite/mooncake/polar.jl | 22 +++- test/testsuite/mooncake/qr.jl | 54 +++++++-- test/testsuite/mooncake/svd.jl | 41 ++++++- 9 files changed, 284 insertions(+), 273 deletions(-) diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index e03c50ae..7b6d3a66 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -13,6 +13,7 @@ using MatrixAlgebraKit using MatrixAlgebraKit: diagview using LinearAlgebra: Diagonal, norm, istriu, istril, I using Random, StableRNGs +using Mooncake using AMDGPU, CUDA const tests = Dict() @@ -109,7 +110,8 @@ include("mooncake/eigh.jl") include("mooncake/svd.jl") include("mooncake/polar.jl") include("mooncake/orthnull.jl") -# include("enzyme.jl") -# include("chainrules.jl") + +include("enzyme.jl") +include("chainrules.jl") end diff --git a/test/testsuite/mooncake/eig.jl b/test/testsuite/mooncake/eig.jl index 6ee5cd91..41d7d7b8 100644 --- a/test/testsuite/mooncake/eig.jl +++ b/test/testsuite/mooncake/eig.jl @@ -1,15 +1,10 @@ -function test_mooncake_eig( - T::Type, sz; - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "Mooncake eig $summary_str" begin - test_mooncake_eig_full(T, sz; kwargs...) - test_mooncake_eig_vals(T, sz; kwargs...) - test_mooncake_eig_trunc(T, sz; kwargs...) - end -end +""" + remove_eig_gauge_dependence!(ΔV, D, V) +Remove the gauge-dependent part from the cotangent `ΔV` of the eigenvector matrix `V`. The +eigenvectors are only determined up to complex phase (and unitary mixing for degenerate +eigenvalues), so the corresponding components of `ΔV` are projected out. +""" function remove_eig_gauge_dependence!( ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) @@ -20,6 +15,25 @@ function remove_eig_gauge_dependence!( return ΔV end +""" + test_mooncake_eig(T, sz; kwargs...) + +Run all Mooncake AD tests for eigendecompositions of element type `T` and size `sz`. +""" +function test_mooncake_eig(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "Mooncake eig $summary_str" begin + test_mooncake_eig_full(T, sz; kwargs...) + test_mooncake_eig_vals(T, sz; kwargs...) + test_mooncake_eig_trunc(T, sz; kwargs...) + end +end + +""" + test_mooncake_eig_full(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rule for `eig_full` and its in-place variant. +""" function test_mooncake_eig_full( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) @@ -42,6 +56,11 @@ function test_mooncake_eig_full( end end +""" + test_mooncake_eig_vals(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rule for `eig_vals` and its in-place variant. +""" function test_mooncake_eig_vals( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) @@ -63,6 +82,12 @@ function test_mooncake_eig_vals( end end +""" + test_mooncake_eig_trunc(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rules for `eig_trunc`, `eig_trunc_no_error`, and their +in-place variants, over a range of truncation ranks and a tolerance-based truncation. +""" function test_mooncake_eig_trunc( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) diff --git a/test/testsuite/mooncake/eigh.jl b/test/testsuite/mooncake/eigh.jl index ea7a87f5..e4487363 100644 --- a/test/testsuite/mooncake/eigh.jl +++ b/test/testsuite/mooncake/eigh.jl @@ -1,48 +1,11 @@ -function mc_copy_eigh_full(A, alg) - A = (A + A') / 2 - return eigh_full(A, alg) -end - -function mc_copy_eigh_full!(A, DV, alg) - A = (A + A') / 2 - return eigh_full!(A, DV, alg) -end - -function mc_copy_eigh_vals(A, alg) - A = (A + A') / 2 - return eigh_vals(A, alg) -end - -function mc_copy_eigh_vals!(A, D, alg) - A = (A + A') / 2 - return eigh_vals!(A, D, alg) -end - -function mc_copy_eigh_trunc(A, alg) - A = (A + A') / 2 - return eigh_trunc(A, alg) -end - -function mc_copy_eigh_trunc!(A, DV, alg) - A = (A + A') / 2 - return eigh_trunc!(A, DV, alg) -end - -function mc_copy_eigh_trunc_no_error(A, alg) - A = (A + A') / 2 - return eigh_trunc_no_error(A, alg) -end - -function mc_copy_eigh_trunc_no_error!(A, DV, alg) - A = (A + A') / 2 - return eigh_trunc_no_error!(A, DV, alg) -end - -MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_full), A) = MatrixAlgebraKit.copy_input(eigh_full, A) -MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_vals), A) = MatrixAlgebraKit.copy_input(eigh_vals, A) -MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_trunc), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) -MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_trunc_no_error), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) - +""" + remove_eigh_gauge_dependence!(ΔV, D, V) + +Remove the gauge-dependent part from the cotangent `ΔV` of the Hermitian eigenvector matrix +`V`. The eigenvectors are only determined up to complex phase (and unitary mixing for +degenerate eigenvalues), so the corresponding anti-Hermitian components of `V' * ΔV` are +projected out. +""" function remove_eigh_gauge_dependence!( ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) @@ -54,13 +17,28 @@ function remove_eigh_gauge_dependence!( return ΔV end +""" + eigh_wrapper(f, A, alg) + +Wrapper that symmetrizes `A` before calling `f(A, alg)`. Used to test Hermitian +eigendecomposition rules on a general matrix by first projecting onto the Hermitian subspace. +""" eigh_wrapper(f, A, alg) = f(project_hermitian(A), alg) + +""" + eigh!_wrapper(f!, A, alg) + +Wrapper that symmetrizes `A` in-place before calling `f!(A, alg)`, then zeros `A`. Used to +test in-place Hermitian eigendecomposition rules via Mooncake's non-primitive AD path. +""" eigh!_wrapper(f!, A, alg) = (F = f!(project_hermitian!(A), alg); MatrixAlgebraKit.zero!(A); F) -function test_mooncake_eigh( - T::Type, sz; - kwargs... - ) +""" + test_mooncake_eigh(T, sz; kwargs...) + +Run all Mooncake AD tests for Hermitian eigendecompositions of element type `T` and size `sz`. +""" +function test_mooncake_eigh(T::Type, sz; kwargs...) summary_str = testargs_summary(T, sz) return @testset "Mooncake eigh $summary_str" begin test_mooncake_eigh_full(T, sz; kwargs...) @@ -69,6 +47,11 @@ function test_mooncake_eigh( end end +""" + test_mooncake_eigh_full(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rule for `eigh_full` and its in-place variant. +""" function test_mooncake_eigh_full( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) @@ -86,11 +69,16 @@ function test_mooncake_eigh_full( ) Mooncake.TestUtils.test_rule( rng, eigh!_wrapper, eigh_full!, A, alg; - mode = Mooncake.ReverseMode, output_tangent = ΔDV, atol, rtol, is_primitive = false + mode = Mooncake.ReverseMode, output_tangent = ΔDV, is_primitive = false, atol, rtol ) end end +""" + test_mooncake_eigh_vals(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rule for `eigh_vals` and its in-place variant. +""" function test_mooncake_eigh_vals( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) @@ -107,11 +95,17 @@ function test_mooncake_eigh_vals( ) Mooncake.TestUtils.test_rule( rng, eigh!_wrapper, eigh_vals!, A, alg; - mode = Mooncake.ReverseMode, output_tangent = ΔD, atol, rtol, is_primitive = false + mode = Mooncake.ReverseMode, output_tangent = ΔD, is_primitive = false, atol, rtol ) end end +""" + test_mooncake_eigh_trunc(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rules for `eigh_trunc`, `eigh_trunc_no_error`, and their +in-place variants, over a range of truncation ranks. +""" function test_mooncake_eigh_trunc( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) @@ -137,11 +131,11 @@ function test_mooncake_eigh_trunc( Mooncake.TestUtils.test_rule( rng, eigh_wrapper, eigh_trunc_no_error, A, alg_trunc; - mode = Mooncake.ReverseMode, output_tangent = ΔDVtrunc, atol, rtol, is_primitive = false + mode = Mooncake.ReverseMode, output_tangent = ΔDVtrunc, is_primitive = false, atol, rtol ) Mooncake.TestUtils.test_rule( rng, eigh!_wrapper, eigh_trunc_no_error!, A, alg_trunc; - mode = Mooncake.ReverseMode, output_tangent = ΔDVtrunc, atol, rtol, is_primitive = false + mode = Mooncake.ReverseMode, output_tangent = ΔDVtrunc, is_primitive = false, atol, rtol ) DVϵ = eigh_trunc(A, alg_trunc) @@ -150,11 +144,11 @@ function test_mooncake_eigh_trunc( Mooncake.TestUtils.test_rule( rng, eigh_wrapper, eigh_trunc, A, alg_trunc; - mode = Mooncake.ReverseMode, output_tangent = ΔDVϵtrunc, atol, rtol, is_primitive = false + mode = Mooncake.ReverseMode, output_tangent = ΔDVϵtrunc, is_primitive = false, atol, rtol ) Mooncake.TestUtils.test_rule( rng, eigh!_wrapper, eigh_trunc!, A, alg_trunc; - mode = Mooncake.ReverseMode, output_tangent = ΔDVϵtrunc, atol, rtol, is_primitive = false + mode = Mooncake.ReverseMode, output_tangent = ΔDVϵtrunc, is_primitive = false, atol, rtol ) end end diff --git a/test/testsuite/mooncake/lq.jl b/test/testsuite/mooncake/lq.jl index 542ce0cd..3ec0cc1d 100644 --- a/test/testsuite/mooncake/lq.jl +++ b/test/testsuite/mooncake/lq.jl @@ -1,15 +1,10 @@ -function test_mooncake_lq( - T::Type, sz; - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "Mooncake lq $summary_str" begin - test_mooncake_lq_compact(T, sz; kwargs...) - test_mooncake_lq_full(T, sz; kwargs...) - test_mooncake_lq_null(T, sz; kwargs...) - end -end +""" + remove_lq_gauge_dependence!(ΔQ, A, L, Q) +Remove the gauge-dependent part from the cotangent `ΔQ` of the full-LQ orthogonal factor `Q`. +For the full LQ decomposition, the extra rows of `Q` beyond `min(m, n)` are not uniquely +determined by `A`, so the corresponding part of `ΔQ` is projected to remove this ambiguity. +""" function remove_lq_gauge_dependence!(ΔQ, A, L, Q) m, n = size(A) minmn = min(m, n) @@ -21,12 +16,38 @@ function remove_lq_gauge_dependence!(ΔQ, A, L, Q) return ΔQ end +""" + remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) + +Remove the gauge-dependent part from the cotangent `ΔNᴴ` of the LQ null space `Nᴴ`. The null +space is only determined up to a unitary rotation, so `ΔNᴴ` is projected onto the row span of +the compact LQ factor `Q₁`. +""" function remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) _, Q = lq_compact(A) ΔNᴴQᴴ = ΔNᴴ * Q' return mul!(ΔNᴴ, ΔNᴴQᴴ, Q) end +""" + test_mooncake_lq(T, sz; kwargs...) + +Run all Mooncake AD tests for LQ decompositions of element type `T` and size `sz`. +""" +function test_mooncake_lq(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "Mooncake lq $summary_str" begin + test_mooncake_lq_compact(T, sz; kwargs...) + test_mooncake_lq_full(T, sz; kwargs...) + test_mooncake_lq_null(T, sz; kwargs...) + end +end + +""" + test_mooncake_lq_compact(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rule for `lq_compact` and its in-place variant. +""" function test_mooncake_lq_compact( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) @@ -49,6 +70,11 @@ function test_mooncake_lq_compact( end end +""" + test_mooncake_lq_full(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rule for `lq_full` and its in-place variant. +""" function test_mooncake_lq_full( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) @@ -71,6 +97,11 @@ function test_mooncake_lq_full( end end +""" + test_mooncake_lq_null(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rule for `lq_null` and its in-place variant. +""" function test_mooncake_lq_null( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) @@ -86,7 +117,6 @@ function test_mooncake_lq_null( rng, lq_null, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔNᴴ, atol, rtol ) - Nᴴ, ΔNᴴ = ad_lq_null_setup(A) Mooncake.TestUtils.test_rule( rng, make_input_scratch!, lq_null!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔNᴴ, atol, rtol, is_primitive = false diff --git a/test/testsuite/mooncake/mooncake.jl b/test/testsuite/mooncake/mooncake.jl index 42e92289..6488cc4e 100644 --- a/test/testsuite/mooncake/mooncake.jl +++ b/test/testsuite/mooncake/mooncake.jl @@ -1,181 +1,23 @@ -using TestExtras -using MatrixAlgebraKit -using Mooncake, Mooncake.TestUtils -using Mooncake: rrule!! -using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD, eigh_trunc -using LinearAlgebra: BlasFloat -using GenericLinearAlgebra - - -make_mooncake_tangent(ΔAelem::T) where {T <: Number} = ΔAelem -make_mooncake_tangent(ΔA::AbstractMatrix) = ΔA -make_mooncake_tangent(ΔA::AbstractVector) = ΔA -make_mooncake_tangent(ΔD::Diagonal) = Mooncake.build_tangent(typeof(ΔD), diagview(ΔD)) - -make_mooncake_tangent(T::Tuple) = Mooncake.build_tangent(typeof(T), make_mooncake_tangent.(T)...) - -make_mooncake_fdata(x) = make_mooncake_tangent(x) -make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(x.diag),)) -make_mooncake_fdata(x::Tuple) = map(make_mooncake_fdata, x) - -# copies a preset tangent into a Mooncake CoDual -# for use in the pullback. -function copy_tangent(var::Mooncake.CoDual, Δargs) - dargs = make_mooncake_fdata(deepcopy(Δargs)) - copyto!(Mooncake.tangent(var), dargs) - return -end - -function copy_tangent(var::Mooncake.CoDual, Δargs::Tuple) - dargs = make_mooncake_fdata.(deepcopy(Δargs)) - for (var_tangent, darg) in zip(Mooncake.tangent(var), dargs) - if var_tangent isa Mooncake.FData - for (var_f, darg_f) in zip(Mooncake._fields(var_tangent), Mooncake._fields(darg)) - copyto!(var_f, darg_f) - end - else - copyto!(var_tangent, darg) - end - end - return -end - -# no `alg` argument -function _get_copying_derivative(f, rrule, A, ΔA, args, Δargs, ::Nothing, rdata) - dA_copy = make_mooncake_fdata(copy(ΔA)) - A_copy = copy(A) - A_dA = Mooncake.CoDual(A_copy, dA_copy) - copy_out, copy_pb!! = rrule(Mooncake.CoDual(f, Mooncake.NoFData()), A_dA) - # copy Δargs into tangent of the output variable for the pullback check - copy_tangent(copy_out, Δargs) - copy_pb!!(rdata) - @test Mooncake.primal(A_dA) == A - return dA_copy, Mooncake.tangent(copy_out) -end - -# `alg` argument -function _get_copying_derivative(f, rrule, A, ΔA, args, Δargs, alg, rdata) - dA_copy = make_mooncake_fdata(copy(ΔA)) - A_copy = copy(A) - A_dA = Mooncake.CoDual(A_copy, dA_copy) - copy_out, copy_pb!! = rrule(Mooncake.CoDual(f, Mooncake.NoFData()), A_dA, Mooncake.CoDual(alg, Mooncake.NoFData())) - # copy Δargs into tangent of the output variable for the pullback check - copy_tangent(copy_out, Δargs) - copy_pb!!(rdata) - @test Mooncake.primal(A_dA) == A - return dA_copy, Mooncake.tangent(copy_out) -end - -function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata; ȳ = Δargs) - dA_inplace = make_mooncake_fdata(copy(ΔA)) - A_inplace = copy(A) - args_copy = deepcopy(args) - dargs_inplace = make_mooncake_fdata(deepcopy(Δargs)) - # not every f! has a handwritten rrule!! - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} - has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) - A_dA = Mooncake.CoDual(A_inplace, dA_inplace) - args_dargs = Mooncake.CoDual(args_copy, dargs_inplace) - if has_handwritten_rule - inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), A_dA, args_dargs) - else - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} - rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) - inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) - inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), A_dA, args_dargs) - end - # copy reference derivative of output ȳ into inplace_out - # needed for inplace methods like svd_trunc! that generate - # new output variables - copy_tangent(inplace_out, ȳ) - inplace_pb!!(rdata) - @test Mooncake.primal(A_dA) == A - @test Mooncake.primal(args_dargs) == args_copy - return dA_inplace, Mooncake.tangent(inplace_out) -end - -function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata; ȳ = Δargs) - dA_inplace = make_mooncake_fdata(copy(ΔA)) - A_inplace = copy(A) - args_copy = deepcopy(args) - dargs_inplace = make_mooncake_fdata(deepcopy(Δargs)) - # not every f! has a handwritten rrule!! - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} - has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) - A_dA = Mooncake.CoDual(A_inplace, dA_inplace) - args_dargs = Mooncake.CoDual(args_copy, dargs_inplace) - if has_handwritten_rule - inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), A_dA, args_dargs, Mooncake.CoDual(alg, Mooncake.NoFData())) - else - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} - rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) - inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) - inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), A_dA, args_dargs, Mooncake.CoDual(alg, Mooncake.NoFData())) - end - # copy reference derivative of output ȳ into inplace_out - # needed for inplace methods like svd_trunc! that generate - # new output variables - copy_tangent(inplace_out, ȳ) - inplace_pb!!(rdata) - @test Mooncake.primal(A_dA) == A - @test Mooncake.primal(args_dargs) == args_copy - return dA_inplace, Mooncake.tangent(inplace_out) -end - """ - test_pullbacks_match(f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) - -Compare the result of running the *in-place, mutating* function `f!`'s reverse rule -with the result of running its *non-mutating* partner function `f`'s reverse rule. -We must compare directly because many of the mutating functions modify `A` as a -scratch workspace, making testing `f!` against finite differences infeasible. + make_input_scratch!(f!, A, alg) -The arguments to this function are: - - `f!` the mutating, in-place version of the function (accepts `args` for the function result) - - `f` the non-mutating version of the function (does not accept `args` for the function result) - - `A` the input matrix to factorize - - `args` preallocated output for `f!` (e.g. `Q` and `R` matrices for `qr_compact!`) - - `Δargs` precomputed derivatives of `args` for pullbacks of `f` and `f!`, to ensure they receive the same input - - `alg` optional algorithm keyword argument - - `rdata` Mooncake reverse data to supply to the pullback, in case `f` and `f!` return scalar results (as truncating functions do) +Helper for testing in-place Mooncake rules. Calls `f!(A, alg)`, zeros out `A` (since `f!` +uses it as scratch space), and returns the output. This allows `Mooncake.TestUtils.test_rule` +to verify the reverse rule of `f!` without pre-allocating the output structure. """ -function test_pullbacks_match(f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData(), ȳ = deepcopy(Δargs)) - sig = isnothing(alg) ? Tuple{typeof(f), typeof(A)} : Tuple{typeof(f), typeof(A), typeof(alg)} - rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) - rrule = Mooncake.build_rrule(rvs_interp, sig) - ΔA = randn(rng, eltype(A), size(A)) - - copy_args = isa(args, Tuple) ? copy.(args) : copy(args) - inplace_args = isa(args, Tuple) ? copy.(args) : copy(args) - dA_copy, dargs_copy = _get_copying_derivative(f, rrule, A, ΔA, copy_args, ȳ, alg, rdata) - dA_inplace, dargs_inplace = _get_inplace_derivative(f!, A, ΔA, inplace_args, Δargs, alg, rdata; ȳ) - - dA_inplace_ = Mooncake.arrayify(A, dA_inplace)[2] - dA_copy_ = Mooncake.arrayify(A, dA_copy)[2] - @test dA_inplace_ ≈ dA_copy_ - @test copy_args == inplace_args - if dargs_copy isa Tuple - for (darg_copy_, darg_inplace_) in zip(dargs_copy, dargs_inplace) - if darg_copy_ isa Mooncake.FData - for (c_f, i_f) in zip(Mooncake._fields(darg_copy_), Mooncake._fields(darg_inplace_)) - @test c_f == i_f - end - else - @test darg_copy_ == darg_inplace_ - end - end - else - @test dargs_copy == dargs_inplace - end - return -end - function make_input_scratch!(f!, A, alg) F′ = f!(A, alg) MatrixAlgebraKit.zero!(A) return F′ end +""" + test_mooncake(T, sz; kwargs...) + +Run all Mooncake AD tests for element type `T` and size `sz`. Dispatches to per-decomposition +sub-suites. Square or vector sizes enable the eigendecomposition tests; element types that are +plain number types enable the orthnull tests. +""" function test_mooncake(T::Type, sz; kwargs...) summary_str = testargs_summary(T, sz) return @testset "Mooncake AD $summary_str" verbose = true begin diff --git a/test/testsuite/mooncake/orthnull.jl b/test/testsuite/mooncake/orthnull.jl index 946f252b..efaeb2b2 100644 --- a/test/testsuite/mooncake/orthnull.jl +++ b/test/testsuite/mooncake/orthnull.jl @@ -1,19 +1,36 @@ +""" + remove_left_null_gauge_dependence!(ΔN, A, N) + +Remove the gauge-dependent part from the cotangent `ΔN` of the left null space `N`. The null +space basis is only determined up to a unitary rotation, so `ΔN` is projected onto the column +span of the compact QR factor `Q₁` of `A`. +""" function remove_left_null_gauge_dependence!(ΔN, A, N) Q, _ = qr_compact(A) mul!(ΔN, Q, Q' * ΔN) return ΔN end +""" + remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) + +Remove the gauge-dependent part from the cotangent `ΔNᴴ` of the right null space `Nᴴ`. The +null space basis is only determined up to a unitary rotation, so `ΔNᴴ` is projected onto the +row span of the compact LQ factor `Q₁` of `A`. +""" function remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) _, Q = lq_compact(A) mul!(ΔNᴴ, ΔNᴴ * Q', Q) return ΔNᴴ end -function test_mooncake_orthnull( - T::Type, sz; - kwargs... - ) +""" + test_mooncake_orthnull(T, sz; kwargs...) + +Run all Mooncake AD tests for orthogonal basis and null space computations of element type `T` +and size `sz`. +""" +function test_mooncake_orthnull(T::Type, sz; kwargs...) summary_str = testargs_summary(T, sz) return @testset "Mooncake orthnull $summary_str" begin test_mooncake_left_orth(T, sz; kwargs...) @@ -23,6 +40,12 @@ function test_mooncake_orthnull( end end +""" + test_mooncake_left_orth(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rules for `left_orth` with QR and polar (when `m >= n`) +algorithms, and their in-place variants. +""" function test_mooncake_left_orth( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) @@ -65,6 +88,12 @@ function test_mooncake_left_orth( end end +""" + test_mooncake_right_orth(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rules for `right_orth` with LQ and polar (when `m <= n`) +algorithms, and their in-place variants. +""" function test_mooncake_right_orth( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) @@ -107,6 +136,12 @@ function test_mooncake_right_orth( end end +""" + test_mooncake_left_null(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rule for `left_null` with the QR algorithm and its +in-place variant. +""" function test_mooncake_left_null( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) @@ -132,6 +167,12 @@ function test_mooncake_left_null( end end +""" + test_mooncake_right_null(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rule for `right_null` with the LQ algorithm and its +in-place variant. +""" function test_mooncake_right_null( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) diff --git a/test/testsuite/mooncake/polar.jl b/test/testsuite/mooncake/polar.jl index 469c6af0..4f04b06b 100644 --- a/test/testsuite/mooncake/polar.jl +++ b/test/testsuite/mooncake/polar.jl @@ -1,7 +1,9 @@ -function test_mooncake_polar( - T::Type, sz; - kwargs... - ) +""" + test_mooncake_polar(T, sz; kwargs...) + +Run all Mooncake AD tests for polar decompositions of element type `T` and size `sz`. +""" +function test_mooncake_polar(T::Type, sz; kwargs...) summary_str = testargs_summary(T, sz) return @testset "Mooncake polar $summary_str" begin test_mooncake_left_polar(T, sz; kwargs...) @@ -9,6 +11,12 @@ function test_mooncake_polar( end end +""" + test_mooncake_left_polar(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rule for `left_polar` and its in-place variant. Only runs +for tall or square matrices (`m >= n`). +""" function test_mooncake_left_polar( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) @@ -33,6 +41,12 @@ function test_mooncake_left_polar( end end +""" + test_mooncake_right_polar(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rule for `right_polar` and its in-place variant. Only runs +for wide or square matrices (`m <= n`). +""" function test_mooncake_right_polar( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) diff --git a/test/testsuite/mooncake/qr.jl b/test/testsuite/mooncake/qr.jl index eb4f3712..da513d80 100644 --- a/test/testsuite/mooncake/qr.jl +++ b/test/testsuite/mooncake/qr.jl @@ -1,15 +1,10 @@ -function test_mooncake_qr( - T::Type, sz; - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "Mooncake qr $summary_str" begin - test_mooncake_qr_compact(T, sz; kwargs...) - test_mooncake_qr_full(T, sz; kwargs...) - test_mooncake_qr_null(T, sz; kwargs...) - end -end +""" + remove_qr_gauge_dependence!(ΔQ, A, Q, R) +Remove the gauge-dependent part from the cotangent `ΔQ` of the full-QR orthogonal factor `Q`. +For the full QR decomposition, the extra columns of `Q` beyond `min(m, n)` are not uniquely +determined by `A`, so the corresponding part of `ΔQ` is projected to remove this ambiguity. +""" function remove_qr_gauge_dependence!(ΔQ, A, Q, R) m, n = size(A) minmn = min(m, n) @@ -21,11 +16,37 @@ function remove_qr_gauge_dependence!(ΔQ, A, Q, R) return ΔQ end +""" + remove_qr_null_gauge_dependence!(ΔN, A, N) + +Remove the gauge-dependent part from the cotangent `ΔN` of the QR null space `N`. The null +space is only determined up to a unitary rotation, so `ΔN` is projected onto the column span +of the compact QR factor `Q₁`. +""" function remove_qr_null_gauge_dependence!(ΔN, A, N) Q, _ = qr_compact(A) return mul!(ΔN, Q, Q' * ΔN) end +""" + test_mooncake_qr(T, sz; kwargs...) + +Run all Mooncake AD tests for QR decompositions of element type `T` and size `sz`. +""" +function test_mooncake_qr(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "Mooncake qr $summary_str" begin + test_mooncake_qr_compact(T, sz; kwargs...) + test_mooncake_qr_full(T, sz; kwargs...) + test_mooncake_qr_null(T, sz; kwargs...) + end +end + +""" + test_mooncake_qr_compact(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rule for `qr_compact` and its in-place variant. +""" function test_mooncake_qr_compact( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) @@ -48,6 +69,11 @@ function test_mooncake_qr_compact( end end +""" + test_mooncake_qr_full(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rule for `qr_full` and its in-place variant. +""" function test_mooncake_qr_full( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) @@ -70,6 +96,11 @@ function test_mooncake_qr_full( end end +""" + test_mooncake_qr_null(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rule for `qr_null` and its in-place variant. +""" function test_mooncake_qr_null( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) @@ -85,7 +116,6 @@ function test_mooncake_qr_null( rng, qr_null, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔN, atol, rtol ) - N, ΔN = ad_qr_null_setup(A) Mooncake.TestUtils.test_rule( rng, make_input_scratch!, qr_null!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔN, atol, rtol, is_primitive = false diff --git a/test/testsuite/mooncake/svd.jl b/test/testsuite/mooncake/svd.jl index d31477ae..0c7cbfad 100644 --- a/test/testsuite/mooncake/svd.jl +++ b/test/testsuite/mooncake/svd.jl @@ -1,3 +1,12 @@ +""" + remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) + +Remove the gauge-dependent part from the cotangents `ΔU` and `ΔVᴴ` of the SVD factors. The +singular vectors are only determined up to a common complex phase per singular value (and +unitary mixing for degenerate singular values), so the corresponding anti-Hermitian components +of `U₁' * ΔU₁ + Vᴴ₁ * ΔVᴴ₁'` are projected out. For the full SVD, the extra columns of `U` +and rows of `Vᴴ` beyond `min(m, n)` are additionally zeroed out. +""" function remove_svd_gauge_dependence!( ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(S) @@ -17,10 +26,12 @@ function remove_svd_gauge_dependence!( return ΔU, ΔVᴴ end -function test_mooncake_svd( - T::Type, sz; - kwargs... - ) +""" + test_mooncake_svd(T, sz; kwargs...) + +Run all Mooncake AD tests for SVD decompositions of element type `T` and size `sz`. +""" +function test_mooncake_svd(T::Type, sz; kwargs...) summary_str = testargs_summary(T, sz) return @testset "Mooncake svd $summary_str" begin test_mooncake_svd_compact(T, sz; kwargs...) @@ -30,6 +41,11 @@ function test_mooncake_svd( end end +""" + test_mooncake_svd_compact(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rule for `svd_compact` and its in-place variant. +""" function test_mooncake_svd_compact( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) @@ -52,6 +68,12 @@ function test_mooncake_svd_compact( end end +""" + test_mooncake_svd_full(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rule for `svd_full` and its in-place variant. The +gauge-dependent extra columns of `U` and rows of `Vᴴ` are zeroed out in the cotangent. +""" function test_mooncake_svd_full( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) @@ -74,6 +96,11 @@ function test_mooncake_svd_full( end end +""" + test_mooncake_svd_vals(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rule for `svd_vals` and its in-place variant. +""" function test_mooncake_svd_vals( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) @@ -95,6 +122,12 @@ function test_mooncake_svd_vals( end end +""" + test_mooncake_svd_trunc(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rules for `svd_trunc`, `svd_trunc_no_error`, and their +in-place variants, over a range of truncation ranks and a tolerance-based truncation. +""" function test_mooncake_svd_trunc( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) From b3329de51a9aa8dc4d7403e13366edeca545c38d Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 19 Feb 2026 15:57:45 -0500 Subject: [PATCH 13/15] rename `call_and_zero!` --- test/testsuite/mooncake/eig.jl | 8 ++++---- test/testsuite/mooncake/lq.jl | 6 +++--- test/testsuite/mooncake/mooncake.jl | 11 ++++++----- test/testsuite/mooncake/orthnull.jl | 12 ++++++------ test/testsuite/mooncake/polar.jl | 4 ++-- test/testsuite/mooncake/qr.jl | 6 +++--- test/testsuite/mooncake/svd.jl | 14 +++++++------- 7 files changed, 31 insertions(+), 30 deletions(-) diff --git a/test/testsuite/mooncake/eig.jl b/test/testsuite/mooncake/eig.jl index 41d7d7b8..2fec4b5d 100644 --- a/test/testsuite/mooncake/eig.jl +++ b/test/testsuite/mooncake/eig.jl @@ -50,7 +50,7 @@ function test_mooncake_eig_full( mode = Mooncake.ReverseMode, output_tangent = ΔDV, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, eig_full!, A, alg; + rng, call_and_zero!, eig_full!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔDV, atol, rtol, is_primitive = false ) end @@ -76,7 +76,7 @@ function test_mooncake_eig_vals( mode = Mooncake.ReverseMode, output_tangent = ΔD, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, eig_vals!, A, alg; + rng, call_and_zero!, eig_vals!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔD, atol, rtol, is_primitive = false ) end @@ -116,7 +116,7 @@ function test_mooncake_eig_trunc( mode = Mooncake.ReverseMode, output_tangent = ΔDVtrunc, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, eig_trunc_no_error!, A, alg_trunc; + rng, call_and_zero!, eig_trunc_no_error!, A, alg_trunc; mode = Mooncake.ReverseMode, output_tangent = ΔDVtrunc, atol, rtol, is_primitive = false ) @@ -129,7 +129,7 @@ function test_mooncake_eig_trunc( mode = Mooncake.ReverseMode, output_tangent = ΔDVϵtrunc, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, eig_trunc!, A, alg_trunc; + rng, call_and_zero!, eig_trunc!, A, alg_trunc; mode = Mooncake.ReverseMode, output_tangent = ΔDVϵtrunc, atol, rtol, is_primitive = false ) end diff --git a/test/testsuite/mooncake/lq.jl b/test/testsuite/mooncake/lq.jl index 3ec0cc1d..66917a76 100644 --- a/test/testsuite/mooncake/lq.jl +++ b/test/testsuite/mooncake/lq.jl @@ -64,7 +64,7 @@ function test_mooncake_lq_compact( mode = Mooncake.ReverseMode, output_tangent = ΔLQ, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, lq_compact!, A, alg; + rng, call_and_zero!, lq_compact!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔLQ, atol, rtol, is_primitive = false ) end @@ -91,7 +91,7 @@ function test_mooncake_lq_full( mode = Mooncake.ReverseMode, output_tangent = ΔLQ, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, lq_full!, A, alg; + rng, call_and_zero!, lq_full!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔLQ, atol, rtol, is_primitive = false ) end @@ -118,7 +118,7 @@ function test_mooncake_lq_null( mode = Mooncake.ReverseMode, output_tangent = ΔNᴴ, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, lq_null!, A, alg; + rng, call_and_zero!, lq_null!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔNᴴ, atol, rtol, is_primitive = false ) end diff --git a/test/testsuite/mooncake/mooncake.jl b/test/testsuite/mooncake/mooncake.jl index 6488cc4e..22f0c729 100644 --- a/test/testsuite/mooncake/mooncake.jl +++ b/test/testsuite/mooncake/mooncake.jl @@ -1,11 +1,12 @@ """ - make_input_scratch!(f!, A, alg) + call_and_zero!(f!, A, alg) -Helper for testing in-place Mooncake rules. Calls `f!(A, alg)`, zeros out `A` (since `f!` -uses it as scratch space), and returns the output. This allows `Mooncake.TestUtils.test_rule` -to verify the reverse rule of `f!` without pre-allocating the output structure. +Helper for testing in-place Mooncake rules. +Calls `f!(A, alg)`, followed by zeroing out `A` and returns the output of `f!`. +This allows `Mooncake.TestUtils.test_rule` to verify the reverse rule of `f!` through finite differences, +without counting the contributions of `A`, as this is used solely as scratch space. """ -function make_input_scratch!(f!, A, alg) +function call_and_zero!(f!, A, alg) F′ = f!(A, alg) MatrixAlgebraKit.zero!(A) return F′ diff --git a/test/testsuite/mooncake/orthnull.jl b/test/testsuite/mooncake/orthnull.jl index efaeb2b2..cbe55a69 100644 --- a/test/testsuite/mooncake/orthnull.jl +++ b/test/testsuite/mooncake/orthnull.jl @@ -64,7 +64,7 @@ function test_mooncake_left_orth( mode = Mooncake.ReverseMode, output_tangent = ΔVC, is_primitive = false, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, left_orth!, A, alg; + rng, call_and_zero!, left_orth!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔVC, is_primitive = false, atol, rtol ) end @@ -80,7 +80,7 @@ function test_mooncake_left_orth( mode = Mooncake.ReverseMode, output_tangent = ΔVC, is_primitive = false, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, left_orth!, A, alg; + rng, call_and_zero!, left_orth!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔVC, is_primitive = false, atol, rtol ) end @@ -112,7 +112,7 @@ function test_mooncake_right_orth( mode = Mooncake.ReverseMode, output_tangent = ΔCVᴴ, is_primitive = false, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, right_orth!, A, alg; + rng, call_and_zero!, right_orth!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔCVᴴ, is_primitive = false, atol, rtol ) end @@ -128,7 +128,7 @@ function test_mooncake_right_orth( mode = Mooncake.ReverseMode, output_tangent = ΔCVᴴ, is_primitive = false, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, right_orth!, A, alg; + rng, call_and_zero!, right_orth!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔCVᴴ, is_primitive = false, atol, rtol ) end @@ -160,7 +160,7 @@ function test_mooncake_left_null( mode = Mooncake.ReverseMode, output_tangent = ΔN, is_primitive = false, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, left_null!, A, alg; + rng, call_and_zero!, left_null!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔN, is_primitive = false, atol, rtol ) end @@ -191,7 +191,7 @@ function test_mooncake_right_null( mode = Mooncake.ReverseMode, output_tangent = ΔNᴴ, is_primitive = false, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, right_null!, A, alg; + rng, call_and_zero!, right_null!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔNᴴ, is_primitive = false, atol, rtol ) end diff --git a/test/testsuite/mooncake/polar.jl b/test/testsuite/mooncake/polar.jl index 4f04b06b..6c1f2758 100644 --- a/test/testsuite/mooncake/polar.jl +++ b/test/testsuite/mooncake/polar.jl @@ -34,7 +34,7 @@ function test_mooncake_left_polar( mode = Mooncake.ReverseMode, output_tangent = ΔWP, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, left_polar!, A, alg; + rng, call_and_zero!, left_polar!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔWP, atol, rtol, is_primitive = false ) end @@ -64,7 +64,7 @@ function test_mooncake_right_polar( mode = Mooncake.ReverseMode, output_tangent = ΔPWᴴ, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, right_polar!, A, alg; + rng, call_and_zero!, right_polar!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔPWᴴ, atol, rtol, is_primitive = false ) end diff --git a/test/testsuite/mooncake/qr.jl b/test/testsuite/mooncake/qr.jl index da513d80..2e732f9c 100644 --- a/test/testsuite/mooncake/qr.jl +++ b/test/testsuite/mooncake/qr.jl @@ -63,7 +63,7 @@ function test_mooncake_qr_compact( mode = Mooncake.ReverseMode, output_tangent = ΔQR, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, qr_compact!, A, alg; + rng, call_and_zero!, qr_compact!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔQR, atol, rtol, is_primitive = false ) end @@ -90,7 +90,7 @@ function test_mooncake_qr_full( mode = Mooncake.ReverseMode, output_tangent = ΔQR, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, qr_full!, A, alg; + rng, call_and_zero!, qr_full!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔQR, atol, rtol, is_primitive = false ) end @@ -117,7 +117,7 @@ function test_mooncake_qr_null( mode = Mooncake.ReverseMode, output_tangent = ΔN, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, qr_null!, A, alg; + rng, call_and_zero!, qr_null!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔN, atol, rtol, is_primitive = false ) end diff --git a/test/testsuite/mooncake/svd.jl b/test/testsuite/mooncake/svd.jl index 0c7cbfad..9b5029aa 100644 --- a/test/testsuite/mooncake/svd.jl +++ b/test/testsuite/mooncake/svd.jl @@ -62,7 +62,7 @@ function test_mooncake_svd_compact( mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴ, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, svd_compact!, A, alg; + rng, call_and_zero!, svd_compact!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴ, atol, rtol, is_primitive = false ) end @@ -90,7 +90,7 @@ function test_mooncake_svd_full( mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴ, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, svd_full!, A, alg; + rng, call_and_zero!, svd_full!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴ, atol, rtol, is_primitive = false ) end @@ -116,7 +116,7 @@ function test_mooncake_svd_vals( mode = Mooncake.ReverseMode, output_tangent = ΔS, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, svd_vals!, A, alg; + rng, call_and_zero!, svd_vals!, A, alg; mode = Mooncake.ReverseMode, output_tangent = ΔS, atol, rtol, is_primitive = false ) end @@ -157,7 +157,7 @@ function test_mooncake_svd_trunc( mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴtrunc, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, svd_trunc_no_error!, A, alg_trunc; + rng, call_and_zero!, svd_trunc_no_error!, A, alg_trunc; mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴtrunc, atol, rtol, is_primitive = false ) @@ -170,7 +170,7 @@ function test_mooncake_svd_trunc( mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴϵtrunc, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, svd_trunc!, A, alg_trunc; + rng, call_and_zero!, svd_trunc!, A, alg_trunc; mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴϵtrunc, atol, rtol, is_primitive = false ) end @@ -189,7 +189,7 @@ function test_mooncake_svd_trunc( mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴtrunc, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, svd_trunc_no_error!, A, alg_trunc; + rng, call_and_zero!, svd_trunc_no_error!, A, alg_trunc; mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴtrunc, atol, rtol, is_primitive = false ) @@ -202,7 +202,7 @@ function test_mooncake_svd_trunc( mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴϵtrunc, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, make_input_scratch!, svd_trunc!, A, alg_trunc; + rng, call_and_zero!, svd_trunc!, A, alg_trunc; mode = Mooncake.ReverseMode, output_tangent = ΔUSVᴴϵtrunc, atol, rtol, is_primitive = false ) end From 61e806165de3a77ac25b27cb7253b7214e587127 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 20 Feb 2026 10:09:26 -0500 Subject: [PATCH 14/15] move gauge dependence removal to ad_utils again --- test/testsuite/ad_utils.jl | 156 +++++++++++++++++++++++++--- test/testsuite/mooncake/eig.jl | 17 --- test/testsuite/mooncake/eigh.jl | 19 ---- test/testsuite/mooncake/lq.jl | 31 ------ test/testsuite/mooncake/orthnull.jl | 26 ----- test/testsuite/mooncake/qr.jl | 30 ------ test/testsuite/mooncake/svd.jl | 28 ----- 7 files changed, 143 insertions(+), 164 deletions(-) diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index 3b0d3d61..d85b1ffc 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -1,14 +1,29 @@ -function remove_svdgauge_dependence!( - ΔU, ΔVᴴ, U, S, Vᴴ; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(S) +""" + remove_eig_gauge_dependence!(ΔV, D, V) + +Remove the gauge-dependent part from the cotangent `ΔV` of the eigenvector matrix `V`. The +eigenvectors are only determined up to complex phase (and unitary mixing for degenerate +eigenvalues), so the corresponding components of `ΔV` are projected out. +""" +function remove_eig_gauge_dependence!( + ΔV, D, V; + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) ) - gaugepart = mul!(U' * ΔU, Vᴴ, ΔVᴴ', true, true) - gaugepart = project_antihermitian!(gaugepart) - gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0 - mul!(ΔU, U, gaugepart, -1, 1) - return ΔU, ΔVᴴ + gaugepart = V' * ΔV + gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 + mul!(ΔV, V / (V' * V), gaugepart, -1, 1) + return ΔV end -function remove_eighgauge_dependence!( + +""" + remove_eigh_gauge_dependence!(ΔV, D, V) + +Remove the gauge-dependent part from the cotangent `ΔV` of the Hermitian eigenvector matrix +`V`. The eigenvectors are only determined up to complex phase (and unitary mixing for +degenerate eigenvalues), so the corresponding anti-Hermitian components of `V' * ΔV` are +projected out. +""" +function remove_eigh_gauge_dependence!( ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) ) @@ -19,6 +34,121 @@ function remove_eighgauge_dependence!( return ΔV end +""" + remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) + +Remove the gauge-dependent part from the cotangents `ΔU` and `ΔVᴴ` of the SVD factors. The +singular vectors are only determined up to a common complex phase per singular value (and +unitary mixing for degenerate singular values), so the corresponding anti-Hermitian components +of `U₁' * ΔU₁ + Vᴴ₁ * ΔVᴴ₁'` are projected out. For the full SVD, the extra columns of `U` +and rows of `Vᴴ` beyond `min(m, n)` are additionally zeroed out. +""" +function remove_svd_gauge_dependence!( + ΔU, ΔVᴴ, U, S, Vᴴ; + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(S) + ) + minmn = length(diagview(S)) + U₁ = view(U, :, 1:minmn) + Vᴴ₁ = view(Vᴴ, 1:minmn, :) + ΔU₁ = view(ΔU, :, 1:minmn) + ΔVᴴ₁ = view(ΔVᴴ, 1:minmn, :) + Sdiag = diagview(S) + gaugepart = mul!(U₁' * ΔU₁, Vᴴ₁, ΔVᴴ₁', true, true) + gaugepart = project_antihermitian!(gaugepart) + gaugepart[abs.(transpose(Sdiag) .- Sdiag) .>= degeneracy_atol] .= 0 + mul!(ΔU₁, U₁, gaugepart, -1, 1) + ΔU[:, (minmn + 1):end] .= 0 + ΔVᴴ[(minmn + 1):end, :] .= 0 + return ΔU, ΔVᴴ +end + +""" + remove_qr_gauge_dependence!(ΔQ, A, Q, R) + +Remove the gauge-dependent part from the cotangent `ΔQ` of the full-QR orthogonal factor `Q`. +For the full QR decomposition, the extra columns of `Q` beyond `min(m, n)` are not uniquely +determined by `A`, so the corresponding part of `ΔQ` is projected to remove this ambiguity. +""" +function remove_qr_gauge_dependence!(ΔQ, A, Q, R) + m, n = size(A) + minmn = min(m, n) + Q₁ = @view Q[:, 1:minmn] + ΔQ₂ = @view ΔQ[:, (minmn + 1):end] + Q₁ᴴΔQ₂ = Q₁' * ΔQ₂ + mul!(ΔQ₂, Q₁, Q₁ᴴΔQ₂) + MatrixAlgebraKit.check_qr_full_cotangents(Q₁, ΔQ₂, Q₁ᴴΔQ₂) + return ΔQ +end + +""" + remove_qr_null_gauge_dependence!(ΔN, A, N) + +Remove the gauge-dependent part from the cotangent `ΔN` of the QR null space `N`. The null +space is only determined up to a unitary rotation, so `ΔN` is projected onto the column span +of the compact QR factor `Q₁`. +""" +function remove_qr_null_gauge_dependence!(ΔN, A, N) + Q, _ = qr_compact(A) + return mul!(ΔN, Q, Q' * ΔN) +end + +""" + remove_lq_gauge_dependence!(ΔQ, A, L, Q) + +Remove the gauge-dependent part from the cotangent `ΔQ` of the full-LQ orthogonal factor `Q`. +For the full LQ decomposition, the extra rows of `Q` beyond `min(m, n)` are not uniquely +determined by `A`, so the corresponding part of `ΔQ` is projected to remove this ambiguity. +""" +function remove_lq_gauge_dependence!(ΔQ, A, L, Q) + m, n = size(A) + minmn = min(m, n) + Q₁ = @view Q[1:minmn, :] + ΔQ₂ = @view ΔQ[(minmn + 1):end, :] + ΔQ₂Q₁ᴴ = ΔQ₂ * Q₁' + mul!(ΔQ₂, ΔQ₂Q₁ᴴ, Q₁) + MatrixAlgebraKit.check_lq_full_cotangents(Q₁, ΔQ₂, ΔQ₂Q₁ᴴ) + return ΔQ +end + +""" + remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) + +Remove the gauge-dependent part from the cotangent `ΔNᴴ` of the LQ null space `Nᴴ`. The null +space is only determined up to a unitary rotation, so `ΔNᴴ` is projected onto the row span of +the compact LQ factor `Q₁`. +""" +function remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) + _, Q = lq_compact(A) + ΔNᴴQᴴ = ΔNᴴ * Q' + return mul!(ΔNᴴ, ΔNᴴQᴴ, Q) +end + +""" + remove_left_null_gauge_dependence!(ΔN, A, N) + +Remove the gauge-dependent part from the cotangent `ΔN` of the left null space `N`. The null +space basis is only determined up to a unitary rotation, so `ΔN` is projected onto the column +span of the compact QR factor `Q₁` of `A`. +""" +function remove_left_null_gauge_dependence!(ΔN, A, N) + Q, _ = qr_compact(A) + mul!(ΔN, Q, Q' * ΔN) + return ΔN +end + +""" + remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) + +Remove the gauge-dependent part from the cotangent `ΔNᴴ` of the right null space `Nᴴ`. The +null space basis is only determined up to a unitary rotation, so `ΔNᴴ` is projected onto the +row span of the compact LQ factor `Q₁` of `A`. +""" +function remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) + _, Q = lq_compact(A) + mul!(ΔNᴴ, ΔNᴴ * Q', Q) + return ΔNᴴ +end + function stabilize_eigvals!(D::AbstractVector) absD = collect(abs.(D)) p = invperm(sortperm(collect(absD))) # rank of abs(D) @@ -246,7 +376,7 @@ function ad_eigh_full_setup(A) D, V = DV Ddiag = diagview(D) ΔV = randn!(similar(A, T, m, m)) - ΔV = remove_eighgauge_dependence!(ΔV, D, V) + ΔV = remove_eigh_gauge_dependence!(ΔV, D, V) ΔD = randn!(similar(A, real(T), m, m)) ΔD2 = Diagonal(randn!(similar(A, real(T), m))) return DV, (ΔD, ΔV), (ΔD2, ΔV) @@ -279,7 +409,7 @@ function ad_svd_compact_setup(A) ΔS2 = Diagonal(randn!(similar(A, real(T), minmn))) ΔVᴴ = randn!(similar(A, T, minmn, n)) U, S, Vᴴ = svd_compact(A) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) + ΔU, ΔVᴴ = remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ), (ΔU, ΔS2, ΔVᴴ) end @@ -292,7 +422,7 @@ function ad_svd_compact_setup(A::Diagonal) ΔS2 = Diagonal(randn!(similar(A.diag, real(T), minmn))) ΔVᴴ = randn!(similar(A.diag, T, m, n)) U, S, Vᴴ = svd_compact(A) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) + ΔU, ΔVᴴ = remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ), (ΔU, ΔS2, ΔVᴴ) end @@ -305,7 +435,7 @@ function ad_svd_full_setup(A) ΔS2 = Diagonal(randn!(similar(A, real(T), minmn))) ΔVᴴ = randn!(similar(A, T, minmn, n)) U, S, Vᴴ = svd_compact(A) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) + ΔU, ΔVᴴ = remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) ΔUfull = similar(A, T, m, m) ΔUfull .= zero(T) ΔSfull = similar(A, real(T), m, n) diff --git a/test/testsuite/mooncake/eig.jl b/test/testsuite/mooncake/eig.jl index 2fec4b5d..059b4416 100644 --- a/test/testsuite/mooncake/eig.jl +++ b/test/testsuite/mooncake/eig.jl @@ -1,20 +1,3 @@ -""" - remove_eig_gauge_dependence!(ΔV, D, V) - -Remove the gauge-dependent part from the cotangent `ΔV` of the eigenvector matrix `V`. The -eigenvectors are only determined up to complex phase (and unitary mixing for degenerate -eigenvalues), so the corresponding components of `ΔV` are projected out. -""" -function remove_eig_gauge_dependence!( - ΔV, D, V; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) - ) - gaugepart = V' * ΔV - gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 - mul!(ΔV, V / (V' * V), gaugepart, -1, 1) - return ΔV -end - """ test_mooncake_eig(T, sz; kwargs...) diff --git a/test/testsuite/mooncake/eigh.jl b/test/testsuite/mooncake/eigh.jl index e4487363..7b91e445 100644 --- a/test/testsuite/mooncake/eigh.jl +++ b/test/testsuite/mooncake/eigh.jl @@ -1,22 +1,3 @@ -""" - remove_eigh_gauge_dependence!(ΔV, D, V) - -Remove the gauge-dependent part from the cotangent `ΔV` of the Hermitian eigenvector matrix -`V`. The eigenvectors are only determined up to complex phase (and unitary mixing for -degenerate eigenvalues), so the corresponding anti-Hermitian components of `V' * ΔV` are -projected out. -""" -function remove_eigh_gauge_dependence!( - ΔV, D, V; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) - ) - gaugepart = V' * ΔV - gaugepart = project_antihermitian!(gaugepart) - gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 - mul!(ΔV, V, gaugepart, -1, 1) - return ΔV -end - """ eigh_wrapper(f, A, alg) diff --git a/test/testsuite/mooncake/lq.jl b/test/testsuite/mooncake/lq.jl index 66917a76..6b75cfed 100644 --- a/test/testsuite/mooncake/lq.jl +++ b/test/testsuite/mooncake/lq.jl @@ -1,34 +1,3 @@ -""" - remove_lq_gauge_dependence!(ΔQ, A, L, Q) - -Remove the gauge-dependent part from the cotangent `ΔQ` of the full-LQ orthogonal factor `Q`. -For the full LQ decomposition, the extra rows of `Q` beyond `min(m, n)` are not uniquely -determined by `A`, so the corresponding part of `ΔQ` is projected to remove this ambiguity. -""" -function remove_lq_gauge_dependence!(ΔQ, A, L, Q) - m, n = size(A) - minmn = min(m, n) - Q₁ = @view Q[1:minmn, :] - ΔQ₂ = @view ΔQ[(minmn + 1):end, :] - ΔQ₂Q₁ᴴ = ΔQ₂ * Q₁' - mul!(ΔQ₂, ΔQ₂Q₁ᴴ, Q₁) - MatrixAlgebraKit.check_lq_full_cotangents(Q₁, ΔQ₂, ΔQ₂Q₁ᴴ) - return ΔQ -end - -""" - remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) - -Remove the gauge-dependent part from the cotangent `ΔNᴴ` of the LQ null space `Nᴴ`. The null -space is only determined up to a unitary rotation, so `ΔNᴴ` is projected onto the row span of -the compact LQ factor `Q₁`. -""" -function remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) - _, Q = lq_compact(A) - ΔNᴴQᴴ = ΔNᴴ * Q' - return mul!(ΔNᴴ, ΔNᴴQᴴ, Q) -end - """ test_mooncake_lq(T, sz; kwargs...) diff --git a/test/testsuite/mooncake/orthnull.jl b/test/testsuite/mooncake/orthnull.jl index cbe55a69..b8a952ce 100644 --- a/test/testsuite/mooncake/orthnull.jl +++ b/test/testsuite/mooncake/orthnull.jl @@ -1,29 +1,3 @@ -""" - remove_left_null_gauge_dependence!(ΔN, A, N) - -Remove the gauge-dependent part from the cotangent `ΔN` of the left null space `N`. The null -space basis is only determined up to a unitary rotation, so `ΔN` is projected onto the column -span of the compact QR factor `Q₁` of `A`. -""" -function remove_left_null_gauge_dependence!(ΔN, A, N) - Q, _ = qr_compact(A) - mul!(ΔN, Q, Q' * ΔN) - return ΔN -end - -""" - remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) - -Remove the gauge-dependent part from the cotangent `ΔNᴴ` of the right null space `Nᴴ`. The -null space basis is only determined up to a unitary rotation, so `ΔNᴴ` is projected onto the -row span of the compact LQ factor `Q₁` of `A`. -""" -function remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) - _, Q = lq_compact(A) - mul!(ΔNᴴ, ΔNᴴ * Q', Q) - return ΔNᴴ -end - """ test_mooncake_orthnull(T, sz; kwargs...) diff --git a/test/testsuite/mooncake/qr.jl b/test/testsuite/mooncake/qr.jl index 2e732f9c..e7892a64 100644 --- a/test/testsuite/mooncake/qr.jl +++ b/test/testsuite/mooncake/qr.jl @@ -1,33 +1,3 @@ -""" - remove_qr_gauge_dependence!(ΔQ, A, Q, R) - -Remove the gauge-dependent part from the cotangent `ΔQ` of the full-QR orthogonal factor `Q`. -For the full QR decomposition, the extra columns of `Q` beyond `min(m, n)` are not uniquely -determined by `A`, so the corresponding part of `ΔQ` is projected to remove this ambiguity. -""" -function remove_qr_gauge_dependence!(ΔQ, A, Q, R) - m, n = size(A) - minmn = min(m, n) - Q₁ = @view Q[:, 1:minmn] - ΔQ₂ = @view ΔQ[:, (minmn + 1):end] - Q₁ᴴΔQ₂ = Q₁' * ΔQ₂ - mul!(ΔQ₂, Q₁, Q₁ᴴΔQ₂) - MatrixAlgebraKit.check_qr_full_cotangents(Q₁, ΔQ₂, Q₁ᴴΔQ₂) - return ΔQ -end - -""" - remove_qr_null_gauge_dependence!(ΔN, A, N) - -Remove the gauge-dependent part from the cotangent `ΔN` of the QR null space `N`. The null -space is only determined up to a unitary rotation, so `ΔN` is projected onto the column span -of the compact QR factor `Q₁`. -""" -function remove_qr_null_gauge_dependence!(ΔN, A, N) - Q, _ = qr_compact(A) - return mul!(ΔN, Q, Q' * ΔN) -end - """ test_mooncake_qr(T, sz; kwargs...) diff --git a/test/testsuite/mooncake/svd.jl b/test/testsuite/mooncake/svd.jl index 9b5029aa..51a56549 100644 --- a/test/testsuite/mooncake/svd.jl +++ b/test/testsuite/mooncake/svd.jl @@ -1,31 +1,3 @@ -""" - remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) - -Remove the gauge-dependent part from the cotangents `ΔU` and `ΔVᴴ` of the SVD factors. The -singular vectors are only determined up to a common complex phase per singular value (and -unitary mixing for degenerate singular values), so the corresponding anti-Hermitian components -of `U₁' * ΔU₁ + Vᴴ₁ * ΔVᴴ₁'` are projected out. For the full SVD, the extra columns of `U` -and rows of `Vᴴ` beyond `min(m, n)` are additionally zeroed out. -""" -function remove_svd_gauge_dependence!( - ΔU, ΔVᴴ, U, S, Vᴴ; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(S) - ) - minmn = length(diagview(S)) - U₁ = view(U, :, 1:minmn) - Vᴴ₁ = view(Vᴴ, 1:minmn, :) - ΔU₁ = view(ΔU, :, 1:minmn) - ΔVᴴ₁ = view(ΔVᴴ, 1:minmn, :) - Sdiag = diagview(S) - gaugepart = mul!(U₁' * ΔU₁, Vᴴ₁, ΔVᴴ₁', true, true) - gaugepart = project_antihermitian!(gaugepart) - gaugepart[abs.(transpose(Sdiag) .- Sdiag) .>= degeneracy_atol] .= 0 - mul!(ΔU₁, U₁, gaugepart, -1, 1) - ΔU[:, (minmn + 1):end] .= 0 - ΔVᴴ[(minmn + 1):end, :] .= 0 - return ΔU, ΔVᴴ -end - """ test_mooncake_svd(T, sz; kwargs...) From 0de37f0ad7c13e4ca018aaba72496daba5107958 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 20 Feb 2026 10:11:59 -0500 Subject: [PATCH 15/15] separate out mooncake tests --- test/mooncake.jl | 29 ----------------------------- test/mooncake/eig.jl | 19 +++++++++++++++++++ test/mooncake/eigh.jl | 19 +++++++++++++++++++ test/mooncake/lq.jl | 19 +++++++++++++++++++ test/mooncake/orthnull.jl | 19 +++++++++++++++++++ test/mooncake/polar.jl | 19 +++++++++++++++++++ test/mooncake/qr.jl | 19 +++++++++++++++++++ test/mooncake/svd.jl | 19 +++++++++++++++++++ test/runtests.jl | 2 +- 9 files changed, 134 insertions(+), 30 deletions(-) delete mode 100644 test/mooncake.jl create mode 100644 test/mooncake/eig.jl create mode 100644 test/mooncake/eigh.jl create mode 100644 test/mooncake/lq.jl create mode 100644 test/mooncake/orthnull.jl create mode 100644 test/mooncake/polar.jl create mode 100644 test/mooncake/qr.jl create mode 100644 test/mooncake/svd.jl diff --git a/test/mooncake.jl b/test/mooncake.jl deleted file mode 100644 index d2f54ece..00000000 --- a/test/mooncake.jl +++ /dev/null @@ -1,29 +0,0 @@ -using MatrixAlgebraKit -using Test -using LinearAlgebra: Diagonal -using CUDA, AMDGPU - -#BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) -BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI -GenericFloats = () -@isdefined(TestSuite) || include("testsuite/TestSuite.jl") -using .TestSuite - -is_buildkite = get(ENV, "BUILDKITE", "false") == "true" - -m = 19 -for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) - TestSuite.seed_rng!(123) - #=if CUDA.functional() - TestSuite.test_mooncake(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - #n == m && TestSuite.test_mooncake(Diagonal{T, CuVector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) - end - if AMDGPU.functional() - TestSuite.test_mooncake(ROCMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - TestSuite.test_mooncake(Diagonal{T, ROCVector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) - end=# # not yet supported - if !is_buildkite - TestSuite.test_mooncake(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - #n == m && TestSuite.test_mooncake(Diagonal{T, Vector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) - end -end diff --git a/test/mooncake/eig.jl b/test/mooncake/eig.jl new file mode 100644 index 00000000..2e8a8606 --- /dev/null +++ b/test/mooncake/eig.jl @@ -0,0 +1,19 @@ +using MatrixAlgebraKit +using Test +using LinearAlgebra: Diagonal +using CUDA, AMDGPU + +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +GenericFloats = () +@isdefined(TestSuite) || include("../testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +m = 19 +for T in (BLASFloats..., GenericFloats...) + TestSuite.seed_rng!(123) + if !is_buildkite + TestSuite.test_mooncake_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + end +end diff --git a/test/mooncake/eigh.jl b/test/mooncake/eigh.jl new file mode 100644 index 00000000..5528af0f --- /dev/null +++ b/test/mooncake/eigh.jl @@ -0,0 +1,19 @@ +using MatrixAlgebraKit +using Test +using LinearAlgebra: Diagonal +using CUDA, AMDGPU + +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +GenericFloats = () +@isdefined(TestSuite) || include("../testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +m = 19 +for T in (BLASFloats..., GenericFloats...) + TestSuite.seed_rng!(123) + if !is_buildkite + TestSuite.test_mooncake_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + end +end diff --git a/test/mooncake/lq.jl b/test/mooncake/lq.jl new file mode 100644 index 00000000..6c9f8fd4 --- /dev/null +++ b/test/mooncake/lq.jl @@ -0,0 +1,19 @@ +using MatrixAlgebraKit +using Test +using LinearAlgebra: Diagonal +using CUDA, AMDGPU + +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +GenericFloats = () +@isdefined(TestSuite) || include("../testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +m = 19 +for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) + TestSuite.seed_rng!(123) + if !is_buildkite + TestSuite.test_mooncake_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end +end diff --git a/test/mooncake/orthnull.jl b/test/mooncake/orthnull.jl new file mode 100644 index 00000000..6f8dac9a --- /dev/null +++ b/test/mooncake/orthnull.jl @@ -0,0 +1,19 @@ +using MatrixAlgebraKit +using Test +using LinearAlgebra: Diagonal +using CUDA, AMDGPU + +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +GenericFloats = () +@isdefined(TestSuite) || include("../testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +m = 19 +for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) + TestSuite.seed_rng!(123) + if !is_buildkite + TestSuite.test_mooncake_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end +end diff --git a/test/mooncake/polar.jl b/test/mooncake/polar.jl new file mode 100644 index 00000000..74c828b6 --- /dev/null +++ b/test/mooncake/polar.jl @@ -0,0 +1,19 @@ +using MatrixAlgebraKit +using Test +using LinearAlgebra: Diagonal +using CUDA, AMDGPU + +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +GenericFloats = () +@isdefined(TestSuite) || include("../testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +m = 19 +for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) + TestSuite.seed_rng!(123) + if !is_buildkite + TestSuite.test_mooncake_polar(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end +end diff --git a/test/mooncake/qr.jl b/test/mooncake/qr.jl new file mode 100644 index 00000000..9ffc4798 --- /dev/null +++ b/test/mooncake/qr.jl @@ -0,0 +1,19 @@ +using MatrixAlgebraKit +using Test +using LinearAlgebra: Diagonal +using CUDA, AMDGPU + +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +GenericFloats = () +@isdefined(TestSuite) || include("../testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +m = 19 +for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) + TestSuite.seed_rng!(123) + if !is_buildkite + TestSuite.test_mooncake_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end +end diff --git a/test/mooncake/svd.jl b/test/mooncake/svd.jl new file mode 100644 index 00000000..982ec040 --- /dev/null +++ b/test/mooncake/svd.jl @@ -0,0 +1,19 @@ +using MatrixAlgebraKit +using Test +using LinearAlgebra: Diagonal +using CUDA, AMDGPU + +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +GenericFloats = () +@isdefined(TestSuite) || include("../testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +m = 19 +for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) + TestSuite.seed_rng!(123) + if !is_buildkite + TestSuite.test_mooncake_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 9b180a90..69c18501 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,7 +27,7 @@ if filter_tests!(testsuite, args) is_apple_ci = Sys.isapple() && get(ENV, "CI", "false") == "true" if is_apple_ci delete!(testsuite, "enzyme") - delete!(testsuite, "mooncake") + filter!(p -> !startswith(first(p), "mooncake/"), testsuite) delete!(testsuite, "chainrules") end Sys.iswindows() && delete!(testsuite, "enzyme")