Skip to content
13 changes: 6 additions & 7 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ for f in (:eig, :eigh)
_warn_pullback_truncerror(dϵ)

# compute pullbacks
$f_pullback!(dA, Ac, DVc, dDVtrunc, ind)
$f_pullback!(dA, Ac, DV, dDVtrunc, ind)
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required

# restore state
Expand Down Expand Up @@ -351,8 +351,8 @@ for f in (:eig, :eigh)
dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc)))
function $f_adjoint!(::NoRData)
# compute pullbacks
$f_pullback!(dA, Ac, DVc, dDVtrunc, ind)
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
$f_pullback!(dA, Ac, DV, dDVtrunc, ind)
zero!.(dDV)

# restore state
copy!(A, Ac)
Expand Down Expand Up @@ -425,7 +425,7 @@ for (f!, f) in (
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
USVᴴc = copy.(USVᴴ)
output = $f!(A, Mooncake.primal(alg_dalg))
output = $f!(A, USVᴴ, Mooncake.primal(alg_dalg))
function svd_adjoint(::NoRData)
copy!(A, Ac)
if $(f! == svd_compact!)
Expand Down Expand Up @@ -590,7 +590,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
_warn_pullback_truncerror(dϵ)

# compute pullbacks
svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
svd_pullback!(dA, Ac, USVᴴ, dUSVᴴtrunc, ind)
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
zero!.(dUSVᴴ)

Expand Down Expand Up @@ -717,8 +717,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U
dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc)))
function svd_trunc_adjoint(::NoRData)
# compute pullbacks
svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
svd_pullback!(dA, Ac, USVᴴ, dUSVᴴtrunc, ind)
zero!.(dUSVᴴ)

# restore state
Expand Down
22 changes: 15 additions & 7 deletions src/pullbacks/qr.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
function check_qr_cotangents(Q, R, ΔQ, ΔR, minmn::Int, p::Int; gauge_atol::Real = default_pullback_gauge_atol(ΔQ))
qr_rank(R; rank_atol = default_pullback_rank_atol(R)) =
@something findlast(>=(rank_atol) ∘ abs, diagview(R)) 0

function check_qr_cotangents(
Q, R, ΔQ, ΔR, p::Int;
gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
)
minmn = min(size(Q, 1), size(R, 2))
if minmn > p # case where A is rank-deficient
Δgauge = abs(zero(eltype(Q)))
if !iszerotangent(ΔQ)
Expand All @@ -7,11 +14,13 @@ function check_qr_cotangents(Q, R, ΔQ, ΔR, minmn::Int, p::Int; gauge_atol::Rea
# columns of ΔQ should be zero for a gauge-invariant
# cost function
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
Δgauge_Q = norm(ΔQ2, Inf)
Δgauge = max(Δgauge, Δgauge_Q)
end
if !iszerotangent(ΔR)
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2))
Δgauge = max(Δgauge, norm(ΔR22, Inf))
Δgauge_R = norm(ΔR22, Inf)
Δgauge = max(Δgauge, Δgauge_R)
end
Δgauge ≤ gauge_atol ||
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
Expand All @@ -29,7 +38,7 @@ function check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol::Real = default_
# Q2' * ΔQ2 as a gauge dependent quantity.
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
Δgauge ≤ gauge_atol ||
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
@warn "`qr` full cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
return
end

Expand Down Expand Up @@ -60,9 +69,8 @@ function qr_pullback!(
Q, R = QR
m = size(Q, 1)
n = size(R, 2)
minmn = min(m, n)
Rd = diagview(R)
p = @something findlast(>=(rank_atol) ∘ abs, Rd) 0
p = qr_rank(R)

ΔQ, ΔR = ΔQR

Expand All @@ -72,7 +80,7 @@ function qr_pullback!(
ΔA1 = view(ΔA, :, 1:p)
ΔA2 = view(ΔA, :, (p + 1):n)

check_qr_cotangents(Q, R, ΔQ, ΔR, minmn, p; gauge_atol)
check_qr_cotangents(Q, R, ΔQ, ΔR, p; gauge_atol)

ΔQ̃ = zero!(similar(Q, (m, p)))
if !iszerotangent(ΔQ)
Expand Down
33 changes: 24 additions & 9 deletions test/testsuite/TestSuite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using MatrixAlgebraKit
using MatrixAlgebraKit: diagview
using LinearAlgebra: Diagonal, norm, istriu, istril, I
using Random, StableRNGs
using Mooncake
using AMDGPU, CUDA

const tests = Dict()
Expand Down Expand Up @@ -86,16 +87,30 @@ instantiate_unitary(::Type{<:Diagonal}, A, sz) = Diagonal(fill!(similar(parent(A

include("ad_utils.jl")

include("qr.jl")
include("lq.jl")
include("polar.jl")
include("projections.jl")
include("schur.jl")
include("eig.jl")
include("eigh.jl")
include("orthnull.jl")
include("svd.jl")
include("mooncake.jl")

# Decompositions
# --------------
include("decompositions/qr.jl")
include("decompositions/lq.jl")
include("decompositions/polar.jl")
include("decompositions/schur.jl")
include("decompositions/eig.jl")
include("decompositions/eigh.jl")
include("decompositions/orthnull.jl")
include("decompositions/svd.jl")

# Mooncake
# --------
include("mooncake/mooncake.jl")
include("mooncake/qr.jl")
include("mooncake/lq.jl")
include("mooncake/eig.jl")
include("mooncake/eigh.jl")
include("mooncake/svd.jl")
include("mooncake/polar.jl")
include("mooncake/orthnull.jl")

include("enzyme.jl")
include("chainrules.jl")

Expand Down
13 changes: 2 additions & 11 deletions test/testsuite/ad_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,6 @@ function remove_svdgauge_dependence!(
mul!(ΔU, U, gaugepart, -1, 1)
return ΔU, ΔVᴴ
end
function remove_eiggauge_dependence!(
ΔV, D, V;
degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D)
)
gaugepart = V' * ΔV
gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0
mul!(ΔV, V / (V' * V), gaugepart, -1, 1)
return ΔV
end
function remove_eighgauge_dependence!(
ΔV, D, V;
degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D)
Expand Down Expand Up @@ -204,7 +195,7 @@ function ad_eig_full_setup(A)
D, V = DV
Ddiag = diagview(D)
ΔV = randn!(similar(A, complex(T), m, m))
ΔV = remove_eiggauge_dependence!(ΔV, D, V)
ΔV = remove_eig_gauge_dependence!(ΔV, D, V)
ΔD = randn!(similar(A, complex(T), m, m))
ΔD2 = Diagonal(randn!(similar(A, complex(T), m)))
return DV, (ΔD, ΔV), (ΔD2, ΔV)
Expand All @@ -216,7 +207,7 @@ function ad_eig_full_setup(A::Diagonal)
DV = eig_full(A)
D, V = DV
ΔV = randn!(similar(A.diag, T, m, m))
ΔV = remove_eiggauge_dependence!(ΔV, D, V)
ΔV = remove_eig_gauge_dependence!(ΔV, D, V)
ΔD = Diagonal(randn!(similar(A.diag, T, m)))
ΔD2 = Diagonal(randn!(similar(A.diag, T, m)))
return DV, (ΔD, ΔV), (ΔD2, ΔV)
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using TestExtras
using LinearAlgebra

include("../linearmap.jl")
include("../../linearmap.jl")

_left_orth_svd(x; kwargs...) = left_orth(x; alg = :svd, kwargs...)
_left_orth_svd!(x, VC; kwargs...) = left_orth!(x, VC; alg = :svd, kwargs...)
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading