Skip to content

Commit f5f4473

Browse files
authored
Merge pull request #67 from StevePny/docs/type-hints
Type hints and renaming some variables
2 parents a48a390 + 1c03929 commit f5f4473

23 files changed

Lines changed: 1074 additions & 726 deletions

dabench/dacycler/_dacycler.py

Lines changed: 85 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -5,43 +5,49 @@
55
import jax
66
import xarray as xr
77
import xarray_jax as xj
8+
from typing import Callable
89

910
import dabench.dacycler._utils as dac_utils
11+
from dabench.model import Model
12+
13+
14+
# For typing
15+
ArrayLike = np.ndarray | jax.Array
16+
XarrayDatasetLike = xr.Dataset | xj.XjDataset
1017

1118
class DACycler():
1219
"""Base class for DACycler object
1320
1421
Attributes:
15-
system_dim (int): System dimension
16-
delta_t (float): The timestep of the model (assumed uniform)
17-
model_obj (dabench.Model): Forecast model object.
18-
in_4d (bool): True for 4D data assimilation techniques (e.g. 4DVar).
22+
system_dim: System dimension
23+
delta_t: The timestep of the model (assumed uniform)
24+
model_obj: Forecast model object.
25+
in_4d: True for 4D data assimilation techniques (e.g. 4DVar).
1926
Default is False.
20-
ensemble (bool): True for ensemble-based data assimilation techniques
27+
ensemble: True for ensemble-based data assimilation techniques
2128
(ETKF). Default is False
22-
B (ndarray): Initial / static background error covariance. Shape:
29+
B: Initial / static background error covariance. Shape:
2330
(system_dim, system_dim). If not provided, will be calculated
2431
automatically.
25-
R (ndarray): Observation error covariance matrix. Shape
32+
R: Observation error covariance matrix. Shape
2633
(obs_dim, obs_dim). If not provided, will be calculated
2734
automatically.
28-
H (ndarray): Observation operator with shape: (obs_dim, system_dim).
35+
H: Observation operator with shape: (obs_dim, system_dim).
2936
If not provided will be calculated automatically.
30-
h (function): Optional observation operator as function. More flexible
37+
h: Optional observation operator as function. More flexible
3138
(allows for more complex observation operator). Default is None.
3239
"""
3340

3441
def __init__(self,
35-
system_dim=None,
36-
delta_t=None,
37-
model_obj=None,
38-
in_4d=False,
39-
ensemble=False,
40-
B=None,
41-
R=None,
42-
H=None,
43-
h=None,
44-
analysis_time_in_window=None
42+
system_dim: int,
43+
delta_t: float,
44+
model_obj: Model,
45+
in_4d: bool = False,
46+
ensemble: bool = False,
47+
B: ArrayLike | None = None,
48+
R: ArrayLike | None = None,
49+
H: ArrayLike | None = None,
50+
h: Callable | None = None,
4551
):
4652

4753
self.h = h
@@ -53,43 +59,64 @@ def __init__(self,
5359
self.system_dim = system_dim
5460
self.delta_t = delta_t
5561
self.model_obj = model_obj
56-
self.analysis_time_in_window = analysis_time_in_window
5762

5863

59-
def _calc_default_H(self, obs_values, obs_loc_indices):
64+
def _calc_default_H(self,
65+
obs_values: ArrayLike,
66+
obs_loc_indices: ArrayLike
67+
) -> jax.Array:
6068
H = jnp.zeros((obs_values.flatten().shape[0], self.system_dim))
6169
H = H.at[jnp.arange(H.shape[0]),
6270
obs_loc_indices.flatten(),
6371
].set(1)
6472
return H
6573

66-
def _calc_default_R(self, obs_values, obs_error_sd):
74+
def _calc_default_R(self,
75+
obs_values: ArrayLike,
76+
obs_error_sd: float
77+
) -> jax.Array:
6778
return jnp.identity(obs_values.flatten().shape[0])*(obs_error_sd**2)
6879

69-
def _calc_default_B(self):
80+
def _calc_default_B(self) -> jax.Array:
7081
"""If B is not provided, identity matrix with shape (system_dim, system_dim."""
7182
return jnp.identity(self.system_dim)
7283

73-
def _step_forecast(self, xa, n_steps=1):
84+
def _step_forecast(self,
85+
xa: XarrayDatasetLike,
86+
n_steps: int = 1
87+
) -> XarrayDatasetLike:
7488
"""Perform forecast using model object"""
7589
return self.model_obj.forecast(xa, n_steps=n_steps)
7690

77-
def _step_cycle(self, xb, obs_vals, obs_locs, obs_time_mask, obs_loc_mask,
78-
H=None, h=None, R=None, B=None, **kwargs):
91+
def _step_cycle(self,
92+
cur_state: XarrayDatasetLike,
93+
obs_vals: ArrayLike,
94+
obs_locs: ArrayLike,
95+
obs_time_mask: ArrayLike,
96+
obs_loc_mask: ArrayLike,
97+
H: ArrayLike | None = None,
98+
h: Callable | None =None,
99+
R: ArrayLike | None = None,
100+
B:ArrayLike | None = None,
101+
**kwargs
102+
) -> XarrayDatasetLike:
79103
if H is not None or h is None:
80104
vals = self._cycle_obsop(
81-
xb, obs_vals, obs_locs, obs_time_mask,
105+
cur_state, obs_vals, obs_locs, obs_time_mask,
82106
obs_loc_mask, H, R, B, **kwargs)
83107
return vals
84108
else:
85109
raise ValueError(
86110
'Only linear obs operators (H) are supported right now.')
87111
vals = self._cycle_general_obsop(
88-
xb, obs_vals, obs_locs, obs_time_mask,
112+
cur_state, obs_vals, obs_locs, obs_time_mask,
89113
obs_loc_mask, h, R, B, **kwargs)
90114
return vals
91115

92-
def _cycle_and_forecast(self, cur_state, filtered_idx):
116+
def _cycle_and_forecast(self,
117+
cur_state: xj.XjDataset,
118+
filtered_idx: ArrayLike
119+
) -> tuple[xj.XjDataset, XarrayDatasetLike]:
93120
# 1. Get data
94121
# 1-b. Calculate obs_time_mask and restore filtered_idx to original values
95122
cur_state = cur_state.to_xarray()
@@ -119,7 +146,10 @@ def _cycle_and_forecast(self, cur_state, filtered_idx):
119146

120147
return xj.from_xarray(next_state), forecast_states
121148

122-
def _cycle_and_forecast_4d(self, cur_state, filtered_idx):
149+
def _cycle_and_forecast_4d(self,
150+
cur_state: xj.XjDataset,
151+
filtered_idx: ArrayLike
152+
) -> tuple[xj.XjDataset, XarrayDatasetLike]:
123153
# 1. Get data
124154
# 1-b. Calculate obs_time_mask and restore filtered_idx to original values
125155
cur_state = cur_state.to_xarray()
@@ -160,35 +190,32 @@ def _cycle_and_forecast_4d(self, cur_state, filtered_idx):
160190
return xj.from_xarray(next_state), forecast_states
161191

162192
def cycle(self,
163-
input_state,
164-
start_time,
165-
obs_vector,
166-
n_cycles,
167-
obs_error_sd=None,
168-
analysis_window=0.2,
169-
analysis_time_in_window=None,
170-
return_forecast=False
171-
):
193+
input_state: XarrayDatasetLike,
194+
start_time: float | np.datetime64,
195+
obs_vector: XarrayDatasetLike,
196+
n_cycles: int,
197+
obs_error_sd: float | ArrayLike | None = None,
198+
analysis_window: float = 0.2,
199+
analysis_time_in_window: float | None = None,
200+
return_forecast: bool = False
201+
) -> XarrayDatasetLike:
172202
"""Perform DA cycle repeatedly, including analysis and forecast
173203
174204
Args:
175-
input_state (vector.StateVector): Input state.
176-
start_time (float or datetime-like): Starting time.
177-
obs_vector (vector.ObsVector): Observations vector.
178-
n_cycles (int): Number of analysis cycles to run, each of length
205+
input_state: Input state as a Xarray Dataset
206+
start_time: Starting time.
207+
obs_vector: Observations vector.
208+
n_cycles: Number of analysis cycles to run, each of length
179209
analysis_window.
180-
analysis_window (float): Time window from which to gather
210+
analysis_window: Time window from which to gather
181211
observations for DA Cycle.
182-
analysis_time_in_window (float): Where within analysis_window
212+
analysis_time_in_window: Where within analysis_window
183213
to perform analysis. For example, 0.0 is the start of the
184214
window. Default is None, which selects the middle of the
185215
window.
186-
return_forecast (bool): If True, returns forecast at each model
216+
return_forecast: If True, returns forecast at each model
187217
timestep. If False, returns only analyses, one per analysis
188-
cycle. Default is False.
189-
190-
Returns:
191-
vector.StateVector of analyses and times.
218+
cycle.
192219
"""
193220

194221
# These could be different if observer doesn't observe all variables
@@ -202,10 +229,11 @@ def cycle(self,
202229
self.analysis_window = analysis_window
203230

204231
# If don't specify analysis_time_in_window, is assumed to be middle
205-
if self.analysis_time_in_window is None and analysis_time_in_window is None:
206-
analysis_time_in_window = self.analysis_window/2
207-
else:
208-
analysis_time_in_window = self.analysis_time_in_window
232+
if analysis_time_in_window is None:
233+
if self.in_4d:
234+
analysis_time_in_window = 0
235+
else:
236+
analysis_time_in_window = self.analysis_window/2
209237

210238
# Steps per window + 1 to include start
211239
self.steps_per_window = round(analysis_window/self.delta_t) + 1
@@ -256,13 +284,13 @@ def cycle(self,
256284
xj.from_xarray(input_state),
257285
all_filtered_padded)
258286

259-
all_vals_xr = xr.Dataset(
287+
all_vals_ds = xr.Dataset(
260288
{var: (('cycle',) + tuple(all_values[var].dims),
261289
all_values[var].data)
262290
for var in all_values.data_vars}
263291
).rename_dims({'time': 'cycle_timestep'})
264292

265293
if return_forecast:
266-
return all_vals_xr.drop_isel(cycle_timestep=-1)
294+
return all_vals_ds.drop_isel(cycle_timestep=-1)
267295
else:
268-
return all_vals_xr.isel(cycle_timestep=0)
296+
return all_vals_ds.isel(cycle_timestep=0)

0 commit comments

Comments
 (0)