Commit e8e86a5

Refactored encode_chunks to encode things in a smarter way
1 parent fe79139 commit e8e86a5

1 file changed

turftopic/encoders/utils.py
Lines changed: 41 additions & 21 deletions

@@ -18,43 +18,63 @@ def batched(iterable, n: int) -> Iterable[List[str]]:
 
 def encode_chunks(
     encoder,
-    sentences,
+    texts,
     batch_size=64,
     window_size=50,
     step_size=40,
-    return_chunks=False,
-    show_progress_bar=False,
 ):
-    chunks = []
+    """
+    Returns
+    -------
+    chunk_embeddings: list[np.ndarray]
+        Embedding matrix of chunks for each document.
+    chunk_positions: list[list[tuple[int, int]]]
+        List of start and end character index of chunks for each document.
+    """
+    chunk_positions = []
     chunk_embeddings = []
     for start_index in trange(
         0,
-        len(sentences),
+        len(texts),
         batch_size,
         desc="Encoding batches...",
-        disable=not show_progress_bar,
     ):
-        batch = sentences[start_index : start_index + batch_size]
+        batch = texts[start_index : start_index + batch_size]
         features = encoder.tokenize(batch)
         with torch.no_grad():
             output_features = encoder.forward(features)
         n_tokens = output_features["attention_mask"].sum(axis=1)
+        # Find first nonzero elements in each document
+        # The document could be padded from the left, so we have to watch out for this.
+        start_token = torch.argmax(
+            (output_features["attention_mask"] > 0).to(torch.long), axis=1
+        )
+        end_token = start_token + n_tokens
         for i_doc in range(len(batch)):
-            for chunk_start in range(0, n_tokens[i_doc], step_size):
-                chunk_end = min(chunk_start + window_size, n_tokens[i_doc])
+            _chunk_embeddings = []
+            _chunk_positions = []
+            for chunk_start in range(
+                start_token[i_doc], end_token[i_doc], step_size
+            ):
+                chunk_end = min(chunk_start + window_size, end_token[i_doc])
                 _emb = output_features["token_embeddings"][
                     i_doc, chunk_start:chunk_end, :
                 ].mean(axis=0)
-                chunk_embeddings.append(_emb)
-                if return_chunks:
-                    chunks.append(
-                        encoder.tokenizer.decode(
-                            features["input_ids"][i_doc, chunk_start:chunk_end]
-                        )
-                        .replace("[CLS]", "")
-                        .replace("[SEP]", "")
+                _chunk_embeddings.append(_emb)
+                chunk_text = (
+                    encoder.tokenizer.decode(
+                        features["input_ids"][i_doc, chunk_start:chunk_end],
+                        skip_special_tokens=True,
                    )
-        if not return_chunks:
-            chunks = None
-    chunk_embeddings = np.stack(chunk_embeddings)
-    return chunk_embeddings, chunks
+                    .replace("[CLS]", "")
+                    .replace("[SEP]", "")
+                    .strip()
+                )
+                doc_text = texts[start_index + i_doc]
+                start_char = doc_text.find(chunk_text)
+                end_char = start_char + len(chunk_text)
+                _chunk_positions.append((start_char, end_char))
+            _chunk_embeddings = np.stack(_chunk_embeddings)
+            chunk_embeddings.append(_chunk_embeddings)
+            chunk_positions.append(_chunk_positions)
+    return chunk_embeddings, chunk_positions
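
For orientation, here is a minimal usage sketch of the refactored function. It assumes a SentenceTransformer-style encoder, since encode_chunks needs an object exposing tokenize(), forward() and a tokenizer attribute; the model name and sample texts below are illustrative assumptions, not taken from the repository.

from sentence_transformers import SentenceTransformer

from turftopic.encoders.utils import encode_chunks

# Hypothetical encoder: any model exposing tokenize/forward/tokenizer should do.
encoder = SentenceTransformer("all-MiniLM-L6-v2")
texts = [
    "Topic models summarize what a corpus is about.",
    "Sliding-window chunking keeps each embedding local to one passage.",
]
chunk_embeddings, chunk_positions = encode_chunks(
    encoder, texts, batch_size=64, window_size=50, step_size=40
)
# One embedding matrix and one list of (start_char, end_char) spans per document.
for doc, emb, spans in zip(texts, chunk_embeddings, chunk_positions):
    for (start, end), vector in zip(spans, emb):
        print(doc[start:end], vector.shape)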

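The left-padding logic added in this commit deserves a closer look: casting the attention mask to integers and taking torch.argmax yields the index of the first maximal value in each row, i.e. the first real token, which is what the old code (which always started chunking at index 0) got wrong for left-padded batches. A toy illustration with made-up masks:

import torch

# Two documents: one left-padded, one right-padded.
attention_mask = torch.tensor(
    [
        [0, 0, 1, 1, 1],  # real tokens start at index 2
        [1, 1, 1, 1, 0],  # real tokens start at index 0
    ]
)
n_tokens = attention_mask.sum(axis=1)  # tensor([3, 4])
# argmax picks the first index where the mask is 1.
start_token = torch.argmax((attention_mask > 0).to(torch.long), axis=1)  # tensor([2, 0])
end_token = start_token + n_tokens  # tensor([5, 4]): one past the last real token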