Skip to content

Commit 2460eb0

Browse files
Merge pull request #46 from x-tabdeveloping/partial_keynmf
Online KeyNMF
2 parents ef5fd2f + 4491342 commit 2460eb0

3 files changed

Lines changed: 554 additions & 255 deletions

File tree

tests/test_integration.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
import tempfile
23
from datetime import datetime
34
from pathlib import Path
@@ -8,8 +9,23 @@
89
from sentence_transformers import SentenceTransformer
910
from sklearn.datasets import fetch_20newsgroups
1011

11-
from turftopic import (GMM, AutoEncodingTopicModel, ClusteringTopicModel,
12-
KeyNMF, SemanticSignalSeparation)
12+
from turftopic import (
13+
GMM,
14+
AutoEncodingTopicModel,
15+
ClusteringTopicModel,
16+
KeyNMF,
17+
SemanticSignalSeparation,
18+
)
19+
20+
21+
def batched(iterable, n: int):
    """Yield successive lists of at most ``n`` items drawn from ``iterable``.

    Every chunk has exactly ``n`` items except possibly the last one, which
    holds whatever is left over. Nothing is yielded for an empty iterable.

    Example: ``batched('ABCDEFG', 3)`` produces ``ABC``, ``DEF``, ``G``.

    Raises:
        ValueError: if ``n`` is smaller than one (raised when iteration starts,
            since this is a generator).
    """
    if n < 1:
        raise ValueError("n must be at least one")
    source = iter(iterable)
    while True:
        chunk = list(itertools.islice(source, n))
        if not chunk:
            # The underlying iterator is exhausted.
            return
        yield chunk
1329

1430

1531
def generate_dates(
@@ -41,8 +57,7 @@ def generate_dates(
4157
models = [
4258
GMM(5, encoder=trf),
4359
SemanticSignalSeparation(5, encoder=trf),
44-
KeyNMF(5, encoder=trf, keyword_scope="document"),
45-
KeyNMF(5, encoder=trf, keyword_scope="corpus"),
60+
KeyNMF(5, encoder=trf),
4661
ClusteringTopicModel(
4762
n_reduce_to=5,
4863
feature_importance="c-tf-idf",
@@ -75,6 +90,8 @@ def generate_dates(
7590
KeyNMF(5, encoder=trf),
7691
]
7792

93+
online_models = [KeyNMF(5, encoder=trf)]
94+
7895

7996
@pytest.mark.parametrize("model", models)
8097
def test_fit_export_table(model):
@@ -100,3 +117,19 @@ def test_fit_dynamic(model):
100117
with out_path.open("w") as out_file:
101118
out_file.write(table)
102119
df = pd.read_csv(out_path)
120+
121+
122+
@pytest.mark.parametrize("model", online_models)
def test_fit_online(model):
    """Fit a model incrementally in mini-batches and check the CSV export.

    The corpus is streamed five times in chunks of 50 documents through
    ``partial_fit``; afterwards the exported topic table must be readable
    as CSV by pandas.
    """
    n_epochs = 5
    batch_size = 50
    for _ in range(n_epochs):
        for chunk in batched(zip(texts, embeddings), batch_size):
            # Split the (text, embedding) pairs back into two parallel columns.
            chunk_texts, chunk_vectors = zip(*chunk)
            model.partial_fit(
                list(chunk_texts), embeddings=np.stack(chunk_vectors)
            )
    table = model.export_topics(format="csv")
    with tempfile.TemporaryDirectory() as tmpdirname:
        out_path = Path(tmpdirname) / "topics.csv"
        with out_path.open("w") as out_file:
            out_file.write(table)
        # Reading the file back verifies that the export is well-formed CSV.
        df = pd.read_csv(out_path)

turftopic/models/_keynmf.py

Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
import itertools
2+
from datetime import datetime
3+
from typing import Iterable, Optional
4+
5+
import numpy as np
6+
import scipy.sparse as spr
7+
from sklearn.base import clone
8+
from sklearn.decomposition._nmf import (
9+
NMF,
10+
MiniBatchNMF,
11+
_initialize_nmf,
12+
_update_coordinate_descent,
13+
)
14+
from sklearn.exceptions import NotFittedError
15+
from sklearn.feature_extraction.text import CountVectorizer
16+
from sklearn.metrics.pairwise import cosine_similarity
17+
from sklearn.utils import check_array
18+
from sklearn.utils.validation import check_non_negative
19+
20+
from turftopic.base import Encoder
21+
22+
23+
def batched(iterable, n: int) -> Iterable[list[str]]:
    """Split ``iterable`` into consecutive lists holding up to ``n`` items.

    All batches contain ``n`` items apart from the final one, which may be
    shorter. E.g. ``batched('ABCDEFG', 3)`` gives ``ABC``, ``DEF``, ``G``.

    Raises:
        ValueError: when ``n`` is less than one (raised on first iteration,
            because the function is a generator).
    """
    if n < 1:
        raise ValueError("n must be at least one")
    iterator = iter(iterable)
    current = list(itertools.islice(iterator, n))
    while current:
        yield current
        current = list(itertools.islice(iterator, n))
31+
32+
33+
def fit_timeslice(
    X,
    W,
    H,
    tol=1e-4,
    max_iter=200,
    l1_reg_W=0,
    l1_reg_H=0,
    l2_reg_W=0,
    l2_reg_H=0,
    verbose=0,
    shuffle=False,
    random_state=None,
):
    """Fit a topic-term matrix given a precomputed document-topic matrix.

    Used to obtain the temporal components in dynamic KeyNMF. Only the
    transposed copy of ``H`` is passed to the coordinate-descent update as
    the matrix being solved for; ``l1_reg_W`` and ``l2_reg_W`` are accepted
    but not used by this function.

    Iteration stops after ``max_iter`` sweeps, when the first sweep reports
    zero violation, or when the violation relative to the first sweep drops
    to ``tol`` or below.

    Returns:
        The tuple ``(W, H, n_iter)`` where ``H`` is the updated topic-term
        matrix and ``n_iter`` the number of the last sweep performed.
    """
    components_t = check_array(H.T, order="C")
    rng = (
        np.random.mtrand._rand
        if random_state is None
        else np.random.RandomState(random_state)
    )
    for n_iter in range(1, max_iter + 1):
        violation = 0.0 + _update_coordinate_descent(
            X.T, components_t, W, l1_reg_H, l2_reg_H, shuffle, rng
        )
        if n_iter == 1:
            # Later sweeps are measured relative to the first one.
            violation_init = violation
        if violation_init == 0:
            break
        relative_violation = violation / violation_init
        if verbose:
            print("violation:", relative_violation)
        if relative_violation <= tol:
            if verbose:
                print("Converged at iteration", n_iter + 1)
            break
    return W, components_t.T, n_iter
71+
72+
73+
class KeywordExtractor:
    """Extracts per-document keywords ranked by embedding similarity.

    For each document, the terms found by the vectorizer are compared to the
    document embedding with cosine similarity, and the ``top_n`` most similar
    terms with a strictly positive similarity are kept.

    Term embeddings are cached across calls: ``key_to_index`` maps every term
    encoded so far to its row in ``term_embeddings``, so each term is only
    passed to the encoder once.
    """

    def __init__(
        self, top_n: int, encoder: Encoder, vectorizer: CountVectorizer
    ):
        self.top_n = top_n
        self.encoder = encoder
        self.vectorizer = vectorizer
        # term -> row index in `term_embeddings`
        self.key_to_index: dict[str, int] = {}
        # One row per known term; None until the first terms are added.
        self.term_embeddings: Optional[np.ndarray] = None

    @property
    def n_vocab(self) -> int:
        """Number of terms whose embeddings have been cached."""
        return len(self.key_to_index)

    def _add_terms(self, new_terms: list[str]):
        """Encode ``new_terms`` and append them to the embedding cache.

        Indices are assigned in the order of ``new_terms`` so that they line
        up with the rows appended to ``term_embeddings``.
        """
        for term in new_terms:
            self.key_to_index[term] = self.n_vocab
        term_encodings = self.encoder.encode(new_terms)
        if self.term_embeddings is not None:
            self.term_embeddings = np.concatenate(
                (self.term_embeddings, term_encodings), axis=0
            )
        else:
            self.term_embeddings = term_encodings

    def batch_extract_keywords(
        self,
        documents: list[str],
        embeddings: Optional[np.ndarray] = None,
    ) -> list[dict[str, float]]:
        """Extract keywords for a batch of documents.

        Parameters:
            documents: Texts to extract keywords from.
            embeddings: Precomputed document embeddings, one row per
                document. Computed with the encoder when omitted.

        Returns:
            One dict per document mapping each selected term to its cosine
            similarity with the document. Documents without any term in the
            batch vocabulary get an empty dict.

        Raises:
            ValueError: if the number of embeddings differs from the number
                of documents.
        """
        if not len(documents):
            return []
        if embeddings is None:
            embeddings = self.encoder.encode(documents)
        if len(embeddings) != len(documents):
            raise ValueError(
                "Number of documents doesn't match number of embeddings."
            )
        keywords = []
        # A fresh clone is fitted per batch so the vocabulary reflects only
        # this batch and the user-supplied vectorizer is left untouched.
        vectorizer = clone(self.vectorizer)
        document_term_matrix = vectorizer.fit_transform(documents)
        batch_vocab = vectorizer.get_feature_names_out()
        # Keep vocabulary order so index assignment is deterministic.
        new_terms = [
            term for term in batch_vocab if term not in self.key_to_index
        ]
        if new_terms:
            self._add_terms(new_terms)
        for i in range(len(documents)):
            counts = np.ravel(document_term_matrix[i, :].toarray())
            # ravel (not squeeze) keeps this 1-D even for a 1-term vocabulary.
            present = np.flatnonzero(counts > 0)
            if not len(present):
                keywords.append(dict())
                continue
            embedding = embeddings[i].reshape(1, -1)
            doc_terms = batch_vocab[present]
            word_embeddings = [
                self.term_embeddings[self.key_to_index[term]]
                for term in doc_terms
            ]
            sim = np.ravel(cosine_similarity(embedding, word_embeddings))
            if self.top_n < len(sim):
                # Indices of the top_n largest similarities (unordered).
                top = np.argpartition(-sim, self.top_n)[: self.top_n]
            else:
                # Fewer candidates than requested: keep all of them.
                top = np.arange(len(sim))
            # Filter words and similarities together so pairs stay aligned.
            keywords.append(
                {
                    word: similarity
                    for word, similarity in zip(doc_terms[top], sim[top])
                    if similarity > 0
                }
            )
        return keywords
139+
140+
141+
class KeywordNMF:
    """Non-negative matrix factorization over keyword-importance matrices.

    Documents are given as dicts mapping keywords to non-negative weights.
    The class keeps its own growing vocabulary (``key_to_index`` /
    ``index_to_key``), turns keyword dicts into sparse document-term
    matrices and factorizes them, in batch, online (``partial_fit``) and
    dynamic (per time bin) variants.
    """

    def __init__(
        self,
        n_components: int,
        seed: Optional[int] = None,
        top_n: Optional[int] = None,
    ):
        self.n_components = n_components
        self.key_to_index: dict[str, int] = {}
        self.index_to_key: list[str] = []
        # Maximum number of keywords kept per document; None disables pruning.
        self.top_n = top_n
        # n_components * n_vocab
        self.components: Optional[np.ndarray] = None
        self.seed = seed
        # n_bins * n_components * n_vocab, set by the dynamic fitting methods
        self.temporal_components: Optional[np.ndarray] = None
        # n_bins * n_components, set by the dynamic fitting methods
        self.temporal_importance_: Optional[np.ndarray] = None

    def prune_keywords(self, keywords: dict[str, float]) -> dict[str, float]:
        """If there are more keywords than allowed, this prunes them.

        Keeps the ``top_n`` keywords with the highest similarity. The input
        is returned unchanged when ``top_n`` is None or not exceeded.
        """
        if (self.top_n is None) or (self.top_n >= len(keywords)):
            return keywords
        words, similarities = zip(*keywords.items())
        # argsort is ascending: reverse it so the strongest keywords survive.
        selected = np.argsort(similarities)[::-1][: self.top_n]
        return {words[i]: similarities[i] for i in selected}

    @property
    def n_vocab(self) -> int:
        """Number of keywords currently in the vocabulary."""
        return len(self.index_to_key)

    def _add_word_components(self, X: spr.csr_matrix):
        """Initializes components for novel vocabulary.

        ``X`` must already be vectorized with the extended vocabulary. The
        trailing columns of a fresh NMF initialisation on ``X`` are appended
        to ``components`` (and to every time slice of
        ``temporal_components`` when present) so their width matches
        ``X.shape[1]``.
        """
        _, H = _initialize_nmf(X, self.n_components, random_state=self.seed)
        if self.components is None:
            self.components = H
        else:
            n_new = X.shape[1] - self.components.shape[1]
            if n_new:
                self.components = np.concatenate(
                    (self.components, H[:, -n_new:]), axis=1
                )
        if self.temporal_components is not None:
            n_new = X.shape[1] - self.temporal_components.shape[-1]
            if n_new:
                new_comps = H[:, -n_new:]
                # Same initial values for the new terms in every time bin.
                new_comps = np.broadcast_to(
                    new_comps,
                    (self.temporal_components.shape[0], *new_comps.shape),
                )
                self.temporal_components = np.concatenate(
                    (self.temporal_components, new_comps), axis=-1
                )

    def vectorize(
        self, keywords: list[dict[str, float]], fitting: bool = False
    ) -> spr.csr_matrix:
        """Turn keyword dicts into a sparse document-term matrix.

        Parameters:
            keywords: One keyword -> weight dict per document.
            fitting: When True, unseen keywords are added to the vocabulary;
                when False they are ignored.

        Returns:
            CSR matrix of shape ``(len(keywords), n_vocab)``.
        """
        indices = []
        indptr = [0]
        values = []
        for k in keywords:
            k = self.prune_keywords(k)
            for w, v in k.items():
                # Adding vocab item if missing
                if (w not in self.key_to_index) and fitting:
                    self.key_to_index[w] = self.n_vocab
                    self.index_to_key.append(w)
                if w in self.key_to_index:
                    indices.append(self.key_to_index[w])
                    values.append(v)
            indptr.append(len(indices))
        shape = (len(indptr) - 1, self.n_vocab)
        document_term_matrix = spr.csr_matrix(
            (values, indices, indptr), shape=shape
        )
        return document_term_matrix

    def fit_transform(self, keywords: list[dict[str, float]]) -> np.ndarray:
        """Fit the factorization on ``keywords``.

        Stores the topic-term matrix in ``components`` and returns the
        document-topic matrix.
        """
        X = self.vectorize(keywords, fitting=True)
        check_non_negative(X, "NMF (input X)")
        W, H = _initialize_nmf(X, self.n_components, random_state=self.seed)
        W, H, self.n_iter = NMF(
            self.n_components, init="custom", random_state=self.seed
        )._fit_transform(X, W=W, H=H, update_H=True)
        self.components = H.astype(X.dtype)
        return W

    def transform(self, keywords: list[dict[str, float]]):
        """Compute the document-topic matrix for ``keywords``.

        The fitted ``components`` are held fixed and keywords outside the
        vocabulary are ignored.

        Raises:
            NotFittedError: if no components have been fitted yet.
        """
        if self.components is None:
            raise NotFittedError(
                "Can't transform() if the model has not been fitted."
            )
        X = self.vectorize(keywords, fitting=False)
        check_non_negative(X, "NMF (input X)")
        W, _, _ = NMF(
            self.n_components, init="custom", random_state=self.seed
        )._fit_transform(X, W=None, H=self.components, update_H=False)
        return W.astype(X.dtype)

    def partial_fit(self, keyword_batch: list[dict[str, float]]):
        """Update the components with one batch of documents; returns self."""
        X = self.vectorize(keyword_batch, fitting=True)
        check_non_negative(X, "NMF (input X)")
        # Widen the components to cover keywords first seen in this batch.
        self._add_word_components(X)
        W, _ = _initialize_nmf(X, self.n_components, random_state=self.seed)
        _minibatchnmf = MiniBatchNMF(
            self.n_components, init="custom", random_state=self.seed
        ).partial_fit(X, W=W, H=self.components)
        self.components = _minibatchnmf.components_.astype(X.dtype)
        return self

    def fit_transform_dynamic(
        self,
        keywords: list[dict[str, float]],
        time_labels: np.ndarray,
        time_bin_edges: list[datetime],
    ) -> np.ndarray:
        """Fit global components, then per-time-bin components.

        Parameters:
            keywords: One keyword -> weight dict per document.
            time_labels: Per-document integer labels; each label is used as
                an index into the first axis of ``temporal_components``.
            time_bin_edges: Edges of the time bins; stored on the instance.

        Returns:
            The document-topic matrix of the global fit.
        """
        self.time_bin_edges = time_bin_edges
        # NOTE(review): this allocates len(edges) + 1 slices; if edges
        # delimit len(edges) - 1 bins the trailing slices stay zero —
        # confirm against how time_labels are produced by the caller.
        n_bins = len(time_bin_edges) + 1
        document_term_matrix = self.vectorize(keywords, fitting=True)
        check_non_negative(document_term_matrix, "NMF (input X)")
        document_topic_matrix, H = _initialize_nmf(
            document_term_matrix,
            self.n_components,
            random_state=self.seed,
        )
        document_topic_matrix, H, self.n_iter = NMF(
            self.n_components, init="custom", random_state=self.seed
        )._fit_transform(
            document_term_matrix, W=document_topic_matrix, H=H, update_H=True
        )
        self.components = H.astype(document_term_matrix.dtype)
        n_comp, n_vocab = self.components.shape
        self.temporal_components = np.zeros(
            (n_bins, n_comp, n_vocab), dtype=document_term_matrix.dtype
        )
        self.temporal_importance_ = np.zeros((n_bins, n_comp))
        for label in np.unique(time_labels):
            idx = np.nonzero(time_labels == label)
            X = document_term_matrix[idx]
            W = document_topic_matrix[idx]
            _, H = _initialize_nmf(
                X, self.components.shape[0], random_state=self.seed
            )
            # Fit the slice's topic-term matrix from the global W rows.
            _, H, _ = fit_timeslice(X, W, H, random_state=self.seed)
            self.temporal_components[label] = H
            topic_importances = np.squeeze(np.asarray(W.sum(axis=0)))
            self.temporal_importance_[label] = topic_importances
        return document_topic_matrix

    def partial_fit_dynamic(
        self,
        keyword_batch: list[dict[str, float]],
        time_labels: np.ndarray,
        time_bin_edges: list[datetime],
    ) -> np.ndarray:
        """Update global and temporal components with one batch.

        The first call delegates to ``fit_transform_dynamic``. Later calls
        extend the vocabulary, update the global components with a
        mini-batch step, then update each time bin present in the batch and
        accumulate its topic importances.

        Returns:
            The document-topic matrix of the batch.
        """
        if self.temporal_components is None:
            return self.fit_transform_dynamic(
                keyword_batch, time_labels, time_bin_edges
            )
        document_term_matrix = self.vectorize(keyword_batch, fitting=True)
        check_non_negative(document_term_matrix, "NMF (input X)")
        self._add_word_components(document_term_matrix)
        # Document-topic estimate under the current components, used as the
        # starting W for the mini-batch update.
        document_topic_matrix = self.transform(keyword_batch)
        _minibatchnmf = MiniBatchNMF(
            self.n_components, init="custom", random_state=self.seed
        ).partial_fit(
            document_term_matrix,
            W=document_topic_matrix,
            H=self.components,
        )
        self.components = _minibatchnmf.components_.astype(
            document_term_matrix.dtype
        )
        # Recompute with the updated components before the per-bin updates.
        document_topic_matrix = self.transform(keyword_batch)
        for label in np.unique(time_labels):
            idx = np.nonzero(time_labels == label)
            X = document_term_matrix[idx]
            W = document_topic_matrix[idx]
            _minibatchnmf = MiniBatchNMF(
                self.n_components, init="custom", random_state=self.seed
            ).partial_fit(
                X,
                W=W,
                H=self.temporal_components[label],
            )
            self.temporal_components[label] = _minibatchnmf.components_
            topic_importances = np.squeeze(np.asarray(W.sum(axis=0)))
            self.temporal_importance_[label] += topic_importances
        return document_topic_matrix

0 commit comments

Comments
 (0)