Commit b47eda8

Started working on contextual encoding and ragged array manipulation
1 parent e707316 commit b47eda8

2 files changed

Lines changed: 362 additions & 0 deletions

turftopic/encoders/contextual.py

Lines changed: 206 additions & 0 deletions
@@ -0,0 +1,206 @@

import itertools
import warnings
from typing import Iterable, Union

import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import normalize
from tokenizers import Tokenizer
from tqdm import trange


def is_contextual(encoder):
    return hasattr(encoder, "encode_tokens")


Offsets = list[tuple[int, int]]
Lengths = list[int]


def flatten_embeddings(
    embeddings: list[np.ndarray],
) -> tuple[np.ndarray, Lengths]:
    """Flattens a ragged array into a single contiguous array.

    Parameters
    ----------
    embeddings: list[ndarray]
        Ragged embedding array.

    Returns
    -------
    flat_embeddings: ndarray
        Flattened embedding array.
    lengths: list[int]
        Length of each document in the corpus.
    """
    lengths = [emb.shape[0] for emb in embeddings]
    return np.concatenate(embeddings, axis=0), lengths


def unflatten_embeddings(
    flat_embeddings: np.ndarray, lengths: Lengths
) -> list[np.ndarray]:
    """Unflattens a flat array back into a ragged array.

    Parameters
    ----------
    flat_embeddings: ndarray
        Flattened embedding array.
    lengths: list[int]
        Length of each document in the corpus.

    Returns
    -------
    embeddings: list[ndarray]
        Ragged embedding array.
    """
    embeddings = []
    start_index = 0
    for length in lengths:
        # Slice this document's rows relative to start_index.
        embeddings.append(flat_embeddings[start_index : start_index + length])
        start_index += length
    return embeddings


class ContextTransformer(SentenceTransformer):
    def encode(
        self, sentences: Union[str, list[str], np.ndarray], *args, **kwargs
    ):
        warnings.warn(
            "Encoder is contextual, but the topic model is not using contextual embeddings. Perhaps you wanted to use another topic model."
        )
        return super().encode(sentences, *args, **kwargs)

    def _encode_tokens(
        self,
        texts,
        batch_size=32,
        show_progress_bar=True,
    ) -> tuple[list[np.ndarray], list[Offsets]]:
        """
        Returns
        -------
        token_embeddings: list[np.ndarray]
            Embedding matrix of tokens for each document.
        offsets: list[list[tuple[int, int]]]
            Start and end character of each token in each document.
        """
        token_embeddings = []
        offsets = []
        tokenizer = Tokenizer.from_pretrained(self.model_card_data.base_model)
        for start_index in trange(
            0,
            len(texts),
            batch_size,
            disable=not show_progress_bar,
            desc="Encoding tokens...",
        ):
            batch = texts[start_index : start_index + batch_size]
            features = self.tokenize(batch)
            with torch.no_grad():
                output_features = self.forward(features)
            n_tokens = output_features["attention_mask"].sum(axis=1)
            # Find the first nonzero element in each document.
            # Documents may be padded from the left, so we have to watch out for this.
            start_token = torch.argmax(
                (output_features["attention_mask"] > 0).to(torch.long), axis=1
            )
            end_token = start_token + n_tokens
            for i_doc in range(len(batch)):
                _token_embeddings = output_features["token_embeddings"][
                    i_doc, start_token[i_doc] : end_token[i_doc], :
                ].numpy(force=True)
                _offsets = tokenizer.encode(batch[i_doc]).offsets
                token_embeddings.append(_token_embeddings)
                offsets.append(_offsets)
        return token_embeddings, offsets

    def encode_tokens(
        self,
        sentences: list[str],
        batch_size: int = 32,
        show_progress_bar: bool = True,
    ):
        """Produces contextual token embeddings over all documents.

        Parameters
        ----------
        sentences: list[str]
            Documents to encode contextually.
        batch_size: int, default 32
            Size of the batch of documents to encode at once.
        show_progress_bar: bool, default True
            Indicates whether a progress bar should be displayed while encoding.

        Returns
        -------
        token_embeddings: list[np.ndarray]
            Embedding matrix of tokens for each document.
        offsets: list[list[tuple[int, int]]]
            Start and end character of each token in each document.
        """
        token_embeddings, offsets = self._encode_tokens(
            sentences,
            batch_size=batch_size,
            show_progress_bar=show_progress_bar,
        )
        # _encode_tokens does not normalize embeddings, which normally happens
        # to document embeddings, so we normalize the token embeddings here.
        token_embeddings = [normalize(emb) for emb in token_embeddings]
        return token_embeddings, offsets

    def encode_windows(
        self,
        sentences: list[str],
        batch_size: int = 32,
        window_size: int = 50,
        step_size: int = 40,
        show_progress_bar: bool = True,
    ):
        """Produces contextual embeddings for a sliding window of tokens, similar to C-Top2Vec.

        Parameters
        ----------
        sentences: list[str]
            Documents to encode contextually.
        batch_size: int, default 32
            Size of the batch of documents to encode at once.
        window_size: int, default 50
            Size of the sliding window.
        step_size: int, default 40
            Step size of the window.
            If step_size < window_size, windows will overlap.
            If step_size == window_size, windows are adjacent and non-overlapping.
            If step_size > window_size, there will be gaps between the windows.
            In this case we emit a warning, as this is probably unintended behaviour.
        show_progress_bar: bool, default True
            Indicates whether a progress bar should be displayed while encoding.

        Returns
        -------
        window_embeddings: list[np.ndarray]
            Embedding matrix of windows for each document.
        offsets: list[list[tuple[int, int]]]
            Start and end character of each window in each document.
        """
        if step_size > window_size:
            warnings.warn(
                "step_size is larger than window_size; some tokens will not be covered by any window."
            )
        token_embeddings, token_offsets = self._encode_tokens(
            sentences,
            batch_size=batch_size,
            show_progress_bar=show_progress_bar,
        )
        window_embeddings = []
        window_offsets = []
        for emb, offs in zip(token_embeddings, token_offsets):
            _offsets = []
            _embeddings = []
            for start_index in trange(0, len(emb), step_size):
                # Clip the window end to the document length.
                end_index = min(start_index + window_size, len(emb))
                window_emb = np.mean(emb[start_index:end_index], axis=0)
                _embeddings.append(window_emb)
                # The last token in the window determines where its character span ends.
                _offsets.append((offs[start_index][0], offs[end_index - 1][1]))
            window_embeddings.append(normalize(np.stack(_embeddings)))
            window_offsets.append(_offsets)
        return window_embeddings, window_offsets
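
A minimal usage sketch of the encoder above (not part of the commit). The checkpoint name is illustrative, and the sketch assumes the loaded model's card exposes a base_model that Tokenizer.from_pretrained can resolve, since _encode_tokens relies on it for character offsets.

from turftopic.encoders.contextual import ContextTransformer

# Illustrative checkpoint; any sentence-transformers model whose model card
# exposes a base_model loadable by Tokenizer.from_pretrained should work.
encoder = ContextTransformer("sentence-transformers/all-MiniLM-L6-v2")
docs = [
    "Topic models uncover themes in a corpus.",
    "Contextual embeddings assign a different vector to every token.",
]

# One (n_tokens_i, hidden_dim) matrix per document, plus character offsets per token.
token_embeddings, token_offsets = encoder.encode_tokens(docs, show_progress_bar=False)

# Sliding-window embeddings in the style of C-Top2Vec: 50-token windows every 40 tokens.
window_embeddings, window_offsets = encoder.encode_windows(
    docs, window_size=50, step_size=40, show_progress_bar=False
)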

turftopic/ragged.py

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@

from typing import Callable, Optional

import numpy as np
from sklearn.base import TransformerMixin

Lengths = list[int]


def flatten_repr(
    repr: list[np.ndarray],
) -> tuple[np.ndarray, Lengths]:
    """Flattens a ragged array into a single contiguous array.

    Parameters
    ----------
    repr: list[ndarray]
        Ragged representation array.

    Returns
    -------
    flat_repr: ndarray
        Flattened representation array.
    lengths: list[int]
        Length of each document in the corpus.
    """
    lengths = [r.shape[0] for r in repr]
    return np.concatenate(repr, axis=0), lengths


def unflatten_repr(
    flat_repr: np.ndarray, lengths: Lengths
) -> list[np.ndarray]:
    """Unflattens a flat array back into a ragged array.

    Parameters
    ----------
    flat_repr: ndarray
        Flattened representation array.
    lengths: list[int]
        Length of each document in the corpus.

    Returns
    -------
    repr: list[ndarray]
        Ragged representation array.
    """
    repr = []
    start_index = 0
    for length in lengths:
        # Slice this document's rows relative to start_index.
        repr.append(flat_repr[start_index : start_index + length])
        start_index += length
    return repr


def pool_flat(flat_repr: np.ndarray, lengths: Lengths, agg=np.mean):
    """Pools a flat representation back down to one row per document."""
    pooled = []
    start_index = 0
    for length in lengths:
        pooled.append(agg(flat_repr[start_index : start_index + length], axis=0))
        start_index += length
    return np.stack(pooled)


class TokenLevel(TransformerMixin):
    def __init__(
        self,
        model: TransformerMixin,
        batch_size: int = 32,
        pooling: Optional[Callable] = None,
    ):
        self.model = model
        self.batch_size = batch_size
        self.pooling = pooling

    def transform(
        self, raw_documents: list[str], embeddings: list[np.ndarray] = None
    ):
        if embeddings is None:
            # encode_tokens returns (embeddings, offsets); only the embeddings are needed here.
            embeddings, _ = self.model.encoder.encode_tokens(
                raw_documents, batch_size=self.batch_size
            )
        flat_embeddings, lengths = flatten_repr(embeddings)
        out_array = self.model.transform(
            raw_documents, embeddings=flat_embeddings
        )
        if self.pooling is None:
            return unflatten_repr(out_array, lengths)
        else:
            return pool_flat(out_array, lengths, agg=self.pooling)

    def fit_transform(
        self,
        raw_documents: list[str],
        y=None,
        embeddings: list[np.ndarray] = None,
    ):
        if embeddings is None:
            embeddings, _ = self.model.encoder.encode_tokens(
                raw_documents, batch_size=self.batch_size
            )
        flat_embeddings, lengths = flatten_repr(embeddings)
        out_array = self.model.fit_transform(
            raw_documents, y, embeddings=flat_embeddings
        )
        if self.pooling is None:
            return unflatten_repr(out_array, lengths)
        else:
            return pool_flat(out_array, lengths, agg=self.pooling)


class Windowed(TransformerMixin):
    def __init__(
        self,
        model: TransformerMixin,
        batch_size: int = 32,
        pooling: Optional[Callable] = None,
    ):
        self.model = model
        self.batch_size = batch_size
        self.pooling = pooling

    def transform(
        self, raw_documents: list[str], embeddings: list[np.ndarray] = None
    ):
        if embeddings is None:
            # Window-level encoding; encode_windows also returns window offsets.
            embeddings, _ = self.model.encoder.encode_windows(
                raw_documents, batch_size=self.batch_size
            )
        flat_embeddings, lengths = flatten_repr(embeddings)
        out_array = self.model.transform(
            raw_documents, embeddings=flat_embeddings
        )
        if self.pooling is None:
            return unflatten_repr(out_array, lengths)
        else:
            return pool_flat(out_array, lengths, agg=self.pooling)

    def fit_transform(
        self,
        raw_documents: list[str],
        y=None,
        embeddings: list[np.ndarray] = None,
    ):
        if embeddings is None:
            embeddings, _ = self.model.encoder.encode_windows(
                raw_documents, batch_size=self.batch_size
            )
        flat_embeddings, lengths = flatten_repr(embeddings)
        out_array = self.model.fit_transform(
            raw_documents, y, embeddings=flat_embeddings
        )
        if self.pooling is None:
            return unflatten_repr(out_array, lengths)
        else:
            return pool_flat(out_array, lengths, agg=self.pooling)
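
A minimal sketch of the ragged-array helpers above (not part of the commit), using random arrays to stand in for token-level embeddings; the import path assumes the package layout shown in this diff.

import numpy as np

from turftopic.ragged import flatten_repr, pool_flat, unflatten_repr

# Two "documents" with 3 and 2 token-level rows of dimension 4.
ragged = [np.random.rand(3, 4), np.random.rand(2, 4)]

flat, lengths = flatten_repr(ragged)       # flat.shape == (5, 4), lengths == [3, 2]
restored = unflatten_repr(flat, lengths)   # round-trips back to the ragged list
pooled = pool_flat(flat, lengths)          # shape (2, 4): mean over each document's rows

TokenLevel and Windowed wrap a topic model around the same pattern: flatten the per-token or per-window embeddings, run the wrapped model's transform on the flat array, then either unflatten the result back into a ragged list or pool it down to one row per document.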
