@@ -139,16 +139,31 @@ def batch_extract_keywords(
139139
140140
141141class KeywordNMF :
142- def __init__ (self , n_components : int , seed : Optional [int ] = None ):
142+ def __init__ (
143+ self ,
144+ n_components : int ,
145+ seed : Optional [int ] = None ,
146+ top_n : Optional [int ] = None ,
147+ ):
143148 self .n_components = n_components
144149 self .key_to_index : dict [str , int ] = {}
145150 self .index_to_key : list [str ] = []
151+ self .top_n = top_n
146152 # n_components * n_vocab
147153 self .components : Optional [np .ndarray ] = None
148154 self .seed = seed
149155 self .temporal_components : Optional [np .ndarray ] = None
150156 self .temporal_importance_ : Optional [np .ndarray ] = None
151157
158+ def prune_keywords (self , keywords : dict [str , float ]) -> dict [str , float ]:
159+ """If there are more keywords than allowed, this prunes them."""
160+ if (self .top_n is None ) or (self .top_n >= len (keywords )):
161+ return keywords
162+ words , similarities = zip (* keywords .items ())
163+ selected = np .argsort (similarities )[: self .top_n ]
164+ items = [(words [i ], similarities [i ]) for i in selected ]
165+ return dict (items )
166+
152167 @property
153168 def n_vocab (self ) -> int :
154169 return len (self .index_to_key )
@@ -183,6 +198,7 @@ def vectorize(
183198 indptr = [0 ]
184199 values = []
185200 for k in keywords :
201+ k = self .prune_keywords (k )
186202 for w , v in k .items ():
187203 # Adding vocab item if missing
188204 if (w not in self .key_to_index ) and fitting :
0 commit comments