Skip to content

Commit 5a4403b

Browse files
Fix banded tests
1 parent 63e2b65 commit 5a4403b

2 files changed

Lines changed: 13 additions & 9 deletions

File tree

naplib/encoding/banded_trf.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def fit(self, data=None, X=['aud'], y='resp'):
160160
self.feature_alphas_ = []
161161
self.alpha_paths = np.zeros((len(X), len(self.alphas)))
162162

163-
for i, current_feat in enumerate(X):
163+
for i in range(len(X)):
164164
best_alpha = None
165165
max_r = -np.inf
166166
r_history = []
@@ -262,6 +262,7 @@ def predict(self, data=None, X=['aud']):
262262

263263
X_mats = self._prepare_matrix(X, self.feature_alphas_)
264264
n_trials = len(X_mats)
265+
n_feat_dim = X_mats[0].shape[1]
265266

266267
if n_trials != len(self.model_):
267268
raise ValueError(
@@ -274,6 +275,12 @@ def predict(self, data=None, X=['aud']):
274275
# Expand (trials, features) -> (trials, 1_target, features)
275276
all_coefs = all_coefs[:, np.newaxis, :]
276277

278+
if n_feat_dim < all_coefs.shape[2]:
279+
all_coefs = all_coefs[:, :, :n_feat_dim]
280+
print('Using reduced TRF')
281+
elif n_feat_dim > all_coefs.shape[2]:
282+
raise ValueError('Too many features for trained model.')
283+
277284
preds = []
278285
for i in range(n_trials):
279286
# Indices for all trials except the current one
@@ -282,11 +289,8 @@ def predict(self, data=None, X=['aud']):
282289
# Average coefficients and intercepts from the other trials
283290
loto_coef = np.mean(all_coefs[loto_indices], axis=0)
284291

285-
# Only use first coef features if using feature subset
286-
sliced_coef = loto_coef[:, :X_mats[i].shape[1]]
287-
288292
# Predict for the current trial
289-
preds.append(X_mats[i] @ sliced_coef.T)
293+
preds.append(X_mats[i] @ loto_coef.T)
290294

291295
return preds
292296

tests/encoding/test_banded_trf.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_banded_trf_loto_consistency(synth_data):
3838
"""Test that coef_ property handles the 4D reshape correctly."""
3939
model = BandedTRF(tmin=synth_data['tmin'], tmax=synth_data['tmax'],
4040
sfreq=synth_data['sfreq'], alphas=[0.1, 10.0])
41-
model.fit(data=synth_data['data'], feature_order=synth_data['feature_order'], target='resp')
41+
model.fit(data=synth_data['data'], X=synth_data['feature_order'], target='resp')
4242

4343
# Shape calculation: 2 targets, 2 features, 4 delays, 3 trials.
4444
# ndelays = (0.03 * 100) - (0 * 100) + 1 = 4.
@@ -47,7 +47,7 @@ def test_banded_trf_loto_consistency(synth_data):
4747
def test_predict_masking_logic(synth_data):
4848
"""Verify that partial feature prediction works with multi-channel targets."""
4949
model = BandedTRF(tmin=synth_data['tmin'], tmax=synth_data['tmax'], sfreq=synth_data['sfreq'])
50-
model.fit(data=synth_data['data'], feature_order=synth_data['feature_order'], target='resp')
50+
model.fit(data=synth_data['data'], X=synth_data['feature_order'], target='resp')
5151

5252
# Full prediction: should match target shape (samples, channels)
5353
preds_all = model.predict(synth_data['data'])
@@ -63,7 +63,7 @@ def test_predict_masking_logic(synth_data):
6363
def test_summary_p_values(synth_data):
6464
"""Verify summary table computes stats across channels correctly."""
6565
model = BandedTRF(tmin=synth_data['tmin'], tmax=synth_data['tmax'], sfreq=synth_data['sfreq'])
66-
model.fit(data=synth_data['data'], feature_order=synth_data['feature_order'], target='resp')
66+
model.fit(data=synth_data['data'], X=synth_data['feature_order'], target='resp')
6767

6868
df = model.summary()
6969
assert isinstance(df, pd.DataFrame)
@@ -81,7 +81,7 @@ def test_unfitted_attribute_error():
8181
def test_predict_trial_mismatch(synth_data):
8282
"""LOTO requires the same number of trials for predict as fit."""
8383
model = BandedTRF(tmin=synth_data['tmin'], tmax=synth_data['tmax'], sfreq=synth_data['sfreq'])
84-
model.fit(data=synth_data['data'], feature_order=synth_data['feature_order'], target='resp')
84+
model.fit(data=synth_data['data'], X=synth_data['feature_order'], target='resp')
8585

8686
# Try predicting with only 2 trials instead of 3
8787
short_data = synth_data['data'][:2]

0 commit comments

Comments
 (0)