Skip to content

Commit a607aa1

Browse files
KeyNMF now prunes keywords if there are more than top_n
1 parent c69bcad commit a607aa1

2 files changed

Lines changed: 20 additions & 2 deletions

File tree

turftopic/models/_keynmf.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,16 +139,31 @@ def batch_extract_keywords(
139139

140140

141141
class KeywordNMF:
142-
def __init__(self, n_components: int, seed: Optional[int] = None):
142+
def __init__(
143+
self,
144+
n_components: int,
145+
seed: Optional[int] = None,
146+
top_n: Optional[int] = None,
147+
):
143148
self.n_components = n_components
144149
self.key_to_index: dict[str, int] = {}
145150
self.index_to_key: list[str] = []
151+
self.top_n = top_n
146152
# n_components * n_vocab
147153
self.components: Optional[np.ndarray] = None
148154
self.seed = seed
149155
self.temporal_components: Optional[np.ndarray] = None
150156
self.temporal_importance_: Optional[np.ndarray] = None
151157

158+
def prune_keywords(self, keywords: dict[str, float]) -> dict[str, float]:
    """Keep at most ``top_n`` keywords, preferring the highest scores.

    Parameters
    ----------
    keywords: dict[str, float]
        Mapping of keyword to its score for a single document.

    Returns
    -------
    dict[str, float]
        The input mapping itself when ``top_n`` is None or the mapping
        already has at most ``top_n`` entries; otherwise a new mapping
        holding the ``top_n`` highest-scoring entries, ordered from
        highest to lowest score.
    """
    if (self.top_n is None) or (self.top_n >= len(keywords)):
        return keywords
    # Sort descending: an ascending sort followed by a head slice would
    # retain the *least* relevant keywords and throw away the best ones.
    ranked = sorted(
        keywords.items(), key=lambda item: item[1], reverse=True
    )
    return dict(ranked[: self.top_n])
166+
152167
@property
def n_vocab(self) -> int:
    """Number of entries currently held in ``index_to_key``."""
    vocabulary = self.index_to_key
    return len(vocabulary)
@@ -183,6 +198,7 @@ def vectorize(
183198
indptr = [0]
184199
values = []
185200
for k in keywords:
201+
k = self.prune_keywords(k)
186202
for w, v in k.items():
187203
# Adding vocab item if missing
188204
if (w not in self.key_to_index) and fitting:

turftopic/models/keynmf.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ def __init__(
9090
self.vectorizer = CountVectorizer()
9191
else:
9292
self.vectorizer = vectorizer
93-
self.model = KeywordNMF(n_components=n_components, seed=random_state)
93+
self.model = KeywordNMF(
94+
n_components=n_components, seed=random_state, top_n=self.top_n
95+
)
9496
self.extractor = KeywordExtractor(
9597
top_n=self.top_n, vectorizer=self.vectorizer, encoder=self.encoder_
9698
)

0 commit comments

Comments
 (0)