Skip to content

Commit 08e4688

Browse files
saveat implementation
1 parent f33a041 commit 08e4688

1 file changed

Lines changed: 20 additions & 14 deletions

File tree

src/simple_regular_solve.jl

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ function simple_implicit_tau_leaping_loop!(
591591
prob, alg, u_current, t_current, t_end, p, rng,
592592
rate, nu, hor, max_hor, max_stoich, numjumps, epsilon,
593593
dtmin, saveat_times, usave, tsave, du, counts, rate_cache,
594-
maj, solver)
594+
maj, solver, save_end)
595595
save_idx = 1
596596

597597
while t_current < t_end
@@ -642,17 +642,23 @@ function simple_implicit_tau_leaping_loop!(
642642
u_current = u_new
643643
t_current = t_new
644644
end
645+
646+
# Save endpoint if requested and not already saved
647+
if save_end && (isempty(tsave) || tsave[end] != t_end)
648+
push!(usave, copy(u_current))
649+
push!(tsave, t_end)
650+
end
645651
end
646652

647653
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
648654
seed = nothing,
649655
dtmin = nothing,
650-
saveat = nothing)
656+
saveat = nothing, save_start = nothing, save_end = nothing)
651657
validate_pure_leaping_inputs(jump_prob, alg) ||
652658
error("SimpleImplicitTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.")
653659

654660
(; prob, rng) = jump_prob
655-
(seed !== nothing) && Random.seed!(rng, seed)
661+
(seed !== nothing) && seed!(rng, seed)
656662

657663
maj = jump_prob.massaction_jump
658664
numjumps = get_num_majumps(maj)
@@ -668,10 +674,18 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
668674
dtmin = 1e-10 * one(typeof(tspan[2]))
669675
end
670676

677+
saveat_times, save_start, save_end = _process_saveat(saveat, tspan, save_start, save_end)
678+
679+
# Initialize current state and saved history
671680
u_current = copy(u0)
672681
t_current = tspan[1]
673-
usave = [copy(u0)]
674-
tsave = [tspan[1]]
682+
if save_start
683+
usave = [copy(u0)]
684+
tsave = [tspan[1]]
685+
else
686+
usave = typeof(u0)[]
687+
tsave = typeof(tspan[1])[]
688+
end
675689
rate_cache = zeros(float(eltype(u0)), numjumps)
676690
counts = zero(rate_cache)
677691
du = similar(u0)
@@ -689,19 +703,11 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
689703
hor = compute_hor(reactant_stoch, numjumps)
690704
max_hor, max_stoich = precompute_reaction_conditions(reactant_stoch, hor, length(u0), numjumps)
691705

692-
# Set up saveat_times
693-
if isnothing(saveat)
694-
saveat_times = Vector{typeof(tspan[1])}()
695-
elseif saveat isa Number
696-
saveat_times = collect(range(tspan[1], tspan[2], step=saveat))
697-
else
698-
saveat_times = collect(saveat)
699-
end
700706
simple_implicit_tau_leaping_loop!(
701707
prob, alg, u_current, t_current, t_end, p, rng,
702708
rate, nu, hor, max_hor, max_stoich, numjumps, epsilon,
703709
dtmin, saveat_times, usave, tsave, du, counts, rate_cache,
704-
maj, solver)
710+
maj, solver, save_end)
705711

706712
sol = DiffEqBase.build_solution(prob, alg, tsave, usave,
707713
calculate_error=false,

0 commit comments

Comments
 (0)