@@ -29,12 +29,12 @@ class BandedTRF(BaseEstimator):
2929 basis_dict : dict, optional
3030 Dictionary mapping feature names to basis objects.
3131 """
32- def __init__ (self , tmin , tmax , sfreq , alphas = None , basis_dict = {} ):
32+ def __init__ (self , tmin , tmax , sfreq , alphas = None , basis_dict = None ):
3333 self .tmin = tmin
3434 self .tmax = tmax
3535 self .sfreq = sfreq
3636 self .alphas = alphas if alphas is not None else np .logspace (- 2 , 5 , 8 )
37- self .basis_dict = basis_dict
37+ self .basis_dict = basis_dict if basis_dict is not None else {}
3838 self .feature_alphas_ = []
3939 self .alpha_paths_ = []
4040 self .feature_order_ = []
@@ -86,6 +86,7 @@ def _prepare_matrix(self, X_list, alphas_list):
8686 if x .ndim == 1 :
8787 x = x [:, np .newaxis ]
8888
89+ name = self .feature_order_ [i ]
8990 if name in self .basis_dict :
9091 x = self .basis_dict [name ].transform (x )
9192
@@ -144,7 +145,7 @@ def fit(self, data=None, X=['aud'], y='resp'):
144145 else :
145146 self .feature_order_ = [chr (i + 65 ) for i in range (len (X ))]
146147 if isinstance (y , str ):
147- self .target_ = target
148+ self .target_ = y
148149 else :
149150 self .target_ = 'target'
150151
@@ -156,7 +157,7 @@ def fit(self, data=None, X=['aud'], y='resp'):
156157 self .n_targets_ = y [0 ].shape [1 ]
157158
158159 self .scores_ = np .zeros ((n_trials , self .n_targets_ , len (X )))
159- self .feature_alphas_ = np . zeros (( len ( X ), ))
160+ self .feature_alphas_ = []
160161 self .alpha_paths = np .zeros ((len (X ), len (self .alphas )))
161162
162163 for i , current_feat in enumerate (X ):
@@ -165,7 +166,7 @@ def fit(self, data=None, X=['aud'], y='resp'):
165166 r_history = []
166167 best_r_per_trial_ch = None
167168
168- for alpha in tqdm (self .alphas , desc = f"Optimizing { current_feat } " , leave = False ):
169+ for alpha in tqdm (self .alphas , desc = f"Optimizing { self . feature_order_ [ i ] } " , leave = False ):
169170 temp_alphas = self .feature_alphas_ + [alpha ]
170171 X_mats = self ._prepare_matrix (X [:i + 1 ], temp_alphas )
171172
@@ -197,15 +198,15 @@ def fit(self, data=None, X=['aud'], y='resp'):
197198 best_r_per_trial_ch = current_alpha_trial_r
198199
199200 self .feature_alphas_ .append (best_alpha )
200- self .alpha_paths_ [current_feat ] = np .array (r_history )
201+ self .alpha_paths_ [i , : ] = np .array (r_history )
201202 self .scores_ [:, :, i ] = best_r_per_trial_ch
202203
203204 # Final fit on each trial separately
204205 final_X = self ._prepare_matrix (X , self .feature_alphas_ )
205206 self .model_ = [Ridge (alpha = 1.0 ).fit (tx , ty ) for tx , ty in zip (final_X , y )]
206207
207208 self .feat_dims_ = []
208- for i , name in enumerate (feature_order ):
209+ for i , name in enumerate (self . feature_order_ ):
209210 x_sample = X [i ][0 ]
210211 if isinstance (x_sample , list ): x_sample = x_sample [0 ]
211212 if x_sample .ndim == 1 : x_sample = x_sample [:, None ]
@@ -257,15 +258,9 @@ def predict(self, data=None, X=['aud']):
257258 if self .model_ is None :
258259 raise ValueError ("Model must be fitted before calling predict." )
259260
260- requested_features = feature_names if feature_names else self .feature_order_
261-
262- # Standardize feature data to list of trial-lists
263- feat_data_list = []
264- for f in requested_features :
265- f_data = _parse_outstruct_args (data , f )
266- feat_data_list .append (f_data if isinstance (f_data , list ) else [f_data ])
261+ X = [_parse_outstruct_args (data , x ) for x in X ]
267262
268- X_mats = self ._prepare_matrix (feat_data_list , requested_features , self .feature_alphas_ )
263+ X_mats = self ._prepare_matrix (X , self .feature_alphas_ )
269264 n_trials = len (X_mats )
270265
271266 if n_trials != len (self .model_ ):
@@ -279,17 +274,6 @@ def predict(self, data=None, X=['aud']):
279274 # Expand (trials, features) -> (trials, 1_target, features)
280275 all_coefs = all_coefs [:, np .newaxis , :]
281276
282- # Handle feature masking if a subset is requested
283- mask = np .ones (all_coefs .shape [2 ], dtype = bool )
284- if feature_names is not None :
285- mask = np .zeros (all_coefs .shape [2 ], dtype = bool )
286- current_col = 0
287- for i , f in enumerate (self .feature_order_ ):
288- num_cols = self .feat_dims_ [i ] * self ._ndelays
289- if f in requested_features :
290- mask [current_col : current_col + num_cols ] = True
291- current_col += num_cols
292-
293277 preds = []
294278 for i in range (n_trials ):
295279 # Indices for all trials except the current one
@@ -298,8 +282,8 @@ def predict(self, data=None, X=['aud']):
298282 # Average coefficients and intercepts from the other trials
299283 loto_coef = np .mean (all_coefs [loto_indices ], axis = 0 )
300284
301- # Apply feature mask
302- sliced_coef = loto_coef [:, mask ]
285+ # Only use first coef features if using feature subset
286+ sliced_coef = loto_coef [:, : X_mats [ i ]. shape [ 1 ] ]
303287
304288 # Predict for the current trial
305289 preds .append (X_mats [i ] @ sliced_coef .T )
@@ -367,7 +351,7 @@ def summary(self, channel=None):
367351 'Feature' : feat ,
368352 'Total R' : np .nanmean (r_report [:, f_idx ]),
369353 'Delta R' : np .nanmean (dr_report [:, f_idx ]),
370- 'Alpha' : self .feature_alphas_ [feat ],
354+ 'Alpha' : self .feature_alphas_ [f_idx ],
371355 't-value' : t_val ,
372356 'p-value' : p_val ,
373357 })
0 commit comments