Skip to content

Commit c7b835c

Browse files
Added Dynamic KeyNMF
1 parent 36aba52 commit c7b835c

1 file changed

Lines changed: 75 additions & 1 deletion

File tree

turftopic/models/keynmf.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,67 @@
11
import itertools
22
import json
33
import random
4+
from datetime import datetime
45
from typing import Dict, Iterable, List, Optional, Union
56

67
import numpy as np
78
from rich.console import Console
89
from sentence_transformers import SentenceTransformer
910
from sklearn.decomposition import NMF, MiniBatchNMF
11+
from sklearn.decomposition._nmf import (_initialize_nmf,
12+
_update_coordinate_descent)
1013
from sklearn.exceptions import NotFittedError
1114
from sklearn.feature_extraction import DictVectorizer
1215
from sklearn.feature_extraction.text import CountVectorizer
1316
from sklearn.metrics.pairwise import cosine_similarity
17+
from sklearn.utils import check_array
1418

1519
from turftopic.base import ContextualModel, Encoder
1620
from turftopic.data import TopicData
21+
from turftopic.dynamic import DynamicTopicModel, bin_timestamps
1722
from turftopic.vectorizer import default_vectorizer
1823

1924

25+
def fit_timeslice(
26+
X,
27+
W,
28+
H,
29+
tol=1e-4,
30+
max_iter=200,
31+
l1_reg_W=0,
32+
l1_reg_H=0,
33+
l2_reg_W=0,
34+
l2_reg_H=0,
35+
verbose=0,
36+
shuffle=False,
37+
random_state=None,
38+
):
39+
"""Fits topic_term_matrix based on a precomputed document_topic_matrix.
40+
This is used to get temporal components in dynamic KeyNMF.
41+
"""
42+
Ht = check_array(H.T, order="C")
43+
if random_state is None:
44+
rng = np.random.mtrand._rand
45+
else:
46+
rng = np.random.RandomState(random_state)
47+
for n_iter in range(1, max_iter + 1):
48+
violation = 0.0
49+
violation += _update_coordinate_descent(
50+
X.T, Ht, W, l1_reg_H, l2_reg_H, shuffle, rng
51+
)
52+
if n_iter == 1:
53+
violation_init = violation
54+
if violation_init == 0:
55+
break
56+
if verbose:
57+
print("violation:", violation / violation_init)
58+
if violation / violation_init <= tol:
59+
if verbose:
60+
print("Converged at iteration", n_iter + 1)
61+
break
62+
return W, Ht.T, n_iter
63+
64+
2065
def batched(iterable, n: int) -> Iterable[List[str]]:
2166
"Batch data into tuples of length n. The last batch may be shorter."
2267
# batched('ABCDEFG', 3) --> ABC DEF G
@@ -48,7 +93,7 @@ def __iter__(self) -> Iterable[Dict[str, float]]:
4893
yield deserialize_keywords(line.strip())
4994

5095

51-
class KeyNMF(ContextualModel):
96+
class KeyNMF(ContextualModel, DynamicTopicModel):
5297
"""Extracts keywords from documents based on semantic similarity of
5398
term encodings to document encodings.
5499
Topics are then extracted with non-negative matrix factorization from
@@ -305,3 +350,32 @@ def prepare_topic_data(
305350
"topic_names": self.topic_names,
306351
}
307352
return res
353+
354+
def fit_transform_dynamic(
355+
self,
356+
raw_documents,
357+
timestamps: list[datetime],
358+
embeddings: Optional[np.ndarray] = None,
359+
bins: Union[int, list[datetime]] = 10,
360+
) -> np.ndarray:
361+
time_labels, self.time_bin_edges = bin_timestamps(timestamps, bins)
362+
topic_data = self.prepare_topic_data(
363+
raw_documents, embeddings=embeddings
364+
)
365+
n_bins = len(self.time_bin_edges) + 1
366+
n_comp, n_vocab = self.components_.shape
367+
self.temporal_components_ = np.zeros((n_bins, n_comp, n_vocab))
368+
self.temporal_importance_ = np.zeros((n_bins, n_comp))
369+
for label in np.unique(time_labels):
370+
idx = np.nonzero(time_labels == label)
371+
X = topic_data["document_term_matrix"][idx]
372+
W = topic_data["document_topic_matrix"][idx]
373+
_, H = _initialize_nmf(
374+
X, self.components_.shape[0], random_state=self.random_state
375+
)
376+
_, H, _ = fit_timeslice(X, W, H, random_state=self.random_state)
377+
self.temporal_components_[label] = H
378+
topic_importances = np.squeeze(np.asarray(W.sum(axis=0)))
379+
topic_importances = topic_importances / topic_importances.sum()
380+
self.temporal_importance_[label] = topic_importances
381+
return topic_data["document_topic_matrix"]

0 commit comments

Comments
 (0)