Skip to content

Commit b2221dc

Browse files
Added docstrings to late.py
1 parent 55dee2d commit b2221dc

1 file changed

Lines changed: 68 additions & 0 deletions

File tree

turftopic/late.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,17 @@
1616

1717

1818
class LateSentenceTransformer(SentenceTransformer):
19+
"""SentenceTransformer model that can produce token and window-level embeddings.
20+
Its output can be used by topic models that can use multi-vector document representations.
21+
22+
!!! warning
23+
This is not checked yet in the library,
24+
but we recommend that you use SentenceTransformers that are
25+
a) **Mean pooled**
26+
b) **L2 Normalized**
27+
This will guarantee that the token/window embeddings are in the same embedding space as the documents.
28+
"""
29+
1930
has_used_token_level = False
2031

2132
def encode(
@@ -215,6 +226,20 @@ def unflatten_repr(
215226

216227

217228
def pool_flat(flat_repr: np.ndarray, lengths: Lengths, agg=np.nanmean):
229+
"""Pools vectors within documents using the agg function.
230+
231+
Parameters
232+
----------
233+
flat_repr: ndarray of shape (n_total_tokens, n_dims)
234+
Flattened document representations.
235+
lengths: Lengths
236+
Number of tokens in each document.
237+
238+
Returns
239+
-------
240+
ndarray of shape (n_documents, n_dims)
241+
Pooled representation for each document.
242+
"""
218243
pooled = []
219244
start_index = 0
220245
for length in lengths:
@@ -228,6 +253,20 @@ def pool_flat(flat_repr: np.ndarray, lengths: Lengths, agg=np.nanmean):
228253
def get_document_chunks(
229254
raw_documents: list[str], offsets: list[Offsets]
230255
) -> list[str]:
256+
"""Extracts text chunks from documents based on token/window offsets.
257+
258+
Parameters
259+
----------
260+
raw_documents: list[str]
261+
Text documents.
262+
offsets: list[Offsets]
263+
Offsets returned when encoding.
264+
265+
Returns
266+
-------
267+
list[str]
268+
Text chunks of tokens/windows in the documents.
269+
"""
231270
chunks = []
232271
for doc, _offs in zip(raw_documents, offsets):
233272
for start_char, end_char in _offs:
@@ -236,6 +275,35 @@ def get_document_chunks(
236275

237276

238277
class LateWrapper(ContextualModel, TransformerMixin):
278+
"""Wraps existing Turftopic model so that they can accept and create
279+
multi-vector document representations.
280+
281+
!!! warning
282+
The model HAS TO HAVE a late interaction encoder model
283+
(e.g. `LateSentenceTransformer`)
284+
285+
Parameters
286+
----------
287+
model
288+
Turftopic model to turn into late-interaction model.
289+
batch_size: int, default 32
290+
Batch size of the transformer.
291+
window_size: int, default None
292+
Size of the sliding window to average tokens over.
293+
If None, documents will be represented at a token level.
294+
step_size: int, default None
295+
Step size of the window.
296+
If (step_size == None) or (step_size == window_size), then windows are separate.
297+
If step_size < window_size, windows will overlap.
298+
If step_size > window_size, there will be gaps between the windows.
299+
In this case, we throw a warning, as this is probably unintended behaviour.
300+
pooling: Callable, default None
301+
Indicates whether and how to pool document-topic matrices.
302+
If None, multi-vector topic proportions are returned in a ragged array.
303+
If Callable, multiple vectors are averaged with the callable in each document.
304+
You could for example take the mean by specifying `pooling=np.nanmean`.
305+
"""
306+
239307
def __init__(
240308
self,
241309
model: TransformerMixin,

0 commit comments

Comments
 (0)