
Commit 34c7601

Added instructions for adding a custom model
1 parent 6b3d14f commit 34c7601

2 files changed

Lines changed: 246 additions & 0 deletions


docs/custom_model.md

Lines changed: 245 additions & 0 deletions
@@ -0,0 +1,245 @@
# Implement a Wrapper/Custom Model

If you would like to use the conveniences of Turftopic, including its pretty-printing and visualization utilities,
but the model you would like to use is either implemented in another library or not yet implemented anywhere,
you might want to consider writing a Turftopic wrapper/custom model.

The primary interface of Turftopic models is implemented in the `ContextualModel` class; your topic model will have to inherit from this class:
```python
from turftopic.base import ContextualModel
```

## Minimal interface

For a `ContextualModel` to work, you will have to implement the following:
### `__init__`

Implement an `__init__` method that takes and assigns the most basic attributes of your topic model (`n_components`, `encoder`, `vectorizer`).
Some of these attributes are optional; to align your model's behaviour with the rest of Turftopic, here is some minimal boilerplate.

!!! note
    Your model might not need an `n_components` attribute if it always discovers the number of topics automatically.
26+
```python
27+
from typing import Optional, Union
28+
29+
import numpy as np
30+
from sentence_transformers import SentenceTransformer
31+
from sklearn.feature_extraction.text import CountVectorizer
32+
from rich.console import Console
33+
34+
from turftopic.base import ContextualModel, Encoder
35+
from turftopic.vectorizers.default import default_vectorizer
36+
37+
class CustomModel(ContextualModel):
38+
def __init__(
39+
self,
40+
n_components: int,
41+
# You could of course change this to a
42+
encoder: Union[
43+
Encoder, str
44+
] = "sentence-transformers/all-MiniLM-L6-v2",
45+
vectorizer: Optional[CountVectorizer] = None,
46+
random_state: Optional[int] = None,
47+
):
48+
self.n_components = n_components
49+
self.encoder = encoder
50+
self.random_state = random_state
51+
if isinstance(encoder, str):
52+
# Note that we assign the actual encoder to encoder_
53+
# This is because scikit-learn requires that the attributes and init parameters match
54+
self.encoder_ = SentenceTransformer(encoder)
55+
else:
56+
self.encoder_ = encoder
57+
if vectorizer is None:
58+
# Assign the default vectorizer from Turftopic
59+
self.vectorizer = default_vectorizer()
60+
else:
61+
self.vectorizer = vectorizer
62+
```
63+
64+
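The reason attribute names mirror the `__init__` parameter names exactly is that it keeps your model compatible with scikit-learn utilities such as `get_params()` and `clone()`. Here is a self-contained sketch of that convention, using scikit-learn's `BaseEstimator` as a stand-in for `ContextualModel` so it runs on its own:

```python
from sklearn.base import BaseEstimator, clone


class Sketch(BaseEstimator):
    # Attribute names match the __init__ parameter names exactly,
    # so scikit-learn can introspect, print, and clone the estimator.
    def __init__(self, n_components=10, random_state=None):
        self.n_components = n_components
        self.random_state = random_state


model = Sketch(n_components=5, random_state=42)
print(model.get_params())  # {'n_components': 5, 'random_state': 42}
copy = clone(model)  # an unfitted copy with the same parameters
```

If a parameter were assigned under a different attribute name (or modified in `__init__`), `clone()` would fail or silently produce a differently configured model, which is why derived objects go into trailing-underscore attributes like `encoder_`.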
### `fit_transform`

You will also have to implement a `fit_transform` method. This method does the following:

1. Learns the vocabulary by fitting the vectorizer.
2. Encodes the documents with the encoder, if embeddings are not provided.
3. Fits the topic model, then assigns the topic-term matrix to `components_`.
4. Returns the document-topic matrix.

!!! tip
    Turftopic models also use the `rich` Python library for progress tracking during model fitting.
    This is entirely optional, but it makes your model more consistent with the rest of the library.

Here's a minimal example with some boilerplate code:
```python
    def fit_transform(
        self, raw_documents, y=None, embeddings: Optional[np.ndarray] = None
    ) -> np.ndarray:
        # We track progress with a rich console
        console = Console()
        with console.status("Fitting model") as status:
            if embeddings is None:
                status.update("Encoding documents")
                # The encode_documents method is implemented in ContextualModel
                embeddings = self.encode_documents(raw_documents)
                console.log("Documents encoded.")
            status.update("Extracting terms.")
            # It is very important that you fit your vectorizer on the corpus:
            # this is how get_vocab() can return the vocabulary items.
            document_term_matrix = self.vectorizer.fit_transform(raw_documents)
            console.log("Term extraction done.")
            status.update("Fitting model")
            # ===========HERE COMES ALL YOUR MODEL FITTING CODE================
            # We assign None to these since we are not implementing a model here,
            # but in your implementation these should be numpy arrays.
            self.components_ = None
            document_topic_matrix = None
            console.log("Model fitting done.")
        return document_topic_matrix
```
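If the `rich` progress-tracking pattern is unfamiliar, here it is in isolation, independent of any topic model. `record=True` is only used here so the output can be inspected afterwards; a real model would use a plain `Console()`:

```python
from rich.console import Console

# record=True captures output so it can be exported later with export_text()
console = Console(record=True)
with console.status("Fitting model") as status:
    # status.update changes the spinner text; console.log prints
    # a timestamped line above the spinner
    status.update("Encoding documents")
    console.log("Documents encoded.")
    status.update("Fitting model")
    console.log("Model fitting done.")
```

The `status` context manager shows an animated spinner while the block runs, so long-running steps such as encoding give visible feedback.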
## Example: Wrapper for Latent Dirichlet Allocation

LDA is the best-known and historically the most widely used topic model.
It is not implemented in Turftopic, since it is not a contextual topic model, but you might need to compare it with Turftopic models,
and it might be convenient to have access to the same interface.

Here's a minimal wrapper for LDA in Turftopic using the boilerplate code above.
Most hyperparameters are left out; adding them would be trivial but would take more code.
```python
from sklearn.decomposition import LatentDirichletAllocation


class LDA(ContextualModel):
    """Latent Dirichlet Allocation model wrapper in Turftopic."""

    def __init__(
        self,
        n_components: int,
        encoder: Union[
            Encoder, str
        ] = "sentence-transformers/all-MiniLM-L6-v2",
        vectorizer: Optional[CountVectorizer] = None,
        random_state: Optional[int] = None,
    ):
        self.n_components = n_components
        self.encoder = encoder
        self.random_state = random_state
        # Since LDA only uses bag-of-words representations, we do not load an encoder
        self.encoder_ = None
        if vectorizer is None:
            self.vectorizer = default_vectorizer()
        else:
            self.vectorizer = vectorizer

    def fit_transform(
        self, raw_documents, y=None, embeddings: Optional[np.ndarray] = None
    ) -> np.ndarray:
        console = Console()
        with console.status("Fitting model") as status:
            status.update("Extracting terms.")
            document_term_matrix = self.vectorizer.fit_transform(raw_documents)
            console.log("Term extraction done.")
            status.update("Fitting model")
            self._lda = LatentDirichletAllocation(
                self.n_components, random_state=self.random_state
            )
            document_topic_matrix = self._lda.fit_transform(document_term_matrix)
            # Since the scikit-learn API matches perfectly, we won't have to do much.
            self.components_ = self._lda.components_
            console.log("Model fitting done.")
        return document_topic_matrix
```
This model can now be used the same way as any other Turftopic model:

```python
from sklearn.datasets import fetch_20newsgroups

ds = fetch_20newsgroups(
    subset="all",
    remove=("headers", "footers", "quotes"),
)
corpus = ds.data

model = LDA(10, random_state=42)
doc_topic = model.fit_transform(corpus)

model.print_topics()
```
```
[10:08:44] Term extraction done. lda_mess.py:43
[10:09:48] Model fitting done. lda_mess.py:53
┏━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Topic ID ┃ Highest Ranking                                                                 ┃
┡━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ 0        │ game, team, year, games, play, don, think, season, good, hockey                 │
├──────────┼─────────────────────────────────────────────────────────────────────────────────┤
│ 1        │ drive, use, like, just, used, scsi, don, time, card, power                      │
├──────────┼─────────────────────────────────────────────────────────────────────────────────┤
│ 2        │ 00, 10, 25, 20, 15, 11, 12, 14, 16, space                                       │
├──────────┼─────────────────────────────────────────────────────────────────────────────────┤
│ 3        │ people, god, don, think, just, know, like, say, does, time                      │
├──────────┼─────────────────────────────────────────────────────────────────────────────────┤
│ 4        │ ax, max, g9v, b8f, a86, pl, 145, 1d9, 34u, 1t                                   │
├──────────┼─────────────────────────────────────────────────────────────────────────────────┤
│ 5        │ know, thanks, like, car, mail, edu, db, just, good, new                         │
├──────────┼─────────────────────────────────────────────────────────────────────────────────┤
│ 6        │ gun, people, government, law, right, use, fbi, don, guns, control               │
├──────────┼─────────────────────────────────────────────────────────────────────────────────┤
│ 7        │ windows, dos, use, file, window, program, using, problem, does, like            │
├──────────┼─────────────────────────────────────────────────────────────────────────────────┤
│ 8        │ edu, image, available, com, data, ftp, information, file, graphics, mail        │
├──────────┼─────────────────────────────────────────────────────────────────────────────────┤
│ 9        │ people, said, president, armenian, israel, government, armenians, mr, jews, war │
└──────────┴─────────────────────────────────────────────────────────────────────────────────┘
```
## Dynamic Topic Models

If you want to implement dynamic functionality, you will have to inherit from `DynamicTopicModel`
and implement a `fit_transform_dynamic` method, which also takes timestamps.

!!! note
    You will have to implement the rest of the methods too, as outlined before.

Here's some boilerplate code:
```python
from datetime import datetime

from turftopic.dynamic import DynamicTopicModel


class CustomDynamic(ContextualModel, DynamicTopicModel):
    ...

    def fit_transform_dynamic(
        self,
        raw_documents,
        timestamps: list[datetime],
        embeddings: Optional[np.ndarray] = None,
        bins: Union[int, list[datetime]] = 10,
    ):
        # bin_timestamps sorts your data into time bins based on the bins
        # the user provided; time_labels contains the time-bin index of each document
        time_labels, self.time_bin_edges = self.bin_timestamps(
            timestamps, bins
        )
        n_bins = len(self.time_bin_edges) - 1
        # Fit the vectorizer as before
        doc_term_matrix = self.vectorizer.fit_transform(raw_documents)
        # Overall, non-time-sensitive topics
        self.components_ = ...
        # Here we assign zeros, but this attribute should contain
        # the topic-word distributions for each time bin
        self.temporal_components_ = np.zeros(
            (n_bins, self.n_components, len(self.get_vocab()))
        )
        # This attribute should contain the importance of each topic in each time bin
        self.temporal_importance_ = np.zeros((n_bins, self.n_components))
        # You should of course assign this too
        doc_topic_matrix = ...
        return doc_topic_matrix
```
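The `bin_timestamps` helper used above is provided by `DynamicTopicModel`: conceptually, it assigns each document to a discrete time bin delimited by the bin edges. A self-contained sketch of that idea (not the library's actual implementation) using `np.digitize`:

```python
from datetime import datetime

import numpy as np

timestamps = [
    datetime(2020, 1, 15),
    datetime(2020, 6, 1),
    datetime(2021, 3, 10),
    datetime(2021, 11, 5),
]
# Three bin edges delimit two time bins: [2020, 2021) and [2021, 2022)
edges = [datetime(2020, 1, 1), datetime(2021, 1, 1), datetime(2022, 1, 1)]

# Convert everything to a sortable numeric form (seconds since the epoch)
ts = np.array([t.timestamp() for t in timestamps])
inner_edges = np.array([e.timestamp() for e in edges[1:-1]])

# Each document gets the index of the time bin it falls into
time_labels = np.digitize(ts, inner_edges)
print(time_labels)  # [0 0 1 1]
```

This is also why `n_bins` is one less than the number of bin edges, and why `temporal_components_` has `n_bins` as its first dimension.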

mkdocs.yml

Lines changed: 1 addition & 0 deletions
```diff
@@ -16,6 +16,7 @@ nav:
   - Modifying and Finetuning Models: finetuning.md
   - Saving and Loading: persistence.md
   - Using TopicData: topic_data.md
+  - Implement a Wrapper/Custom Model: custom_model.md
   - Tutorials:
     - Tutorial Overview: tutorials/overview.md
     - Analyzing the Landscape of Machine Learning Research: tutorials/arxiv_ml.md
```
