1616
1717
1818class LateSentenceTransformer (SentenceTransformer ):
19+ """SentenceTransformer model that can produce token and window-level embeddings.
20+ Its output can be used by topic models that can use multi-vector document representations.
21+
22+ !!! warning
23+ This is not checked yet in the library,
24+ but we recommend that you use SentenceTransformers that are
25+ a) **Mean pooled**
26+ b) **L2 Normalized**
27+ This will guarantee that the token/window embeddings are in the same embedding space as the documents.
28+ """
29+
1930 has_used_token_level = False
2031
2132 def encode (
@@ -215,6 +226,20 @@ def unflatten_repr(
215226
216227
217228def pool_flat (flat_repr : np .ndarray , lengths : Lengths , agg = np .nanmean ):
229+ """Pools vectors within documents using the agg function.
230+
231+ Parameters
232+ ----------
233+ flat_repr: ndarray of shape (n_total_tokens, n_dims)
234+ Flattened document representations.
235+ lengths: Lengths
236+ Number of tokens in each document.
237+
238+ Returns
239+ -------
240+ ndarray of shape (n_documents, n_dims)
241+ Pooled representation for each document.
242+ """
218243 pooled = []
219244 start_index = 0
220245 for length in lengths :
@@ -228,6 +253,20 @@ def pool_flat(flat_repr: np.ndarray, lengths: Lengths, agg=np.nanmean):
228253def get_document_chunks (
229254 raw_documents : list [str ], offsets : list [Offsets ]
230255) -> list [str ]:
256+ """Extracts text chunks from documents based on token/window offsets.
257+
258+ Parameters
259+ ----------
260+ raw_documents: list[str]
261+ Text documents.
262+ offsets: list[Offsets]
263+ Offsets returned when encoding.
264+
265+ Returns
266+ -------
267+ list[str]
268+ Text chunks of tokens/windows in the documents.
269+ """
231270 chunks = []
232271 for doc , _offs in zip (raw_documents , offsets ):
233272 for start_char , end_char in _offs :
@@ -236,6 +275,35 @@ def get_document_chunks(
236275
237276
238277class LateWrapper (ContextualModel , TransformerMixin ):
278+ """Wraps existing Turftopic model so that they can accept and create
279+ multi-vector document representations.
280+
281+ !!! warning
282+ The model HAS TO HAVE a late interaction encoder model
283+ (e.g. `LateSentenceTransformer`)
284+
285+ Parameters
286+ ----------
287+ model
288+ Turftopic model to turn into late-interaction model.
289+ batch_size: int, default 32
290+ Batch size of the transformer.
291+ window_size: int, default None
292+ Size of the sliding window to average tokens over.
293+ If None, documents will be represented at a token level.
294+ step_size: int, default None
295+ Step size of the window.
296+ If (step_size == None) or (step_size == window_size), then windows are separate.
297+ If step_size < window_size, windows will overlap.
298+ If step_size > window_size, there will be gaps between the windows.
299+ In this case, we throw a warning, as this is probably unintended behaviour.
300+ pooling: Callable, default None
301+ Indicates whether and how to pool document-topic matrices.
302+ If None, multi-vector topic proportions are returned in a ragged array.
303+ If Callable, multiple vectors are averaged with the callable in each document.
304+ You could for example take the mean by specifying `pooling=np.nanmean`.
305+ """
306+
239307 def __init__ (
240308 self ,
241309 model : TransformerMixin ,
0 commit comments