@@ -279,6 +279,20 @@ numdiff = map(numdiffn) do n
279279end
280280
281281
282+ let n = first (csan)
283+ bfun, b_u0, b_p, brusselator_jac, brusselator_comp = makebrusselator! (PROBS, n)
284+ solver = Rodas5 (autodiff = false )
285+ for alg in ADJOINT_METHODS_IQ
286+ f = SciMLSensitivity. alg_autodiff (alg) ? bfun :
287+ ODEFunction (bfun, jac = brusselator_jac)
288+ diffeq_sen_l2 (f, b_u0, tspan, b_p, bt, solver; sensalg = alg, tols... )
289+ end
290+ for alg in ADJOINT_METHODS_G
291+ diffeq_sen_l2 (bfun, b_u0, tspan, b_p, bt, solver; sensalg = alg, tols... )
292+ end
293+ end
294+
295+
282296csa_iq = map (csan) do n
283297 bfun, b_u0, b_p, brusselator_jac, brusselator_comp = makebrusselator! (PROBS, n)
284298 @time ts = map (ADJOINT_METHODS_IQ) do alg
@@ -366,6 +380,18 @@ _adjoint_methods = ntuple(4) do ii
366380end |> NamedTuple{(:interp , :quad , :gauss , :gausskronrod )}
367381adjoint_methods = mapreduce (collect, vcat, _adjoint_methods)
368382
383+
384+ let n = first (csan)
385+ bfun, b_u0, b_p, brusselator_jac, brusselator_comp = makebrusselator! (PROBS, n)
386+ solver = Rodas5 (autodiff = false )
387+ for alg in adjoint_methods
388+ f = SciMLSensitivity. alg_autodiff (alg) ? bfun :
389+ ODEFunction (bfun, jac = brusselator_jac)
390+ diffeq_sen_l2 (f, b_u0, tspan, b_p, bt, solver; sensalg = alg, tols... )
391+ end
392+ end
393+
394+
369395csavjp = map (csan) do n
370396 bfun, b_u0, b_p, brusselator_jac, brusselator_comp = makebrusselator! (PROBS, n)
371397 @time ts = map (adjoint_methods) do alg
@@ -448,6 +474,275 @@ yaxis!(plt_gk, "Runtime (s)", :log10);
448474plot! (plt_gk, legend = :outertopleft , size = (1200 , 600 ))
449475
450476
477+ const CHILD_PREAMBLE = raw """
478+ using OrdinaryDiffEq, ReverseDiff, ForwardDiff, FiniteDiff, SciMLSensitivity
479+ using LinearAlgebra, Mooncake
480+
481+ function get_rss_mib()
482+ statm = read("/proc/self/statm", String)
483+ resident_pages = parse(Int, split(statm)[2])
484+ return resident_pages * 4096 / (1024^2)
485+ end
486+
487+ function makebrusselator(N = 8)
488+ xyd_brusselator = range(0, stop = 1, length = N)
489+ function limit(a, N)
490+ if a == N+1
491+ return 1
492+ elseif a == 0
493+ return N
494+ else
495+ return a
496+ end
497+ end
498+ brusselator_f(x, y, t) = ifelse(
499+ (((x-0.3)^2 + (y-0.6)^2) <= 0.1^2) &&
500+ (t >= 1.1), 5.0, 0.0)
501+ brusselator_2d_loop = let N=N, xyd=xyd_brusselator, dx=step(xyd_brusselator)
502+ function brusselator_2d_loop(du, u, p, t)
503+ @inbounds begin
504+ ii1 = N^2
505+ ii2 = ii1+N^2
506+ ii3 = ii2+2(N^2)
507+ A = @view p[1:ii1]
508+ B = @view p[(ii1 + 1):ii2]
509+ α = @view p[(ii2 + 1):ii3]
510+ II = LinearIndices((N, N, 2))
511+ for I in CartesianIndices((N, N))
512+ x = xyd[I[1]]
513+ y = xyd[I[2]]
514+ i = I[1]
515+ j = I[2]
516+ ip1 = limit(i+1, N);
517+ im1 = limit(i-1, N)
518+ jp1 = limit(j+1, N);
519+ jm1 = limit(j-1, N)
520+ du[II[i, j, 1]] = α[II[
521+ i, j, 1]]*(u[II[im1, j, 1]] + u[II[ip1, j, 1]] +
522+ u[II[i, jp1, 1]] + u[II[i, jm1, 1]] -
523+ 4u[II[i, j, 1]])/dx^2 +
524+ B[II[i, j, 1]] + u[II[i, j, 1]]^2*u[II[i, j, 2]] -
525+ (A[II[i, j, 1]] + 1)*u[II[i, j, 1]] +
526+ brusselator_f(x, y, t)
527+ end
528+ for I in CartesianIndices((N, N))
529+ i = I[1]
530+ j = I[2]
531+ ip1 = limit(i+1, N)
532+ im1 = limit(i-1, N)
533+ jp1 = limit(j+1, N)
534+ jm1 = limit(j-1, N)
535+ du[II[i, j, 2]] = α[II[
536+ i, j, 2]]*(u[II[im1, j, 2]] + u[II[ip1, j, 2]] + u[II[i, jp1, 2]] +
537+ u[II[i, jm1, 2]] - 4u[II[i, j, 2]])/dx^2 +
538+ A[II[i, j, 1]]*u[II[i, j, 1]] -
539+ u[II[i, j, 1]]^2*u[II[i, j, 2]]
540+ end
541+ return nothing
542+ end
543+ end
544+ end
545+ function init_brusselator_2d(xyd)
546+ N = length(xyd)
547+ u = zeros(N, N, 2)
548+ for I in CartesianIndices((N, N))
549+ x = xyd[I[1]]
550+ y = xyd[I[2]]
551+ u[I, 1] = 22*(y*(1-y))^(3/2)
552+ u[I, 2] = 27*(x*(1-x))^(3/2)
553+ end
554+ vec(u)
555+ end
556+ dx = step(xyd_brusselator)
557+ e1 = ones(N-1)
558+ off = N-1
559+ e4 = ones(N-off)
560+ T = diagm(0=>-2ones(N), -1=>e1, 1=>e1, off=>e4, -off=>e4) ./ dx^2
561+ Ie = Matrix{Float64}(I, N, N)
562+ Op = kron(Ie, T) + kron(T, Ie)
563+ brusselator_jac = let N=N
564+ (J, a, p, t) -> begin
565+ ii1 = N^2
566+ ii2 = ii1+N^2
567+ ii3 = ii2+2(N^2)
568+ A = @view p[1:ii1]
569+ B = @view p[(ii1 + 1):ii2]
570+ α = @view p[(ii2 + 1):ii3]
571+ u = @view a[1:(end ÷ 2)]
572+ v = @view a[(end ÷ 2 + 1):end]
573+ N2 = length(a)÷2
574+ α1 = @view α[1:(end ÷ 2)]
575+ α2 = @view α[(end ÷ 2 + 1):end]
576+ fill!(J, 0)
577+ J[1:N2, 1:N2] .= α1 .* Op
578+ J[(N2 + 1):end, (N2 + 1):end] .= α2 .* Op
579+ J1 = @view J[1:N2, 1:N2]
580+ J2 = @view J[(N2 + 1):end, 1:N2]
581+ J3 = @view J[1:N2, (N2 + 1):end]
582+ J4 = @view J[(N2 + 1):end, (N2 + 1):end]
583+ J1[diagind(J1)] .+= @. 2u*v-(A+1)
584+ J2[diagind(J2)] .= @. A-2u*v
585+ J3[diagind(J3)] .= @. u^2
586+ J4[diagind(J4)] .+= @. -u^2
587+ nothing
588+ end
589+ end
590+ u0 = init_brusselator_2d(xyd_brusselator)
591+ p = [fill(3.4, N^2); fill(1.0, N^2); fill(10.0, 2*N^2)]
592+ brusselator_2d_loop, u0, p, brusselator_jac
593+ end
594+
595+ Base.vec(v::Adjoint{<:Real, <:AbstractVector}) = vec(v')
596+
597+ bt = 0:0.1:1
598+ tspan = (0.0, 1.0)
599+ tols = (abstol = 1e-5, reltol = 1e-7)
600+
601+ function auto_sen_l2(
602+ f, u0, tspan, p, t, alg = Tsit5(); diffalg = ReverseDiff.gradient, kwargs...)
603+ test_f(p) = begin
604+ prob = ODEProblem{true, SciMLBase.FullSpecialize}(f, convert.(eltype(p), u0), tspan, p)
605+ sol = solve(prob, alg, saveat = t; kwargs...)
606+ sum(sol.u) do x
607+ sum(z->(1-z)^2/2, x)
608+ end
609+ end
610+ diffalg(test_f, p)
611+ end
612+
613+ @inline function diffeq_sen_l2(df, u0, tspan, p, t, alg = Tsit5();
614+ abstol = 1e-5, reltol = 1e-7, iabstol = abstol, ireltol = reltol,
615+ sensalg = SensitivityAlg(), kwargs...)
616+ prob = ODEProblem{true, SciMLBase.FullSpecialize}(df, u0, tspan, p)
617+ saveat = tspan[1] != t[1] && tspan[end] != t[end] ? vcat(tspan[1], t, tspan[end]) : t
618+ sol = solve(prob, alg, abstol = abstol, reltol = reltol, saveat = saveat; kwargs...)
619+ dg(out, u, p, t, i) = (out.=u .- 1.0)
620+ adjoint_sensitivities(sol, alg; t, abstol = abstol, dgdu_discrete = dg,
621+ reltol = reltol, sensealg = sensalg)
622+ end
623+ """
624+
625+ const PROJECT_DIR = @__DIR__
626+
627+ function run_memory_benchmark (n:: Int , method_setup:: String )
628+ child_script = CHILD_PREAMBLE * """
629+
630+ n = $(n)
631+ bfun, b_u0, b_p, brusselator_jac = makebrusselator(n)
632+
633+ GC.gc(); GC.gc()
634+ rss_before = get_rss_mib()
635+
636+ """ * method_setup * """
637+
638+ GC.gc(); GC.gc()
639+ rss_after = get_rss_mib()
640+
641+ println("BRUSSMEM_TIMING:", t)
642+ println("BRUSSMEM_RSS_BEFORE:", rss_before)
643+ println("BRUSSMEM_RSS_AFTER:", rss_after)
644+ """
645+
646+ try
647+ output = read (
648+ ` $(Base. julia_cmd ()) --project=$(PROJECT_DIR) -e $(child_script) ` , String)
649+ time_m = match (r" BRUSSMEM_TIMING:([\d .eE+-]+)" , output)
650+ before_m = match (r" BRUSSMEM_RSS_BEFORE:([\d .eE+-]+)" , output)
651+ after_m = match (r" BRUSSMEM_RSS_AFTER:([\d .eE+-]+)" , output)
652+ if time_m === nothing || before_m === nothing || after_m === nothing
653+ @warn " Failed to parse subprocess output" n output
654+ return (; rss_before = NaN , rss_after = NaN , delta_mib = NaN , timing = NaN )
655+ end
656+ timing = parse (Float64, time_m. captures[1 ])
657+ rss_before = parse (Float64, before_m. captures[1 ])
658+ rss_after = parse (Float64, after_m. captures[1 ])
659+ delta_mib = rss_after - rss_before
660+ return (; rss_before, rss_after, delta_mib, timing)
661+ catch e
662+ @warn " Subprocess failed" n exception = (e, catch_backtrace ())
663+ return (; rss_before = NaN , rss_after = NaN , delta_mib = NaN , timing = NaN )
664+ end
665+ end
666+
667+ mem_sizes = [2 , 4 , 6 , 8 , 10 , 12 ]
668+
669+
670+ forwarddiff_mem = map (mem_sizes) do n
671+ result = run_memory_benchmark (n, """
672+ auto_sen_l2(bfun, b_u0, tspan, b_p, bt, Rodas5();
673+ diffalg = ForwardDiff.gradient, tols...)
674+ t = @elapsed auto_sen_l2(bfun, b_u0, tspan, b_p, bt, Rodas5();
675+ diffalg = ForwardDiff.gradient, tols...)
676+ """ )
677+ @show n, result
678+ result
679+ end
680+
681+
682+ numdiff_mem = map (mem_sizes) do n
683+ result = run_memory_benchmark (n, """
684+ auto_sen_l2(bfun, b_u0, tspan, b_p, bt, Rodas5();
685+ diffalg = FiniteDiff.finite_difference_gradient, tols...)
686+ t = @elapsed auto_sen_l2(bfun, b_u0, tspan, b_p, bt, Rodas5();
687+ diffalg = FiniteDiff.finite_difference_gradient, tols...)
688+ """ )
689+ @show n, result
690+ result
691+ end
692+
693+
694+ adjoint_ad_configs = [
695+ (" Interp user-Jacobian" ,
696+ " InterpolatingAdjoint(autodiff = false, autojacvec = false)" , true ),
697+ (" Interp AD-Jacobian" ,
698+ " InterpolatingAdjoint(autodiff = true, autojacvec = false)" , false ),
699+ (" Quad user-Jacobian" ,
700+ " QuadratureAdjoint(autodiff = false, autojacvec = false)" , true ),
701+ (" Quad AD-Jacobian" ,
702+ " QuadratureAdjoint(autodiff = true, autojacvec = false)" , false ),
703+ (" Gauss AD-Jacobian" ,
704+ " GaussAdjoint(autodiff = true, autojacvec = false)" , false ),
705+ (" GaussKronrod AD-Jacobian" ,
706+ " GaussKronrodAdjoint(autodiff = true, autojacvec = false)" , false ),
707+ ]
708+
709+ adjoint_ad_mem = map (adjoint_ad_configs) do (name, sensalg_str, needs_jac)
710+ results = map (mem_sizes) do n
711+ f_expr = needs_jac ? " ODEFunction(bfun, jac = brusselator_jac)" : " bfun"
712+ result = run_memory_benchmark (n, """
713+ sensalg = $(sensalg_str)
714+ f = $(f_expr)
715+ solver = Rodas5(autodiff = false)
716+ diffeq_sen_l2(f, b_u0, tspan, b_p, bt, solver; sensalg = sensalg, tols...)
717+ t = @elapsed diffeq_sen_l2(f, b_u0, tspan, b_p, bt, solver;
718+ sensalg = sensalg, tols...)
719+ """ )
720+ @show name, n, result
721+ result
722+ end
723+ (name = name, results = results)
724+ end
725+
726+
727+ mem_params = n_to_param .(mem_sizes)
728+
729+ plt_mem1 = plot (title = " Brusselator Sensitivity Memory Scaling" );
730+ plot! (plt_mem1, mem_params, [r. delta_mib for r in forwarddiff_mem],
731+ lab = " Forward-Mode DSAAD" , lw = lw, marksize = ms,
732+ linestyle = :auto , marker = :auto );
733+ plot! (plt_mem1, mem_params, [r. delta_mib for r in numdiff_mem],
734+ lab = " Numerical Differentiation" , lw = lw, marksize = ms,
735+ linestyle = :auto , marker = :auto );
736+ for entry in adjoint_ad_mem
737+ plot! (plt_mem1, mem_params, [r. delta_mib for r in entry. results],
738+ lab = entry. name, lw = lw, marksize = ms,
739+ linestyle = :auto , marker = :auto )
740+ end
741+ xaxis! (plt_mem1, " Number of Parameters" , :log10 );
742+ yaxis! (plt_mem1, " Memory (MiB)" );
743+ plot! (plt_mem1, legend = :outertopleft , size = (1200 , 600 ))
744+
745+
451746using SciMLBenchmarks
452747SciMLBenchmarks. bench_footer (WEAVE_ARGS[:folder ], WEAVE_ARGS[:file ])
453748
0 commit comments