Skip to content

Commit 57e5e23

Browse files
committed
Working truncs
1 parent 0db2345 commit 57e5e23

File tree

3 files changed

+74
-165
lines changed

3 files changed

+74
-165
lines changed

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 49 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ module MatrixAlgebraKitEnzymeExt
22

33
using MatrixAlgebraKit
44
using MatrixAlgebraKit: copy_input
5-
using MatrixAlgebraKit: diagview, inv_safe, eig_trunc!, eigh_trunc!
5+
using MatrixAlgebraKit: diagview, inv_safe, eig_trunc!, eigh_trunc!, truncate
66
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
77
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
8-
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_trunc_pullback!, eigh_trunc_pullback!
8+
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!
99
using MatrixAlgebraKit: svd_pullback!
1010
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
1111
using Enzyme
@@ -187,7 +187,7 @@ end
187187

188188
function EnzymeRules.augmented_primal(
189189
config::EnzymeRules.RevConfigWidth{1},
190-
func::Const{typeof(svd_trunc!)},
190+
func::Const{typeof(svd_trunc_no_error!)},
191191
::Type{RT},
192192
A::Annotation,
193193
USVᴴ::Annotation,
@@ -218,7 +218,7 @@ function EnzymeRules.augmented_primal(
218218
end
219219
function EnzymeRules.reverse(
220220
config::EnzymeRules.RevConfigWidth{1},
221-
func::Const{typeof(svd_trunc!)},
221+
func::Const{typeof(svd_trunc_no_error!)},
222222
dret::Type{RT},
223223
cache,
224224
A::Annotation,
@@ -235,137 +235,54 @@ function EnzymeRules.reverse(
235235
!isa(USVᴴ, Const) && make_zero!(USVᴴ.dval)
236236
return (nothing, nothing, nothing)
237237
end
238-
#=
239-
function EnzymeRules.augmented_primal(
240-
config::EnzymeRules.RevConfigWidth{1},
241-
func::Const{typeof(svd_trunc)},
242-
::Type{MixedDuplicated},
243-
A::Annotation,
244-
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
245-
)
246-
# form cache if needed
247-
cache_A = copy(A.val)
248-
U, S, Vᴴ, ϵ = svd_trunc(A.val, USVᴴ.val, alg.val.alg)
249-
primal = EnzymeRules.needs_primal(config) ? (U, S, Vᴴ, ϵ) : nothing
250-
dU = zero(U)
251-
dS = zero(S)
252-
dVᴴ = zero(Vᴴ)
253-
dϵ = zero(ϵ)
254-
shadow = EnzymeRules.needs_shadow(config) ? (dU, dS, dVᴴ, dϵ) : nothing
255-
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, (U, S, Vᴴ), (dU, dS, dVᴴ)))
256-
end
257-
function EnzymeRules.reverse(
258-
config::EnzymeRules.RevConfigWidth{1},
259-
func::Const{typeof(svd_trunc)},
260-
dret::Type{MixedDuplicated},
261-
cache,
262-
A::Annotation,
263-
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
264-
)
265-
cache_A, cache_USVᴴ, shadow_USVᴴ = cache
266-
U, S, Vᴴ = cache_USVᴴ
267-
dU, dS, dVᴴ = shadow_USVᴴ
268-
Aval = isnothing(cache_A) ? A.val : cache_A
269-
if !isa(A, Const) && !isa(USVᴴ, Const)
270-
svd_trunc_pullback!(A.dval, Aval, (U, S, Vᴴ), shadow_USVᴴ, ind)
271-
end
272-
return (nothing, nothing, nothing)
273-
end
274-
=#
275-
function EnzymeRules.augmented_primal(
276-
config::EnzymeRules.RevConfigWidth{1},
277-
func::Const{typeof(eigh_trunc!)},
278-
::Type{RT},
279-
A::Annotation,
280-
DV::Annotation{Tuple{TD, TV}},
281-
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
282-
) where {RT, TD, TV}
283-
# form cache if needed
284-
cache_A = copy(A.val)
285-
MatrixAlgebraKit.eigh_full!(A.val, DV.val, alg.val.alg)
286-
cache_DV = copy.(DV.val)
287-
DV′, ind = MatrixAlgebraKit.truncate(eigh_trunc!, DV.val, alg.val.trunc)
288-
ϵ.val = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind)
289-
primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ.val) : nothing
290-
shadow_DV = if !isa(A, Const) && !isa(DV, Const)
291-
dD, dV = DV.dval
292-
dDtrunc = Diagonal(diagview(dD)[ind])
293-
dVtrunc = dV[:, ind]
294-
(dDtrunc, dVtrunc)
295-
else
296-
(nothing, nothing)
297-
end
298-
!isa(ϵ, Const) && make_zero!(ϵ.dval)
299-
shadow_ϵ = !isa(ϵ, Const) ? ϵ.dval : zero(T)
300-
shadow = EnzymeRules.needs_shadow(config) ? (shadow_DV..., shadow_ϵ) : nothing
301-
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV, shadow_DV, ind))
302-
end
303-
function EnzymeRules.reverse(
304-
config::EnzymeRules.RevConfigWidth{1},
305-
func::Const{typeof(eigh_trunc!)},
306-
::Type{RT},
307-
cache,
308-
A::Annotation,
309-
DV::Annotation{Tuple{TD, TV}},
310-
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
311-
) where {RT, TD, TV}
312-
cache_A, cache_DV, cache_dDVtrunc, ind = cache
313-
Aval = cache_A
314-
D, V = cache_DV
315-
dD, dV = cache_dDVtrunc
316-
if !isa(A, Const) && !isa(DV, Const)
317-
MatrixAlgebraKit.eigh_pullback!(A.dval, Aval, (D, V), (dD, dV), ind)
318-
end
319-
!isa(DV, Const) && make_zero!(DV.dval)
320-
!isa(ϵ, Const) && make_zero!(ϵ.dval)
321-
return (nothing, nothing, nothing, nothing)
322-
end
323238

324-
function EnzymeRules.augmented_primal(
325-
config::EnzymeRules.RevConfigWidth{1},
326-
func::Const{typeof(eig_trunc!)},
327-
::Type{RT},
328-
A::Annotation,
329-
DV::Annotation{Tuple{TD, TV}},
330-
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
331-
) where {RT, TD, TV}
332-
# form cache if needed
333-
cache_A = copy(A.val)
334-
eig_full!(A.val, DV.val, alg.val.alg)
335-
cache_DV = copy.(DV.val)
336-
DV′, ind = MatrixAlgebraKit.truncate(eig_trunc!, DV.val, alg.val.trunc)
337-
ϵ.val = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind)
338-
primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ.val) : nothing
339-
shadow_DV = if !isa(A, Const) && !isa(DV, Const)
340-
dD, dV = DV.dval
341-
dDtrunc = Diagonal(diagview(dD)[ind])
342-
dVtrunc = dV[:, ind]
343-
(dDtrunc, dVtrunc)
344-
else
345-
(nothing, nothing)
239+
for (f, trunc_f, full_f, pb) in ((:eigh_trunc_no_error!, :eigh_trunc!, :eigh_full!, :eigh_pullback!),
240+
(:eig_trunc_no_error!, :eig_trunc!, :eig_full!, :eig_pullback!),
241+
)
242+
@eval function EnzymeRules.augmented_primal(
243+
config::EnzymeRules.RevConfigWidth{1},
244+
func::Const{typeof($f)},
245+
::Type{RT},
246+
A::Annotation,
247+
DV::Annotation{Tuple{TD, TV}},
248+
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
249+
) where {RT, TD, TV}
250+
# form cache if needed
251+
cache_A = copy(A.val)
252+
$full_f(A.val, DV.val, alg.val.alg)
253+
cache_DV = copy.(DV.val)
254+
DV′, ind = truncate($trunc_f, DV.val, alg.val.trunc)
255+
primal = EnzymeRules.needs_primal(config) ? DV′ : nothing
256+
shadow_DV = if !isa(A, Const) && !isa(DV, Const)
257+
dD, dV = DV.dval
258+
dDtrunc = Diagonal(diagview(dD)[ind])
259+
dVtrunc = dV[:, ind]
260+
(dDtrunc, dVtrunc)
261+
else
262+
(nothing, nothing)
263+
end
264+
shadow = EnzymeRules.needs_shadow(config) ? shadow_DV : nothing
265+
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV, shadow_DV, ind))
346266
end
347-
shadow = EnzymeRules.needs_shadow(config) ? (shadow_DV..., zero(T)) : nothing
348-
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV, shadow_DV))
349-
end
350-
function EnzymeRules.reverse(
351-
config::EnzymeRules.RevConfigWidth{1},
352-
func::Const{typeof(eig_trunc!)},
353-
::Type{RT},
354-
cache,
355-
A::Annotation,
356-
DV::Annotation{Tuple{TD, TV}},
357-
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
358-
) where {RT, TD, TV}
359-
cache_A, cache_DV, cache_dDVtrunc = cache
360-
D, V = cache_DV
361-
Aval = cache_A
362-
dD, dV = cache_dDVtrunc
363-
if !isa(A, Const) && !isa(DV, Const)
364-
eig_trunc_pullback!(A.dval, Aval, (D, V), (dD, dV))
267+
@eval function EnzymeRules.reverse(
268+
config::EnzymeRules.RevConfigWidth{1},
269+
func::Const{typeof($f)},
270+
::Type{RT},
271+
cache,
272+
A::Annotation,
273+
DV::Annotation{Tuple{TD, TV}},
274+
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
275+
) where {RT, TD, TV}
276+
cache_A, cache_DV, cache_dDVtrunc, ind = cache
277+
Aval = cache_A
278+
D, V = cache_DV
279+
dD, dV = cache_dDVtrunc
280+
if !isa(A, Const) && !isa(DV, Const)
281+
$pb(A.dval, Aval, (D, V), (dD, dV), ind)
282+
end
283+
!isa(DV, Const) && make_zero!(DV.dval)
284+
return (nothing, nothing, nothing)
365285
end
366-
!isa(DV, Const) && make_zero!(DV.dval)
367-
!isa(ϵ, Const) && make_zero!(ϵ.dval)
368-
return (nothing, nothing, nothing, nothing)
369286
end
370287

371288
for (f!, f_full!, pb!) in (

test/enzyme.jl

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@ using Enzyme, EnzymeTestUtils
77
using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD
88
using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul!
99

10-
is_ci = get(ENV, "CI", "false") == "true"
11-
12-
ETs = is_ci ? (Float64, Float32) : (Float64, Float32, ComplexF32, ComplexF64) # Enzyme/#2631
10+
ETs = (Float32, ComplexF64)
1311
include("ad_utils.jl")
1412
function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; ȳ = copy.(Δargs), return_act = Duplicated)
1513
ΔA = randn(rng, eltype(A), size(A)...)
@@ -188,10 +186,8 @@ end
188186
Vtrunc = V[:, ind]
189187
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
190188
ΔVtrunc = ΔV[:, ind]
191-
# broken due to Enzyme
192-
#test_reverse(eig_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
193-
# broken due to Enzyme
194-
#test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
189+
test_reverse(eig_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
190+
test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc), return_act=RT)
195191
dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
196192
dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
197193
@test isapprox(dA1, dA2; atol = atol, rtol = rtol)
@@ -202,10 +198,8 @@ end
202198
Vtrunc = V[:, ind]
203199
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
204200
ΔVtrunc = ΔV[:, ind]
205-
# broken due to Enzyme
206-
#test_reverse(eig_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
207-
# broken due to Enzyme
208-
#test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
201+
test_reverse(eig_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
202+
test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg; ȳ=(ΔDtrunc, ΔVtrunc), return_act=RT)
209203
dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
210204
dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
211205
@test isapprox(dA1, dA2; atol = atol, rtol = rtol)
@@ -253,24 +247,24 @@ function copy_eigh_vals!(A, D, alg; kwargs...)
253247
return eigh_vals!(A, D, alg; kwargs...)
254248
end
255249

256-
function copy_eigh_trunc(A; kwargs...)
250+
function copy_eigh_trunc_no_error(A; kwargs...)
257251
A = (A + A') / 2
258-
return eigh_trunc(A; kwargs...)
252+
return eigh_trunc_no_error(A; kwargs...)
259253
end
260254

261-
function copy_eigh_trunc!(A, DV; kwargs...)
255+
function copy_eigh_trunc_no_error!(A, DV; kwargs...)
262256
A = (A + A') / 2
263-
return eigh_trunc!(A, DV; kwargs...)
257+
return eigh_trunc_no_error!(A, DV; kwargs...)
264258
end
265259

266-
function copy_eigh_trunc(A, alg; kwargs...)
260+
function copy_eigh_trunc_no_error(A, alg; kwargs...)
267261
A = (A + A') / 2
268-
return eigh_trunc(A; kwargs...)
262+
return eigh_trunc_no_error(A, alg; kwargs...)
269263
end
270264

271-
function copy_eigh_trunc!(A, DV, alg; kwargs...)
265+
function copy_eigh_trunc_no_error!(A, DV, alg; kwargs...)
272266
A = (A + A') / 2
273-
return eigh_trunc!(A, DV; kwargs...)
267+
return eigh_trunc_no_error!(A, DV, alg; kwargs...)
274268
end
275269

276270
@timedtestset "EIGH AD Rules with eltype $T" for T in ETs
@@ -307,9 +301,8 @@ end
307301
Vtrunc = V[:, ind]
308302
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
309303
ΔVtrunc = ΔV[:, ind]
310-
# broken due to Enzyme
311-
#test_reverse(copy_eigh_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
312-
#test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
304+
test_reverse(copy_eigh_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
305+
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc), return_act=RT)
313306
end
314307
Ddiag = diagview(D)
315308
truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2))
@@ -318,9 +311,8 @@ end
318311
Vtrunc = V[:, ind]
319312
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
320313
ΔVtrunc = ΔV[:, ind]
321-
# broken due to Enzyme
322-
#test_reverse(copy_eigh_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
323-
#test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
314+
test_reverse(copy_eigh_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
315+
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc), return_act=RT)
324316
end
325317
end
326318
end
@@ -386,8 +378,8 @@ end
386378
ΔStrunc = Diagonal(diagview(ΔS2)[ind])
387379
ΔUtrunc = ΔU[:, ind]
388380
ΔVᴴtrunc = ΔVᴴ[ind, :]
389-
test_reverse(svd_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm = fdm)
390-
test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc), return_act=RT)
381+
test_reverse(svd_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm = fdm)
382+
test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc), return_act=RT)
391383
end
392384
U, S, Vᴴ = svd_compact(A)
393385
ΔU = randn(rng, T, m, minmn)
@@ -403,8 +395,8 @@ end
403395
ΔStrunc = Diagonal(diagview(ΔS2)[ind])
404396
ΔUtrunc = ΔU[:, ind]
405397
ΔVᴴtrunc = ΔVᴴ[ind, :]
406-
test_reverse(svd_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm = fdm)
407-
test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc), return_act=RT)
398+
test_reverse(svd_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm = fdm)
399+
test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc), return_act=RT)
408400
end
409401
end
410402
end

test/runtests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using SafeTestsets
44
# specific ones
55
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
66
if !is_buildkite
7-
#=@safetestset "Algorithms" begin
7+
@safetestset "Algorithms" begin
88
include("algorithms.jl")
99
end
1010
@safetestset "Projections" begin
@@ -37,11 +37,11 @@ if !is_buildkite
3737
end
3838
@safetestset "Image and Null Space" begin
3939
include("orthnull.jl")
40-
end=#
40+
end
4141
@safetestset "Enzyme" begin
4242
include("enzyme.jl")
4343
end
44-
#=@safetestset "Mooncake" begin
44+
@safetestset "Mooncake" begin
4545
include("mooncake.jl")
4646
end
4747
@safetestset "ChainRules" begin
@@ -75,7 +75,7 @@ if !is_buildkite
7575
using GenericSchur
7676
@safetestset "General Eigenvalue Decomposition" begin
7777
include("genericschur/eig.jl")
78-
end=#
78+
end
7979
end
8080

8181
using CUDA

0 commit comments

Comments
 (0)