# Late Interaction Topic Models

Late interaction, or multi-vector, models use the token representations from a Sentence Transformer before they are pooled together into a single document embedding.
This can be particularly useful for clustering models: by default, they assign one topic to each document, but with access to token representations they can assign topics on a per-token basis.

!!! info
    There are currently no native late-interaction models in Turftopic, that is, models that explicitly model token representations in the context of a document.
    We are working on implementing such models, but in the meantime we include wrappers that let regular models use embeddings of higher granularity.
    **Visualization utilities** are also on the way.

## Encoding Tokens and Ragged Array Manipulation

Turftopic provides a convenience class for encoding documents at the token level using Sentence Transformers, instead of pooling the tokens together into document embeddings.
To initialize an encoder, load `LateSentenceTransformer` and specify which model you would like to use:

!!! tip
    While you could use any encoder model with `LateSentenceTransformer`, we recommend that you stick to ones that use mean pooling and normalize embeddings.
    In such models, you can be sure that the pooled document embeddings and the token embeddings lie in the same semantic space.

### Token Embeddings

```python
from turftopic.late import LateSentenceTransformer

documents = ["This is a text", "This is another but slightly longer text"]

encoder = LateSentenceTransformer("all-MiniLM-L6-v2")
token_embeddings, offsets = encoder.encode_tokens(documents)
print(token_embeddings)
print(offsets)
```

```python
[
    array([[-0.01135089,  0.04170538,  0.00379963, ...,  0.01383126,
            -0.00274855, -0.05360783],
           ...
           [ 0.05069249,  0.03840942, -0.03545087, ...,  0.03142243,
             0.01929936, -0.09216172]], shape=(6, 384), dtype=float32),
    array([[-0.00047079,  0.03402771,  0.00037086, ...,  0.0228903 ,
            -0.01734272, -0.04073172],
           ...,
           [-0.02586325,  0.03737643,  0.02260585, ...,  0.05613737,
            -0.01032581, -0.03799873]], shape=(9, 384), dtype=float32)
]
[[(0, 0), (0, 4), (5, 7), (8, 9), (10, 14), (0, 0)], [(0, 0), (0, 4), (5, 7), (8, 15), (16, 19), (20, 28), (29, 35), (36, 40), (0, 0)]]
```

As you can see, `encode_tokens` returns two values, the first being the token embeddings. This is a ragged array, where longer documents have more embeddings.
`offsets` contains a list of tuples for each document, where the first element of each tuple is the start character of the given token, and the second element is the end character.
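
To make the offsets concrete, here is a small sketch that recovers each token's surface text by slicing the document with its character offsets. The values are copied from the output above; the `(0, 0)` entries, which here presumably correspond to special tokens, carry no surface text and are skipped:

```python
documents = ["This is a text", "This is another but slightly longer text"]
offsets = [
    [(0, 0), (0, 4), (5, 7), (8, 9), (10, 14), (0, 0)],
    [(0, 0), (0, 4), (5, 7), (8, 15), (16, 19), (20, 28), (29, 35), (36, 40), (0, 0)],
]
for document, doc_offsets in zip(documents, offsets):
    # Slice out each token's text; skip (0, 0) spans, which carry no text
    tokens = [document[start:end] for start, end in doc_offsets if (start, end) != (0, 0)]
    print(tokens)
# ['This', 'is', 'a', 'text']
# ['This', 'is', 'another', 'but', 'slightly', 'longer', 'text']
```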

### Rolling Window Embeddings

You can also pool these embeddings over a rolling window of tokens.
This way, you still represent each document with multiple vectors, but don't need to model every token individually:

```python
window_embeddings, window_offsets = encoder.encode_windows(documents, window_size=5, step_size=4)
for doc_emb, doc_off in zip(window_embeddings, window_offsets):
    print(doc_emb.shape, doc_off)
```

```python
(2, 384) [(0, 14), (10, 0)]
(3, 384) [(0, 19), (16, 0), (0, 0)]
```

### Ragged Array Manipulation

These ragged data structures are hard to deal with, especially when using array operations, so we include convenience functions for manipulating them.
**`flatten_repr`** flattens the ragged array into a single large array and returns the length of each sub-array:

```python
from turftopic.late import flatten_repr, unflatten_repr

flat_token_embeddings, lengths = flatten_repr(token_embeddings)
print(flat_token_embeddings.shape)
# (15, 384)
```

**`unflatten_repr`** turns a flattened representation back into a ragged array:
```python
token_embeddings = unflatten_repr(flat_token_embeddings, lengths)
```

**`pool_flat`** pools document representations in a flattened array using a given aggregation function:
```python
import numpy as np
from turftopic.late import pool_flat

pooled = pool_flat(flat_token_embeddings, lengths, agg=np.nanmean)
print(pooled.shape)
# (2, 384)
```
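
Conceptually, these helpers can be pictured with plain NumPy operations. This is only a rough sketch of the idea, not Turftopic's actual implementation:

```python
import numpy as np

# Two "documents" with 6 and 9 token embeddings of dimension 4
ragged = [np.ones((6, 4)), 2 * np.ones((9, 4))]

# Flattening: concatenate the per-document arrays and remember their lengths
flat = np.concatenate(ragged)
lengths = [len(arr) for arr in ragged]
print(flat.shape)  # (15, 4)

# Unflattening: split the flat array back at the document boundaries
unflattened = np.split(flat, np.cumsum(lengths)[:-1])
print([arr.shape for arr in unflattened])  # [(6, 4), (9, 4)]

# Pooling: aggregate each document's rows into a single vector
pooled_doc = np.stack([arr.mean(axis=0) for arr in unflattened])
print(pooled_doc.shape)  # (2, 4)
```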

## Turning Regular Models into Multi-Vector Models

The `LateWrapper` class can turn your regular topic models into ones that utilize windowed or token-level embeddings.
Here's how `LateWrapper` works:

 1. It encodes documents at the token or window level, based on its parameters.
 2. It flattens the embedding array and feeds this into the topic model, along with the token/window text.
 3. It unflattens the output of the topic model (`doc_topic_matrix`) into a ragged array, where you get topic importance for each token.
 4. *\[OPTIONAL\]* It pools token-level topic content at the document level, so that you get one document-topic vector for each document instead of each token.
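
The steps above can be sketched in a few lines of NumPy, with a hypothetical `fit_transform_flat` standing in for the wrapped topic model:

```python
import numpy as np

def fit_transform_flat(flat_embeddings, n_topics=3):
    # Stand-in for the wrapped topic model: assigns random topic
    # weights to each token/window row of the flat embedding array
    rng = np.random.default_rng(42)
    return rng.random((len(flat_embeddings), n_topics))

lengths = [6, 9]  # number of tokens/windows per document
flat_embeddings = np.zeros((sum(lengths), 384))  # step 2: flattened embeddings

flat_topics = fit_transform_flat(flat_embeddings)  # fit on the flat array
# Step 3: unflatten into a ragged array of per-token topic weights
ragged_topics = np.split(flat_topics, np.cumsum(lengths)[:-1])
# Step 4 (optional): pool back to one topic vector per document
doc_topics = np.stack([t.mean(axis=0) for t in ragged_topics])
print(doc_topics.shape)  # (2, 3)
```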

Let's see how this works in practice by creating a [Topeax](Topeax.md) model that uses windowed embeddings instead of document-level embeddings:

```python
from sklearn.datasets import fetch_20newsgroups
from turftopic import Topeax
from turftopic.late import LateWrapper, LateSentenceTransformer

corpus = fetch_20newsgroups(subset="all", categories=["alt.atheism"]).data

model = LateWrapper(
    Topeax(encoder=LateSentenceTransformer("all-MiniLM-L6-v2")),
    window_size=50,  # If we don't specify a window size, token-level embeddings are used
    step_size=40,  # Since the step size is smaller than the window size, windows will overlap
)
doc_topic_matrix, offsets = model.fit_transform(corpus)
model.print_topics()
```

| Topic ID | Highest Ranking |
| - | - |
| 0 | morality, moral, morals, immoral, objective, behavior, instinctive, species, inherent, animals |
| 1 | matthew, luke, bible, text, passages, mormon, texts, translations, copy, john |
| 2 | atheism, agnostics, atheist, beliefs, belief, faith, contradictory, believers, contradictions, theists |
| 3 | punishment, cruel, abortion, penalty, death, constitution, homosexuality, painless, capital, punish |
| 4 | war, arms, invaded, gulf, hussein, civilians, military, kuwait, peace, sell |
| 5 | islam, islamic, muslim, qur, muslims, imams, rushdie, quran, koran, khomeini |

The document-topic matrix we created is now a ragged array and contains document-topic proportions for each window in a document.
Let's see what this means in practice for the first document in our corpus:
```python
import pandas as pd

# We select document 0, then collect all information into a dataframe:
window_topic_matrix = doc_topic_matrix[0]
window_offs = offsets[0]
document = corpus[0]
# We extract the text for each window based on the offsets
window_text = [document[window_start:window_end] for window_start, window_end in window_offs]
df = pd.DataFrame(window_topic_matrix, index=window_text, columns=model.topic_names)
print(df)
```

```python
                                                    0_morality_moral_morals_immoral  1_matthew_luke_bible_text  ...  4_war_arms_invaded_gulf  5_islam_islamic_muslim_qur
From: acooper@mac.cc.macalstr.edu (Turin Turamb...                         0.334267               1.287207e-13  ...             2.626869e-26                1.459101e-04
alester College\nLines: 55\n\nIn article <C5sA2...                         0.360400               8.898302e-14  ...             3.290858e-26                1.382718e-04
u (Mike Cobb) writes:\n> I guess I'm delving in...                         0.847002               5.002921e-22  ...             4.852574e-41                3.141366e-07
this you just have a spiral.  What\nwould then ...                         0.848413               5.819050e-22  ...             8.139559e-41                3.286224e-07
, even though this would hardly seem moral.  Fo...                         0.863685               1.272204e-21  ...             2.823941e-41                2.815930e-07
whatever helps this goal is\n"moral", whatever ...                         0.864913               1.584558e-21  ...             5.780971e-41                3.003952e-07
a "hyper-morality" to apply to just the methods...                         0.865558               1.919885e-21  ...             1.251694e-40                3.231265e-07
not doing something because it is\n> a personal...                         0.868360               2.951441e-21  ...             3.085662e-40                3.494368e-07
we only consider something moral or immoral if ...                         0.872827               5.444738e-21  ...             4.708349e-40                3.580695e-07
here we have a way to discriminate\nmorals.  I ...                         0.876951               1.021014e-20  ...             3.486096e-40                3.411401e-07
enough and\nlistened to the arguments, I could ...                         0.878680               2.302363e-20  ...             5.866410e-40                3.565728e-07
.  Or, as you brought out,\n> if whatever is ri...                         0.878953               3.004052e-20  ...             5.977738e-40                3.566668e-07
> *******************************                                          0.647793               5.664651e-17  ...             1.805073e-19                4.612731e-04
```
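
From a window-topic table like this, you can label each window with its dominant topic via a row-wise argmax. Here is a toy illustration; the matrix and topic names are made up, and with the real output you would use `doc_topic_matrix[0]` and `model.topic_names` instead:

```python
import numpy as np

# Hypothetical topic weights for three windows over three topics
window_topic_matrix = np.array([
    [0.80, 0.15, 0.05],
    [0.10, 0.70, 0.20],
    [0.05, 0.10, 0.85],
])
topic_names = ["0_morality", "1_bible", "2_atheism"]  # made-up labels

# The dominant topic of each window is the column with the highest weight
dominant = [topic_names[i] for i in window_topic_matrix.argmax(axis=1)]
print(dominant)  # ['0_morality', '1_bible', '2_atheism']
```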

## C-Top2Vec

Contextual Top2Vec [(Angelov and Inkpen, 2024)](https://aclanthology.org/2024.findings-emnlp.790/) is a late-interaction topic model that uses windowed representations.
The model is essentially the same as a regular Top2Vec model wrapped in `LateWrapper`, but we provide a convenience class in Turftopic so that it is easy to initialize.
It comes pre-loaded with the following features:

 - The same hyperparameters as in Angelov and Inkpen (2024)
 - A phrase vectorizer that finds common phrases based on PMI
 - `LateSentenceTransformer` by default, though you can specify any model

Our implementation is considerably more flexible than the original top2vec package, and you may be able to use much more powerful or novel embedding models.

```python
from turftopic import CTop2Vec

model = CTop2Vec(n_reduce_to=5)
doc_topic_matrix = model.fit_transform(corpus)

model.print_topics()
```

| Topic ID | Highest Ranking |
| - | - |
| -1 | caused atheism organization, genocide caused atheism, atheism organization, atheism, subject political atheists, alt atheism, caused atheism, political atheists organization, subject amusing atheists, amusing atheists |
| 166 | atheists organization, political atheists organization, christian morality organization, caused atheism organization, morality organization, atheism organization, atheists organization california, subject amusing atheists, cwru edu article, alt atheism |
| 172 | biblical, read bible, caused atheism, agnostics, caused atheism organization, atheists agnostics, christianity, alt atheism, atheism, christian morality organization |
| 173 | objective morality, morality, subject christian morality, christian morality, natural morality, say christian morality, morality organization, christian morality organization, behavior moral, moral |
| 175 | atheism, atheism organization, caused atheism organization, atheists agnostics, caused atheism, subject political atheists, alt atheism, genocide caused atheism, subject amusing atheists, amusing atheists |
| 176 | rushdie islamic law, subject rushdie islamic, islamic genocide, islamic law, genocide caused atheism, subject islamic, islamic law organization, islamic genocide organization, rushdie islamic, islamic authority |

You might also observe that the output of this model is a regular document-topic matrix, not a ragged one.
```python
print(doc_topic_matrix.shape)
# (1024, 6)
```

This way, the model has the same API as other Turftopic models and works the same way as the top2vec package, making migration easier.

## API Reference

### Encoder

::: turftopic.late.LateSentenceTransformer

### Wrapper

::: turftopic.late.LateWrapper

### Utility functions

::: turftopic.late.flatten_repr

::: turftopic.late.unflatten_repr

::: turftopic.late.pool_flat

::: turftopic.late.get_document_chunks