Skip to content

Commit c9765c3

Browse files
committed
first stab at conditional keyword extraction logic
1 parent b408540 commit c9765c3

1 file changed

Lines changed: 19 additions & 7 deletions

File tree

turftopic/models/keynmf.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ class KeyNMF(ContextualModel):
7474
Can be used to prune or filter the vocabulary.
7575
top_n: int, default 25
7676
Number of keywords to extract for each document.
77+
keyword_scope: str, default 'document'
78+
Specifies whether keyword extraction for each document
79+
is performed on the whole vocabulary ('corpus') or only
80+
using words that are included in the document ('document').
81+
Setting this to 'corpus' allows for multilingual topics.
7782
"""
7883

7984
def __init__(
@@ -84,7 +89,10 @@ def __init__(
8489
] = "sentence-transformers/all-MiniLM-L6-v2",
8590
vectorizer: Optional[CountVectorizer] = None,
8691
top_n: int = 25,
92+
keyword_scope: str = 'document',
8793
):
94+
if keyword_scope not in ['document', 'corpus']:
95+
raise ValueError("keyword_scope must be 'document' or 'corpus'")
8896
self.n_components = n_components
8997
self.top_n = top_n
9098
self.encoder = encoder
@@ -98,6 +106,7 @@ def __init__(
98106
self.vectorizer = vectorizer
99107
self.dict_vectorizer_ = DictVectorizer()
100108
self.nmf_ = NMF(n_components)
109+
self.keyword_scope = keyword_scope
101110

102111
def extract_keywords(
103112
self,
@@ -114,13 +123,16 @@ def extract_keywords(
114123
for i in range(total):
115124
terms = document_term_matrix[i, :].todense()
116125
embedding = embeddings[i].reshape(1, -1)
117-
nonzero = terms > 0
118-
if not np.any(nonzero):
119-
keywords.append(dict())
120-
continue
121-
important_terms = np.squeeze(np.asarray(nonzero))
122-
word_embeddings = self.vocab_embeddings[important_terms]
123-
sim = cosine_similarity(embedding, word_embeddings)
126+
if self.keyword_scope == 'document':
127+
nonzero = terms > 0
128+
if not np.any(nonzero):
129+
keywords.append(dict())
130+
continue
131+
important_terms = np.squeeze(np.asarray(nonzero))
132+
word_embeddings = self.vocab_embeddings[important_terms]
133+
sim = cosine_similarity(embedding, word_embeddings)
134+
else:
135+
sim = cosine_similarity(embedding, self.vocab_embeddings)
124136
sim = np.ravel(sim)
125137
kth = min(self.top_n, len(sim) - 1)
126138
top = np.argpartition(-sim, kth)[:kth]

0 commit comments

Comments (0)