Skip to content

Commit 63e2b65

Browse files
Standardize Banded
1 parent 9323d3b commit 63e2b65

1 file changed

Lines changed: 13 additions & 29 deletions

File tree

naplib/encoding/banded_trf.py

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ class BandedTRF(BaseEstimator):
2929
basis_dict : dict, optional
3030
Dictionary mapping feature names to basis objects.
3131
"""
32-
def __init__(self, tmin, tmax, sfreq, alphas=None, basis_dict={}):
32+
def __init__(self, tmin, tmax, sfreq, alphas=None, basis_dict=None):
3333
self.tmin = tmin
3434
self.tmax = tmax
3535
self.sfreq = sfreq
3636
self.alphas = alphas if alphas is not None else np.logspace(-2, 5, 8)
37-
self.basis_dict = basis_dict
37+
self.basis_dict = basis_dict if basis_dict is not None else {}
3838
self.feature_alphas_ = []
3939
self.alpha_paths_ = []
4040
self.feature_order_ = []
@@ -86,6 +86,7 @@ def _prepare_matrix(self, X_list, alphas_list):
8686
if x.ndim == 1:
8787
x = x[:, np.newaxis]
8888

89+
name = self.feature_order_[i]
8990
if name in self.basis_dict:
9091
x = self.basis_dict[name].transform(x)
9192

@@ -144,7 +145,7 @@ def fit(self, data=None, X=['aud'], y='resp'):
144145
else:
145146
self.feature_order_ = [chr(i+65) for i in range(len(X))]
146147
if isinstance(y, str):
147-
self.target_ = target
148+
self.target_ = y
148149
else:
149150
self.target_ = 'target'
150151

@@ -156,7 +157,7 @@ def fit(self, data=None, X=['aud'], y='resp'):
156157
self.n_targets_ = y[0].shape[1]
157158

158159
self.scores_ = np.zeros((n_trials, self.n_targets_, len(X)))
159-
self.feature_alphas_ = np.zeros((len(X), ))
160+
self.feature_alphas_ = []
160161
self.alpha_paths = np.zeros((len(X), len(self.alphas)))
161162

162163
for i, current_feat in enumerate(X):
@@ -165,7 +166,7 @@ def fit(self, data=None, X=['aud'], y='resp'):
165166
r_history = []
166167
best_r_per_trial_ch = None
167168

168-
for alpha in tqdm(self.alphas, desc=f"Optimizing {current_feat}", leave=False):
169+
for alpha in tqdm(self.alphas, desc=f"Optimizing {self.feature_order_[i]}", leave=False):
169170
temp_alphas = self.feature_alphas_ + [alpha]
170171
X_mats = self._prepare_matrix(X[:i+1], temp_alphas)
171172

@@ -197,15 +198,15 @@ def fit(self, data=None, X=['aud'], y='resp'):
197198
best_r_per_trial_ch = current_alpha_trial_r
198199

199200
self.feature_alphas_.append(best_alpha)
200-
self.alpha_paths_[current_feat] = np.array(r_history)
201+
self.alpha_paths_[i, :] = np.array(r_history)
201202
self.scores_[:, :, i] = best_r_per_trial_ch
202203

203204
# Final fit on each trial separately
204205
final_X = self._prepare_matrix(X, self.feature_alphas_)
205206
self.model_ = [Ridge(alpha=1.0).fit(tx, ty) for tx, ty in zip(final_X, y)]
206207

207208
self.feat_dims_ = []
208-
for i, name in enumerate(feature_order):
209+
for i, name in enumerate(self.feature_order_):
209210
x_sample = X[i][0]
210211
if isinstance(x_sample, list): x_sample = x_sample[0]
211212
if x_sample.ndim == 1: x_sample = x_sample[:, None]
@@ -257,15 +258,9 @@ def predict(self, data=None, X=['aud']):
257258
if self.model_ is None:
258259
raise ValueError("Model must be fitted before calling predict.")
259260

260-
requested_features = feature_names if feature_names else self.feature_order_
261-
262-
# Standardize feature data to list of trial-lists
263-
feat_data_list = []
264-
for f in requested_features:
265-
f_data = _parse_outstruct_args(data, f)
266-
feat_data_list.append(f_data if isinstance(f_data, list) else [f_data])
261+
X = [_parse_outstruct_args(data, x) for x in X]
267262

268-
X_mats = self._prepare_matrix(feat_data_list, requested_features, self.feature_alphas_)
263+
X_mats = self._prepare_matrix(X, self.feature_alphas_)
269264
n_trials = len(X_mats)
270265

271266
if n_trials != len(self.model_):
@@ -279,17 +274,6 @@ def predict(self, data=None, X=['aud']):
279274
# Expand (trials, features) -> (trials, 1_target, features)
280275
all_coefs = all_coefs[:, np.newaxis, :]
281276

282-
# Handle feature masking if a subset is requested
283-
mask = np.ones(all_coefs.shape[2], dtype=bool)
284-
if feature_names is not None:
285-
mask = np.zeros(all_coefs.shape[2], dtype=bool)
286-
current_col = 0
287-
for i, f in enumerate(self.feature_order_):
288-
num_cols = self.feat_dims_[i] * self._ndelays
289-
if f in requested_features:
290-
mask[current_col : current_col + num_cols] = True
291-
current_col += num_cols
292-
293277
preds = []
294278
for i in range(n_trials):
295279
# Indices for all trials except the current one
@@ -298,8 +282,8 @@ def predict(self, data=None, X=['aud']):
298282
# Average coefficients and intercepts from the other trials
299283
loto_coef = np.mean(all_coefs[loto_indices], axis=0)
300284

301-
# Apply feature mask
302-
sliced_coef = loto_coef[:, mask]
285+
# Only use first coef features if using feature subset
286+
sliced_coef = loto_coef[:, :X_mats[i].shape[1]]
303287

304288
# Predict for the current trial
305289
preds.append(X_mats[i] @ sliced_coef.T)
@@ -367,7 +351,7 @@ def summary(self, channel=None):
367351
'Feature': feat,
368352
'Total R': np.nanmean(r_report[:, f_idx]),
369353
'Delta R': np.nanmean(dr_report[:, f_idx]),
370-
'Alpha': self.feature_alphas_[feat],
354+
'Alpha': self.feature_alphas_[f_idx],
371355
't-value': t_val,
372356
'p-value': p_val,
373357
})

0 commit comments

Comments (0)