1616ArrayLike = np .ndarray | jax .Array
1717
1818class Data ():
19- """Generic class for data generator objects.
19+ """Base for all data generator objects.
2020
21- Attributes :
21+ Args :
2222 system_dim: system dimension
23- time_dim: total time steps
2423 original_dim: dimensions in original space, e.g. could be 3x3
2524 for a 2d system with system_dim = 9. Defaults to (system_dim),
2625 i.e. 1d.
@@ -32,7 +31,6 @@ class Data():
3231
3332 def __init__ (self ,
3433 system_dim : int = 3 ,
35- time_dim : int = 1 ,
3634 original_dim : tuple [int , ...] | None = None ,
3735 random_seed : int = 37 ,
3836 delta_t : float = 0.01 ,
@@ -42,7 +40,6 @@ def __init__(self,
4240 """Initializes the base data object"""
4341
4442 self .system_dim = system_dim
45- self .time_dim = time_dim
4643 self .random_seed = random_seed
4744 self .delta_t = delta_t
4845 self .store_as_jax = store_as_jax
@@ -98,8 +95,7 @@ def generate(self,
9895
9996 Notes:
10097 Either provide n_steps or t_final in order to indicate the length
101- of the forecast. These are used to set the values, times, and
102- time_dim attributes.
98+ of the forecast.
10399
104100 Args:
105101 n_steps: Number of timesteps. One of n_steps OR
@@ -118,8 +114,8 @@ def generate(self,
118114 convergence tolerance, etc.).
119115
120116 Returns:
121- Xarray Dataset of output vector and ( if return_tlm=True)
122- Xarray DataArray of TLMs corresponding to the system trajectory.
117+ Xarray Dataset of output vector, and if return_tlm=True
118+ Xarray DataArray of TLMs corresponding to the system trajectory.
123119 """
124120
125121 # Check that n_steps or t_final is supplied
@@ -172,8 +168,8 @@ def generate(self,
172168 ** kwargs )
173169
174170 # Convert to JAX if necessary
175- self . time_dim = t .shape [0 ]
176- out_dim = (self . time_dim ,) + self .original_dim
171+ time_dim = t .shape [0 ]
172+ out_dim = (time_dim ,) + self .original_dim
177173 if self .store_as_jax :
178174 y_out = jnp .array (y [:,:self .system_dim ].reshape (out_dim ))
179175 else :
@@ -197,13 +193,13 @@ def generate(self,
197193 # Reshape M matrix
198194 if self .store_as_jax :
199195 M = jnp .reshape (y [:, self .system_dim :],
200- (self . time_dim ,
196+ (time_dim ,
201197 self .system_dim ,
202198 self .system_dim )
203199 )
204200 else :
205201 M = np .reshape (y [:, self .system_dim :],
206- (self . time_dim ,
202+ (time_dim ,
207203 self .system_dim ,
208204 self .system_dim )
209205 )
@@ -283,7 +279,7 @@ def calc_lyapunov_exponents_series(
283279
284280 Returns:
285281 Lyapunov exponents for all timesteps, array of size
286- (total_time/rescale_time - 1, system_dim)
282+ (total_time/rescale_time - 1, system_dim)
287283 """
288284
289285 # Set total_time
0 commit comments