@@ -167,9 +167,6 @@ def __init__(self,
167167 self .r = jnp .array (r , dtype ) # Ekman damping (at z=0)
168168 self .tdiab = jnp .array (tdiab , dtype ) # thermal relaxation damping.
169169
170- # Initialize time counter
171- self .t = tstart
172-
173170 # Setup basic state pv (for thermal relaxation)
174171 self .symmetric = symmetric
175172 y = jnp .arange (0 , self .L , self .L / self .N , dtype = dtype )
@@ -199,6 +196,7 @@ def __init__(self,
199196 pvbar = pvbar * jnp .ones ((2 , N , N ), dtype )
200197 self .pvbar = pvbar
201198 # state to relax to with timescale tdiab
199+ # NOTE: Is this an error? It is never updated.
202200 self .pvspec_eq = rfft2 (pvbar )
203201 # initial pv field (spectral)
204202 self .pvspec = rfft2 (pv )
@@ -461,9 +459,6 @@ def integrate(self,
461459 ) -> tuple [jax .Array , jax .Array ]:
462460 """Advances pv forward number of timesteps given by self.n_steps.
463461
464- Note:
465- If pv not specified, use pvspec instance variable.
466-
467462 Args:
468463 function (function): right hand side (rhs) of the ODE. Not used, but
469464 needed to function with generate() from _data.Data().
@@ -480,15 +475,14 @@ def integrate(self,
480475 delta_t = self .delta_t
481476
482477 # Run integration
483- pvspec , values = jax .lax .scan (self ._rk4 , pvspec , xs = times )
478+ pvspec_updated , values = jax .lax .scan (self ._rk4 , pvspec , xs = times [:- 1 ])
479+
480+ # Prepend input to values so x0 is include
481+ values = jnp .insert (values , 0 , pvspec , axis = 0 )
484482
485483 # Apply reverse fft to
486484 values = self .ifft2 (values )
487485
488- # Update internal states
489- self .pvspec = pvspec
490- self .t = times [- 1 ]
491-
492486 return values , times
493487
494488 def rhs (self ,
0 commit comments