Skip to content

Commit badeeec

Browse files
Fixed pooling
1 parent 2146f68 commit badeeec

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

turftopic/late.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,11 +210,13 @@ def unflatten_repr(
210210
return repr
211211

212212

213-
def pool_flat(flat_repr: np.ndarray, lengths: Lengths, agg=np.mean):
213+
def pool_flat(flat_repr: np.ndarray, lengths: Lengths, agg=np.nanmean):
214214
pooled = []
215215
start_index = 0
216216
for length in lengths:
217-
pooled.append(agg(flat_repr[start_index:length], axis=0))
217+
pooled.append(
218+
agg(flat_repr[start_index : start_index + length], axis=0)
219+
)
218220
start_index += length
219221
return np.stack(pooled)
220222

0 commit comments

Comments
 (0)