Skip to content

Commit cc89d5d

Browse files
committed
Update x0 -> xb in etkf and 3dvar
1 parent b5b4ccc commit cc89d5d

2 files changed

Lines changed: 24 additions & 24 deletions

File tree

dabench/dacycler/_etkf.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -80,19 +80,19 @@ def _step_forecast(self,
8080
xr.concat(ensemble_forecasts, dim='ensemble'))
8181

8282
def _apply_obsop(self,
83-
X0: ArrayLike,
83+
Xb: ArrayLike,
8484
H: ArrayLike | None,
8585
h: Callable | None
8686
) -> ArrayLike:
8787
if H is not None:
88-
Yb = H @ X0
88+
Yb = H @ Xb
8989
else:
90-
Yb = h(X0)
90+
Yb = h(Xb)
9191

9292
return Yb
9393

9494
def _compute_analysis(self,
95-
X0: ArrayLike,
95+
Xb: ArrayLike,
9696
Y: ArrayLike,
9797
H: ArrayLike | None,
9898
h: Callable | None,
@@ -102,7 +102,7 @@ def _compute_analysis(self,
102102
"""ETKF analysis algorithm
103103
104104
Args:
105-
X0: Forecast/background ensemble with shape
105+
Xb: Forecast/background ensemble with shape
106106
(system_dim, ensemble_dim).
107107
Y: Observation array with shape (obs_time_time, observation_dim)
108108
H: Linear observation operator with shape (observation_dim,
@@ -117,22 +117,22 @@ def _compute_analysis(self,
117117
Xa: Analysis ensemble [size: (system_dim, ensemble_dim)]
118118
"""
119119
# Number of state variables, ensemble members and observations
120-
system_dim, ensemble_dim = X0.shape
120+
system_dim, ensemble_dim = Xb.shape
121121

122122
# Auxiliary matrices that will ease the computations
123123
U = jnp.ones((ensemble_dim, ensemble_dim))/ensemble_dim
124124
I = jnp.identity(ensemble_dim)
125125

126126
# The ensemble is inflated (rho=1.0 is no inflation)
127-
X0_pert = X0 @ (I-U)
128-
X0 = X0_pert + X0 @ U
127+
Xb_pert = Xb @ (I-U)
128+
Xb = Xb_pert + Xb @ U
129129

130130
# Map every ensemble member into observation space
131-
Yb = self._apply_obsop(X0, H, h)
131+
Yb = self._apply_obsop(Xb, H, h)
132132

133133
# Get ensemble means and perturbations
134-
X0_bar = jnp.mean(X0, axis=1)
135-
X0_pert = X0 @ (I-U)
134+
Xb_bar = jnp.mean(Xb, axis=1)
135+
Xb_pert = Xb @ (I-U)
136136

137137
yb_bar = jnp.mean(Yb, axis=1)
138138
Yb_pert = Yb @ (I-U)
@@ -153,17 +153,17 @@ def _compute_analysis(self,
153153

154154
wa = Pa_ens @ Yb_pert.T @ Rinv @ (Y.flatten()-yb_bar)
155155

156-
Xa_pert = X0_pert @ Wa
156+
Xa_pert = Xb_pert @ Wa
157157

158-
Xa_bar = X0_bar + jnp.ravel(X0_pert @ wa)
158+
Xa_bar = Xb_bar + jnp.ravel(Xb_pert @ wa)
159159

160160
v = jnp.ones((1, ensemble_dim))
161161
Xa = Xa_pert + Xa_bar[:, None] @ v
162162

163163
return Xa
164164

165165
def _cycle_obsop(self,
166-
X0_ds: XarrayDatasetLike,
166+
Xb_ds: XarrayDatasetLike,
167167
obs_values: ArrayLike,
168168
obs_loc_indices: ArrayLike,
169169
obs_time_mask: ArrayLike,
@@ -192,8 +192,8 @@ def _cycle_obsop(self,
192192
else:
193193
B = self.B
194194

195-
X0 = X0_ds.to_stacked_array('system',['ensemble']).data.T
196-
n_sys, n_ens = X0.shape
195+
Xb = Xb_ds.to_stacked_array('system',['ensemble']).data.T
196+
n_sys, n_ens = Xb.shape
197197
assert n_ens == self.ensemble_dim, (
198198
'cycle:: model_forecast must have dimension {}x{}').format(
199199
self.ensemble_dim, self.system_dim)
@@ -203,11 +203,11 @@ def _cycle_obsop(self,
203203
H = jnp.where(obs_loc_mask.flatten(), H.T, 0).T
204204

205205
# Analysis cycles over all obs in data_obs
206-
Xa = self._compute_analysis(X0=X0,
206+
Xa = self._compute_analysis(Xb=Xb,
207207
Y=obs_values,
208208
H=H,
209209
h=h,
210210
R=R,
211211
rho=self.multiplicative_inflation)
212212

213-
return X0_ds.assign(x=(['ensemble','i'], Xa.T))
213+
return Xb_ds.assign(x=(['ensemble','i'], Xa.T))

dabench/dacycler/_var3d.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(self,
5656
B=B, R=R, H=H, h=h)
5757

5858
def _cycle_obsop(self,
59-
x0_ds: XarrayDatasetLike,
59+
xb_ds: XarrayDatasetLike,
6060
obs_values: ArrayLike,
6161
obs_loc_indices: ArrayLike,
6262
obs_time_mask: ArrayLike,
@@ -85,26 +85,26 @@ def _cycle_obsop(self,
8585
else:
8686
B = self.B
8787

88-
x0 = x0_ds.to_stacked_array('system',[]).data.flatten()
88+
xb = xb_ds.to_stacked_array('system',[]).data.flatten()
8989
y = obs_values.flatten()
9090

9191
# Apply masks to H
9292
H = jnp.where(obs_time_mask.flatten(), H.T, 0).T
9393
H = jnp.where(obs_loc_mask.flatten(), H.T, 0).T
9494

9595
# Set parameters
96-
xdim = x0.size # Size or get one of the shape params?
96+
xdim = xb.size # Size or get one of the shape params?
9797
Rinv = jnp.linalg.inv(R)
9898

9999
# 'preconditioning with B'
100100
I = jnp.identity(xdim)
101101
BHt = jnp.dot(B, H.T)
102102
BHtRinv = jnp.dot(BHt, Rinv)
103103
A = I + jnp.dot(BHtRinv, H)
104-
b1 = x0 + jnp.dot(BHtRinv, y)
104+
b1 = xb + jnp.dot(BHtRinv, y)
105105

106106
# Use minimization algorithm to minimize cost function:
107-
xa, ierr = jscipy.sparse.linalg.cg(A, b1, x0=x0, tol=1e-05,
107+
xa, ierr = jscipy.sparse.linalg.cg(A, b1, x0=xb, tol=1e-05,
108108
maxiter=1000)
109109

110-
return x0_ds.assign(x=(x0_ds.dims, xa.T))
110+
return xb_ds.assign(x=(xb_ds.dims, xa.T))

0 commit comments

Comments
 (0)