@@ -7,9 +7,7 @@ using Enzyme, EnzymeTestUtils
77using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD
88using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul!
99
10- is_ci = get (ENV , " CI" , " false" ) == " true"
11-
12- ETs = is_ci ? (Float64, Float32) : (Float64, Float32, ComplexF32, ComplexF64) # Enzyme/#2631
10+ ETs = (Float32, ComplexF64)
1311include (" ad_utils.jl" )
1412function test_pullbacks_match (rng, f!, f, A, args, Δargs, alg = nothing ; ȳ = copy .(Δargs), return_act = Duplicated)
1513 ΔA = randn (rng, eltype (A), size (A)... )
188186 Vtrunc = V[:, ind]
189187 ΔDtrunc = Diagonal (diagview (ΔD2)[ind])
190188 ΔVtrunc = ΔV[:, ind]
191- # broken due to Enzyme
192- # test_reverse(eig_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
193- # broken due to Enzyme
194- # test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
189+ test_reverse (eig_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
190+ test_pullbacks_match (rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ= (ΔDtrunc, ΔVtrunc), return_act= RT)
195191 dA1 = MatrixAlgebraKit. eig_pullback! (zero (A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
196192 dA2 = MatrixAlgebraKit. eig_trunc_pullback! (zero (A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
197193 @test isapprox (dA1, dA2; atol = atol, rtol = rtol)
202198 Vtrunc = V[:, ind]
203199 ΔDtrunc = Diagonal (diagview (ΔD2)[ind])
204200 ΔVtrunc = ΔV[:, ind]
205- # broken due to Enzyme
206- # test_reverse(eig_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
207- # broken due to Enzyme
208- # test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
201+ test_reverse (eig_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
202+ test_pullbacks_match (rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg; ȳ= (ΔDtrunc, ΔVtrunc), return_act= RT)
209203 dA1 = MatrixAlgebraKit. eig_pullback! (zero (A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
210204 dA2 = MatrixAlgebraKit. eig_trunc_pullback! (zero (A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
211205 @test isapprox (dA1, dA2; atol = atol, rtol = rtol)
@@ -253,24 +247,24 @@ function copy_eigh_vals!(A, D, alg; kwargs...)
253247 return eigh_vals! (A, D, alg; kwargs... )
254248end
255249
256- function copy_eigh_trunc (A; kwargs... )
250+ function copy_eigh_trunc_no_error (A; kwargs... )
257251 A = (A + A' ) / 2
258- return eigh_trunc (A; kwargs... )
252+ return eigh_trunc_no_error (A; kwargs... )
259253end
260254
261- function copy_eigh_trunc ! (A, DV; kwargs... )
255+ function copy_eigh_trunc_no_error ! (A, DV; kwargs... )
262256 A = (A + A' ) / 2
263- return eigh_trunc ! (A, DV; kwargs... )
257+ return eigh_trunc_no_error ! (A, DV; kwargs... )
264258end
265259
266- function copy_eigh_trunc (A, alg; kwargs... )
260+ function copy_eigh_trunc_no_error (A, alg; kwargs... )
267261 A = (A + A' ) / 2
268- return eigh_trunc (A ; kwargs... )
262+ return eigh_trunc_no_error (A, alg ; kwargs... )
269263end
270264
271- function copy_eigh_trunc ! (A, DV, alg; kwargs... )
265+ function copy_eigh_trunc_no_error ! (A, DV, alg; kwargs... )
272266 A = (A + A' ) / 2
273- return eigh_trunc ! (A, DV; kwargs... )
267+ return eigh_trunc_no_error ! (A, DV, alg ; kwargs... )
274268end
275269
276270@timedtestset " EIGH AD Rules with eltype $T " for T in ETs
307301 Vtrunc = V[:, ind]
308302 ΔDtrunc = Diagonal (diagview (ΔD2)[ind])
309303 ΔVtrunc = ΔV[:, ind]
310- # broken due to Enzyme
311- # test_reverse(copy_eigh_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
312- # test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
304+ test_reverse (copy_eigh_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
305+ test_pullbacks_match (rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ= (ΔDtrunc, ΔVtrunc), return_act= RT)
313306 end
314307 Ddiag = diagview (D)
315308 truncalg = TruncatedAlgorithm (alg, trunctol (; atol = maximum (abs, Ddiag) / 2 ))
318311 Vtrunc = V[:, ind]
319312 ΔDtrunc = Diagonal (diagview (ΔD2)[ind])
320313 ΔVtrunc = ΔV[:, ind]
321- # broken due to Enzyme
322- # test_reverse(copy_eigh_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
323- # test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
314+ test_reverse (copy_eigh_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
315+ test_pullbacks_match (rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ= (ΔDtrunc, ΔVtrunc), return_act= RT)
324316 end
325317 end
326318end
386378 ΔStrunc = Diagonal (diagview (ΔS2)[ind])
387379 ΔUtrunc = ΔU[:, ind]
388380 ΔVᴴtrunc = ΔVᴴ[ind, :]
389- test_reverse (svd_trunc , RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm = fdm)
390- test_pullbacks_match (rng, svd_trunc !, svd_trunc , A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ= (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), return_act= RT)
381+ test_reverse (svd_trunc_no_error , RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm = fdm)
382+ test_pullbacks_match (rng, svd_trunc_no_error !, svd_trunc_no_error , A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ= (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), return_act= RT)
391383 end
392384 U, S, Vᴴ = svd_compact (A)
393385 ΔU = randn (rng, T, m, minmn)
403395 ΔStrunc = Diagonal (diagview (ΔS2)[ind])
404396 ΔUtrunc = ΔU[:, ind]
405397 ΔVᴴtrunc = ΔVᴴ[ind, :]
406- test_reverse (svd_trunc , RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm = fdm)
407- test_pullbacks_match (rng, svd_trunc !, svd_trunc , A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ= (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), return_act= RT)
398+ test_reverse (svd_trunc_no_error , RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm = fdm)
399+ test_pullbacks_match (rng, svd_trunc_no_error !, svd_trunc_no_error , A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ= (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), return_act= RT)
408400 end
409401 end
410402 end
0 commit comments