diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl
index 400b2a79..acfe0d83 100644
--- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl
+++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl
@@ -274,4 +274,30 @@ function ChainRulesCore.rrule(::typeof(right_polar!), A, PWᴴ, alg)
     return PWᴴ, right_polar_pullback
 end
 
+# Reverse rule for `project_hermitian(A, alg)`: the cotangent of `A` is the incoming
+# cotangent passed through `project_hermitian`; the function and `alg` get `NoTangent()`.
+function ChainRulesCore.rrule(::typeof(project_hermitian), A, alg)
+    Aₕ = project_hermitian(A, alg)
+    function project_hermitian_pullback(ΔAₕ)
+        Δ = unthunk(ΔAₕ)
+        # Return a structural zero unchanged instead of handing it to the projection.
+        Δ isa ChainRulesCore.AbstractZero && return NoTangent(), Δ, NoTangent()
+        return NoTangent(), project_hermitian(Δ), NoTangent()
+    end
+    return Aₕ, project_hermitian_pullback
+end
+
+# Reverse rule for `project_antihermitian(A, alg)`: the cotangent of `A` is the incoming
+# cotangent passed through `project_antihermitian`; the function and `alg` get `NoTangent()`.
+function ChainRulesCore.rrule(::typeof(project_antihermitian), A, alg)
+    Aₐ = project_antihermitian(A, alg)
+    function project_antihermitian_pullback(ΔAₐ)
+        Δ = unthunk(ΔAₐ)
+        # Return a structural zero unchanged instead of handing it to the projection.
+        Δ isa ChainRulesCore.AbstractZero && return NoTangent(), Δ, NoTangent()
+        return NoTangent(), project_antihermitian(Δ), NoTangent()
+    end
+    return Aₐ, project_antihermitian_pullback
+end
+
 end
diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
index 3a113c20..c51184ae 100644
--- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
+++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
@@ -779,4 +779,54 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al
     return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
 end
 
+# single-output projections: project_hermitian!, project_antihermitian!
+for (f!, f, adj) in (
+        (:project_hermitian!, :project_hermitian, :project_hermitian_adjoint),
+        (:project_antihermitian!, :project_antihermitian, :project_antihermitian_adjoint),
+    )
+    @eval begin
+        # In-place rule (`f!(A, arg, alg)`); `arg` may be the very same CoDual as `A`.
+        @is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
+        function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
+            A, dA = arrayify(A_dA)
+            arg, darg = A_dA === arg_darg ? (A, dA) : arrayify(arg_darg)
+            argc = copy(arg) # snapshot of `arg` before the call; copied back in the adjoint
+            # Bind the result to a new name: rebinding `arg` here would make the variable
+            # captured by the adjoint closure multiply-assigned, which Julia boxes.
+            out = $f!(A, arg, Mooncake.primal(alg_dalg))
+
+            function $adj(::NoRData)
+                $f!(darg)
+                if dA !== darg
+                    # separate buffers: accumulate the projected cotangent into `dA`
+                    dA .+= darg
+                    zero!(darg)
+                end
+                copy!(out, argc)
+                return ntuple(Returns(NoRData()), 4)
+            end
+
+            return arg_darg, $adj
+        end
+
+        # Out-of-place rule (`f(A, alg)`).
+        @is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
+        function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
+            A, dA = arrayify(A_dA)
+            output = $f(A, Mooncake.primal(alg_dalg))
+            output_doutput = Mooncake.zero_fcodual(output)
+
+            doutput = last(arrayify(output_doutput))
+            function $adj(::NoRData)
+                # TODO: need accumulating projection to avoid intermediate here
+                dA .+= $f(doutput)
+                zero!(doutput)
+                return ntuple(Returns(NoRData()), 3)
+            end
+
+            return output_doutput, $adj
+        end
+    end
+end
+
 end
diff --git a/src/pullbacks/polar.jl b/src/pullbacks/polar.jl
index 4d498da0..a549321e 100644
--- a/src/pullbacks/polar.jl
+++ b/src/pullbacks/polar.jl
@@ -17,7 +17,7 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...)
     !iszerotangent(ΔW) && mul!(M, W', ΔW, 1, 1)
     !iszerotangent(ΔP) && mul!(M, ΔP, P, -1, 1)
     C = _sylvester(P, P, M' - M)
-    C .+= ΔP
+    !iszerotangent(ΔP) && (C .+= ΔP)
     ΔA = mul!(ΔA, W, C, 1, 1)
     if !iszerotangent(ΔW)
         ΔWP = ΔW / P
@@ -47,7 +47,7 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...
     !iszerotangent(ΔWᴴ) && mul!(M, ΔWᴴ, Wᴴ', 1, 1)
     !iszerotangent(ΔP) && mul!(M, P, ΔP, -1, 1)
     C = _sylvester(P, P, M' - M)
-    C .+= ΔP
+    !iszerotangent(ΔP) && (C .+= ΔP)
     ΔA = mul!(ΔA, C, Wᴴ, 1, 1)
     if !iszerotangent(ΔWᴴ)
         PΔWᴴ = P \ ΔWᴴ
diff --git a/test/testsuite/chainrules.jl b/test/testsuite/chainrules.jl
index b4126c59..f937e0b0 100644
--- a/test/testsuite/chainrules.jl
+++ b/test/testsuite/chainrules.jl
@@ -46,6 +46,7 @@ function test_chainrules(T::Type, sz; kwargs...)
         test_chainrules_svd(T, sz; kwargs...)
         test_chainrules_polar(T, sz; kwargs...)
         test_chainrules_orthnull(T, sz; kwargs...)
+        test_chainrules_projections(T, sz; kwargs...)
     end
 end
 
@@ -610,3 +611,27 @@ function test_chainrules_orthnull(
         )
     end
 end
+
+# Checks the ChainRules `rrule`s of `project_hermitian` and `project_antihermitian` with
+# `test_rrule`; the checks only run for square inputs (`m == n`).
+function test_chainrules_projections(
+        T::Type, sz;
+        atol::Real = 0, rtol::Real = precision(T),
+        kwargs...
+    )
+    summary_str = testargs_summary(T, sz)
+    return @testset "Projections Chainrules AD rules $summary_str" begin
+        A = instantiate_matrix(T, sz)
+        m, n = size(A)
+        if m == n
+            @testset "project_hermitian" begin
+                alg = MatrixAlgebraKit.default_hermitian_algorithm(A)
+                test_rrule(project_hermitian, A, alg; atol, rtol)
+            end
+            @testset "project_antihermitian" begin
+                alg = MatrixAlgebraKit.default_hermitian_algorithm(A)
+                test_rrule(project_antihermitian, A, alg; atol, rtol)
+            end
+        end
+    end
+end
diff --git a/test/testsuite/mooncake.jl b/test/testsuite/mooncake.jl
index 29d65e31..9ca869d2 100644
--- a/test/testsuite/mooncake.jl
+++ b/test/testsuite/mooncake.jl
@@ -229,6 +229,7 @@ function test_mooncake(T::Type, sz; kwargs...)
         if T <: Number
             test_mooncake_orthnull(T, sz; kwargs...)
         end
+        test_mooncake_projections(T, sz; kwargs...)
     end
 end
 
@@ -537,3 +538,31 @@ function test_mooncake_orthnull(
         test_pullbacks_match(((X, Nᴴ) -> right_null!(X, Nᴴ; alg = :lq)), right_null_lq, A, Nᴴ, ΔNᴴ)
     end
 end
+
+# Checks the Mooncake rules of the hermitian / antihermitian projections with `test_rule`
+# and `test_pullbacks_match`; the checks only run for square inputs (`m == n`).
+function test_mooncake_projections(
+        T::Type, sz;
+        atol::Real = 0, rtol::Real = precision(T),
+        kwargs...
+    )
+    summary_str = testargs_summary(T, sz)
+    return @testset "Projections Mooncake AD rules $summary_str" begin
+        A = instantiate_matrix(T, sz)
+        m, n = size(A)
+        if m == n
+            @testset "project_hermitian" begin
+                Aₕ = project_hermitian(A)
+                ΔAₕ = make_mooncake_tangent(Aₕ)
+                Mooncake.TestUtils.test_rule(rng, project_hermitian, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol)
+                test_pullbacks_match(project_hermitian!, project_hermitian, A, Aₕ, ΔAₕ)
+            end
+            @testset "project_antihermitian" begin
+                Aₐ = project_antihermitian(A)
+                ΔAₐ = make_mooncake_tangent(Aₐ)
+                Mooncake.TestUtils.test_rule(rng, project_antihermitian, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol)
+                test_pullbacks_match(project_antihermitian!, project_antihermitian, A, Aₐ, ΔAₐ)
+            end
+        end
+    end
+end