Skip to content

Commit 7ac9b1b

Browse files
committed
Remove self.time_dim attribute from data generators, doesn't make sense since we now produce xarray objects which have that information (used to save values in the data generator object instead)
1 parent 9a3bcdc commit 7ac9b1b

7 files changed

Lines changed: 17 additions & 42 deletions

File tree

dabench/data/_data.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ class Data():
2020
2121
Args:
2222
system_dim: system dimension
23-
time_dim: total time steps
2423
original_dim: dimensions in original space, e.g. could be 3x3
2524
for a 2d system with system_dim = 9. Defaults to (system_dim),
2625
i.e. 1d.
@@ -32,7 +31,6 @@ class Data():
3231

3332
def __init__(self,
3433
system_dim: int = 3,
35-
time_dim: int = 1,
3634
original_dim: tuple[int, ...] | None = None,
3735
random_seed: int = 37,
3836
delta_t: float = 0.01,
@@ -42,7 +40,6 @@ def __init__(self,
4240
"""Initializes the base data object"""
4341

4442
self.system_dim = system_dim
45-
self.time_dim = time_dim
4643
self.random_seed = random_seed
4744
self.delta_t = delta_t
4845
self.store_as_jax = store_as_jax
@@ -98,8 +95,7 @@ def generate(self,
9895
9996
Notes:
10097
Either provide n_steps or t_final in order to indicate the length
101-
of the forecast. These are used to set the values, times, and
102-
time_dim attributes.
98+
of the forecast.
10399
104100
Args:
105101
n_steps: Number of timesteps. One of n_steps OR
@@ -172,8 +168,8 @@ def generate(self,
172168
**kwargs)
173169

174170
# Convert to JAX if necessary
175-
self.time_dim = t.shape[0]
176-
out_dim = (self.time_dim,) + self.original_dim
171+
time_dim = t.shape[0]
172+
out_dim = (time_dim,) + self.original_dim
177173
if self.store_as_jax:
178174
y_out = jnp.array(y[:,:self.system_dim].reshape(out_dim))
179175
else:
@@ -197,13 +193,13 @@ def generate(self,
197193
# Reshape M matrix
198194
if self.store_as_jax:
199195
M = jnp.reshape(y[:, self.system_dim:],
200-
(self.time_dim,
196+
(time_dim,
201197
self.system_dim,
202198
self.system_dim)
203199
)
204200
else:
205201
M = np.reshape(y[:, self.system_dim:],
206-
(self.time_dim,
202+
(time_dim,
207203
self.system_dim,
208204
self.system_dim)
209205
)

dabench/data/_enso_indices.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ class ENSOIndices(_data.Data):
2121
2222
Args:
2323
system_dim: system dimension
24-
time_dim: total time steps
2524
store_as_jax: Store values as jax array instead of numpy array.
2625
Default is False (store as numpy).
2726
file_dict: Lists of files to get. Dict keys are type of data:
@@ -58,15 +57,14 @@ def __init__(self,
5857
file_dict: dict | None = None,
5958
var_types: dict | None = None,
6059
system_dim: int | None = None,
61-
time_dim: int | None = None,
6260
store_as_jax: bool = False,
6361
**kwargs):
6462

6563
"""Initialize ENSOIndices object, subclass of Base"""
6664

6765
self.file_dict = file_dict
6866
self.var_types = var_types
69-
super().__init__(system_dim=system_dim, time_dim=time_dim,
67+
super().__init__(system_dim=system_dim,
7068
values=None, delta_t=None, **kwargs,
7169
store_as_jax=store_as_jax)
7270

dabench/data/_lorenz63.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ class Lorenz63(_data.Data):
3030
and initial conditions [0., 1., 0.], a spinup which replicates
3131
the simulation described in Lorenz, 1963.
3232
system_dim: system dimension. Must be 3 for Lorenz63.
33-
time_dim: total time steps
3433
store_as_jax: Store values as jax array instead of numpy array.
3534
Default is False (store as numpy).
3635
"""
@@ -42,7 +41,6 @@ def __init__(self,
4241
delta_t: float = 0.01,
4342
x0: ArrayLike | None = jnp.array([-10.0, -15.0, 21.3]),
4443
system_dim: int = 3,
45-
time_dim: int | None = None,
4644
values: ArrayLike | None = None,
4745
store_as_jax: bool = False,
4846
**kwargs):
@@ -57,7 +55,7 @@ def __init__(self,
5755
print('Assigning system_dim to 3.')
5856
system_dim = 3
5957

60-
super().__init__(system_dim=system_dim, time_dim=time_dim,
58+
super().__init__(system_dim=system_dim,
6159
values=values, delta_t=delta_t,
6260
store_as_jax=store_as_jax, **kwargs)
6361

dabench/data/_lorenz96.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ class Lorenz96(_data.Data):
3333
which is set to 0.01.
3434
system_dim: System dimension, must be between 4 and 40.
3535
Default is 36.
36-
time_dim: Total time steps
3736
delta_t: Length of one time step. Default is 0.05 from
3837
Lorenz, 1996, but on modern computers 0.01 is often used.
3938
store_as_jax: Store values as jax array instead of numpy array.
@@ -45,13 +44,12 @@ def __init__(self,
4544
delta_t: float = 0.05,
4645
x0: ArrayLike | None = None,
4746
system_dim: int = 36,
48-
time_dim: int | None = None,
4947
values: ArrayLike | None = None,
5048
store_as_jax: bool = False,
5149
**kwargs):
5250
"""Initialize Lorenz96 object, subclass of Base"""
5351

54-
super().__init__(system_dim=system_dim, time_dim=time_dim,
52+
super().__init__(system_dim=system_dim,
5553
values=values, delta_t=delta_t,
5654
store_as_jax=store_as_jax, **kwargs)
5755

dabench/data/_pyqg_jax.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ def __init__(self,
7272
ny: int | None = None,
7373
delta_t: float = 7200,
7474
random_seed: int = 37,
75-
time_dim: int | None = None,
7675
store_as_jax: bool = False,
7776
**kwargs):
7877
""" Initialize PyQGJax QGModel object, subclass of Base
@@ -110,7 +109,7 @@ def __init__(self,
110109
jax.random.PRNGKey(0)
111110
)
112111
super().__init__(system_dim=system_dim, original_dim=original_dim,
113-
time_dim=time_dim, delta_t=delta_t,
112+
delta_t=delta_t,
114113
store_as_jax=store_as_jax, x0=x0,
115114
**kwargs)
116115

@@ -157,8 +156,7 @@ def generate(self,
157156
158157
Notes:
159158
Either provide n_steps or t_final in order to indicate the length
160-
of the forecast. These are used to set the values, times, and
161-
time_dim attributes.
159+
of the forecast.
162160
163161
Args:
164162
n_steps: Number of timesteps. One of n_steps OR

dabench/data/_qgs.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ def __init__(self,
5353
x0: ArrayLike | None = None,
5454
delta_t: ArrayLike | None = 0.1,
5555
system_dim: int | None = None,
56-
time_dim: int | None = None,
5756
store_as_jax: bool = False,
5857
random_seed: int = 37,
5958
**kwargs):
@@ -86,7 +85,7 @@ def __init__(self,
8685
if x0 is None:
8786
x0 = self._rng.random(system_dim)*0.001
8887

89-
super().__init__(system_dim=system_dim, time_dim=time_dim,
88+
super().__init__(system_dim=system_dim,
9089
delta_t=delta_t, store_as_jax=store_as_jax, x0=x0,
9190
**kwargs)
9291

@@ -169,8 +168,7 @@ def generate(self,
169168
170169
Notes:
171170
Either provide n_steps or t_final in order to indicate the length
172-
of the forecast. These are used to set the values, times, and
173-
time_dim attributes.
171+
of the forecast.
174172
175173
Args:
176174
n_steps (int): Number of timesteps. One of n_steps OR
@@ -243,8 +241,8 @@ def generate(self,
243241
**kwargs)
244242

245243
# Convert to JAX if necessary
246-
self.time_dim = t.shape[0]
247-
out_dim = (self.time_dim,) + self.original_dim
244+
time_dim = t.shape[0]
245+
out_dim = (time_dim,) + self.original_dim
248246
if self.store_as_jax:
249247
y_out = jnp.array(y[:,:self.system_dim].reshape(out_dim))
250248
else:
@@ -268,13 +266,13 @@ def generate(self,
268266
# Reshape M matrix
269267
if self.store_as_jax:
270268
M = jnp.reshape(y[:, self.system_dim:],
271-
(self.time_dim,
269+
(time_dim,
272270
self.system_dim,
273271
self.system_dim)
274272
)
275273
else:
276274
M = np.reshape(y[:, self.system_dim:],
277-
(self.time_dim,
275+
(time_dim,
278276
self.system_dim,
279277
self.system_dim)
280278
)

dabench/data/_sqgturb.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ class SQGTurb(_data.Data):
5959
https://github.com/jswhit/sqgturb. 57600 steps matches the
6060
"nature run" spin up in that repository.
6161
system_dim: The dimension of the system state
62-
time_dim: The dimension of the timeseries (not used)
6362
delta_t: model time step (seconds)
6463
x0: Initial state, array of floats of size
6564
(system_dim).
@@ -499,7 +498,6 @@ def integrate(self,
499498
if include_x0:
500499
n_steps = n_steps + 1
501500

502-
self.time_dim = n_steps
503501
times = t + jnp.arange(n_steps)*delta_t
504502

505503
# Integrate in spectral spacestep_n
@@ -548,13 +546,4 @@ def rhs(self,
548546
# save wind field
549547
self.u = -psiy
550548
self.v = psix
551-
return dpvspecdt
552-
553-
def _to_original_dim(self) -> np.ndarray:
554-
"""Going back to 2D is a bit trickier for sqgturb"""
555-
gridded_vals = np.zeros((self.time_dim, self.Nv, self.Nx, self.Nx))
556-
557-
for t in np.arange(self.time_dim):
558-
gridded_vals[t] = self.map1dto2d_ifft2(self.values[t])
559-
560-
return gridded_vals
549+
return dpvspecdt

0 commit comments

Comments
 (0)