Skip to content

Commit 9b1fec4

Browse files
committed
Updated var names for 4dvar. xb = bg, xa = analysis, x = temporary during loops, 0 denotes first timestep
1 parent cc89d5d commit 9b1fec4

2 files changed

Lines changed: 29 additions & 31 deletions

File tree

dabench/dacycler/_var4d.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def _calc_J_term(self,
126126
@partial(jax.jit, static_argnums=[0, 1])
127127
def _innerloop_4d(self,
128128
system_dim: int,
129-
Xb_ds: XarrayDatasetLike,
129+
X_ds: XarrayDatasetLike,
130130
xb0_ds: XarrayDatasetLike,
131131
obs_vals: ArrayLike,
132132
Hs: ArrayLike,
@@ -137,8 +137,8 @@ def _innerloop_4d(self,
137137
obs_time_mask: ArrayLike
138138
) -> XarrayDatasetLike:
139139
"""4DVar innerloop"""
140-
x0_prev_ds = Xb_ds.isel(time=0)
141-
Xb_ar = Xb_ds.to_stacked_array('system',['time'])
140+
x0_ds = X_ds.isel(time=0)
141+
X_ar = X_ds.to_stacked_array('system',['time'])
142142

143143
# Set up Variables
144144
SumMtHtRinvHM = jnp.zeros_like(B) # A input
@@ -151,22 +151,22 @@ def _innerloop_4d(self,
151151
lambda: self._calc_J_term(
152152
Hs.at[i].get(mode='clip'),
153153
M.data[j],
154-
Rinv, obs_vals[i], Xb_ar.data[j]),
154+
Rinv, obs_vals[i], X_ar.data[j]),
155155
lambda: (jnp.zeros_like(SumMtHtRinvHM),
156156
jnp.zeros_like(SumMtHtRinvD))
157157
)
158158
SumMtHtRinvHM += Jb
159159
SumMtHtRinvD += Jo
160160
# Compute initial departure
161-
db0 = (xb0_ds - x0_prev_ds).to_stacked_array('system',[]).data
161+
db0 = (xb0_ds - x0_ds).to_stacked_array('system',[]).data
162162

163163
# Solve Ax=b for the initial perturbation
164164
dx0 = self._solve(db0, SumMtHtRinvHM, SumMtHtRinvD, B)
165165

166166
# New x0 guess is the last guess plus the analyzed delta
167-
x0_new_ds = x0_prev_ds + dx0.ravel()
167+
xa0_ds = x0_ds + dx0.ravel()
168168

169-
return x0_new_ds
169+
return xa0_ds
170170

171171
def _make_outerloop_4d(self,
172172
xb0_ds: XarrayDatasetLike,
@@ -185,18 +185,18 @@ def _outerloop_4d(x0_ds: XarrayDatasetLike,
185185
# Get TLM and current forecast trajectory
186186
# Based on current best guess for x0
187187
x0_ds = x0_ds.to_xarray()
188-
xb_ds, M = self.model_obj.compute_tlm(
188+
X_ds, M = self.model_obj.compute_tlm(
189189
n_steps=n_steps,
190190
state_vec=x0_ds
191191
)
192192

193193
# 4D-Var inner loop
194-
x0_new_ds = self._innerloop_4d(
195-
self.system_dim, xb_ds, xb0_ds, obs_values,
194+
xa0_ds = self._innerloop_4d(
195+
self.system_dim, X_ds, xb0_ds, obs_values,
196196
Hs, B, Rinv, M, obs_window_indices, obs_time_mask
197197
)
198198

199-
return xj.from_xarray(x0_new_ds.assign_coords(x0_ds.coords)), x0_ds
199+
return xj.from_xarray(xa0_ds.assign_coords(x0_ds.coords)), x0_ds
200200

201201
return _outerloop_4d
202202

@@ -236,7 +236,7 @@ def _solve(self,
236236
return dx0
237237

238238
def _cycle_obsop(self,
239-
x0_ds: XarrayDatasetLike,
239+
xb0_ds: XarrayDatasetLike,
240240
obs_values: ArrayLike,
241241
obs_loc_indices: ArrayLike,
242242
obs_time_mask: ArrayLike,
@@ -284,14 +284,12 @@ def _cycle_obsop(self,
284284
# Static Variables
285285
Rinv = jscipy.linalg.inv(R)
286286

287-
# Best guess for x0 starts as background
288-
x0_new_ds = deepcopy(x0_ds)
289-
290287
outerloop_4d_func = self._make_outerloop_4d(
291-
x0_ds, Hs, B, Rinv, obs_values, obs_window_indices,
288+
xb0_ds, Hs, B, Rinv, obs_values, obs_window_indices,
292289
obs_time_mask, self.steps_per_window)
293290

294-
x0_new_ds, all_x0s = jax.lax.scan(outerloop_4d_func, init=xj.from_xarray(x0_new_ds),
291+
xa0_ds, all_x0s = jax.lax.scan(outerloop_4d_func,
292+
init=xj.from_xarray(xb0_ds),
295293
xs=None, length=self.n_outer_loops)
296294

297-
return x0_new_ds.to_xarray()
295+
return xa0_ds.to_xarray()

dabench/dacycler/_var4d_backprop.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,13 @@ def _callback_raise_error(self,
134134

135135
# @partial(jax.jit, static_argnums=[0])
136136
def _calc_obs_term(self,
137-
Xb: ArrayLike,
137+
X: ArrayLike,
138138
obs_vals: ArrayLike,
139139
Ht: ArrayLike,
140140
Rinv: ArrayLike
141141
) -> jax.Array:
142-
Yb = Xb @ Ht
143-
resid = Yb.ravel() - obs_vals.ravel()
142+
Y = X @ Ht
143+
resid = Y.ravel() - obs_vals.ravel()
144144

145145
return jnp.sum(resid.T @ Rinv @ resid)
146146

@@ -163,15 +163,15 @@ def loss_4dvarcost(x0: XarrayDatasetLike) -> jax.Array:
163163

164164
# Make new prediction
165165
# NOTE: [1] selects the full forecast instead of last timestep only
166-
Xb = self._step_forecast(
166+
X = self._step_forecast(
167167
x0, n_steps)[1].to_stacked_array('system',['time']).data
168168

169169
# Calculate observation term of J_0
170170
obs_term = 0
171171
for i, j in enumerate(obs_window_indices):
172172
obs_term += jax.lax.cond(
173173
obs_time_mask.at[i].get(mode='fill', fill_value=0),
174-
lambda: self._calc_obs_term(Xb[j], obs_vals[i],
174+
lambda: self._calc_obs_term(X[j], obs_vals[i],
175175
Hs.at[i].get(mode='clip').T,
176176
Rinv),
177177
lambda: 0.0
@@ -220,15 +220,15 @@ def _backprop_epoch(
220220
updates, opt_state = optimizer.update(dx0_hess, opt_state)
221221
x0_ar.data = optax.apply_updates(
222222
x0_ar.data, updates)
223-
x0_new_ds = x0_ar.to_unstacked_dataset('system').assign_attrs(
223+
xa0_ds = x0_ar.to_unstacked_dataset('system').assign_attrs(
224224
x0_ds.attrs
225225
)
226-
return (xj.from_xarray(x0_new_ds), init_loss, opt_state), loss_val
226+
return (xj.from_xarray(xa0_ds), init_loss, opt_state), loss_val
227227

228228
return _backprop_epoch
229229

230230
def _cycle_obsop(self,
231-
x0_ds: XarrayDatasetLike,
231+
xb0_ds: XarrayDatasetLike,
232232
obs_values: ArrayLike,
233233
obs_loc_indices: ArrayLike,
234234
obs_time_mask: ArrayLike,
@@ -281,7 +281,7 @@ def _cycle_obsop(self,
281281
Binv + Hs.at[0].get().T @ Rinv @ Hs.at[0].get())
282282

283283
loss_func = self._make_loss(
284-
x0_ds,
284+
xb0_ds,
285285
obs_values,
286286
Hs,
287287
Binv,
@@ -295,15 +295,15 @@ def _cycle_obsop(self,
295295
1,
296296
self.lr_decay)
297297
optimizer = optax.sgd(lr)
298-
opt_state = optimizer.init(x0_ds.to_stacked_array('system',[]).data)
298+
opt_state = optimizer.init(xb0_ds.to_stacked_array('system',[]).data)
299299

300300
# Make initial forecast and calculate loss
301301
backprop_epoch_func = self._make_backprop_epoch(loss_func, optimizer,
302302
hessian_inv)
303303
epoch_state_tuple, loss_vals = jax.lax.scan(
304-
backprop_epoch_func, init=(xj.from_xarray(x0_ds), 0., opt_state),
304+
backprop_epoch_func, init=(xj.from_xarray(xb0_ds), 0., opt_state),
305305
xs=jnp.arange(self.num_iters))
306306

307-
x0_new_ds = epoch_state_tuple[0].to_xarray()
307+
xa0_ds = epoch_state_tuple[0].to_xarray()
308308

309-
return x0_new_ds
309+
return xa0_ds

0 commit comments

Comments
 (0)