@@ -591,7 +591,7 @@ function simple_implicit_tau_leaping_loop!(
591591 prob, alg, u_current, t_current, t_end, p, rng,
592592 rate, nu, hor, max_hor, max_stoich, numjumps, epsilon,
593593 dtmin, saveat_times, usave, tsave, du, counts, rate_cache,
594- maj, solver)
594+ maj, solver, save_end )
595595 save_idx = 1
596596
597597 while t_current < t_end
@@ -642,17 +642,23 @@ function simple_implicit_tau_leaping_loop!(
642642 u_current = u_new
643643 t_current = t_new
644644 end
645+
646+ # Save endpoint if requested and not already saved
647+ if save_end && (isempty (tsave) || tsave[end ] != t_end)
648+ push! (usave, copy (u_current))
649+ push! (tsave, t_end)
650+ end
645651end
646652
647653function DiffEqBase. solve (jump_prob:: JumpProblem , alg:: SimpleImplicitTauLeaping ;
648654 seed = nothing ,
649655 dtmin = nothing ,
650- saveat = nothing )
656+ saveat = nothing , save_start = nothing , save_end = nothing )
651657 validate_pure_leaping_inputs (jump_prob, alg) ||
652658 error (" SimpleImplicitTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump." )
653659
654660 (; prob, rng) = jump_prob
655- (seed != = nothing ) && Random . seed! (rng, seed)
661+ (seed != = nothing ) && seed! (rng, seed)
656662
657663 maj = jump_prob. massaction_jump
658664 numjumps = get_num_majumps (maj)
@@ -668,10 +674,18 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
668674 dtmin = 1e-10 * one (typeof (tspan[2 ]))
669675 end
670676
677+ saveat_times, save_start, save_end = _process_saveat (saveat, tspan, save_start, save_end)
678+
679+ # Initialize current state and saved history
671680 u_current = copy (u0)
672681 t_current = tspan[1 ]
673- usave = [copy (u0)]
674- tsave = [tspan[1 ]]
682+ if save_start
683+ usave = [copy (u0)]
684+ tsave = [tspan[1 ]]
685+ else
686+ usave = typeof (u0)[]
687+ tsave = typeof (tspan[1 ])[]
688+ end
675689 rate_cache = zeros (float (eltype (u0)), numjumps)
676690 counts = zero (rate_cache)
677691 du = similar (u0)
@@ -689,19 +703,11 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
689703 hor = compute_hor (reactant_stoch, numjumps)
690704 max_hor, max_stoich = precompute_reaction_conditions (reactant_stoch, hor, length (u0), numjumps)
691705
692- # Set up saveat_times
693- if isnothing (saveat)
694- saveat_times = Vector {typeof(tspan[1])} ()
695- elseif saveat isa Number
696- saveat_times = collect (range (tspan[1 ], tspan[2 ], step= saveat))
697- else
698- saveat_times = collect (saveat)
699- end
700706 simple_implicit_tau_leaping_loop! (
701707 prob, alg, u_current, t_current, t_end, p, rng,
702708 rate, nu, hor, max_hor, max_stoich, numjumps, epsilon,
703709 dtmin, saveat_times, usave, tsave, du, counts, rate_cache,
704- maj, solver)
710+ maj, solver, save_end )
705711
706712 sol = DiffEqBase. build_solution (prob, alg, tsave, usave,
707713 calculate_error= false ,
0 commit comments