@@ -126,7 +126,7 @@ def _calc_J_term(self,
126126 @partial (jax .jit , static_argnums = [0 , 1 ])
127127 def _innerloop_4d (self ,
128128 system_dim : int ,
129- Xb_ds : XarrayDatasetLike ,
129+ X_ds : XarrayDatasetLike ,
130130 xb0_ds : XarrayDatasetLike ,
131131 obs_vals : ArrayLike ,
132132 Hs : ArrayLike ,
@@ -137,8 +137,8 @@ def _innerloop_4d(self,
137137 obs_time_mask : ArrayLike
138138 ) -> XarrayDatasetLike :
139139 """4DVar innerloop"""
140- x0_prev_ds = Xb_ds .isel (time = 0 )
141- Xb_ar = Xb_ds .to_stacked_array ('system' ,['time' ])
140+ x0_ds = X_ds .isel (time = 0 )
141+ X_ar = X_ds .to_stacked_array ('system' ,['time' ])
142142
143143 # Set up Variables
144144 SumMtHtRinvHM = jnp .zeros_like (B ) # A input
@@ -151,22 +151,22 @@ def _innerloop_4d(self,
151151 lambda : self ._calc_J_term (
152152 Hs .at [i ].get (mode = 'clip' ),
153153 M .data [j ],
154- Rinv , obs_vals [i ], Xb_ar .data [j ]),
154+ Rinv , obs_vals [i ], X_ar .data [j ]),
155155 lambda : (jnp .zeros_like (SumMtHtRinvHM ),
156156 jnp .zeros_like (SumMtHtRinvD ))
157157 )
158158 SumMtHtRinvHM += Jb
159159 SumMtHtRinvD += Jo
160160 # Compute initial departure
161- db0 = (xb0_ds - x0_prev_ds ).to_stacked_array ('system' ,[]).data
161+ db0 = (xb0_ds - x0_ds ).to_stacked_array ('system' ,[]).data
162162
163163 # Solve Ax=b for the initial perturbation
164164 dx0 = self ._solve (db0 , SumMtHtRinvHM , SumMtHtRinvD , B )
165165
166166 # New x0 guess is the last guess plus the analyzed delta
167- x0_new_ds = x0_prev_ds + dx0 .ravel ()
167+ xa0_ds = x0_ds + dx0 .ravel ()
168168
169- return x0_new_ds
169+ return xa0_ds
170170
171171 def _make_outerloop_4d (self ,
172172 xb0_ds : XarrayDatasetLike ,
@@ -185,18 +185,18 @@ def _outerloop_4d(x0_ds: XarrayDatasetLike,
185185 # Get TLM and current forecast trajectory
186186 # Based on current best guess for x0
187187 x0_ds = x0_ds .to_xarray ()
188- xb_ds , M = self .model_obj .compute_tlm (
188+ X_ds , M = self .model_obj .compute_tlm (
189189 n_steps = n_steps ,
190190 state_vec = x0_ds
191191 )
192192
193193 # 4D-Var inner loop
194- x0_new_ds = self ._innerloop_4d (
195- self .system_dim , xb_ds , xb0_ds , obs_values ,
194+ xa0_ds = self ._innerloop_4d (
195+ self .system_dim , X_ds , xb0_ds , obs_values ,
196196 Hs , B , Rinv , M , obs_window_indices , obs_time_mask
197197 )
198198
199- return xj .from_xarray (x0_new_ds .assign_coords (x0_ds .coords )), x0_ds
199+ return xj .from_xarray (xa0_ds .assign_coords (x0_ds .coords )), x0_ds
200200
201201 return _outerloop_4d
202202
@@ -236,7 +236,7 @@ def _solve(self,
236236 return dx0
237237
238238 def _cycle_obsop (self ,
239- x0_ds : XarrayDatasetLike ,
239+ xb0_ds : XarrayDatasetLike ,
240240 obs_values : ArrayLike ,
241241 obs_loc_indices : ArrayLike ,
242242 obs_time_mask : ArrayLike ,
@@ -284,14 +284,12 @@ def _cycle_obsop(self,
284284 # Static Variables
285285 Rinv = jscipy .linalg .inv (R )
286286
287- # Best guess for x0 starts as background
288- x0_new_ds = deepcopy (x0_ds )
289-
290287 outerloop_4d_func = self ._make_outerloop_4d (
291- x0_ds , Hs , B , Rinv , obs_values , obs_window_indices ,
288+ xb0_ds , Hs , B , Rinv , obs_values , obs_window_indices ,
292289 obs_time_mask , self .steps_per_window )
293290
294- x0_new_ds , all_x0s = jax .lax .scan (outerloop_4d_func , init = xj .from_xarray (x0_new_ds ),
291+ xa0_ds , all_x0s = jax .lax .scan (outerloop_4d_func ,
292+ init = xj .from_xarray (xb0_ds ),
295293 xs = None , length = self .n_outer_loops )
296294
297- return x0_new_ds .to_xarray ()
295+ return xa0_ds .to_xarray ()
0 commit comments