Skip to content

Commit 9323d3b

Browse files
Standardize BandedTRF with Data
1 parent 137eb2a commit 9323d3b

2 files changed

Lines changed: 44 additions & 40 deletions

File tree

naplib/encoding/banded_trf.py

Lines changed: 43 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@ class BandedTRF(BaseEstimator):
2929
basis_dict : dict, optional
3030
Dictionary mapping feature names to basis objects.
3131
"""
32-
def __init__(self, tmin, tmax, sfreq, alphas=None, basis_dict=None):
32+
def __init__(self, tmin, tmax, sfreq, alphas=None, basis_dict={}):
3333
self.tmin = tmin
3434
self.tmax = tmax
3535
self.sfreq = sfreq
3636
self.alphas = alphas if alphas is not None else np.logspace(-2, 5, 8)
37-
self.basis_dict = basis_dict if basis_dict is not None else {}
38-
self.feature_alphas_ = {}
39-
self.alpha_paths_ = {}
37+
self.basis_dict = basis_dict
38+
self.feature_alphas_ = []
39+
self.alpha_paths_ = []
4040
self.feature_order_ = []
4141
self.model_ = None # Will store a list of fitted Ridge models (one per trial)
4242
self.target_ = None
@@ -69,13 +69,13 @@ def coef_(self):
6969

7070
return all_coefs.reshape(n_targets, n_feat_dim, self._ndelays, n_trials)
7171

72-
def _prepare_matrix(self, X_list, feature_names, alphas_dict):
72+
def _prepare_matrix(self, X_list, alphas_list):
7373
processed_trials = []
7474
n_trials = len(X_list[0])
7575

7676
for trl in range(n_trials):
7777
mats = []
78-
for i, name in enumerate(feature_names):
78+
for i in range(len(X_list)):
7979
x = X_list[i][trl]
8080

8181
if isinstance(x, list) and len(x) == 1:
@@ -89,7 +89,7 @@ def _prepare_matrix(self, X_list, feature_names, alphas_dict):
8989
if name in self.basis_dict:
9090
x = self.basis_dict[name].transform(x)
9191

92-
alpha = alphas_dict.get(name, 1.0)
92+
alpha = alphas_list[i]
9393
mats.append(x / np.sqrt(alpha))
9494

9595
if not mats:
@@ -117,16 +117,15 @@ def fit(self, data=None, X=['aud'], y='resp'):
117117
and ``y`` arguments.
118118
X : list of str | list of list of np.ndarrays
119119
Data to be used as predictor in the regression. Should be a list,
120-
in which each element is feature, corresponding to a list of trials,
121-
each of shape (time, num_features).
120+
in which each element is a feature, corresponding to a list of trials,
121+
each of which is a numpy array of shape (time, num_features).
122122
If a string, it must specify a list of fields of the Data
123123
provided in the first argument.
124-
y : str | list of np.ndarrays or a multidimensional np.ndarray
124+
y : str | list of np.ndarrays
125125
Data to be used as target(s) in the regression. Once arranged,
126-
should be of shape (time, num_targets[, num_features_y]).
126+
should be of shape (time, num_targets).
127127
If a string, it must specify one of the fields of the Data
128-
provided in the first argument. If a multidimensional array, first dimension
129-
indicates the trial/instances.
128+
provided in the first argument.
130129
131130
Returns
132131
-------
@@ -140,31 +139,35 @@ def fit(self, data=None, X=['aud'], y='resp'):
140139
The prediction for a held-out trial $i$ is generated using the mean
141140
coefficients of all trials $j \neq i$.
142141
"""
143-
self.feature_order_ = feature_order
144-
self.target_ = target
142+
if isinstance(X[0], str):
143+
self.feature_order_ = X
144+
else:
145+
self.feature_order_ = [chr(i+65) for i in range(len(X))]
146+
if isinstance(y, str):
147+
self.target_ = target
148+
else:
149+
self.target_ = 'target'
145150

146-
y = _parse_outstruct_args(data, target)
151+
y = _parse_outstruct_args(data, y)
147152
if not isinstance(y, list): y = [y]
153+
X = [_parse_outstruct_args(data, x) for x in X]
148154

149155
n_trials = len(y)
150156
self.n_targets_ = y[0].shape[1]
151-
152-
all_features_data = []
153-
for f in feature_order:
154-
f_data = _parse_outstruct_args(data, f)
155-
all_features_data.append(f_data if isinstance(f_data, list) else [f_data])
156157

157-
self.scores_ = np.zeros((n_trials, self.n_targets_, len(feature_order)))
158+
self.scores_ = np.zeros((n_trials, self.n_targets_, len(X)))
159+
self.feature_alphas_ = np.zeros((len(X), ))
160+
self.alpha_paths = np.zeros((len(X), len(self.alphas)))
158161

159-
for i, current_feat in enumerate(feature_order):
162+
for i, current_feat in enumerate(X):
160163
best_alpha = None
161164
max_r = -np.inf
162165
r_history = []
163166
best_r_per_trial_ch = None
164167

165168
for alpha in tqdm(self.alphas, desc=f"Optimizing {current_feat}", leave=False):
166-
temp_alphas = {**self.feature_alphas_, current_feat: alpha}
167-
X_mats = self._prepare_matrix(all_features_data[:i+1], feature_order[:i+1], temp_alphas)
169+
temp_alphas = self.feature_alphas_ + [alpha]
170+
X_mats = self._prepare_matrix(X[:i+1], temp_alphas)
168171

169172
trial_betas = [Ridge(alpha=1.0).fit(tx, ty.reshape(-1, self.n_targets_)).coef_ for tx, ty in zip(X_mats, y)]
170173

@@ -193,17 +196,17 @@ def fit(self, data=None, X=['aud'], y='resp'):
193196
max_r, best_alpha = avg_r, alpha
194197
best_r_per_trial_ch = current_alpha_trial_r
195198

196-
self.feature_alphas_[current_feat] = best_alpha
199+
self.feature_alphas_.append(best_alpha)
197200
self.alpha_paths_[current_feat] = np.array(r_history)
198201
self.scores_[:, :, i] = best_r_per_trial_ch
199202

200203
# Final fit on each trial separately
201-
final_X = self._prepare_matrix(all_features_data, feature_order, self.feature_alphas_)
204+
final_X = self._prepare_matrix(X, self.feature_alphas_)
202205
self.model_ = [Ridge(alpha=1.0).fit(tx, ty) for tx, ty in zip(final_X, y)]
203206

204207
self.feat_dims_ = []
205208
for i, name in enumerate(feature_order):
206-
x_sample = all_features_data[i][0]
209+
x_sample = X[i][0]
207210
if isinstance(x_sample, list): x_sample = x_sample[0]
208211
if x_sample.ndim == 1: x_sample = x_sample[:, None]
209212
if name in self.basis_dict:
@@ -212,25 +215,27 @@ def fit(self, data=None, X=['aud'], y='resp'):
212215

213216
return self
214217

215-
def predict(self, data, feature_names=None):
218+
def predict(self, data=None, X=['aud']):
216219
"""
217220
Predict target responses using the fitted Banded Ridge model.
218221
219222
This method performs Leave-One-Trial-Out (LOTO) prediction. For each
220223
trial in the input data, it averages the regression coefficients
221224
from all *other* trials (fitted during training) to generate the
222225
prediction for the current trial.
223-
226+
224227
Parameters
225228
----------
226-
data : naplib.OutStruct or list of dict
227-
The data containing the features to predict from. Must contain
228-
the same number of trials as used during `fit`.
229-
feature_names : list of str, optional
230-
The subset of features to use for prediction. If None (default),
231-
uses all features specified in the `feature_order` during `fit`.
232-
This allows for isolating the contribution of specific bands.
233-
229+
data : naplib.Data object, optional
230+
Data object containing data to predict from in one of the fields.
231+
If not given, must give the X data directly.
232+
X : list of str | list of list of np.ndarrays
233+
Data to be used as predictor in the regression. Should be a list,
234+
in which each element is a feature, corresponding to a list of trials,
235+
each of which is a numpy array of shape (time, num_features).
236+
If a list of strings, it must specify a list of fields of the Data
237+
provided in the first argument.
238+
234239
Returns
235240
-------
236241
preds : list of np.ndarray

naplib/encoding/trf.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,7 @@ def predict(self, data=None, X='aud'):
195195
----------
196196
data : naplib.Data object, optional
197197
Data object containing data to be normalized in one of the field.
198-
If not given, must give the X and y data directly as the ``X``
199-
and ``y`` arguments.
198+
If not given, must give the X data directly as the ``X`` argument.
200199
X : str | list of np.ndarrays or a multidimensional np.ndarray
201200
Data to be used as predictor in the regression. Once arranged,
202201
should be of shape (time, num_features).

0 commit comments

Comments
 (0)