@@ -74,6 +74,11 @@ class KeyNMF(ContextualModel):
         Can be used to prune or filter the vocabulary.
     top_n: int, default 25
         Number of keywords to extract for each document.
+    keyword_scope: str, default 'document'
+        Specifies whether keyword extraction for each document
+        is performed over the whole vocabulary ('corpus') or only
+        over words that occur in the document ('document').
+        Setting this to 'corpus' allows for multilingual topics.
     """
 
     def __init__(
@@ -84,7 +89,10 @@ def __init__(
         ] = "sentence-transformers/all-MiniLM-L6-v2",
         vectorizer: Optional[CountVectorizer] = None,
         top_n: int = 25,
+        keyword_scope: str = 'document',
     ):
+        if keyword_scope not in ['document', 'corpus']:
+            raise ValueError("keyword_scope must be 'document' or 'corpus'")
         self.n_components = n_components
         self.top_n = top_n
         self.encoder = encoder
@@ -98,6 +106,7 @@ def __init__(
         self.vectorizer = vectorizer
         self.dict_vectorizer_ = DictVectorizer()
         self.nmf_ = NMF(n_components)
+        self.keyword_scope = keyword_scope
 
     def extract_keywords(
         self,
@@ -114,13 +123,16 @@ def extract_keywords(
         for i in range(total):
             terms = document_term_matrix[i, :].todense()
             embedding = embeddings[i].reshape(1, -1)
-            nonzero = terms > 0
-            if not np.any(nonzero):
-                keywords.append(dict())
-                continue
-            important_terms = np.squeeze(np.asarray(nonzero))
-            word_embeddings = self.vocab_embeddings[important_terms]
-            sim = cosine_similarity(embedding, word_embeddings)
+            if self.keyword_scope == 'document':
+                nonzero = terms > 0
+                if not np.any(nonzero):
+                    keywords.append(dict())
+                    continue
+                important_terms = np.squeeze(np.asarray(nonzero))
+                word_embeddings = self.vocab_embeddings[important_terms]
+                sim = cosine_similarity(embedding, word_embeddings)
+            else:
+                sim = cosine_similarity(embedding, self.vocab_embeddings)
             sim = np.ravel(sim)
             kth = min(self.top_n, len(sim) - 1)
             top = np.argpartition(-sim, kth)[:kth]
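
For context, using the new parameter would look roughly like the sketch below. This is a minimal illustration, not part of the diff: the turftopic import path and the sklearn-style fit call are assumed from the surrounding ContextualModel API.

from turftopic import KeyNMF  # import path assumed, not shown in this diff

# 'corpus' scope lets a document's keywords come from anywhere in the
# shared vocabulary, which is what enables multilingual topics.
model = KeyNMF(n_components=10, top_n=25, keyword_scope="corpus")
model.fit(documents)  # assuming an sklearn-style fit over raw documents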
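And here is a self-contained sketch of what the two scopes compute per document. The toy vocabulary, random embeddings, and top_n value are illustrative stand-ins, not the library's data:

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

rng = np.random.default_rng(0)

# Toy stand-ins: six vocabulary terms with 4-dimensional embeddings,
# one document embedding, and that document's term counts.
vocab = np.array(["cat", "dog", "fish", "car", "road", "wheel"])
vocab_embeddings = rng.normal(size=(6, 4))
doc_embedding = rng.normal(size=(1, 4))
term_counts = np.array([2, 1, 0, 0, 3, 0])
top_n = 3

# 'document' scope: only terms occurring in the document are candidates,
# mirroring the nonzero mask in the diff above.
mask = term_counts > 0
doc_sim = np.ravel(cosine_similarity(doc_embedding, vocab_embeddings[mask]))
doc_keywords = dict(zip(vocab[mask], doc_sim))

# 'corpus' scope: every vocabulary term is a candidate, so a document can
# pick up keywords it never contains (e.g. translations in other languages).
corpus_sim = np.ravel(cosine_similarity(doc_embedding, vocab_embeddings))
kth = min(top_n, len(corpus_sim) - 1)
top = np.argpartition(-corpus_sim, kth)[:kth]  # unordered indices of the top-n scores
corpus_keywords = dict(zip(vocab[top], corpus_sim[top]))

print(doc_keywords)
print(corpus_keywords)

As in the diff, np.argpartition is used rather than a full sort: only the top-n similarity scores are needed, so fully ordering the whole vocabulary would be wasted work.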