Skip to content

Commit f02c4bd

Browse files
Fixed token encoding for GPUs
1 parent 77eeab1 commit f02c4bd

1 file changed

Lines changed: 3 additions & 0 deletions

File tree

turftopic/late.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ def _encode_tokens(
6363
):
6464
batch = texts[start_index : start_index + batch_size]
6565
features = self.tokenize(batch)
66+
features = {
67+
key: value.to(self.device) for key, value in features.items()
68+
}
6669
with torch.no_grad():
6770
output_features = self.forward(features)
6871
n_tokens = output_features["attention_mask"].sum(axis=1)

0 commit comments

Comments
 (0)