Skip to content

Commit 4070e99

Browse files
authored
Merge pull request #70 from StevePny/docs/rtd-autoapi-fixes
ReadTheDocs autoapi fixes
2 parents 8e787f5 + 560d575 commit 4070e99

25 files changed

Lines changed: 155 additions & 162 deletions

dabench/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
"""DataAssimBench"""
12
from . import data, model, observer, obsop, dacycler, _suppl_data

dabench/dacycler/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Data Assimilation cyclers"""
2+
13
from ._dacycler import DACycler
24
from ._var3d import Var3D
35
from ._etkf import ETKF

dabench/dacycler/_dacycler.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,12 @@
1616
XarrayDatasetLike = xr.Dataset | xj.XjDataset
1717

1818
class DACycler():
19-
"""Base class for DACycler object
19+
"""Base for all DACyclers
2020
21-
Attributes:
21+
Args:
2222
system_dim: System dimension
2323
delta_t: The timestep of the model (assumed uniform)
2424
model_obj: Forecast model object.
25-
in_4d: True for 4D data assimilation techniques (e.g. 4DVar).
26-
Default is False.
27-
ensemble: True for ensemble-based data assimilation techniques
28-
(ETKF). Default is False
2925
B: Initial / static background error covariance. Shape:
3026
(system_dim, system_dim). If not provided, will be calculated
3127
automatically.
@@ -37,13 +33,13 @@ class DACycler():
3733
h: Optional observation operator as function. More flexible
3834
(allows for more complex observation operator). Default is None.
3935
"""
36+
_in_4d: bool = False
37+
_uses_ensemble: bool = False
4038

4139
def __init__(self,
4240
system_dim: int,
4341
delta_t: float,
4442
model_obj: Model,
45-
in_4d: bool = False,
46-
ensemble: bool = False,
4743
B: ArrayLike | None = None,
4844
R: ArrayLike | None = None,
4945
H: ArrayLike | None = None,
@@ -54,8 +50,6 @@ def __init__(self,
5450
self.H = H
5551
self.R = R
5652
self.B = B
57-
self.in_4d = in_4d
58-
self.ensemble = ensemble
5953
self.system_dim = system_dim
6054
self.delta_t = delta_t
6155
self.model_obj = model_obj
@@ -230,7 +224,7 @@ def cycle(self,
230224

231225
# If don't specify analysis_time_in_window, is assumed to be middle
232226
if analysis_time_in_window is None:
233-
if self.in_4d:
227+
if self._in_4d:
234228
analysis_time_in_window = 0
235229
else:
236230
analysis_time_in_window = self.analysis_window/2
@@ -257,7 +251,7 @@ def cycle(self,
257251
obs_times=jnp.array(obs_vector.time.values),
258252
analysis_times=all_times+_time_offset,
259253
start_inclusive=True,
260-
end_inclusive=self.in_4d,
254+
end_inclusive=self._in_4d,
261255
analysis_window=analysis_window
262256
)
263257
input_state = input_state.assign(_cur_time=start_time)
@@ -273,7 +267,7 @@ def cycle(self,
273267
obs_vector[self._observed_vars].to_array().data)
274268
self._obs_vector=self._obs_vector.fillna(0)
275269

276-
if self.in_4d:
270+
if self._in_4d:
277271
cur_state, all_values = jax.lax.scan(
278272
self._cycle_and_forecast_4d,
279273
xj.from_xarray(input_state),

dabench/dacycler/_etkf.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
XarrayDatasetLike = xr.Dataset | xj.XjDataset
1818

1919
class ETKF(dacycler.DACycler):
20-
"""Class for building ETKF DA Cycler
20+
"""Ensemble transform Kalman filter DA Cycler
2121
22-
Attributes:
22+
Args:
2323
system_dim: System dimension.
2424
delta_t: The timestep of the model (assumed uniform)
2525
model_obj: Forecast model object.
@@ -38,6 +38,8 @@ class ETKF(dacycler.DACycler):
3838
multiplicative_inflation: Scaling factor by which to multiply ensemble
3939
deviation. Default is 1.0 (no inflation).
4040
"""
41+
_in_4d: bool = False
42+
_uses_ensemble: bool = True
4143

4244
def __init__(self,
4345
system_dim: int,
@@ -57,8 +59,6 @@ def __init__(self,
5759
super().__init__(system_dim=system_dim,
5860
delta_t=delta_t,
5961
model_obj=model_obj,
60-
in_4d=False,
61-
ensemble=True,
6262
B=B, R=R, H=H, h=h)
6363

6464
def _step_forecast(self,

dabench/dacycler/_var3d.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
XarrayDatasetLike = xr.Dataset | xj.XjDataset
1717

1818
class Var3D(dacycler.DACycler):
19-
"""Class for building 3DVar DA Cycler
19+
"""3D-Var DA Cycler
2020
21-
Attributes:
21+
Args:
2222
system_dim: System dimension.
2323
delta_t: The timestep of the model (assumed uniform)
2424
model_obj: Forecast model object.
@@ -33,6 +33,8 @@ class Var3D(dacycler.DACycler):
3333
h: Optional observation operator as function. More flexible
3434
(allows for more complex observation operator). Default is None.
3535
"""
36+
_in_4d: bool = False
37+
_uses_ensemble: bool = False
3638

3739
def __init__(self,
3840
system_dim: int,
@@ -47,8 +49,6 @@ def __init__(self,
4749
super().__init__(system_dim=system_dim,
4850
delta_t=delta_t,
4951
model_obj=model_obj,
50-
in_4d=False,
51-
ensemble=False,
5252
B=B, R=R, H=H, h=h)
5353

5454
def _cycle_obsop(self,

dabench/dacycler/_var4d.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,12 @@
2525
XarrayDatasetLike = xr.Dataset | xj.XjDataset
2626

2727
class Var4D(dacycler.DACycler):
28-
"""Class for building 4D DA Cycler
28+
"""4D-Var DA Cycler
2929
30-
Attributes:
30+
Args:
3131
system_dim: System dimension.
3232
delta_t: The timestep of the model (assumed uniform)
3333
model_obj: Forecast model object.
34-
in_4d: True for 4D data assimilation techniques (e.g. 4DVar).
35-
Always True for Var4D.
36-
ensemble: True for ensemble-based data assimilation techniques
37-
(ETKF). Always False for Var4D.
3834
B: Initial / static background error covariance. Shape:
3935
(system_dim, system_dim). If not provided, will be calculated
4036
automatically.
@@ -59,6 +55,8 @@ class Var4D(dacycler.DACycler):
5955
[0, 1, 2, 3, 4, 5]. If None (default), will calculate
6056
automatically.
6157
"""
58+
_in_4d: bool = True
59+
_uses_ensemble: bool = False
6260

6361
def __init__(self,
6462
system_dim: int,
@@ -87,8 +85,6 @@ def __init__(self,
8785
super().__init__(system_dim=system_dim,
8886
delta_t=delta_t,
8987
model_obj=model_obj,
90-
in_4d=True,
91-
ensemble=False,
9288
B=B, R=R, H=H, h=h)
9389

9490
def _calc_default_H(self,

dabench/dacycler/_var4d_backprop.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,12 @@
2525
ScheduleState = Any
2626

2727
class Var4DBackprop(dacycler.DACycler):
28-
"""Class for building Backpropagation 4D DA Cycler
28+
"""Backpropagation 4D-Var DA Cycler
2929
30-
Attributes:
30+
Args:
3131
system_dim: System dimension.
3232
delta_t: The timestep of the model (assumed uniform)
3333
model_obj: Forecast model object.
34-
in_4d: True for 4D data assimilation techniques (e.g. 4DVar).
35-
Always True for Var4DBackprop.
36-
ensemble: True for ensemble-based data assimilation techniques
37-
(ETKF). Always False for Var4DBackprop.
3834
B: Initial / static background error covariance. Shape:
3935
(system_dim, system_dim). If not provided, will be calculated
4036
automatically.
@@ -65,6 +61,8 @@ class Var4DBackprop(dacycler.DACycler):
6561
return an error. This prevents it from hanging indefinitely
6662
when loss grows exponentionally. Default is 10.
6763
"""
64+
_in_4d: bool = True
65+
_uses_ensemble: bool = False
6866

6967
def __init__(self,
7068
system_dim: int,
@@ -97,8 +95,6 @@ def __init__(self,
9795
super().__init__(system_dim=system_dim,
9896
delta_t=delta_t,
9997
model_obj=model_obj,
100-
in_4d=True,
101-
ensemble=False,
10298
B=B, R=R, H=H, h=h)
10399

104100
def _calc_default_H(self,

dabench/data/__init__.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1+
"""Data generators and downloaders"""
12
from ._data import Data
23

3-
from .lorenz63 import Lorenz63
4-
from .lorenz96 import Lorenz96
5-
from .sqgturb import SQGTurb
6-
from .gcp import GCP
7-
from .pyqg import PyQG
8-
from .pyqg_jax import PyQGJax
9-
from .barotropic import Barotropic
10-
from .enso_indices import ENSOIndices
11-
from .qgs import QGS
4+
from ._lorenz63 import Lorenz63
5+
from ._lorenz96 import Lorenz96
6+
from ._sqgturb import SQGTurb
7+
from ._gcp import GCP
8+
from ._pyqg import PyQG
9+
from ._pyqg_jax import PyQGJax
10+
from ._barotropic import Barotropic
11+
from ._enso_indices import ENSOIndices
12+
from ._qgs import QGS
1213
from ._xarray_accessor import DABenchDatasetAccessor, DABenchDataArrayAccessor
1314

1415
__all__ = [
Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,21 @@
3030

3131

3232
class Barotropic(_data.Data):
33-
""" Class to set up barotropic model
33+
"""Barotropic model data generator based on pyqg
3434
35-
The data class is a wrapper of a "optional" pyqg package.
35+
This data class is a wrapper of a "optional" pyqg package.
3636
See https://pyqg.readthedocs.io
3737
3838
Notes:
39+
DEPRECATED
3940
Uses default attribute values from pyqg.BTModel:
4041
https://pyqg.readthedocs.io/en/latest/api.html#pyqg.BTModel
4142
Those values originally come from Mcwilliams 1984:
4243
J. C. Mcwilliams (1984). The emergence of isolated coherent
4344
vortices in turbulent flow. Journal of Fluid Mechanics, 146,
4445
pp 21-43 doi:10.1017/S0022112084001750.
4546
46-
Attributes:
47+
Args:
4748
system_dim: system dimension
4849
beta: Gradient of coriolis parameter. Units: meters^-1 *
4950
seconds^-1. Default is 0.
@@ -207,8 +208,8 @@ def __advance__(self,):
207208
"""Advances the QG model according to set attributes
208209
209210
Returns:
210-
qs (array_like): absolute potential vorticity (relative potential
211-
vorticity + background vorticity).
211+
Array of absolute potential vorticity (relative potential
212+
vorticity + background vorticity).
212213
"""
213214
qs = []
214215
for _ in self.m.run_with_snapshots(tsnapstart=0, tsnapint=self.m.dt):

dabench/data/_data.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,10 @@
1616
ArrayLike = np.ndarray | jax.Array
1717

1818
class Data():
19-
"""Generic class for data generator objects.
19+
"""Base for all data generator objects.
2020
21-
Attributes:
21+
Args:
2222
system_dim: system dimension
23-
time_dim: total time steps
2423
original_dim: dimensions in original space, e.g. could be 3x3
2524
for a 2d system with system_dim = 9. Defaults to (system_dim),
2625
i.e. 1d.
@@ -32,7 +31,6 @@ class Data():
3231

3332
def __init__(self,
3433
system_dim: int = 3,
35-
time_dim: int = 1,
3634
original_dim: tuple[int, ...] | None = None,
3735
random_seed: int = 37,
3836
delta_t: float = 0.01,
@@ -42,7 +40,6 @@ def __init__(self,
4240
"""Initializes the base data object"""
4341

4442
self.system_dim = system_dim
45-
self.time_dim = time_dim
4643
self.random_seed = random_seed
4744
self.delta_t = delta_t
4845
self.store_as_jax = store_as_jax
@@ -98,8 +95,7 @@ def generate(self,
9895
9996
Notes:
10097
Either provide n_steps or t_final in order to indicate the length
101-
of the forecast. These are used to set the values, times, and
102-
time_dim attributes.
98+
of the forecast.
10399
104100
Args:
105101
n_steps: Number of timesteps. One of n_steps OR
@@ -118,8 +114,8 @@ def generate(self,
118114
convergence tolerance, etc.).
119115
120116
Returns:
121-
Xarray Dataset of output vector and (if return_tlm=True)
122-
Xarray DataArray of TLMs corresponding to the system trajectory.
117+
Xarray Dataset of output vector, and if return_tlm=True
118+
Xarray DataArray of TLMs corresponding to the system trajectory.
123119
"""
124120

125121
# Check that n_steps or t_final is supplied
@@ -172,8 +168,8 @@ def generate(self,
172168
**kwargs)
173169

174170
# Convert to JAX if necessary
175-
self.time_dim = t.shape[0]
176-
out_dim = (self.time_dim,) + self.original_dim
171+
time_dim = t.shape[0]
172+
out_dim = (time_dim,) + self.original_dim
177173
if self.store_as_jax:
178174
y_out = jnp.array(y[:,:self.system_dim].reshape(out_dim))
179175
else:
@@ -197,13 +193,13 @@ def generate(self,
197193
# Reshape M matrix
198194
if self.store_as_jax:
199195
M = jnp.reshape(y[:, self.system_dim:],
200-
(self.time_dim,
196+
(time_dim,
201197
self.system_dim,
202198
self.system_dim)
203199
)
204200
else:
205201
M = np.reshape(y[:, self.system_dim:],
206-
(self.time_dim,
202+
(time_dim,
207203
self.system_dim,
208204
self.system_dim)
209205
)
@@ -283,7 +279,7 @@ def calc_lyapunov_exponents_series(
283279
284280
Returns:
285281
Lyapunov exponents for all timesteps, array of size
286-
(total_time/rescale_time - 1, system_dim)
282+
(total_time/rescale_time - 1, system_dim)
287283
"""
288284

289285
# Set total_time

0 commit comments

Comments
 (0)