
Commit 450184b

Added random_state argument to all models so results are exactly reproducible.

1 parent 6d5ac33 · commit 450184b

4 files changed: 45 additions & 10 deletions
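A minimal sketch of the intended usage (the top-level imports and the positional n_components argument are assumptions inferred from the diffs below, not guaranteed by this commit):

from turftopic import (
    GMM,
    AutoEncodingTopicModel,
    ClusteringTopicModel,
    KeyNMF,
)

# Every model now accepts the same keyword argument; fixing it to an int
# should make fitting deterministic across runs.
models = [
    ClusteringTopicModel(random_state=42),
    AutoEncodingTopicModel(10, random_state=42),
    GMM(10, random_state=42),
    KeyNMF(10, random_state=42),
]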


turftopic/models/cluster.py

Lines changed: 18 additions & 4 deletions
@@ -137,6 +137,14 @@ class ClusteringTopicModel(ContextualModel, ClusterMixin, DynamicTopicModel):
         The specified reduction method will be used to merge them.
         By default, topics are not merged.
     reduction_method: 'agglomerative', 'smallest'
+        Method used to reduce the number of topics post-hoc.
+        When 'agglomerative', BERTopic's topic reduction method is used,
+        where topic vectors are hierarchically clustered.
+        When 'smallest', the smallest topic gets merged into the closest
+        non-outlier cluster until the desired number of topics
+        is reached, similarly to Top2Vec.
+    random_state: int, default None
+        Random state to use so that results are exactly reproducible.
     """
 
     def __init__(
@@ -154,8 +162,10 @@ def __init__(
         reduction_method: Literal[
             "agglomerative", "smallest"
         ] = "agglomerative",
+        random_state: Optional[int] = None,
     ):
         self.encoder = encoder
+        self.random_state = random_state
         if feature_importance not in ["c-tf-idf", "soft-c-tf-idf", "centroid"]:
             raise ValueError(feature_message)
         if isinstance(encoder, int):
@@ -174,7 +184,7 @@ def __init__(
         self.clustering = clustering
         if dimensionality_reduction is None:
             self.dimensionality_reduction = TSNE(
-                n_components=2, metric="cosine"
+                n_components=2, metric="cosine", random_state=random_state
             )
         else:
             self.dimensionality_reduction = dimensionality_reduction
@@ -196,7 +206,9 @@ def _merge_agglomerative(self, n_reduce_to: int) -> np.ndarray:
         )
         old_labels = [label for label in self.classes_ if label != -1]
         new_labels = AgglomerativeClustering(
-            n_clusters=n_reduce_to, metric="cosine", linkage="average"
+            n_clusters=n_reduce_to,
+            metric="cosine",
+            linkage="average",
         ).fit_predict(interesting_topic_vectors)
         res = {}
         if -1 in self.classes_:
@@ -235,7 +247,9 @@ def _estimate_parameters(
             self.labels_, classes=self.classes_
         )
         if self.feature_importance == "soft-c-tf-idf":
-            self.components_ = soft_ctf_idf(document_topic_matrix, doc_term_matrix)  # type: ignore
+            self.components_ = soft_ctf_idf(
+                document_topic_matrix, doc_term_matrix
+            )  # type: ignore
         elif self.feature_importance == "centroid":
             self.components_ = cluster_centroid_distance(
                 self.topic_vectors_,
@@ -327,7 +341,7 @@ def fit_transform_dynamic(
         if embeddings is None:
             embeddings = self.encoder_.encode(raw_documents)
         for i_timebin in np.arange(len(self.time_bin_edges) - 1):
-            if hasattr(self, 'components_'):
+            if hasattr(self, "components_"):
                 doc_topic_matrix = label_binarize(
                     self.labels_, classes=self.classes_
                 )
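In ClusteringTopicModel the seed is only forwarded to the default TSNE reducer, so for the reduction step a run with random_state set is roughly what you would get by seeding the reducer yourself. A sketch under that assumption (sklearn's TSNE, as in the diff; a custom reducer you pass in still needs its own seed):

from sklearn.manifold import TSNE

from turftopic import ClusteringTopicModel

# Seeded default reducer...
model = ClusteringTopicModel(random_state=42)
# ...roughly equivalent to constructing it by hand:
model = ClusteringTopicModel(
    dimensionality_reduction=TSNE(
        n_components=2, metric="cosine", random_state=42
    )
)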

turftopic/models/ctm.py

Lines changed: 7 additions & 1 deletion
@@ -1,4 +1,6 @@
 import math
+import random
+import sys
 from typing import Optional, Union
 
 import numpy as np
@@ -129,6 +131,8 @@ class AutoEncodingTopicModel(ContextualModel):
         Learning rate for the optimizer.
     n_epochs: int, default 50
         Number of epochs to run during training.
+    random_state: int, default None
+        Random state to use so that results are exactly reproducible.
     """
 
     def __init__(
@@ -144,8 +148,10 @@ def __init__(
         batch_size: int = 42,
         learning_rate: float = 1e-2,
         n_epochs: int = 50,
+        random_state: Optional[int] = None,
     ):
         self.n_components = n_components
+        self.random_state = random_state
         self.encoder = encoder
         if isinstance(encoder, str):
             self.encoder_ = SentenceTransformer(encoder)
@@ -205,7 +211,7 @@ def fit(
             status.update("Extracting terms.")
             document_term_matrix = self.vectorizer.fit_transform(raw_documents)
             console.log("Term extraction done.")
-            seed = 0
+            seed = self.random_state or random.randint(0, sys.maxsize - 1)
             torch.manual_seed(seed)
             pyro.set_rng_seed(seed)
             device = torch.device(
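Here the fit path seeds torch and pyro from a single value; a standalone sketch of that fallback logic (seed_everything is a hypothetical helper name, not part of the commit):

import random
import sys

import pyro
import torch


def seed_everything(random_state=None):
    # Use the given seed, or draw a fresh one when random_state is None.
    seed = random_state or random.randint(0, sys.maxsize - 1)
    torch.manual_seed(seed)
    pyro.set_rng_seed(seed)
    return seed

With this pattern a fit without random_state draws a new seed every time, so only runs that set the argument are reproducible.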

turftopic/models/gmm.py

Lines changed: 9 additions & 2 deletions
@@ -54,6 +54,8 @@ class GMM(ContextualModel, DynamicTopicModel):
         result in Gaussian components.
         For even larger datasets you can use IncrementalPCA to reduce
         memory load.
+    random_state: int, default None
+        Random state to use so that results are exactly reproducible.
 
     Attributes
     ----------
@@ -71,11 +73,13 @@ def __init__(
         dimensionality_reduction: Optional[TransformerMixin] = None,
         weight_prior: Literal["dirichlet", "dirichlet_process", None] = None,
         gamma: Optional[float] = None,
+        random_state: Optional[int] = None,
     ):
         self.n_components = n_components
         self.encoder = encoder
         self.weight_prior = weight_prior
         self.gamma = gamma
+        self.random_state = random_state
         if isinstance(encoder, str):
             self.encoder_ = SentenceTransformer(encoder)
         else:
@@ -94,9 +98,12 @@ def __init__(
                     else "dirichlet_process"
                 ),
                 weight_concentration_prior=gamma,
+                random_state=self.random_state,
             )
         else:
-            mixture = GaussianMixture(n_components)
+            mixture = GaussianMixture(
+                n_components, random_state=self.random_state
+            )
         if dimensionality_reduction is not None:
             self.gmm_ = make_pipeline(dimensionality_reduction, mixture)
         else:
@@ -162,7 +169,7 @@ def fit_transform_dynamic(
         bins: Union[int, list[datetime]] = 10,
     ):
         time_labels, self.time_bin_edges = bin_timestamps(timestamps, bins)
-        if hasattr(self, 'components_'):
+        if hasattr(self, "components_"):
             doc_topic_matrix = self.transform(
                 raw_documents, embeddings=embeddings
             )
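A sketch of a fully seeded GMM setup (the PCA reducer is a hypothetical choice for illustration): the random_state argument reaches the underlying GaussianMixture or BayesianGaussianMixture, while a dimensionality reduction you pass in has to carry its own seed:

from sklearn.decomposition import PCA

from turftopic import GMM

model = GMM(
    10,
    dimensionality_reduction=PCA(n_components=20, random_state=42),
    random_state=42,
)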

turftopic/models/keynmf.py

Lines changed: 11 additions & 3 deletions
@@ -79,6 +79,8 @@ class KeyNMF(ContextualModel):
         is performed on the whole vocabulary ('corpus') or only
         using words that are included in the document ('document').
         Setting this to 'corpus' allows for multilingual topics.
+    random_state: int, default None
+        Random state to use so that results are exactly reproducible.
     """
 
     def __init__(
@@ -90,7 +92,9 @@ def __init__(
         vectorizer: Optional[CountVectorizer] = None,
         top_n: int = 25,
         keyword_scope: str = "document",
+        random_state: Optional[int] = None,
     ):
+        self.random_state = random_state
         if keyword_scope not in ["document", "corpus"]:
             raise ValueError("keyword_scope must be 'document' or 'corpus'")
         self.n_components = n_components
@@ -105,7 +109,7 @@ def __init__(
         else:
             self.vectorizer = vectorizer
         self.dict_vectorizer_ = DictVectorizer()
-        self.nmf_ = NMF(n_components)
+        self.nmf_ = NMF(n_components, random_state=self.random_state)
         self.keyword_scope = keyword_scope
 
     def extract_keywords(
@@ -172,7 +176,9 @@ def minibatch_train(
         console=None,
     ):
         self.dict_vectorizer_.fit(keywords)
-        self.nmf_ = MiniBatchNMF(self.n_components)
+        self.nmf_ = MiniBatchNMF(
+            self.n_components, random_state=self.random_state
+        )
         epoch_costs = []
         for i_epoch in range(max_epochs):
             epoch_cost = 0
@@ -220,7 +226,9 @@ def big_fit(
             console.log("Keywords extracted.")
             keywords = KeywordIterator(keyword_file)
             status.update("Fitting NMF.")
-            self.minibatch_train(keywords, max_epochs, batch_size, console=console)  # type: ignore
+            self.minibatch_train(
+                keywords, max_epochs, batch_size, console=console
+            )  # type: ignore
             console.log("NMF fitted.")
             return self
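Finally, a quick end-to-end reproducibility check, sketched under the assumptions that KeyNMF is importable from the package root, that fit_transform returns a document-topic matrix, and that the default embedding model encodes deterministically (the toy corpus is purely illustrative):

import numpy as np

from turftopic import KeyNMF

corpus = [
    "stock markets fell sharply today",
    "investors worry about rising interest rates",
    "the home team won the championship game",
    "fans celebrated the victory downtown",
]

# Two fits with the same seed should yield identical factorizations.
doc_topic_a = KeyNMF(2, random_state=42).fit_transform(corpus)
doc_topic_b = KeyNMF(2, random_state=42).fit_transform(corpus)
np.testing.assert_allclose(doc_topic_a, doc_topic_b)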
