@@ -113,6 +113,9 @@ def __init__(self,
113113 store_as_jax = store_as_jax , x0 = x0 ,
114114 ** kwargs )
115115
116+ self .coord_names = ['level' ,'x' ,'y' ]
117+ self .var_names = ['q' ]
118+
116119 @functools .partial (jax .jit , static_argnames = ["self" , "num_steps" ])
117120 def _roll_out_state (self , state , num_steps ):
118121 """Helper method taken from pyqg-jax docs:
@@ -122,7 +125,7 @@ def _roll_out_state(self, state, num_steps):
122125 def loop_fn (carry , _x ):
123126 current_state = carry
124127 next_state = self .m .step_model (current_state )
125- return next_state , next_state
128+ return next_state , current_state
126129
127130 _final_carry , traj_steps = jax .lax .scan (
128131 loop_fn , state , None , length = num_steps
@@ -213,18 +216,38 @@ def generate(self,
213216 )
214217 )
215218
216- self .x0 = x0 .flatten ()
217-
218219 # Store step times
219- self . times = jnp .arange (0 , t_final , self .delta_t )
220+ times = np .arange (0 , t_final , self .delta_t )
220221
221222 # Run simulation
222223 traj = self ._roll_out_state (init_state , num_steps = n_steps )
223224 qs = traj .state .q
224225
226+ # Build Xarray object for output
227+ coord_dict = dict (zip (
228+ ['time' ] + self .coord_names ,
229+ [times ] + [np .arange (dim ) for dim in self .original_dim ]
230+ ))
231+ time_dim = times .shape [0 ]
232+ out_dim = (time_dim ,) + self .original_dim
233+
234+ # Convert to JAX if necessary
235+ y = qs
236+ if self .store_as_jax or isinstance (y , jax .core .Tracer ):
237+ y_out = jnp .array (y [:, :self .system_dim ].reshape (out_dim ))
238+ else :
239+ y_out = np .array (y [:, :self .system_dim ].reshape (out_dim ))
240+ out_vec = xr .Dataset (
241+ {self .var_names [0 ]: (coord_dict .keys (), y_out )},
242+ coords = coord_dict ,
243+ attrs = {'store_as_jax' : self .store_as_jax ,
244+ 'system_dim' : self .system_dim ,
245+ 'delta_t' : self .delta_t
246+ }
247+ )
248+
225249 # Save values
226- self .time_dim = qs .shape [0 ]
227- self .values = qs .reshape ((self .time_dim , - 1 ))
250+ return out_vec
228251
229252 # TODO: Remove? Believe this is deprecated
230253 def forecast (self , n_steps = None , t_final = None , x0 = None ):
0 commit comments