Skip to content

Commit f87feb0

Browse files
committed
Updating rc to return full n_step forecast and new index dimension name
1 parent dfe4148 commit f87feb0

1 file changed

Lines changed: 3 additions & 6 deletions

File tree

dabench/model/_rc.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ class RCModel(model.Model):
2121
system_dim (int): Dimension of reservoir output.
2222
input_dim (int): Dimension of reservoir input signal.
2323
reservoir_dim (int): Dimension of reservoir state. Default: 512.
24-
2524
sparsity (float): the percentage of zero-valued entries in the
2625
adjacency matrix (A). Default: 0.99.
2726
sparse_adj_matrix (bool): If True, A is computed using scipy sparse.
@@ -37,10 +36,8 @@ class RCModel(model.Model):
3736
readout_method (str): How to handle reservoir state elements during
3837
readout. One of 'linear', 'biased', or 'quadratic'.
3938
Default: 'linear'.
40-
4139
random_seed (int): Random seed for random number generation. Default
4240
is 1.
43-
4441
s (ndarray): Model states over entire time series.
4542
s_last (ndarray): Last
4643
ybar (ndarray): y.T @ st, set in _compute_Wout.
@@ -386,7 +383,7 @@ def train(self, data_obj, update_Wout=True):
386383

387384
r = self.generate(data_obj)['r'].data
388385
# u = data_obj.to_array().transpose(..., 'variable').data.reshape(data_obj.sizes['time'], -1)
389-
u = data_obj.to_array().stack(system=['variable','i']).data
386+
u = data_obj.to_array().stack(system=['variable','index']).data
390387
self.Wout = self._compute_Wout(r, u, update_Wout=update_Wout, u=u.T)
391388

392389
def _compute_Wout(self, rt, y, update_Wout=True, u=None):
@@ -496,13 +493,13 @@ def forecast(self, state_vec, n_steps=1):
496493
r_full = jnp.zeros((n_steps, self.reservoir_dim))
497494
for i in range(n_steps):
498495
r_full = r_full.at[i].set(r)
499-
if i < n_steps-1:
496+
if i < n_steps:
500497
r = self.update(r, self.readout(r))
501498

502499
new_vec = xr.Dataset(
503500
{'r':(('time','reservoir'), r_full)}
504501
)
505-
return new_vec.isel(time=-1), new_vec.drop_isel(time=-1)
502+
return new_vec.isel(time=-1), new_vec
506503

507504
def save_weights(self, pkl_path):
508505
"""Save RC reservoir weights as pkl file.

0 commit comments

Comments
 (0)