@@ -26,10 +26,8 @@ class BandedTRF(BaseEstimator):
2626 Sampling frequency (Hz).
2727 alphas : np.ndarray, optional
2828 Alphas to sweep for each feature. Default is np.logspace(-2, 5, 8).
29- basis_dict : dict, optional
30- Dictionary mapping feature names to basis objects.
3129 """
32- def __init__ (self , tmin , tmax , sfreq , alphas = None , basis_dict = None ):
30+ def __init__ (self , tmin , tmax , sfreq , alphas = None ):
3331 self .tmin = tmin
3432 self .tmax = tmax
3533 self .sfreq = sfreq
@@ -86,10 +84,6 @@ def _prepare_matrix(self, X_list, alphas_list):
8684 if x .ndim == 1 :
8785 x = x [:, np .newaxis ]
8886
89- name = self .feature_order_ [i ]
90- if name in self .basis_dict :
91- x = self .basis_dict [name ].transform (x )
92-
9387 alpha = alphas_list [i ]
9488 mats .append (x / np .sqrt (alpha ))
9589
@@ -206,12 +200,10 @@ def fit(self, data=None, X=['aud'], y='resp'):
206200 self .model_ = [Ridge (alpha = 1.0 ).fit (tx , ty ) for tx , ty in zip (final_X , y )]
207201
208202 self .feat_dims_ = []
209- for i , name in enumerate ( self . feature_order_ ):
203+ for i in range ( len ( X ) ):
210204 x_sample = X [i ][0 ]
211205 if isinstance (x_sample , list ): x_sample = x_sample [0 ]
212206 if x_sample .ndim == 1 : x_sample = x_sample [:, None ]
213- if name in self .basis_dict :
214- x_sample = self .basis_dict [name ].transform (x_sample )
215207 self .feat_dims_ .append (x_sample .shape [1 ])
216208
217209 return self
0 commit comments