55import jax
66import xarray as xr
77import xarray_jax as xj
8+ from typing import Callable
89
910import dabench .dacycler ._utils as dac_utils
11+ from dabench .model import Model
12+
13+
# Type aliases used in the annotations of this module.
# NOTE(review): `np` is not imported in the visible part of the file --
# confirm `import numpy as np` exists above these lines.
ArrayLike = np.ndarray | jax.Array
XarrayDatasetLike = xr.Dataset | xj.XjDataset
1017
1118class DACycler ():
1219 """Base class for DACycler object
1320
1421 Attributes:
15- system_dim (int) : System dimension
16- delta_t (float) : The timestep of the model (assumed uniform)
17- model_obj (dabench.Model) : Forecast model object.
18- in_4d (bool) : True for 4D data assimilation techniques (e.g. 4DVar).
22+ system_dim: System dimension
23+ delta_t: The timestep of the model (assumed uniform)
24+ model_obj: Forecast model object.
25+ in_4d: True for 4D data assimilation techniques (e.g. 4DVar).
1926 Default is False.
20- ensemble (bool) : True for ensemble-based data assimilation techniques
27+ ensemble: True for ensemble-based data assimilation techniques
2128 (ETKF). Default is False
22- B (ndarray) : Initial / static background error covariance. Shape:
29+ B: Initial / static background error covariance. Shape:
2330 (system_dim, system_dim). If not provided, will be calculated
2431 automatically.
25- R (ndarray) : Observation error covariance matrix. Shape
32+ R: Observation error covariance matrix. Shape
2633 (obs_dim, obs_dim). If not provided, will be calculated
2734 automatically.
28- H (ndarray) : Observation operator with shape: (obs_dim, system_dim).
35+ H: Observation operator with shape: (obs_dim, system_dim).
2936 If not provided will be calculated automatically.
30- h (function) : Optional observation operator as function. More flexible
37+ h: Optional observation operator as function. More flexible
3138 (allows for more complex observation operator). Default is None.
3239 """
3340
3441 def __init__ (self ,
35- system_dim = None ,
36- delta_t = None ,
37- model_obj = None ,
38- in_4d = False ,
39- ensemble = False ,
40- B = None ,
41- R = None ,
42- H = None ,
43- h = None ,
44- analysis_time_in_window = None
42+ system_dim : int ,
43+ delta_t : float ,
44+ model_obj : Model ,
45+ in_4d : bool = False ,
46+ ensemble : bool = False ,
47+ B : ArrayLike | None = None ,
48+ R : ArrayLike | None = None ,
49+ H : ArrayLike | None = None ,
50+ h : Callable | None = None ,
4551 ):
4652
4753 self .h = h
@@ -53,43 +59,64 @@ def __init__(self,
5359 self .system_dim = system_dim
5460 self .delta_t = delta_t
5561 self .model_obj = model_obj
56- self .analysis_time_in_window = analysis_time_in_window
5762
5863
59- def _calc_default_H (self , obs_values , obs_loc_indices ):
64+ def _calc_default_H (self ,
65+ obs_values : ArrayLike ,
66+ obs_loc_indices : ArrayLike
67+ ) -> jax .Array :
6068 H = jnp .zeros ((obs_values .flatten ().shape [0 ], self .system_dim ))
6169 H = H .at [jnp .arange (H .shape [0 ]),
6270 obs_loc_indices .flatten (),
6371 ].set (1 )
6472 return H
6573
66- def _calc_default_R (self , obs_values , obs_error_sd ):
74+ def _calc_default_R (self ,
75+ obs_values : ArrayLike ,
76+ obs_error_sd : float
77+ ) -> jax .Array :
6778 return jnp .identity (obs_values .flatten ().shape [0 ])* (obs_error_sd ** 2 )
6879
69- def _calc_default_B (self ):
80+ def _calc_default_B (self ) -> jax . Array :
7081 """If B is not provided, identity matrix with shape (system_dim, system_dim."""
7182 return jnp .identity (self .system_dim )
7283
73- def _step_forecast (self , xa , n_steps = 1 ):
84+ def _step_forecast (self ,
85+ xa : XarrayDatasetLike ,
86+ n_steps : int = 1
87+ ) -> XarrayDatasetLike :
7488 """Perform forecast using model object"""
7589 return self .model_obj .forecast (xa , n_steps = n_steps )
7690
77- def _step_cycle (self , xb , obs_vals , obs_locs , obs_time_mask , obs_loc_mask ,
78- H = None , h = None , R = None , B = None , ** kwargs ):
91+ def _step_cycle (self ,
92+ cur_state : XarrayDatasetLike ,
93+ obs_vals : ArrayLike ,
94+ obs_locs : ArrayLike ,
95+ obs_time_mask : ArrayLike ,
96+ obs_loc_mask : ArrayLike ,
97+ H : ArrayLike | None = None ,
98+ h : Callable | None = None ,
99+ R : ArrayLike | None = None ,
100+ B :ArrayLike | None = None ,
101+ ** kwargs
102+ ) -> XarrayDatasetLike :
79103 if H is not None or h is None :
80104 vals = self ._cycle_obsop (
81- xb , obs_vals , obs_locs , obs_time_mask ,
105+ cur_state , obs_vals , obs_locs , obs_time_mask ,
82106 obs_loc_mask , H , R , B , ** kwargs )
83107 return vals
84108 else :
85109 raise ValueError (
86110 'Only linear obs operators (H) are supported right now.' )
87111 vals = self ._cycle_general_obsop (
88- xb , obs_vals , obs_locs , obs_time_mask ,
112+ cur_state , obs_vals , obs_locs , obs_time_mask ,
89113 obs_loc_mask , h , R , B , ** kwargs )
90114 return vals
91115
92- def _cycle_and_forecast (self , cur_state , filtered_idx ):
116+ def _cycle_and_forecast (self ,
117+ cur_state : xj .XjDataset ,
118+ filtered_idx : ArrayLike
119+ ) -> tuple [xj .XjDataset , XarrayDatasetLike ]:
93120 # 1. Get data
94121 # 1-b. Calculate obs_time_mask and restore filtered_idx to original values
95122 cur_state = cur_state .to_xarray ()
@@ -119,7 +146,10 @@ def _cycle_and_forecast(self, cur_state, filtered_idx):
119146
120147 return xj .from_xarray (next_state ), forecast_states
121148
122- def _cycle_and_forecast_4d (self , cur_state , filtered_idx ):
149+ def _cycle_and_forecast_4d (self ,
150+ cur_state : xj .XjDataset ,
151+ filtered_idx : ArrayLike
152+ ) -> tuple [xj .XjDataset , XarrayDatasetLike ]:
123153 # 1. Get data
124154 # 1-b. Calculate obs_time_mask and restore filtered_idx to original values
125155 cur_state = cur_state .to_xarray ()
@@ -160,35 +190,32 @@ def _cycle_and_forecast_4d(self, cur_state, filtered_idx):
160190 return xj .from_xarray (next_state ), forecast_states
161191
162192 def cycle (self ,
163- input_state ,
164- start_time ,
165- obs_vector ,
166- n_cycles ,
167- obs_error_sd = None ,
168- analysis_window = 0.2 ,
169- analysis_time_in_window = None ,
170- return_forecast = False
171- ):
193+ input_state : XarrayDatasetLike ,
194+ start_time : float | np . datetime64 ,
195+ obs_vector : XarrayDatasetLike ,
196+ n_cycles : int ,
197+ obs_error_sd : float | ArrayLike | None = None ,
198+ analysis_window : float = 0.2 ,
199+ analysis_time_in_window : float | None = None ,
200+ return_forecast : bool = False
201+ ) -> XarrayDatasetLike :
172202 """Perform DA cycle repeatedly, including analysis and forecast
173203
174204 Args:
175- input_state (vector.StateVector) : Input state.
176- start_time (float or datetime-like) : Starting time.
177- obs_vector (vector.ObsVector) : Observations vector.
178- n_cycles (int) : Number of analysis cycles to run, each of length
205+ input_state: Input state as a Xarray Dataset
206+ start_time: Starting time.
207+ obs_vector: Observations vector.
208+ n_cycles: Number of analysis cycles to run, each of length
179209 analysis_window.
180- analysis_window (float) : Time window from which to gather
210+ analysis_window: Time window from which to gather
181211 observations for DA Cycle.
182- analysis_time_in_window (float) : Where within analysis_window
212+ analysis_time_in_window: Where within analysis_window
183213 to perform analysis. For example, 0.0 is the start of the
184214 window. Default is None, which selects the middle of the
185215 window.
186- return_forecast (bool) : If True, returns forecast at each model
216+ return_forecast: If True, returns forecast at each model
187217 timestep. If False, returns only analyses, one per analysis
188- cycle. Default is False.
189-
190- Returns:
191- vector.StateVector of analyses and times.
218+ cycle.
192219 """
193220
194221 # These could be different if observer doesn't observe all variables
@@ -202,10 +229,11 @@ def cycle(self,
202229 self .analysis_window = analysis_window
203230
204231 # If don't specify analysis_time_in_window, is assumed to be middle
205- if self .analysis_time_in_window is None and analysis_time_in_window is None :
206- analysis_time_in_window = self .analysis_window / 2
207- else :
208- analysis_time_in_window = self .analysis_time_in_window
232+ if analysis_time_in_window is None :
233+ if self .in_4d :
234+ analysis_time_in_window = 0
235+ else :
236+ analysis_time_in_window = self .analysis_window / 2
209237
210238 # Steps per window + 1 to include start
211239 self .steps_per_window = round (analysis_window / self .delta_t ) + 1
@@ -256,13 +284,13 @@ def cycle(self,
256284 xj .from_xarray (input_state ),
257285 all_filtered_padded )
258286
259- all_vals_xr = xr .Dataset (
287+ all_vals_ds = xr .Dataset (
260288 {var : (('cycle' ,) + tuple (all_values [var ].dims ),
261289 all_values [var ].data )
262290 for var in all_values .data_vars }
263291 ).rename_dims ({'time' : 'cycle_timestep' })
264292
265293 if return_forecast :
266- return all_vals_xr .drop_isel (cycle_timestep = - 1 )
294+ return all_vals_ds .drop_isel (cycle_timestep = - 1 )
267295 else :
268- return all_vals_xr .isel (cycle_timestep = 0 )
296+ return all_vals_ds .isel (cycle_timestep = 0 )
0 commit comments