
Commit f50b983

Moved late interaction code to one file
1 parent 3ececcc commit f50b983

2 files changed

Lines changed: 178 additions & 169 deletions

File tree

turftopic/encoders/late_interaction.py

Lines changed: 0 additions & 159 deletions
This file was deleted.

turftopic/late.py

Lines changed: 178 additions & 10 deletions
@@ -1,13 +1,160 @@
-from typing import Callable, Optional
+import itertools
+import warnings
+from typing import Callable, Iterable, Optional, Union
 
 import numpy as np
+import torch
+from sentence_transformers import SentenceTransformer
 from sklearn.base import TransformerMixin
+from sklearn.preprocessing import normalize
+from tokenizers import Tokenizer
+from tqdm import trange
 
-from turftopic.encoders.contextual import Offsets
-
+Offsets = list[tuple[int, int]]
 Lengths = list[int]
 
 
+class LateSentenceTransformer(SentenceTransformer):
+    def encode(
+        self, sentences: Union[str, list[str], np.ndarray], *args, **kwargs
+    ):
+        warnings.warn(
+            "Encoder is contextual but topic model is not using contextual embeddings. Perhaps you wanted to use another topic model."
+        )
+        return super().encode(sentences, *args, **kwargs)
+
+    def _encode_tokens(
+        self,
+        texts,
+        batch_size=32,
+        show_progress_bar=True,
+    ) -> tuple[list[np.ndarray], list[Offsets]]:
+        """
+        Returns
+        -------
+        token_embeddings: list[np.ndarray]
+            Embedding matrix of tokens for each document.
+        offsets: list[list[tuple[int, int]]]
+            Start and end character of each token in each document.
+        """
+        token_embeddings = []
+        offsets = []
+        tokenizer = Tokenizer.from_pretrained(self.model_card_data.base_model)
+        for start_index in trange(
+            0,
+            len(texts),
+            batch_size,
+            disable=not show_progress_bar,
+            desc="Encoding tokens...",
+        ):
+            batch = texts[start_index : start_index + batch_size]
+            features = self.tokenize(batch)
+            with torch.no_grad():
+                output_features = self.forward(features)
+            n_tokens = output_features["attention_mask"].sum(axis=1)
+            # Find first nonzero elements in each document
+            # The document could be padded from the left, so we have to watch out for this.
+            start_token = torch.argmax(
+                (output_features["attention_mask"] > 0).to(torch.long), axis=1
+            )
+            end_token = start_token + n_tokens
+            for i_doc in range(len(batch)):
+                _token_embeddings = output_features["token_embeddings"][
+                    i_doc, start_token[i_doc] : end_token[i_doc], :
+                ].numpy(force=True)
+                _offsets = tokenizer.encode(batch[i_doc]).offsets
+                token_embeddings.append(_token_embeddings)
+                offsets.append(_offsets)
+        return token_embeddings, offsets
+
+    def encode_tokens(
+        self,
+        sentences: list[str],
+        batch_size: int = 32,
+        show_progress_bar: bool = True,
+    ):
+        """Produces contextual token embeddings over all documents.
+
+        Parameters
+        ----------
+        sentences: list[str]
+            Documents to encode contextually.
+        batch_size: int, default 32
+            Size of the batch of document to encode at once.
+        show_progress_bar: bool, default True
+            Indicates whether a progress bar should be displayed when encoding.
+
+        Returns
+        -------
+        token_embeddings: list[np.ndarray]
+            Embedding matrix of tokens for each document.
+        offsets: list[list[tuple[int, int]]]
+            Start and end character of each token in each document.
+        """
+        # This is needed because the above implementation does not normalize embeddings,
+        # which normally happens to document embeddings.
+        token_embeddings, offsets = self._encode_tokens(
+            sentences,
+            batch_size=batch_size,
+            show_progress_bar=show_progress_bar,
+        )
+        token_embeddings = [normalize(emb) for emb in token_embeddings]
+        return token_embeddings, offsets
+
+    def encode_windows(
+        self,
+        sentences: list[str],
+        batch_size: int = 32,
+        window_size: int = 50,
+        step_size: int = 40,
+        show_progress_bar: bool = True,
+    ):
+        """Produces contextual embeddings for a sliding window of tokens similar to C-Top2Vec.
+
+        Parameters
+        ----------
+        sentences: list[str]
+            Documents to encode contextually.
+        batch_size: int, default 32
+            Size of the batch of document to encode at once.
+        window_size: int, default 50
+            Size of the sliding window.
+        step_size: int, default 40
+            Step size of the window.
+            If step_size < window_size, windows will overlap.
+            If step_size == window_size, then windows are separate.
+            If step_size > window_size, there will be gaps between the windows.
+            In this case, we throw a warning, as this is probably unintended behaviour.
+        show_progress_bar: bool, default True
+            Indicates whether a progress bar should be displayed when encoding.
+
+        Returns
+        -------
+        window_embeddings: list[np.ndarray]
+            Embedding matrix of windows for each document.
+        offsets: list[list[tuple[int, int]]]
+            Start and end character of each token in each document.
+        """
+        token_embeddings, token_offsets = self._encode_tokens(
+            sentences,
+            batch_size=batch_size,
+            show_progress_bar=show_progress_bar,
+        )
+        window_embeddings = []
+        window_offsets = []
+        for emb, offs in zip(token_embeddings, token_offsets):
+            _offsets = []
+            _embeddings = []
+            for start_index in trange(0, len(emb), step_size):
+                end_index = start_index + window_size
+                window_emb = np.mean(emb[start_index:end_index], axis=0)
+                _embeddings.append(window_emb)
+                _offsets.append((offs[start_index][0], offs[end_index][1]))
+            window_embeddings.append(normalize(np.stack(_embeddings)))
+            window_offsets.append(_offsets)
+        return window_embeddings, window_offsets
+
+
 def flatten_repr(
     repr: list[np.ndarray],
 ) -> tuple[np.ndarray, Lengths]:
@@ -78,12 +225,37 @@ class LateModel(TransformerMixin):
     def __init__(
         self,
         model: TransformerMixin,
-        batch_size: int = 32,
+        batch_size: Optional[int] = 32,
+        window_size: Optional[int] = None,
+        step_size: Optional[int] = None,
         pooling: Optional[Callable] = None,
     ):
         self.model = model
         self.batch_size = batch_size
         self.pooling = pooling
+        self.window_size = window_size
+        self.step_size = step_size
+
+    def encode_documents(
+        self, raw_documents: list[str]
+    ) -> tuple[np.ndarray, list[Offsets]]:
+        if self.window_size is None:
+            embeddings, offsets = self.model.encoder.encode_tokens(
+                raw_documents, batch_size=self.batch_size
+            )
+            return embeddings, offsets
+        # If the window_size is specified, but not step_size, we set the step size to the window size
+        # Thereby getting non-overlapping windows
+        step_size = (
+            self.window_size if self.step_size is None else self.step_size
+        )
+        embeddings, offsets = self.model.encoder.encode_windows(
+            raw_documents,
+            batch_size=self.batch_size,
+            window_size=self.window_size,
+            step_size=step_size,
+        )
+        return embeddings, offsets
 
     def transform(
         self,
@@ -92,9 +264,7 @@ def transform(
         offsets: list[Offsets] = None,
     ):
         if (embeddings is None) or (offsets is None):
-            embeddings, offsets = self.model.encoder.encode_tokens(
-                raw_documents, batch_size=self.batch_size
-            )
+            embeddings, offsets = self.encode_documents(raw_documents)
         flat_embeddings, lengths = flatten_repr(embeddings)
         chunks = get_document_chunks(raw_documents, offsets)
         out_array = self.model.transform(chunks, embeddings=flat_embeddings)
@@ -111,9 +281,7 @@ def fit_transform(
         offsets: list[Offsets] = None,
     ):
         if (embeddings is None) or (offsets is None):
-            embeddings, offsets = self.model.encoder.encode_tokens(
-                raw_documents, batch_size=self.batch_size
-            )
+            embeddings, offsets = self.encode_documents(raw_documents)
         flat_embeddings, lengths = flatten_repr(embeddings)
         chunks = get_document_chunks(raw_documents, offsets)
         out_array = self.model.fit_transform(
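For context, a minimal usage sketch of the token- and window-level encoding that the new LateSentenceTransformer exposes. This is not part of the commit: the model name is a placeholder, and the snippet assumes the class stays importable from turftopic.late and that the chosen checkpoint's model card records a base model whose tokenizer can be loaded with Tokenizer.from_pretrained.

    from turftopic.late import LateSentenceTransformer

    # Hypothetical model choice; any SentenceTransformer-compatible checkpoint
    # with a resolvable base-model tokenizer should work the same way.
    encoder = LateSentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    docs = [
        "Late interaction keeps one embedding per token.",
        "Sliding windows pool those token embeddings into chunk embeddings.",
    ]

    # One (n_tokens, dim) matrix per document, plus character offsets per token.
    token_embeddings, token_offsets = encoder.encode_tokens(docs, batch_size=2)

    # Windows of 50 tokens stepping by 40 tokens, so consecutive windows
    # overlap by 10 tokens (the defaults from encode_windows).
    window_embeddings, window_offsets = encoder.encode_windows(
        docs, window_size=50, step_size=40
    )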

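And a small self-contained illustration, again not from the commit, of the window arithmetic described in the encode_windows docstring and the encode_documents comments: step_size relative to window_size controls overlap, which is why LateModel falls back to step_size = window_size (non-overlapping windows) when only window_size is given.

    # Toy illustration of window start/end token indices over a 10-token document.
    n_tokens = 10

    def window_bounds(window_size: int, step_size: int) -> list[tuple[int, int]]:
        return [
            (start, min(start + window_size, n_tokens))
            for start in range(0, n_tokens, step_size)
        ]

    print(window_bounds(window_size=4, step_size=2))  # overlapping: (0, 4), (2, 6), (4, 8), (6, 10), (8, 10)
    print(window_bounds(window_size=4, step_size=4))  # adjacent, the LateModel fallback: (0, 4), (4, 8), (8, 10)
    print(window_bounds(window_size=4, step_size=6))  # gaps between windows: (0, 4), (6, 10)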