Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions ext/MatrixAlgebraKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,4 +274,22 @@ function ChainRulesCore.rrule(::typeof(right_polar!), A, PWᴴ, alg)
return PWᴴ, right_polar_pullback
end

function ChainRulesCore.rrule(::typeof(project_hermitian), A, alg)
Aₕ = project_hermitian(A, alg)
function project_hermitian_pullback(ΔAₕ)
ΔA = project_hermitian(unthunk(ΔAₕ))
return NoTangent(), ΔA, NoTangent()
end
return Aₕ, project_hermitian_pullback
end

function ChainRulesCore.rrule(::typeof(project_antihermitian), A, alg)
Aₐ = project_antihermitian(A, alg)
function project_antihermitian_pullback(ΔAₐ)
ΔA = project_antihermitian(unthunk(ΔAₐ))
return NoTangent(), ΔA, NoTangent()
end
return Aₐ, project_antihermitian_pullback
end

end
45 changes: 45 additions & 0 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -779,4 +779,49 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
end

# single-output projections: project_hermitian!, project_antihermitian!
for (f!, f, adj) in (
(:project_hermitian!, :project_hermitian, :project_hermitian_adjoint),
(:project_antihermitian!, :project_antihermitian, :project_antihermitian_adjoint),
)
@eval begin
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
A, dA = arrayify(A_dA)
arg, darg = A_dA === arg_darg ? (A, dA) : arrayify(arg_darg)
argc = copy(arg)
arg = $f!(A, arg, Mooncake.primal(alg_dalg))

function $adj(::NoRData)
$f!(darg)
if dA !== darg
dA .+= darg
zero!(darg)
end
copy!(arg, argc)
return ntuple(Returns(NoRData()), 4)
end

return arg_darg, $adj
end

@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
A, dA = arrayify(A_dA)
output = $f(A, Mooncake.primal(alg_dalg))
output_doutput = Mooncake.zero_fcodual(output)

doutput = last(arrayify(output_doutput))
function $adj(::NoRData)
# TODO: need accumulating projection to avoid intermediate here
dA .+= $f(doutput)
zero!(doutput)
return ntuple(Returns(NoRData()), 3)
end

return output_doutput, $adj
end
end
end

end
4 changes: 2 additions & 2 deletions src/pullbacks/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...)
!iszerotangent(ΔW) && mul!(M, W', ΔW, 1, 1)
!iszerotangent(ΔP) && mul!(M, ΔP, P, -1, 1)
C = _sylvester(P, P, M' - M)
C .+= ΔP
!iszerotangent(ΔP) && (C .+= ΔP)
ΔA = mul!(ΔA, W, C, 1, 1)
if !iszerotangent(ΔW)
ΔWP = ΔW / P
Expand Down Expand Up @@ -47,7 +47,7 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...
!iszerotangent(ΔWᴴ) && mul!(M, ΔWᴴ, Wᴴ', 1, 1)
!iszerotangent(ΔP) && mul!(M, P, ΔP, -1, 1)
C = _sylvester(P, P, M' - M)
C .+= ΔP
!iszerotangent(ΔP) && (C .+= ΔP)
ΔA = mul!(ΔA, C, Wᴴ, 1, 1)
if !iszerotangent(ΔWᴴ)
PΔWᴴ = P \ ΔWᴴ
Expand Down
23 changes: 23 additions & 0 deletions test/testsuite/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ function test_chainrules(T::Type, sz; kwargs...)
test_chainrules_svd(T, sz; kwargs...)
test_chainrules_polar(T, sz; kwargs...)
test_chainrules_orthnull(T, sz; kwargs...)
test_chainrules_projections(T, sz; kwargs...)
end
end

Expand Down Expand Up @@ -610,3 +611,25 @@ function test_chainrules_orthnull(
)
end
end

function test_chainrules_projections(
T::Type, sz;
atol::Real = 0, rtol::Real = precision(T),
kwargs...
)
summary_str = testargs_summary(T, sz)
return @testset "Projections Chainrules AD rules $summary_str" begin
A = instantiate_matrix(T, sz)
m, n = size(A)
if m == n
@testset "project_hermitian" begin
alg = MatrixAlgebraKit.default_hermitian_algorithm(A)
test_rrule(project_hermitian, A, alg; atol, rtol)
end
@testset "project_antihermitian" begin
alg = MatrixAlgebraKit.default_hermitian_algorithm(A)
test_rrule(project_antihermitian, A, alg; atol, rtol)
end
end
end
end
27 changes: 27 additions & 0 deletions test/testsuite/mooncake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ function test_mooncake(T::Type, sz; kwargs...)
if T <: Number
test_mooncake_orthnull(T, sz; kwargs...)
end
test_mooncake_projections(T, sz; kwargs...)
end
end

Expand Down Expand Up @@ -537,3 +538,29 @@ function test_mooncake_orthnull(
test_pullbacks_match(((X, Nᴴ) -> right_null!(X, Nᴴ; alg = :lq)), right_null_lq, A, Nᴴ, ΔNᴴ)
end
end

function test_mooncake_projections(
T::Type, sz;
atol::Real = 0, rtol::Real = precision(T),
kwargs...
)
summary_str = testargs_summary(T, sz)
return @testset "Projections Mooncake AD rules $summary_str" begin
A = instantiate_matrix(T, sz)
m, n = size(A)
if m == n
@testset "project_hermitian" begin
Aₕ = project_hermitian(A)
ΔAₕ = make_mooncake_tangent(Aₕ)
Mooncake.TestUtils.test_rule(rng, project_hermitian, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol)
test_pullbacks_match(project_hermitian!, project_hermitian, A, Aₕ, ΔAₕ)
end
@testset "project_antihermitian" begin
Aₐ = project_antihermitian(A)
ΔAₐ = make_mooncake_tangent(Aₐ)
Mooncake.TestUtils.test_rule(rng, project_antihermitian, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol)
test_pullbacks_match(project_antihermitian!, project_antihermitian, A, Aₐ, ΔAₐ)
end
end
end
end