Skip to content

Commit aaea063

Browse files
Simplify BandedTRF
Remove basis_dict, can be easily implemented externally
1 parent a4ffc49 commit aaea063

1 file changed

Lines changed: 2 additions & 10 deletions

File tree

naplib/encoding/banded_trf.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,8 @@ class BandedTRF(BaseEstimator):
2626
Sampling frequency (Hz).
2727
alphas : np.ndarray, optional
2828
Alphas to sweep for each feature. Default is np.logspace(-2, 5, 8).
29-
basis_dict : dict, optional
30-
Dictionary mapping feature names to basis objects.
3129
"""
32-
def __init__(self, tmin, tmax, sfreq, alphas=None, basis_dict=None):
30+
def __init__(self, tmin, tmax, sfreq, alphas=None):
3331
self.tmin = tmin
3432
self.tmax = tmax
3533
self.sfreq = sfreq
@@ -86,10 +84,6 @@ def _prepare_matrix(self, X_list, alphas_list):
8684
if x.ndim == 1:
8785
x = x[:, np.newaxis]
8886

89-
name = self.feature_order_[i]
90-
if name in self.basis_dict:
91-
x = self.basis_dict[name].transform(x)
92-
9387
alpha = alphas_list[i]
9488
mats.append(x / np.sqrt(alpha))
9589

@@ -206,12 +200,10 @@ def fit(self, data=None, X=['aud'], y='resp'):
206200
self.model_ = [Ridge(alpha=1.0).fit(tx, ty) for tx, ty in zip(final_X, y)]
207201

208202
self.feat_dims_ = []
209-
for i, name in enumerate(self.feature_order_):
203+
for i in range(len(X)):
210204
x_sample = X[i][0]
211205
if isinstance(x_sample, list): x_sample = x_sample[0]
212206
if x_sample.ndim == 1: x_sample = x_sample[:, None]
213-
if name in self.basis_dict:
214-
x_sample = self.basis_dict[name].transform(x_sample)
215207
self.feat_dims_.append(x_sample.shape[1])
216208

217209
return self

0 commit comments

Comments
 (0)