Skip to content

Commit e25f0cd

Browse files
Added float conversion and progress bar to encode_token()
1 parent 12e6e86 commit e25f0cd

1 file changed

Lines changed: 5 additions & 2 deletions

File tree

turftopic/late.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,10 @@ def _encode_tokens(
5454
"""
5555
self.has_used_token_level = True
5656
token_embeddings = self.encode(
57-
texts, output_value="token_embeddings", batch_size=batch_size
57+
texts,
58+
output_value="token_embeddings",
59+
batch_size=batch_size,
60+
show_progress_bar=show_progress_bar,
5861
)
5962
offsets = self.tokenizer(
6063
texts, return_offsets_mapping=True, verbose=False
@@ -63,7 +66,7 @@ def _encode_tokens(
6366
offs[: len(embs)] for offs, embs in zip(offsets, token_embeddings)
6467
]
6568
token_embeddings = [
66-
embs.numpy(force=True)
69+
embs.float().numpy(force=True)
6770
for embs in token_embeddings
6871
if torch.is_tensor(embs)
6972
]

0 commit comments

Comments
 (0)