Skip to content

Commit c6a3410

Browse files
committed
Pyqg jax with xarray output
1 parent c941404 commit c6a3410

1 file changed

Lines changed: 29 additions & 6 deletions

File tree

dabench/data/_pyqg_jax.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ def __init__(self,
113113
store_as_jax=store_as_jax, x0=x0,
114114
**kwargs)
115115

116+
self.coord_names = ['level','x','y']
117+
self.var_names=['q']
118+
116119
@functools.partial(jax.jit, static_argnames=["self", "num_steps"])
117120
def _roll_out_state(self, state, num_steps):
118121
"""Helper method taken from pyqg-jax docs:
@@ -122,7 +125,7 @@ def _roll_out_state(self, state, num_steps):
122125
def loop_fn(carry, _x):
123126
current_state = carry
124127
next_state = self.m.step_model(current_state)
125-
return next_state, next_state
128+
return next_state, current_state
126129

127130
_final_carry, traj_steps = jax.lax.scan(
128131
loop_fn, state, None, length=num_steps
@@ -213,18 +216,38 @@ def generate(self,
213216
)
214217
)
215218

216-
self.x0 = x0.flatten()
217-
218219
# Store step times
219-
self.times = jnp.arange(0, t_final, self.delta_t)
220+
times = np.arange(0, t_final, self.delta_t)
220221

221222
# Run simulation
222223
traj = self._roll_out_state(init_state, num_steps=n_steps)
223224
qs = traj.state.q
224225

226+
# Build Xarray object for output
227+
coord_dict = dict(zip(
228+
['time'] + self.coord_names,
229+
[times] + [np.arange(dim) for dim in self.original_dim]
230+
))
231+
time_dim = times.shape[0]
232+
out_dim = (time_dim,) + self.original_dim
233+
234+
# Convert to JAX if necessary
235+
y = qs
236+
if self.store_as_jax or isinstance(y, jax.core.Tracer):
237+
y_out = jnp.array(y[:, :self.system_dim].reshape(out_dim))
238+
else:
239+
y_out = np.array(y[:, :self.system_dim].reshape(out_dim))
240+
out_vec = xr.Dataset(
241+
{self.var_names[0]: (coord_dict.keys(), y_out)},
242+
coords=coord_dict,
243+
attrs={'store_as_jax': self.store_as_jax,
244+
'system_dim': self.system_dim,
245+
'delta_t': self.delta_t
246+
}
247+
)
248+
225249
# Save values
226-
self.time_dim = qs.shape[0]
227-
self.values = qs.reshape((self.time_dim, -1))
250+
return out_vec
228251

229252
# TODO: Remove? Believe this is deprecated
230253
def forecast(self, n_steps=None, t_final=None, x0=None):

0 commit comments

Comments
 (0)