From 7069619bd7b7168a717e001b0073060480e7003d Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 14 Feb 2026 16:11:46 -0500 Subject: [PATCH 1/9] Fix zero tangent guard in polar pullback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Guard `C .+= ΔP` with `!iszerotangent(ΔP)` in both `left_polar_pullback!` and `right_polar_pullback!` to handle the case where ΔP is `nothing`. Co-Authored-By: Claude Opus 4.6 --- src/pullbacks/polar.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pullbacks/polar.jl b/src/pullbacks/polar.jl index 4d498da0..a549321e 100644 --- a/src/pullbacks/polar.jl +++ b/src/pullbacks/polar.jl @@ -17,7 +17,7 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...) !iszerotangent(ΔW) && mul!(M, W', ΔW, 1, 1) !iszerotangent(ΔP) && mul!(M, ΔP, P, -1, 1) C = _sylvester(P, P, M' - M) - C .+= ΔP + !iszerotangent(ΔP) && (C .+= ΔP) ΔA = mul!(ΔA, W, C, 1, 1) if !iszerotangent(ΔW) ΔWP = ΔW / P @@ -47,7 +47,7 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs... !iszerotangent(ΔWᴴ) && mul!(M, ΔWᴴ, Wᴴ', 1, 1) !iszerotangent(ΔP) && mul!(M, P, ΔP, -1, 1) C = _sylvester(P, P, M' - M) - C .+= ΔP + !iszerotangent(ΔP) && (C .+= ΔP) ΔA = mul!(ΔA, C, Wᴴ, 1, 1) if !iszerotangent(ΔWᴴ) PΔWᴴ = P \ ΔWᴴ From ab8aea1d47f394a42b21da9af258839aedd69e63 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 16 Feb 2026 17:33:29 -0500 Subject: [PATCH 2/9] Use eigendecomposition-based Sylvester solver for symmetric case Replace LAPACK trsyl!-based solver with a direct eigendecomposition approach when both arguments are the same Hermitian matrix (as in polar pullbacks). This avoids LAPACKException(1) for close eigenvalues. Co-Authored-By: Claude Opus 4.6 --- src/common/pullbacks.jl | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/common/pullbacks.jl b/src/common/pullbacks.jl index 4fe853cd..41834f5d 100644 --- a/src/common/pullbacks.jl +++ b/src/common/pullbacks.jl @@ -11,5 +11,22 @@ function iszerotangent end iszerotangent(::Any) = false iszerotangent(::Nothing) = true -# fallback -_sylvester(A, B, C) = LinearAlgebra.sylvester(A, B, C) +# Solve the Sylvester equation A*X + X*B + C = 0. +# When A === B (same Hermitian PD matrix, as in polar pullbacks), use an +# eigendecomposition-based solver to avoid LAPACK's trsyl! failing with +# LAPACKException(1) for close eigenvalues. +function _sylvester(A, B, C) + if A === B + return _sylvester_symm(A, C) + end + return LinearAlgebra.sylvester(A, B, C) +end + +function _sylvester_symm(P, C) + D, Q = LinearAlgebra.eigen(LinearAlgebra.Hermitian(P)) + Y = Q' * C * Q + @inbounds for j in axes(Y, 2), i in axes(Y, 1) + Y[i, j] = -Y[i, j] / (D[i] + D[j]) + end + return Q * Y * Q' +end From 5474c5ec44bf5843631e893e0c2e00ec7fb49f83 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 17 Feb 2026 08:47:48 -0500 Subject: [PATCH 3/9] Add AD rules for projection methods MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add rrules/pullbacks for `project_hermitian!`, `project_antihermitian!`, and `project_isometric!` directly in each AD backend extension (ChainRulesCore, Enzyme, Mooncake). The hermitian/antihermitian pullbacks are self-adjoint, while the isometric pullback delegates to `left_polar_pullback!` with zero ΔP. Includes test utilities and tests for all three backends. Co-Authored-By: Claude Opus 4.6 --- ext/MatrixAlgebraKitChainRulesCoreExt.jl | 42 +++++++++ .../MatrixAlgebraKitEnzymeExt.jl | 87 +++++++++++++++++++ .../MatrixAlgebraKitMooncakeExt.jl | 78 +++++++++++++++++ test/testsuite/ad_utils.jl | 24 +++++ test/testsuite/chainrules.jl | 57 ++++++++++++ test/testsuite/enzyme.jl | 39 +++++++++ test/testsuite/mooncake.jl | 35 ++++++++ 7 files changed, 362 insertions(+) diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index 400b2a79..ead8ebf7 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -274,4 +274,46 @@ function ChainRulesCore.rrule(::typeof(right_polar!), A, PWᴴ, alg) return PWᴴ, right_polar_pullback end +function ChainRulesCore.rrule(::typeof(project_hermitian!), A, Aₕ, alg) + Ac = copy_input(project_hermitian, A) + Aₕ = project_hermitian!(Ac, Aₕ, alg) + function project_hermitian_pullback(ΔAₕ) + ΔA = project_hermitian(unthunk(ΔAₕ)) + return NoTangent(), ΔA, ZeroTangent(), NoTangent() + end + function project_hermitian_pullback(::ZeroTangent) + return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + end + return Aₕ, project_hermitian_pullback +end + +function ChainRulesCore.rrule(::typeof(project_antihermitian!), A, Aₐ, alg) + Ac = copy_input(project_antihermitian, A) + Aₐ = project_antihermitian!(Ac, Aₐ, alg) + function project_antihermitian_pullback(ΔAₐ) + ΔA = project_antihermitian(unthunk(ΔAₐ)) + return NoTangent(), ΔA, ZeroTangent(), NoTangent() + end + function project_antihermitian_pullback(::ZeroTangent) + return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + end + return Aₐ, project_antihermitian_pullback +end + +function ChainRulesCore.rrule(::typeof(project_isometric!), A, W, alg) + Ac = copy_input(project_isometric, A) + # Compute the full polar decomposition to cache P for the pullback + WP = left_polar!(Ac, (similar(W), similar(W, size(W, 2), size(W, 2))), alg) + W_out = copy!(W, WP[1]) + function project_isometric_pullback(ΔW) + ΔA = zero(A) + MatrixAlgebraKit.left_polar_pullback!(ΔA, A, WP, (unthunk(ΔW), nothing)) + return NoTangent(), ΔA, ZeroTangent(), NoTangent() + end + function project_isometric_pullback(::ZeroTangent) + return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + end + return W_out, project_isometric_pullback +end + end diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index 24a1fa5e..4b5e01e5 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -454,4 +454,91 @@ function EnzymeRules.reverse( return (nothing, nothing, nothing) end +# single-output projections: project_hermitian!, project_antihermitian! +# single-output projections: project_hermitian!, project_antihermitian! +for (f!, project_f) in ( + (project_hermitian!, project_hermitian), + (project_antihermitian!, project_antihermitian), + ) + @eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f!)}, + ::Type{RT}, + A::Annotation, + arg::Annotation{TA}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT, TA} + ret = func.val(A.val, arg.val, alg.val) + cache_arg = (arg.val !== ret) || EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing + dret = if EnzymeRules.needs_shadow(config) + (TA == Nothing || isa(arg, Const)) ? zero(ret) : arg.dval + else + nothing + end + primal = EnzymeRules.needs_primal(config) ? ret : nothing + return EnzymeRules.AugmentedReturn(primal, dret, (cache_arg, dret)) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f!)}, + ::Type{RT}, + cache, + A::Annotation, + arg::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + cache_arg, darg = cache + argdval = something(darg, arg.dval) + if !isa(A, Const) + A.dval .+= $project_f(argdval) + end + !isa(arg, Const) && make_zero!(arg.dval) + return (nothing, nothing, nothing) + end + end +end + +# project_isometric! needs special handling: compute full polar decomposition +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(project_isometric!)}, + ::Type{RT}, + A::Annotation, + W::Annotation{TW}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT, TW} + # Compute the full polar decomposition for the pullback + Ac = copy(A.val) + m, n = size(A.val) + P = similar(A.val, n, n) + WP = left_polar!(Ac, (W.val, P), alg.val) + cache_WP = EnzymeRules.overwritten(config)[3] ? copy.(WP) : nothing + dret = if EnzymeRules.needs_shadow(config) + (TW == Nothing || isa(W, Const)) ? zero(WP[1]) : W.dval + else + nothing + end + primal = EnzymeRules.needs_primal(config) ? WP[1] : nothing + return EnzymeRules.AugmentedReturn(primal, dret, (cache_WP, dret)) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(project_isometric!)}, + ::Type{RT}, + cache, + A::Annotation, + W::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + cache_WP, dW = cache + Aval = nothing + WPval = something(cache_WP, (W.val, cache_WP[2])) + if !isa(A, Const) + left_polar_pullback!(A.dval, Aval, WPval, (dW, nothing)) + end + !isa(W, Const) && make_zero!(W.dval) + return (nothing, nothing, nothing) +end + end diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index 3a113c20..038414af 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -779,4 +779,82 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint end +# single-output projections: project_hermitian!, project_antihermitian! +# single-output projections: project_hermitian!, project_antihermitian! +for (f!, f, adj) in ( + (:project_hermitian!, :project_hermitian, :project_hermitian_adjoint), + (:project_antihermitian!, :project_antihermitian, :project_antihermitian_adjoint), + ) + @eval begin + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + Ac = copy(A) + arg, darg = arrayify(arg_darg) + argc = copy(arg) + $f!(A, arg, Mooncake.primal(alg_dalg)) + function $adj(::NoRData) + copy!(A, Ac) + dA .+= $f(darg) + copy!(arg, argc) + zero!(darg) + return NoRData(), NoRData(), NoRData(), NoRData() + end + return arg_darg, $adj + end + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + output = $f(A, Mooncake.primal(alg_dalg)) + output_codual = CoDual(output, Mooncake.zero_tangent(output)) + function $adj(::NoRData) + arg, darg = arrayify(output_codual) + dA .+= $f(darg) + zero!(darg) + return NoRData(), NoRData(), NoRData() + end + return output_codual, $adj + end + end +end + +# project_isometric! needs special handling: compute full polar decomposition +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(project_isometric!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.rrule!!(f_df::CoDual{typeof(project_isometric!)}, A_dA::CoDual, W_dW::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + W, dW = arrayify(W_dW) + Ac = copy(A) + Wc = copy(W) + # Compute the full polar decomposition for the pullback + m, n = size(A) + P = similar(A, n, n) + WP = left_polar!(copy(A), (copy(W), P), Mooncake.primal(alg_dalg)) + copy!(W, WP[1]) + function project_isometric_adjoint(::NoRData) + copy!(A, Ac) + left_polar_pullback!(dA, A, WP, (dW, nothing)) + copy!(W, Wc) + zero!(dW) + return NoRData(), NoRData(), NoRData(), NoRData() + end + return W_dW, project_isometric_adjoint +end + +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(project_isometric), Any, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.rrule!!(f_df::CoDual{typeof(project_isometric)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + # Compute the full polar decomposition for the pullback + WP = left_polar(A, alg) + W_out = WP[1] + output_codual = CoDual(W_out, Mooncake.zero_tangent(W_out)) + function project_isometric_adjoint(::NoRData) + W, dW = arrayify(output_codual) + left_polar_pullback!(dA, A, WP, (dW, nothing)) + zero!(dW) + return NoRData(), NoRData(), NoRData() + end + return output_codual, project_isometric_adjoint +end + end diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index 31fb8ca1..7c671ee0 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -421,3 +421,27 @@ function ad_right_null_setup(A) ΔNᴴ = randn!(similar(A, T, n - min(m, n), min(m, n))) * right_orth(A; alg = :lq)[2] return Nᴴ, ΔNᴴ end + +function ad_project_hermitian_setup(A) + m, n = size(A) + T = eltype(A) + Aₕ = project_hermitian(A) + ΔAₕ = randn!(similar(A, T, m, n)) + return Aₕ, ΔAₕ +end + +function ad_project_antihermitian_setup(A) + m, n = size(A) + T = eltype(A) + Aₐ = project_antihermitian(A) + ΔAₐ = randn!(similar(A, T, m, n)) + return Aₐ, ΔAₐ +end + +function ad_project_isometric_setup(A) + m, n = size(A) + T = eltype(A) + W = project_isometric(A) + ΔW = randn!(similar(A, T, m, n)) + return W, ΔW +end diff --git a/test/testsuite/chainrules.jl b/test/testsuite/chainrules.jl index b4126c59..1d3047ad 100644 --- a/test/testsuite/chainrules.jl +++ b/test/testsuite/chainrules.jl @@ -10,6 +10,7 @@ for f in :eig_trunc_no_error, :eigh_trunc_no_error, :svd_compact, :svd_trunc, :svd_trunc_no_error, :svd_vals, :left_polar, :right_polar, + :project_hermitian, :project_antihermitian, :project_isometric, ) copy_f = Symbol(:cr_copy_, f) f! = Symbol(f, '!') @@ -46,6 +47,7 @@ function test_chainrules(T::Type, sz; kwargs...) test_chainrules_svd(T, sz; kwargs...) test_chainrules_polar(T, sz; kwargs...) test_chainrules_orthnull(T, sz; kwargs...) + test_chainrules_projections(T, sz; kwargs...) end end @@ -610,3 +612,58 @@ function test_chainrules_orthnull( ) end end + +function test_chainrules_projections( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Projections Chainrules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + config = Zygote.ZygoteRuleConfig() + if m == n + alg_h = MatrixAlgebraKit.default_hermitian_algorithm(A) + @testset "project_hermitian" begin + Aₕ, ΔAₕ = ad_project_hermitian_setup(A) + test_rrule( + cr_copy_project_hermitian, A, alg_h ⊢ NoTangent(); + output_tangent = ΔAₕ, atol = atol, rtol = rtol + ) + test_rrule( + config, project_hermitian, A; + output_tangent = ΔAₕ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "project_antihermitian" begin + Aₐ, ΔAₐ = ad_project_antihermitian_setup(A) + test_rrule( + cr_copy_project_antihermitian, A, alg_h ⊢ NoTangent(); + output_tangent = ΔAₐ, atol = atol, rtol = rtol + ) + test_rrule( + config, project_antihermitian, A; + output_tangent = ΔAₐ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end + if m > n + @testset "project_isometric" begin + W, ΔW = ad_project_isometric_setup(A) + alg_iso = MatrixAlgebraKit.default_polar_algorithm(A) + test_rrule( + cr_copy_project_isometric, A, alg_iso ⊢ NoTangent(); + output_tangent = ΔW, atol = atol, rtol = rtol + ) + test_rrule( + config, project_isometric, A; + output_tangent = ΔW, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end + end +end diff --git a/test/testsuite/enzyme.jl b/test/testsuite/enzyme.jl index 14723680..6ad56e38 100644 --- a/test/testsuite/enzyme.jl +++ b/test/testsuite/enzyme.jl @@ -105,6 +105,7 @@ function test_enzyme(T::Type, sz; kwargs...) test_enzyme_polar(T, sz; kwargs...) test_enzyme_orthnull(T, sz; kwargs...) end + test_enzyme_projections(T, sz; kwargs...) end end @@ -462,3 +463,41 @@ function test_enzyme_orthnull( end end end + +function test_enzyme_projections( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Projections Enzyme AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + if m == n + @testset "project_hermitian" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + Aₕ, ΔAₕ = ad_project_hermitian_setup(A) + eltype(T) <: BlasFloat && test_reverse(project_hermitian, RT, (A, TA); atol, rtol, output_tangent = ΔAₕ, fdm) + is_cpu(A) && enz_test_pullbacks_match(rng, project_hermitian!, project_hermitian, A, Aₕ, ΔAₕ) + end + end + @testset "project_antihermitian" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + Aₐ, ΔAₐ = ad_project_antihermitian_setup(A) + eltype(T) <: BlasFloat && test_reverse(project_antihermitian, RT, (A, TA); atol, rtol, output_tangent = ΔAₐ, fdm) + is_cpu(A) && enz_test_pullbacks_match(rng, project_antihermitian!, project_antihermitian, A, Aₐ, ΔAₐ) + end + end + end + if m > n + @testset "project_isometric" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + W, ΔW = ad_project_isometric_setup(A) + eltype(T) <: BlasFloat && test_reverse(project_isometric, RT, (A, TA); atol, rtol, output_tangent = ΔW, fdm) + is_cpu(A) && enz_test_pullbacks_match(rng, project_isometric!, project_isometric, A, W, ΔW) + end + end + end + end +end diff --git a/test/testsuite/mooncake.jl b/test/testsuite/mooncake.jl index 29d65e31..6a437c9b 100644 --- a/test/testsuite/mooncake.jl +++ b/test/testsuite/mooncake.jl @@ -229,6 +229,7 @@ function test_mooncake(T::Type, sz; kwargs...) if T <: Number test_mooncake_orthnull(T, sz; kwargs...) end + test_mooncake_projections(T, sz; kwargs...) end end @@ -537,3 +538,37 @@ function test_mooncake_orthnull( test_pullbacks_match(((X, Nᴴ) -> right_null!(X, Nᴴ; alg = :lq)), right_null_lq, A, Nᴴ, ΔNᴴ) end end + +function test_mooncake_projections( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Projections Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + if m == n + @testset "project_hermitian" begin + Aₕ, ΔAₕ = ad_project_hermitian_setup(A) + dAₕ = make_mooncake_tangent(ΔAₕ) + Mooncake.TestUtils.test_rule(rng, project_hermitian, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dAₕ, atol, rtol) + test_pullbacks_match(project_hermitian!, project_hermitian, A, Aₕ, ΔAₕ) + end + @testset "project_antihermitian" begin + Aₐ, ΔAₐ = ad_project_antihermitian_setup(A) + dAₐ = make_mooncake_tangent(ΔAₐ) + Mooncake.TestUtils.test_rule(rng, project_antihermitian, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dAₐ, atol, rtol) + test_pullbacks_match(project_antihermitian!, project_antihermitian, A, Aₐ, ΔAₐ) + end + end + if m > n + @testset "project_isometric" begin + W, ΔW = ad_project_isometric_setup(A) + dW = make_mooncake_tangent(ΔW) + Mooncake.TestUtils.test_rule(rng, project_isometric, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dW, atol, rtol) + test_pullbacks_match(project_isometric!, project_isometric, A, W, ΔW) + end + end + end +end From 62dd5caacd02a25cd1fa117109e70e18004b9c46 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 17 Feb 2026 09:48:34 -0500 Subject: [PATCH 4/9] simplify implementations --- ext/MatrixAlgebraKitChainRulesCoreExt.jl | 36 ++-------- .../MatrixAlgebraKitMooncakeExt.jl | 66 ++++--------------- 2 files changed, 20 insertions(+), 82 deletions(-) diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index ead8ebf7..acfe0d83 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -274,46 +274,22 @@ function ChainRulesCore.rrule(::typeof(right_polar!), A, PWᴴ, alg) return PWᴴ, right_polar_pullback end -function ChainRulesCore.rrule(::typeof(project_hermitian!), A, Aₕ, alg) - Ac = copy_input(project_hermitian, A) - Aₕ = project_hermitian!(Ac, Aₕ, alg) +function ChainRulesCore.rrule(::typeof(project_hermitian), A, alg) + Aₕ = project_hermitian(A, alg) function project_hermitian_pullback(ΔAₕ) ΔA = project_hermitian(unthunk(ΔAₕ)) - return NoTangent(), ΔA, ZeroTangent(), NoTangent() - end - function project_hermitian_pullback(::ZeroTangent) - return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + return NoTangent(), ΔA, NoTangent() end return Aₕ, project_hermitian_pullback end -function ChainRulesCore.rrule(::typeof(project_antihermitian!), A, Aₐ, alg) - Ac = copy_input(project_antihermitian, A) - Aₐ = project_antihermitian!(Ac, Aₐ, alg) +function ChainRulesCore.rrule(::typeof(project_antihermitian), A, alg) + Aₐ = project_antihermitian(A, alg) function project_antihermitian_pullback(ΔAₐ) ΔA = project_antihermitian(unthunk(ΔAₐ)) - return NoTangent(), ΔA, ZeroTangent(), NoTangent() - end - function project_antihermitian_pullback(::ZeroTangent) - return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + return NoTangent(), ΔA, NoTangent() end return Aₐ, project_antihermitian_pullback end -function ChainRulesCore.rrule(::typeof(project_isometric!), A, W, alg) - Ac = copy_input(project_isometric, A) - # Compute the full polar decomposition to cache P for the pullback - WP = left_polar!(Ac, (similar(W), similar(W, size(W, 2), size(W, 2))), alg) - W_out = copy!(W, WP[1]) - function project_isometric_pullback(ΔW) - ΔA = zero(A) - MatrixAlgebraKit.left_polar_pullback!(ΔA, A, WP, (unthunk(ΔW), nothing)) - return NoTangent(), ΔA, ZeroTangent(), NoTangent() - end - function project_isometric_pullback(::ZeroTangent) - return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() - end - return W_out, project_isometric_pullback -end - end diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index 038414af..0f7dcfe2 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -779,82 +779,44 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint end -# single-output projections: project_hermitian!, project_antihermitian! # single-output projections: project_hermitian!, project_antihermitian! for (f!, f, adj) in ( (:project_hermitian!, :project_hermitian, :project_hermitian_adjoint), (:project_antihermitian!, :project_antihermitian, :project_antihermitian_adjoint), ) @eval begin - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) A, dA = arrayify(A_dA) - Ac = copy(A) - arg, darg = arrayify(arg_darg) + arg, darg = A_dA === arg_darg ? (A, dA) : arrayify(arg_darg) argc = copy(arg) - $f!(A, arg, Mooncake.primal(alg_dalg)) + arg = $f!(A, arg, Mooncake.primal(alg_dalg)) + function $adj(::NoRData) - copy!(A, Ac) dA .+= $f(darg) + dA === darg || zero!(darg) copy!(arg, argc) - zero!(darg) return NoRData(), NoRData(), NoRData(), NoRData() end return arg_darg, $adj end - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + + @is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) A, dA = arrayify(A_dA) output = $f(A, Mooncake.primal(alg_dalg)) - output_codual = CoDual(output, Mooncake.zero_tangent(output)) + output_doutput = Mooncake.zero_fcodual(output) + + doutput = last(arrayify(output_doutput)) function $adj(::NoRData) - arg, darg = arrayify(output_codual) - dA .+= $f(darg) - zero!(darg) - return NoRData(), NoRData(), NoRData() + # TODO: need accumulating projection to avoid intermediate here + dA .+= $f(doutput) + return ntuple(Returns(NoRData(), 3)) end + return output_codual, $adj end end end -# project_isometric! needs special handling: compute full polar decomposition -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(project_isometric!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} -function Mooncake.rrule!!(f_df::CoDual{typeof(project_isometric!)}, A_dA::CoDual, W_dW::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) - A, dA = arrayify(A_dA) - W, dW = arrayify(W_dW) - Ac = copy(A) - Wc = copy(W) - # Compute the full polar decomposition for the pullback - m, n = size(A) - P = similar(A, n, n) - WP = left_polar!(copy(A), (copy(W), P), Mooncake.primal(alg_dalg)) - copy!(W, WP[1]) - function project_isometric_adjoint(::NoRData) - copy!(A, Ac) - left_polar_pullback!(dA, A, WP, (dW, nothing)) - copy!(W, Wc) - zero!(dW) - return NoRData(), NoRData(), NoRData(), NoRData() - end - return W_dW, project_isometric_adjoint -end - -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(project_isometric), Any, MatrixAlgebraKit.AbstractAlgorithm} -function Mooncake.rrule!!(f_df::CoDual{typeof(project_isometric)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) - A, dA = arrayify(A_dA) - alg = Mooncake.primal(alg_dalg) - # Compute the full polar decomposition for the pullback - WP = left_polar(A, alg) - W_out = WP[1] - output_codual = CoDual(W_out, Mooncake.zero_tangent(W_out)) - function project_isometric_adjoint(::NoRData) - W, dW = arrayify(output_codual) - left_polar_pullback!(dA, A, WP, (dW, nothing)) - zero!(dW) - return NoRData(), NoRData(), NoRData() - end - return output_codual, project_isometric_adjoint -end - end From fbfcbccd2ab02158c3089adba30915ee03316c0b Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 17 Feb 2026 09:53:44 -0500 Subject: [PATCH 5/9] remove enzyme --- .../MatrixAlgebraKitEnzymeExt.jl | 87 ------------------- 1 file changed, 87 deletions(-) diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index 4b5e01e5..24a1fa5e 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -454,91 +454,4 @@ function EnzymeRules.reverse( return (nothing, nothing, nothing) end -# single-output projections: project_hermitian!, project_antihermitian! -# single-output projections: project_hermitian!, project_antihermitian! -for (f!, project_f) in ( - (project_hermitian!, project_hermitian), - (project_antihermitian!, project_antihermitian), - ) - @eval begin - function EnzymeRules.augmented_primal( - config::EnzymeRules.RevConfigWidth{1}, - func::Const{typeof($f!)}, - ::Type{RT}, - A::Annotation, - arg::Annotation{TA}, - alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, - ) where {RT, TA} - ret = func.val(A.val, arg.val, alg.val) - cache_arg = (arg.val !== ret) || EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing - dret = if EnzymeRules.needs_shadow(config) - (TA == Nothing || isa(arg, Const)) ? zero(ret) : arg.dval - else - nothing - end - primal = EnzymeRules.needs_primal(config) ? ret : nothing - return EnzymeRules.AugmentedReturn(primal, dret, (cache_arg, dret)) - end - function EnzymeRules.reverse( - config::EnzymeRules.RevConfigWidth{1}, - func::Const{typeof($f!)}, - ::Type{RT}, - cache, - A::Annotation, - arg::Annotation, - alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, - ) where {RT} - cache_arg, darg = cache - argdval = something(darg, arg.dval) - if !isa(A, Const) - A.dval .+= $project_f(argdval) - end - !isa(arg, Const) && make_zero!(arg.dval) - return (nothing, nothing, nothing) - end - end -end - -# project_isometric! needs special handling: compute full polar decomposition -function EnzymeRules.augmented_primal( - config::EnzymeRules.RevConfigWidth{1}, - func::Const{typeof(project_isometric!)}, - ::Type{RT}, - A::Annotation, - W::Annotation{TW}, - alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, - ) where {RT, TW} - # Compute the full polar decomposition for the pullback - Ac = copy(A.val) - m, n = size(A.val) - P = similar(A.val, n, n) - WP = left_polar!(Ac, (W.val, P), alg.val) - cache_WP = EnzymeRules.overwritten(config)[3] ? copy.(WP) : nothing - dret = if EnzymeRules.needs_shadow(config) - (TW == Nothing || isa(W, Const)) ? zero(WP[1]) : W.dval - else - nothing - end - primal = EnzymeRules.needs_primal(config) ? WP[1] : nothing - return EnzymeRules.AugmentedReturn(primal, dret, (cache_WP, dret)) -end -function EnzymeRules.reverse( - config::EnzymeRules.RevConfigWidth{1}, - func::Const{typeof(project_isometric!)}, - ::Type{RT}, - cache, - A::Annotation, - W::Annotation, - alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, - ) where {RT} - cache_WP, dW = cache - Aval = nothing - WPval = something(cache_WP, (W.val, cache_WP[2])) - if !isa(A, Const) - left_polar_pullback!(A.dval, Aval, WPval, (dW, nothing)) - end - !isa(W, Const) && make_zero!(W.dval) - return (nothing, nothing, nothing) -end - end From e0c709b35fb553b1589a9567e33c543c5b110336 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 17 Feb 2026 09:58:26 -0500 Subject: [PATCH 6/9] simplify chainrules tests --- test/testsuite/ad_utils.jl | 24 --------------------- test/testsuite/chainrules.jl | 42 ++++-------------------------------- 2 files changed, 4 insertions(+), 62 deletions(-) diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index 7c671ee0..31fb8ca1 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -421,27 +421,3 @@ function ad_right_null_setup(A) ΔNᴴ = randn!(similar(A, T, n - min(m, n), min(m, n))) * right_orth(A; alg = :lq)[2] return Nᴴ, ΔNᴴ end - -function ad_project_hermitian_setup(A) - m, n = size(A) - T = eltype(A) - Aₕ = project_hermitian(A) - ΔAₕ = randn!(similar(A, T, m, n)) - return Aₕ, ΔAₕ -end - -function ad_project_antihermitian_setup(A) - m, n = size(A) - T = eltype(A) - Aₐ = project_antihermitian(A) - ΔAₐ = randn!(similar(A, T, m, n)) - return Aₐ, ΔAₐ -end - -function ad_project_isometric_setup(A) - m, n = size(A) - T = eltype(A) - W = project_isometric(A) - ΔW = randn!(similar(A, T, m, n)) - return W, ΔW -end diff --git a/test/testsuite/chainrules.jl b/test/testsuite/chainrules.jl index 1d3047ad..f937e0b0 100644 --- a/test/testsuite/chainrules.jl +++ b/test/testsuite/chainrules.jl @@ -10,7 +10,6 @@ for f in :eig_trunc_no_error, :eigh_trunc_no_error, :svd_compact, :svd_trunc, :svd_trunc_no_error, :svd_vals, :left_polar, :right_polar, - :project_hermitian, :project_antihermitian, :project_isometric, ) copy_f = Symbol(:cr_copy_, f) f! = Symbol(f, '!') @@ -622,47 +621,14 @@ function test_chainrules_projections( return @testset "Projections Chainrules AD rules $summary_str" begin A = instantiate_matrix(T, sz) m, n = size(A) - config = Zygote.ZygoteRuleConfig() if m == n - alg_h = MatrixAlgebraKit.default_hermitian_algorithm(A) @testset "project_hermitian" begin - Aₕ, ΔAₕ = ad_project_hermitian_setup(A) - test_rrule( - cr_copy_project_hermitian, A, alg_h ⊢ NoTangent(); - output_tangent = ΔAₕ, atol = atol, rtol = rtol - ) - test_rrule( - config, project_hermitian, A; - output_tangent = ΔAₕ, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) + alg = MatrixAlgebraKit.default_hermitian_algorithm(A) + test_rrule(project_hermitian, A, alg; atol, rtol) end @testset "project_antihermitian" begin - Aₐ, ΔAₐ = ad_project_antihermitian_setup(A) - test_rrule( - cr_copy_project_antihermitian, A, alg_h ⊢ NoTangent(); - output_tangent = ΔAₐ, atol = atol, rtol = rtol - ) - test_rrule( - config, project_antihermitian, A; - output_tangent = ΔAₐ, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end - end - if m > n - @testset "project_isometric" begin - W, ΔW = ad_project_isometric_setup(A) - alg_iso = MatrixAlgebraKit.default_polar_algorithm(A) - test_rrule( - cr_copy_project_isometric, A, alg_iso ⊢ NoTangent(); - output_tangent = ΔW, atol = atol, rtol = rtol - ) - test_rrule( - config, project_isometric, A; - output_tangent = ΔW, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) + alg = MatrixAlgebraKit.default_hermitian_algorithm(A) + test_rrule(project_antihermitian, A, alg; atol, rtol) end end end From 187d5f2dfe625e38e26e438ae8fa94de8e17c81d Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 17 Feb 2026 10:15:49 -0500 Subject: [PATCH 7/9] simplify mooncake tests --- .../MatrixAlgebraKitMooncakeExt.jl | 7 ++++--- test/testsuite/mooncake.jl | 20 ++++++------------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index 0f7dcfe2..37b35b32 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -796,7 +796,7 @@ for (f!, f, adj) in ( dA .+= $f(darg) dA === darg || zero!(darg) copy!(arg, argc) - return NoRData(), NoRData(), NoRData(), NoRData() + return ntuple(Returns(NoRData()), 4) end return arg_darg, $adj end @@ -811,10 +811,11 @@ for (f!, f, adj) in ( function $adj(::NoRData) # TODO: need accumulating projection to avoid intermediate here dA .+= $f(doutput) - return ntuple(Returns(NoRData(), 3)) + zero!(doutput) + return ntuple(Returns(NoRData()), 3) end - return output_codual, $adj + return output_doutput, $adj end end end diff --git a/test/testsuite/mooncake.jl b/test/testsuite/mooncake.jl index 6a437c9b..9ca869d2 100644 --- a/test/testsuite/mooncake.jl +++ b/test/testsuite/mooncake.jl @@ -550,25 +550,17 @@ function test_mooncake_projections( m, n = size(A) if m == n @testset "project_hermitian" begin - Aₕ, ΔAₕ = ad_project_hermitian_setup(A) - dAₕ = make_mooncake_tangent(ΔAₕ) - Mooncake.TestUtils.test_rule(rng, project_hermitian, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dAₕ, atol, rtol) + Aₕ = project_hermitian(A) + ΔAₕ = make_mooncake_tangent(Aₕ) + Mooncake.TestUtils.test_rule(rng, project_hermitian, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) test_pullbacks_match(project_hermitian!, project_hermitian, A, Aₕ, ΔAₕ) end @testset "project_antihermitian" begin - Aₐ, ΔAₐ = ad_project_antihermitian_setup(A) - dAₐ = make_mooncake_tangent(ΔAₐ) - Mooncake.TestUtils.test_rule(rng, project_antihermitian, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dAₐ, atol, rtol) + Aₐ = project_antihermitian(A) + ΔAₐ = make_mooncake_tangent(Aₐ) + Mooncake.TestUtils.test_rule(rng, project_antihermitian, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) test_pullbacks_match(project_antihermitian!, project_antihermitian, A, Aₐ, ΔAₐ) end end - if m > n - @testset "project_isometric" begin - W, ΔW = ad_project_isometric_setup(A) - dW = make_mooncake_tangent(ΔW) - Mooncake.TestUtils.test_rule(rng, project_isometric, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dW, atol, rtol) - test_pullbacks_match(project_isometric!, project_isometric, A, W, ΔW) - end - end end end From 3e73f99b0c93498d8be6dacbc126ccea180cff03 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 17 Feb 2026 10:17:31 -0500 Subject: [PATCH 8/9] revert changes --- src/common/pullbacks.jl | 21 ++------------------- test/testsuite/enzyme.jl | 39 --------------------------------------- 2 files changed, 2 insertions(+), 58 deletions(-) diff --git a/src/common/pullbacks.jl b/src/common/pullbacks.jl index 41834f5d..4fe853cd 100644 --- a/src/common/pullbacks.jl +++ b/src/common/pullbacks.jl @@ -11,22 +11,5 @@ function iszerotangent end iszerotangent(::Any) = false iszerotangent(::Nothing) = true -# Solve the Sylvester equation A*X + X*B + C = 0. -# When A === B (same Hermitian PD matrix, as in polar pullbacks), use an -# eigendecomposition-based solver to avoid LAPACK's trsyl! failing with -# LAPACKException(1) for close eigenvalues. -function _sylvester(A, B, C) - if A === B - return _sylvester_symm(A, C) - end - return LinearAlgebra.sylvester(A, B, C) -end - -function _sylvester_symm(P, C) - D, Q = LinearAlgebra.eigen(LinearAlgebra.Hermitian(P)) - Y = Q' * C * Q - @inbounds for j in axes(Y, 2), i in axes(Y, 1) - Y[i, j] = -Y[i, j] / (D[i] + D[j]) - end - return Q * Y * Q' -end +# fallback +_sylvester(A, B, C) = LinearAlgebra.sylvester(A, B, C) diff --git a/test/testsuite/enzyme.jl b/test/testsuite/enzyme.jl index 6ad56e38..14723680 100644 --- a/test/testsuite/enzyme.jl +++ b/test/testsuite/enzyme.jl @@ -105,7 +105,6 @@ function test_enzyme(T::Type, sz; kwargs...) test_enzyme_polar(T, sz; kwargs...) test_enzyme_orthnull(T, sz; kwargs...) end - test_enzyme_projections(T, sz; kwargs...) end end @@ -463,41 +462,3 @@ function test_enzyme_orthnull( end end end - -function test_enzyme_projections( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "Projections Enzyme AD rules $summary_str" begin - A = instantiate_matrix(T, sz) - m, n = size(A) - fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) - if m == n - @testset "project_hermitian" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - Aₕ, ΔAₕ = ad_project_hermitian_setup(A) - eltype(T) <: BlasFloat && test_reverse(project_hermitian, RT, (A, TA); atol, rtol, output_tangent = ΔAₕ, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, project_hermitian!, project_hermitian, A, Aₕ, ΔAₕ) - end - end - @testset "project_antihermitian" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - Aₐ, ΔAₐ = ad_project_antihermitian_setup(A) - eltype(T) <: BlasFloat && test_reverse(project_antihermitian, RT, (A, TA); atol, rtol, output_tangent = ΔAₐ, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, project_antihermitian!, project_antihermitian, A, Aₐ, ΔAₐ) - end - end - end - if m > n - @testset "project_isometric" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - W, ΔW = ad_project_isometric_setup(A) - eltype(T) <: BlasFloat && test_reverse(project_isometric, RT, (A, TA); atol, rtol, output_tangent = ΔW, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, project_isometric!, project_isometric, A, W, ΔW) - end - end - end - end -end From eb2228527314756b40d896ff6fa30dcab8f9ff23 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 18 Feb 2026 07:33:46 -0500 Subject: [PATCH 9/9] possibly fix implementation --- .../MatrixAlgebraKitMooncakeExt.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index 37b35b32..c51184ae 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -793,11 +793,15 @@ for (f!, f, adj) in ( arg = $f!(A, arg, Mooncake.primal(alg_dalg)) function $adj(::NoRData) - dA .+= $f(darg) - dA === darg || zero!(darg) + $f!(darg) + if dA !== darg + dA .+= darg + zero!(darg) + end copy!(arg, argc) return ntuple(Returns(NoRData()), 4) end + return arg_darg, $adj end