@@ -80,19 +80,19 @@ def _step_forecast(self,
8080 xr .concat (ensemble_forecasts , dim = 'ensemble' ))
8181
8282 def _apply_obsop (self ,
83- X0 : ArrayLike ,
83+ Xb : ArrayLike ,
8484 H : ArrayLike | None ,
8585 h : Callable | None
8686 ) -> ArrayLike :
8787 if H is not None :
88- Yb = H @ X0
88+ Yb = H @ Xb
8989 else :
90- Yb = h (X0 )
90+ Yb = h (Xb )
9191
9292 return Yb
9393
9494 def _compute_analysis (self ,
95- X0 : ArrayLike ,
95+ Xb : ArrayLike ,
9696 Y : ArrayLike ,
9797 H : ArrayLike | None ,
9898 h : Callable | None ,
@@ -102,7 +102,7 @@ def _compute_analysis(self,
102102 """ETKF analysis algorithm
103103
104104 Args:
105- X0 : Forecast/background ensemble with shape
105+ Xb : Forecast/background ensemble with shape
106106 (system_dim, ensemble_dim).
107107 Y: Observation array with shape (obs_time_time, observation_dim)
108108 H: Linear observation operator with shape (observation_dim,
@@ -117,22 +117,22 @@ def _compute_analysis(self,
117117 Xa: Analysis ensemble [size: (system_dim, ensemble_dim)]
118118 """
119119 # Number of state variables, ensemble members and observations
120- system_dim , ensemble_dim = X0 .shape
120+ system_dim , ensemble_dim = Xb .shape
121121
122122 # Auxiliary matrices that will ease the computations
123123 U = jnp .ones ((ensemble_dim , ensemble_dim ))/ ensemble_dim
124124 I = jnp .identity (ensemble_dim )
125125
126126 # The ensemble is inflated (rho=1.0 is no inflation)
127- X0_pert = X0 @ (I - U )
128- X0 = X0_pert + X0 @ U
127+ Xb_pert = Xb @ (I - U )
128+ Xb = Xb_pert + Xb @ U
129129
130130 # Map every ensemble member into observation space
131- Yb = self ._apply_obsop (X0 , H , h )
131+ Yb = self ._apply_obsop (Xb , H , h )
132132
133133 # Get ensemble means and perturbations
134- X0_bar = jnp .mean (X0 , axis = 1 )
135- X0_pert = X0 @ (I - U )
134+ Xb_bar = jnp .mean (Xb , axis = 1 )
135+ Xb_pert = Xb @ (I - U )
136136
137137 yb_bar = jnp .mean (Yb , axis = 1 )
138138 Yb_pert = Yb @ (I - U )
@@ -153,17 +153,17 @@ def _compute_analysis(self,
153153
154154 wa = Pa_ens @ Yb_pert .T @ Rinv @ (Y .flatten ()- yb_bar )
155155
156- Xa_pert = X0_pert @ Wa
156+ Xa_pert = Xb_pert @ Wa
157157
158- Xa_bar = X0_bar + jnp .ravel (X0_pert @ wa )
158+ Xa_bar = Xb_bar + jnp .ravel (Xb_pert @ wa )
159159
160160 v = jnp .ones ((1 , ensemble_dim ))
161161 Xa = Xa_pert + Xa_bar [:, None ] @ v
162162
163163 return Xa
164164
165165 def _cycle_obsop (self ,
166- X0_ds : XarrayDatasetLike ,
166+ Xb_ds : XarrayDatasetLike ,
167167 obs_values : ArrayLike ,
168168 obs_loc_indices : ArrayLike ,
169169 obs_time_mask : ArrayLike ,
@@ -192,8 +192,8 @@ def _cycle_obsop(self,
192192 else :
193193 B = self .B
194194
195- X0 = X0_ds .to_stacked_array ('system' ,['ensemble' ]).data .T
196- n_sys , n_ens = X0 .shape
195+ Xb = Xb_ds .to_stacked_array ('system' ,['ensemble' ]).data .T
196+ n_sys , n_ens = Xb .shape
197197 assert n_ens == self .ensemble_dim , (
198198 'cycle:: model_forecast must have dimension {}x{}' ).format (
199199 self .ensemble_dim , self .system_dim )
@@ -203,11 +203,11 @@ def _cycle_obsop(self,
203203 H = jnp .where (obs_loc_mask .flatten (), H .T , 0 ).T
204204
205205 # Analysis cycles over all obs in data_obs
206- Xa = self ._compute_analysis (X0 = X0 ,
206+ Xa = self ._compute_analysis (Xb = Xb ,
207207 Y = obs_values ,
208208 H = H ,
209209 h = h ,
210210 R = R ,
211211 rho = self .multiplicative_inflation )
212212
213- return X0_ds .assign (x = (['ensemble' ,'i' ], Xa .T ))
213+ return Xb_ds .assign (x = (['ensemble' ,'i' ], Xa .T ))
0 commit comments