Commit 443bf25

Fixed some issues induced by sentence-transformers==5.4.0
1 parent 77eeab1 commit 443bf25

2 files changed: 15 additions & 35 deletions

turftopic/base.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -41,7 +41,7 @@ def encode_documents(self, raw_documents: Iterable[str]) -> np.ndarray:
         """
         if not hasattr(self.encoder_, "encode"):
             return self.encoder.get_text_embeddings(list(raw_documents))
-        return self.encoder_.encode(raw_documents)
+        return self.encoder_.encode(list(raw_documents))
 
     @abstractmethod
     def fit_transform(
```
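
The base.py fix is a one-liner: materialize raw_documents with list() before handing it to encode(). A minimal sketch of the failure mode this guards against, assuming the newer sentence-transformers release validates its input and no longer accepts plain generators (the model name is illustrative, not part of this commit):

```python
from sentence_transformers import SentenceTransformer

encoder = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative model choice

# encode_documents takes an Iterable[str], which may well be a generator:
docs = (f"document {i}" for i in range(4))

# encoder.encode(docs)  # can break on recent releases, which expect
#                       # a sized, indexable sequence of strings
embeddings = encoder.encode(list(docs))  # materializing the iterable is safe
print(embeddings.shape)  # (4, embedding_dim)
```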

turftopic/late.py

Lines changed: 14 additions & 34 deletions
```diff
@@ -53,40 +53,20 @@ def _encode_tokens(
         Start and end character of each token in each document.
         """
         self.has_used_token_level = True
-        token_embeddings = []
-        offsets = []
-        for start_index in trange(
-            0,
-            len(texts),
-            batch_size,
-            desc="Encoding batches...",
-        ):
-            batch = texts[start_index : start_index + batch_size]
-            features = self.tokenize(batch)
-            with torch.no_grad():
-                output_features = self.forward(features)
-            n_tokens = output_features["attention_mask"].sum(axis=1)
-            # Find first nonzero elements in each document
-            # The document could be padded from the left, so we have to watch out for this.
-            start_token = torch.argmax(
-                (output_features["attention_mask"] > 0).to(torch.long), axis=1
-            )
-            end_token = start_token + n_tokens
-            for i_doc in range(len(batch)):
-                _token_embeddings = (
-                    output_features["token_embeddings"][
-                        i_doc, start_token[i_doc] : end_token[i_doc], :
-                    ]
-                    .float()
-                    .numpy(force=True)
-                )
-                _n = _token_embeddings.shape[0]
-                # We extract the character offsets and prune it at the maximum context length
-                _offsets = self.tokenizer(
-                    batch[i_doc], return_offsets_mapping=True, verbose=False
-                )["offset_mapping"][:_n]
-                token_embeddings.append(_token_embeddings)
-                offsets.append(_offsets)
+        token_embeddings = self.encode(
+            texts, output_value="token_embeddings", batch_size=batch_size
+        )
+        offsets = self.tokenizer(
+            texts, return_offsets_mapping=True, verbose=False
+        )["offset_mapping"]
+        offsets = [
+            offs[: len(embs)] for offs, embs in zip(offsets, token_embeddings)
+        ]
+        token_embeddings = [
+            embs.numpy(force=True)
+            for embs in token_embeddings
+            if torch.is_tensor(embs)
+        ]
         return token_embeddings, offsets
 
     def encode_tokens(
```
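
The late.py change drops the hand-rolled batching loop (manual forward pass, attention-mask slicing, left-padding bookkeeping) in favor of sentence-transformers' own encode() with output_value="token_embeddings", which already returns one padding-stripped tensor of token vectors per document. A standalone sketch of that API against a plain SentenceTransformer, assuming a fast tokenizer (required for return_offsets_mapping); the model name is again illustrative:

```python
import torch
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative model choice
texts = ["Topic models are neat.", "Token embeddings make them interpretable."]

# One (n_tokens, hidden_dim) tensor per document; encode() strips the padding.
token_embeddings = model.encode(texts, output_value="token_embeddings")

# Character offsets per token from the fast tokenizer. The tokenizer does not
# truncate here, so each offset list is cut down to the number of tokens the
# encoder actually embedded (it prunes at the maximum context length).
offsets = model.tokenizer(texts, return_offsets_mapping=True, verbose=False)[
    "offset_mapping"
]
offsets = [offs[: len(embs)] for offs, embs in zip(offsets, token_embeddings)]

for embs, offs in zip(token_embeddings, offsets):
    assert torch.is_tensor(embs) and len(embs) == len(offs)
    print(embs.shape, offs[:3])
```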
