# Implement a Wrapper/Custom Model

If you would like to use the conveniences of Turftopic, including pretty printing and visualization utilities,
but the model you would like to use is either implemented in another library or not yet implemented,
you might want to consider writing a Turftopic wrapper/custom model.

The primary interface of Turftopic models is implemented in the `ContextualModel` class, and your topic model will have to inherit from it:

```python
from turftopic.base import ContextualModel
```

## Minimal interface

For a `ContextualModel` to work, you will have to implement the following:

### `__init__`

Implement an `__init__` method that takes and assigns the most basic attributes of your topic model (`n_components`, `encoder`, `vectorizer`).
Some of these attributes are optional. To align your model's behaviour with the rest of Turftopic, here is some minimal boilerplate:

!!! note
    Your model might not need an `n_components` attribute if it always discovers the number of topics automatically.


```python
from typing import Optional, Union

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import CountVectorizer
from rich.console import Console

from turftopic.base import ContextualModel, Encoder
from turftopic.vectorizers.default import default_vectorizer

class CustomModel(ContextualModel):
    def __init__(
        self,
        n_components: int,
        # You could of course change this to a different default encoder
        encoder: Union[
            Encoder, str
        ] = "sentence-transformers/all-MiniLM-L6-v2",
        vectorizer: Optional[CountVectorizer] = None,
        random_state: Optional[int] = None,
    ):
        self.n_components = n_components
        self.encoder = encoder
        self.random_state = random_state
        if isinstance(encoder, str):
            # Note that we assign the actual encoder to encoder_,
            # because scikit-learn requires that attributes match init parameters verbatim
            self.encoder_ = SentenceTransformer(encoder)
        else:
            self.encoder_ = encoder
        if vectorizer is None:
            # Assign the default vectorizer from Turftopic
            self.vectorizer = default_vectorizer()
        else:
            self.vectorizer = vectorizer
```
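The trailing-underscore pattern in the comments above follows the scikit-learn estimator convention: `__init__` parameters are stored untouched under their own names, while anything loaded or derived from them gets a trailing underscore. A minimal, self-contained sketch of the idea, with a hypothetical dictionary standing in for loading a real encoder:

```python
class TinyModel:
    """Illustrates the scikit-learn attribute convention used above."""

    def __init__(self, encoder="default-encoder"):
        # The init parameter is stored verbatim, so that scikit-learn's
        # get_params() and clone() can round-trip the estimator
        self.encoder = encoder
        # Anything loaded or computed from a parameter gets a trailing underscore
        if isinstance(encoder, str):
            self.encoder_ = {"name": encoder}  # stand-in for loading a model
        else:
            self.encoder_ = encoder

model = TinyModel()
print(model.encoder)   # the untouched parameter: 'default-encoder'
print(model.encoder_)  # the derived object: {'name': 'default-encoder'}
```

This separation is what lets your wrapper accept either a model name or a preloaded encoder object without breaking scikit-learn's parameter introspection.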

### `fit_transform`

You will also have to implement a `fit_transform` method. This method does the following things:

1. Learns the vocabulary by training the vectorizer.
2. Encodes the documents using the encoder if the embeddings are not provided.
3. Fits the topic model, then assigns the topic-term matrix to `components_`.
4. Returns the document-topic matrix.

!!! tip
    Turftopic models also use the `rich` Python library for progress tracking during model fitting.
    This is entirely optional, but it makes your model more consistent with the rest of the library.

Here's a minimal example with some boilerplate code:

```python
    def fit_transform(
        self, raw_documents, y=None, embeddings: Optional[np.ndarray] = None
    ) -> np.ndarray:
        # We track progress with a rich console
        console = Console()
        with console.status("Fitting model") as status:
            if embeddings is None:
                status.update("Encoding documents")
                # The encode_documents method is implemented in ContextualModel
                embeddings = self.encode_documents(raw_documents)
                console.log("Documents encoded.")
            status.update("Extracting terms.")
            # It is very important that you fit your vectorizer on the corpus;
            # this is how vocabulary items become available through get_vocab()
            document_term_matrix = self.vectorizer.fit_transform(raw_documents)
            console.log("Term extraction done.")
            status.update("Fitting model")
            # ===========HERE COMES ALL YOUR MODEL FITTING CODE================
            # We assign None to these since we are not implementing a model here,
            # but in your implementation they should be assigned numpy arrays
            self.components_ = None
            document_topic_matrix = None
            console.log("Model fitting done.")
            return document_topic_matrix
```

## Example: Wrapper for Latent Dirichlet Allocation

LDA is the most well-known and historically the most widely used topic model.
It is not implemented in Turftopic, since it is not a contextual topic model, but you might need to compare it with Turftopic models,
and it might be convenient to have access to the same interface.

Here's a minimal wrapper for LDA in Turftopic using the boilerplate code above.
I will not implement most hyperparameters, since doing so would be trivial but would take more code.

```python
from sklearn.decomposition import LatentDirichletAllocation

class LDA(ContextualModel):
    """Latent Dirichlet Allocation model wrapper in Turftopic."""
    def __init__(
        self,
        n_components: int,
        encoder: Union[
            Encoder, str
        ] = "sentence-transformers/all-MiniLM-L6-v2",
        vectorizer: Optional[CountVectorizer] = None,
        random_state: Optional[int] = None,
    ):
        self.n_components = n_components
        self.encoder = encoder
        self.random_state = random_state
        # Since LDA only uses bag-of-words representations, we do not load an encoder
        self.encoder_ = None
        if vectorizer is None:
            self.vectorizer = default_vectorizer()
        else:
            self.vectorizer = vectorizer

    def fit_transform(
        self, raw_documents, y=None, embeddings: Optional[np.ndarray] = None
    ) -> np.ndarray:
        console = Console()
        with console.status("Fitting model") as status:
            status.update("Extracting terms.")
            document_term_matrix = self.vectorizer.fit_transform(raw_documents)
            console.log("Term extraction done.")
            status.update("Fitting model")
            self._lda = LatentDirichletAllocation(
                self.n_components, random_state=self.random_state
            )
            document_topic_matrix = self._lda.fit_transform(document_term_matrix)
            # Since the scikit-learn API matches perfectly, we don't have to do much
            self.components_ = self._lda.components_
            console.log("Model fitting done.")
            return document_topic_matrix
```

This model can now be used the same way you would use any other Turftopic model:

```python
from sklearn.datasets import fetch_20newsgroups

ds = fetch_20newsgroups(
    subset="all",
    remove=("headers", "footers", "quotes"),
)
corpus = ds.data

model = LDA(10, random_state=42)
doc_topic = model.fit_transform(corpus)

model.print_topics()
```

```
[10:08:44] Term extraction done. lda_mess.py:43
[10:09:48] Model fitting done. lda_mess.py:53
┏━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Topic ID ┃ Highest Ranking                                                                 ┃
┡━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ 0        │ game, team, year, games, play, don, think, season, good, hockey                 │
├──────────┼─────────────────────────────────────────────────────────────────────────────────┤
│ 1        │ drive, use, like, just, used, scsi, don, time, card, power                      │
├──────────┼─────────────────────────────────────────────────────────────────────────────────┤
│ 2        │ 00, 10, 25, 20, 15, 11, 12, 14, 16, space                                       │
├──────────┼─────────────────────────────────────────────────────────────────────────────────┤
│ 3        │ people, god, don, think, just, know, like, say, does, time                      │
├──────────┼─────────────────────────────────────────────────────────────────────────────────┤
│ 4        │ ax, max, g9v, b8f, a86, pl, 145, 1d9, 34u, 1t                                   │
├──────────┼─────────────────────────────────────────────────────────────────────────────────┤
│ 5        │ know, thanks, like, car, mail, edu, db, just, good, new                         │
├──────────┼─────────────────────────────────────────────────────────────────────────────────┤
│ 6        │ gun, people, government, law, right, use, fbi, don, guns, control               │
├──────────┼─────────────────────────────────────────────────────────────────────────────────┤
│ 7        │ windows, dos, use, file, window, program, using, problem, does, like            │
├──────────┼─────────────────────────────────────────────────────────────────────────────────┤
│ 8        │ edu, image, available, com, data, ftp, information, file, graphics, mail        │
├──────────┼─────────────────────────────────────────────────────────────────────────────────┤
│ 9        │ people, said, president, armenian, israel, government, armenians, mr, jews, war │
└──────────┴─────────────────────────────────────────────────────────────────────────────────┘
```

## Dynamic Topic Models

If you want to implement dynamic functionality, you will have to inherit from `DynamicTopicModel`
and implement a `fit_transform_dynamic` method, which also takes timestamps.

!!! note
    You will have to implement the rest of the methods too, as outlined before.

Here's some boilerplate code:

```python
from datetime import datetime

from turftopic.dynamic import DynamicTopicModel

class CustomDynamic(ContextualModel, DynamicTopicModel):
    ...

    def fit_transform_dynamic(
        self,
        raw_documents,
        timestamps: list[datetime],
        embeddings: Optional[np.ndarray] = None,
        bins: Union[int, list[datetime]] = 10,
    ):
        # bin_timestamps sorts your data into the time bins the user provided;
        # time_labels contains a time bin index for each document
        time_labels, self.time_bin_edges = self.bin_timestamps(
            timestamps, bins
        )
        n_bins = len(self.time_bin_edges) - 1
        # Fit the vectorizer as before
        doc_term_matrix = self.vectorizer.fit_transform(raw_documents)
        # Overall topics, not time-sensitive
        self.components_ = ...
        # Here we assign zeros,
        # but this attribute should contain the topic-word distributions for each time bin
        self.temporal_components_ = np.zeros(
            (n_bins, self.n_components, len(self.get_vocab()))
        )
        # This attribute should contain the importance of each topic in each time bin
        self.temporal_importance_ = np.zeros((n_bins, self.n_components))
        # You should of course assign this too
        doc_topic_matrix = ...
        return doc_topic_matrix
```
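
To make the binning step concrete, here is one possible implementation of what the binning does. This is an illustrative stand-in, not Turftopic's actual `bin_timestamps` code: it sorts documents into equally spaced time bins and returns a bin label per document along with the bin edges.

```python
from datetime import datetime

import numpy as np

def bin_timestamps_sketch(timestamps: list[datetime], n_bins: int = 10):
    """Illustrative stand-in for timestamp binning: assigns each
    document to one of n_bins equally spaced time bins."""
    ts = np.array([t.timestamp() for t in timestamps])
    edges = np.linspace(ts.min(), ts.max(), n_bins + 1)
    # digitize returns 1-based bin indices; shift and clip so the
    # maximum timestamp falls into the last bin
    labels = np.clip(np.digitize(ts, edges) - 1, 0, n_bins - 1)
    return labels, edges

timestamps = [datetime(2020, 1, 1), datetime(2020, 6, 1), datetime(2021, 1, 1)]
labels, edges = bin_timestamps_sketch(timestamps, n_bins=2)
print(labels)      # [0 0 1]
print(len(edges))  # 3 edges delimit 2 bins
```

Note that `n_bins + 1` edges delimit `n_bins` bins, which is why the boilerplate above computes `n_bins` as `len(self.time_bin_edges) - 1`.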