Skip to content

Commit 7ae0a78

Browse files
saveat implementation
1 parent eb53701 commit 7ae0a78

1 file changed

Lines changed: 20 additions & 14 deletions

File tree

src/simple_regular_solve.jl

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ function simple_implicit_tau_leaping_loop!(
468468
prob, alg, u_current, t_current, t_end, p, rng,
469469
rate, nu, hor, max_hor, max_stoich, numjumps, epsilon,
470470
dtmin, saveat_times, usave, tsave, du, counts, rate_cache,
471-
maj, solver)
471+
maj, solver, save_end)
472472
save_idx = 1
473473

474474
while t_current < t_end
@@ -519,17 +519,23 @@ function simple_implicit_tau_leaping_loop!(
519519
u_current = u_new
520520
t_current = t_new
521521
end
522+
523+
# Save endpoint if requested and not already saved
524+
if save_end && (isempty(tsave) || tsave[end] != t_end)
525+
push!(usave, copy(u_current))
526+
push!(tsave, t_end)
527+
end
522528
end
523529

524530
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
525531
seed = nothing,
526532
dtmin = nothing,
527-
saveat = nothing)
533+
saveat = nothing, save_start = nothing, save_end = nothing)
528534
validate_pure_leaping_inputs(jump_prob, alg) ||
529535
error("SimpleImplicitTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.")
530536

531537
(; prob, rng) = jump_prob
532-
(seed !== nothing) && Random.seed!(rng, seed)
538+
(seed !== nothing) && seed!(rng, seed)
533539

534540
maj = jump_prob.massaction_jump
535541
numjumps = get_num_majumps(maj)
@@ -545,10 +551,18 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
545551
dtmin = 1e-10 * one(typeof(tspan[2]))
546552
end
547553

554+
saveat_times, save_start, save_end = _process_saveat(saveat, tspan, save_start, save_end)
555+
556+
# Initialize current state and saved history
548557
u_current = copy(u0)
549558
t_current = tspan[1]
550-
usave = [copy(u0)]
551-
tsave = [tspan[1]]
559+
if save_start
560+
usave = [copy(u0)]
561+
tsave = [tspan[1]]
562+
else
563+
usave = typeof(u0)[]
564+
tsave = typeof(tspan[1])[]
565+
end
552566
rate_cache = zeros(float(eltype(u0)), numjumps)
553567
counts = zero(rate_cache)
554568
du = similar(u0)
@@ -566,19 +580,11 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
566580
hor = compute_hor(reactant_stoch, numjumps)
567581
max_hor, max_stoich = precompute_reaction_conditions(reactant_stoch, hor, length(u0), numjumps)
568582

569-
# Set up saveat_times
570-
if isnothing(saveat)
571-
saveat_times = Vector{typeof(tspan[1])}()
572-
elseif saveat isa Number
573-
saveat_times = collect(range(tspan[1], tspan[2], step=saveat))
574-
else
575-
saveat_times = collect(saveat)
576-
end
577583
simple_implicit_tau_leaping_loop!(
578584
prob, alg, u_current, t_current, t_end, p, rng,
579585
rate, nu, hor, max_hor, max_stoich, numjumps, epsilon,
580586
dtmin, saveat_times, usave, tsave, du, counts, rate_cache,
581-
maj, solver)
587+
maj, solver, save_end)
582588

583589
sol = DiffEqBase.build_solution(prob, alg, tsave, usave,
584590
calculate_error=false,

0 commit comments

Comments
 (0)