Skip to content

Commit e7ca104

Browse files
committed
sqturb with proper xarray outputs
1 parent c6a3410 commit e7ca104

1 file changed

Lines changed: 5 additions & 11 deletions

File tree

dabench/data/_sqgturb.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,6 @@ def __init__(self,
167167
self.r = jnp.array(r, dtype) # Ekman damping (at z=0)
168168
self.tdiab = jnp.array(tdiab, dtype) # thermal relaxation damping.
169169

170-
# Initialize time counter
171-
self.t = tstart
172-
173170
# Setup basic state pv (for thermal relaxation)
174171
self.symmetric = symmetric
175172
y = jnp.arange(0, self.L, self.L / self.N, dtype=dtype)
@@ -199,6 +196,7 @@ def __init__(self,
199196
pvbar = pvbar * jnp.ones((2, N, N), dtype)
200197
self.pvbar = pvbar
201198
# state to relax to with timescale tdiab
199+
# NOTE: Is this an error? It is never updated.
202200
self.pvspec_eq = rfft2(pvbar)
203201
# initial pv field (spectral)
204202
self.pvspec = rfft2(pv)
@@ -461,9 +459,6 @@ def integrate(self,
461459
) -> tuple[jax.Array, jax.Array]:
462460
"""Advances pv forward number of timesteps given by self.n_steps.
463461
464-
Note:
465-
If pv not specified, use pvspec instance variable.
466-
467462
Args:
468463
function (function): right hand side (rhs) of the ODE. Not used, but
469464
needed to function with generate() from _data.Data().
@@ -480,15 +475,14 @@ def integrate(self,
480475
delta_t = self.delta_t
481476

482477
# Run integration
483-
pvspec, values = jax.lax.scan(self._rk4, pvspec, xs=times)
478+
pvspec_updated, values = jax.lax.scan(self._rk4, pvspec, xs=times[:-1])
479+
480+
# Prepend input to values so x0 is include
481+
values = jnp.insert(values, 0, pvspec, axis=0)
484482

485483
# Apply reverse fft to
486484
values = self.ifft2(values)
487485

488-
# Update internal states
489-
self.pvspec = pvspec
490-
self.t = times[-1]
491-
492486
return values, times
493487

494488
def rhs(self,

0 commit comments

Comments
 (0)