
Commit aaa0e17

Added properties and printing to LateWrapper
1 parent f50b983 · commit aaa0e17

1 file changed: turftopic/late.py (44 additions, 16 deletions)
```diff
@@ -7,9 +7,10 @@
 from sentence_transformers import SentenceTransformer
 from sklearn.base import TransformerMixin
 from sklearn.preprocessing import normalize
-from tokenizers import Tokenizer
 from tqdm import trange
 
+from turftopic.base import ContextualModel
+
 Offsets = list[tuple[int, int]]
 Lengths = list[int]
```

```diff
@@ -39,13 +40,11 @@ def _encode_tokens(
         """
         token_embeddings = []
         offsets = []
-        tokenizer = Tokenizer.from_pretrained(self.model_card_data.base_model)
         for start_index in trange(
             0,
             len(texts),
             batch_size,
-            disable=not show_progress_bar,
-            desc="Encoding tokens...",
+            desc="Encoding batches...",
         ):
             batch = texts[start_index : start_index + batch_size]
             features = self.tokenize(batch)
```
```diff
@@ -59,10 +58,18 @@ def _encode_tokens(
             )
             end_token = start_token + n_tokens
             for i_doc in range(len(batch)):
-                _token_embeddings = output_features["token_embeddings"][
-                    i_doc, start_token[i_doc] : end_token[i_doc], :
-                ].numpy(force=True)
-                _offsets = tokenizer.encode(batch[i_doc]).offsets
+                _token_embeddings = (
+                    output_features["token_embeddings"][
+                        i_doc, start_token[i_doc] : end_token[i_doc], :
+                    ]
+                    .float()
+                    .numpy(force=True)
+                )
+                _n = _token_embeddings.shape[0]
+                # We extract the character offsets and prune them at the maximum context length
+                _offsets = self.tokenizer(
+                    batch[i_doc], return_offsets_mapping=True, verbose=False
+                )["offset_mapping"][:_n]
                 token_embeddings.append(_token_embeddings)
                 offsets.append(_offsets)
         return token_embeddings, offsets
```
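
The rewrite above drops the standalone `tokenizers.Tokenizer` (removed from the imports as well) and reuses the encoder's own Hugging Face tokenizer, which can return character offsets directly. Truncating the offset mapping to `_n` keeps it aligned with the embedding tensor when a document exceeds the model's context length. A minimal sketch of the offset-mapping call; the checkpoint name is an assumption, any fast tokenizer behaves the same:

```python
# Sketch only: what `return_offsets_mapping=True` yields, and why the result
# must be pruned to the number of tokens that were actually embedded.
from transformers import AutoTokenizer

# Assumed checkpoint; any fast (Rust-backed) tokenizer exposes offsets.
tokenizer = AutoTokenizer.from_pretrained(
    "sentence-transformers/all-MiniLM-L6-v2"
)
text = "Late chunking aligns token embeddings with character spans."
offsets = tokenizer(text, return_offsets_mapping=True, verbose=False)[
    "offset_mapping"
]
# Each entry is a (start_char, end_char) span into the original string;
# special tokens like [CLS] map to the empty span (0, 0).
for start, end in offsets:
    print((start, end), repr(text[start:end]))
```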
```diff
@@ -145,11 +152,12 @@ def encode_windows(
         for emb, offs in zip(token_embeddings, token_offsets):
             _offsets = []
             _embeddings = []
-            for start_index in trange(0, len(emb), step_size):
+            for start_index in range(0, len(emb), step_size):
                 end_index = start_index + window_size
                 window_emb = np.mean(emb[start_index:end_index], axis=0)
+                off = offs[start_index:end_index]
                 _embeddings.append(window_emb)
-                _offsets.append((offs[start_index][0], offs[end_index][1]))
+                _offsets.append((off[0][0], off[-1][1]))
             window_embeddings.append(normalize(np.stack(_embeddings)))
             window_offsets.append(_offsets)
         return window_embeddings, window_offsets
```
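
Two fixes here: the inner loop no longer draws a nested progress bar (`trange` becomes `range`), and each window's end offset now comes from the last token actually inside the window. The old `offs[end_index]` pointed one token past the window and raised `IndexError` whenever the final window was shorter than `window_size`. A toy illustration with made-up offsets:

```python
# Hypothetical offsets, not from the repo: five tokens of four characters.
offs = [(0, 4), (5, 9), (10, 14), (15, 19), (20, 24)]
window_size, step_size = 3, 2

for start_index in range(0, len(offs), step_size):
    # Slice first, then read the first and last offsets of the window.
    off = offs[start_index : start_index + window_size]
    print((off[0][0], off[-1][1]))
# Prints (0, 14), (10, 24), (20, 24). The old offs[start_index + window_size]
# was off by one for full windows and raised IndexError on the final one.
```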
```diff
@@ -197,7 +205,7 @@ def unflatten_repr(
     repr = []
     start_index = 0
     for length in lengths:
-        repr.append(flat_repr[start_index:length])
+        repr.append(flat_repr[start_index : start_index + length])
         start_index += length
     return repr
```
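
The old slice in `unflatten_repr` treated `length` as an absolute end index rather than a length, so every group after the first came back truncated or empty. A worked example with assumed values:

```python
# Toy data: a flat representation of 2 + 3 chunk-level rows.
flat_repr = ["a", "b", "c", "d", "e"]
lengths = [2, 3]

groups, start_index = [], 0
for length in lengths:
    groups.append(flat_repr[start_index : start_index + length])
    start_index += length
print(groups)  # [['a', 'b'], ['c', 'd', 'e']]
# The old flat_repr[start_index:length] would have given [['a', 'b'], ['c']].
```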

```diff
@@ -217,11 +225,11 @@ def get_document_chunks(
     chunks = []
     for doc, _offs in zip(raw_documents, offsets):
         for start_char, end_char in _offs:
-            chunks.append(raw_documents[start_char, end_char])
+            chunks.append(doc[start_char:end_char])
     return chunks
 
 
-class LateModel(TransformerMixin):
+class LateWrapper(ContextualModel, TransformerMixin):
     def __init__(
         self,
         model: TransformerMixin,
```
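
The chunking bug fixed above indexed the *list* of documents with a `(start_char, end_char)` tuple instead of slicing the current document string, so it raised `TypeError` on any input. A minimal repro with assumed inputs:

```python
# Assumed inputs: one document with two character-offset windows.
raw_documents = ["late chunking splits long documents"]
offsets = [[(0, 13), (14, 35)]]

chunks = []
for doc, _offs in zip(raw_documents, offsets):
    for start_char, end_char in _offs:
        # Fixed version: slice the string, not the list of documents.
        chunks.append(doc[start_char:end_char])
print(chunks)  # ['late chunking', 'splits long documents']
# raw_documents[start_char, end_char] raises:
# TypeError: list indices must be integers or slices, not tuple
```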
```diff
@@ -236,7 +244,7 @@ def __init__(
         self.window_size = window_size
         self.step_size = step_size
 
-    def encode_documents(
+    def encode_late(
         self, raw_documents: list[str]
     ) -> tuple[np.ndarray, list[Offsets]]:
         if self.window_size is None:
@@ -264,7 +272,7 @@ def transform(
         offsets: list[Offsets] = None,
     ):
         if (embeddings is None) or (offsets is None):
-            embeddings, offsets = self.encode_documents(raw_documents)
+            embeddings, offsets = self.encode_late(raw_documents)
         flat_embeddings, lengths = flatten_repr(embeddings)
         chunks = get_document_chunks(raw_documents, offsets)
         out_array = self.model.transform(chunks, embeddings=flat_embeddings)
@@ -281,7 +289,7 @@ def fit_transform(
         offsets: list[Offsets] = None,
     ):
         if (embeddings is None) or (offsets is None):
-            embeddings, offsets = self.encode_documents(raw_documents)
+            embeddings, offsets = self.encode_late(raw_documents)
         flat_embeddings, lengths = flatten_repr(embeddings)
         chunks = get_document_chunks(raw_documents, offsets)
         out_array = self.model.fit_transform(
```
```diff
@@ -291,3 +299,23 @@ def fit_transform(
             return unflatten_repr(out_array, lengths)
         else:
             return pool_flat(out_array, lengths)
+
+    @property
+    def components_(self):
+        return self.model.components_
+
+    @property
+    def hierarchy(self):
+        return self.model.hierarchy
+
+    @property
+    def topic_names(self):
+        return self.model.topic_names
+
+    @property
+    def classes_(self):
+        return self.model.classes_
+
+    @property
+    def vectorizer(self):
+        return self.model.vectorizer
```
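
The new read-only properties forward fitted attributes of the wrapped topic model, so helpers inherited from `ContextualModel` (topic printing and naming among them) can find `components_`, `classes_`, and the vectorizer on the wrapper itself. A hypothetical usage sketch; the constructor arguments and the corpus are assumptions, not taken from this diff:

```python
from sklearn.datasets import fetch_20newsgroups
from turftopic import KeyNMF
from turftopic.late import LateWrapper  # assumed import path (turftopic/late.py)

corpus = fetch_20newsgroups(subset="test").data[:100]
model = KeyNMF(10)
# LateWrapper may take further arguments (window_size, step_size, ...);
# only `model` is visible in this diff's __init__ signature.
wrapper = LateWrapper(model)
doc_topic = wrapper.fit_transform(corpus)

# The properties proxy the fitted inner model, so inherited utilities work:
assert wrapper.components_ is model.components_
print(wrapper.topic_names)
```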
