Skip to content

Commit 1639f28

Browse files
Merge pull request #31 from x-tabdeveloping/positive_negative
WIP: Positive/negative highest ranking terms
2 parents 1da5425 + 99e9744 commit 1639f28

5 files changed

Lines changed: 171 additions & 29 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -93,7 +93,7 @@ model.print_topics()
9393

9494
```python
9595
# Print highest ranking documents for topic 0
96-
model.print_highest_ranking_documents(0, corpus, document_topic_matrix)
96+
model.print_representative_documents(0, corpus, document_topic_matrix)
9797
```
9898

9999
<center>

docs/index.md

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -179,7 +179,7 @@ model.print_topics()
179179

180180
```python
181181
# Print highest ranking documents for topic 0
182-
model.print_highest_ranking_documents(0, corpus, document_topic_matrix)
182+
model.print_representative_documents(0, corpus, document_topic_matrix)
183183
```
184184

185185
<center>
@@ -217,7 +217,7 @@ csv_table: str = model.export_topic_distribution("something something", format="
217217

218218
latex_table: str = model.export_topics(format="latex")
219219

220-
md_table: str = model.export_highest_ranking_documents(0, corpus, document_topic_matrix, format="markdown")
220+
md_table: str = model.export_representative_documents(0, corpus, document_topic_matrix, format="markdown")
221221
```
222222

223223
### Visualization

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,12 @@
11
[tool.black]
22
line-length=79
33

4+
[tool.ruff]
5+
line-length=79
6+
47
[tool.poetry]
58
name = "turftopic"
6-
version = "0.2.12"
9+
version = "0.2.13"
710
description = "Topic modeling with contextual representations from sentence transformers."
811
authors = ["Márton Kardos <power.up1163@gmail.com>"]
912
license = "MIT"

turftopic/base.py

Lines changed: 113 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -29,9 +29,7 @@ def get_topics(
2929
"""Returns high-level topic representations in form of the top K words
3030
in each topic.
3131
32-
Parameters
33-
----------
34-
top_k: int, default 10
32+
Parameters
----------
top_k: int, default 10
3533
Number of top words to return for each topic.
3634
3735
Returns
@@ -62,22 +60,57 @@ def get_topics(
6260
return topics
6361

6462
def _topics_table(
65-
self, top_k: int = 10, show_scores: bool = False
63+
self,
64+
top_k: int = 10,
65+
show_scores: bool = False,
66+
show_negative: bool = False,
6667
) -> list[list[str]]:
67-
topics = self.get_topics(top_k)
68-
columns = ["Topic ID", f"Top {top_k} Words"]
68+
columns = ["Topic ID", "Highest Ranking"]
69+
if show_negative:
70+
columns.append("Lowest Ranking")
6971
rows = []
70-
for topic_id, terms in topics:
72+
try:
73+
classes = self.classes_
74+
except AttributeError:
75+
classes = list(range(self.components_.shape[0]))
76+
vocab = self.get_vocab()
77+
for topic_id, component in zip(classes, self.components_):
78+
highest = np.argpartition(-component, top_k)[:top_k]
79+
highest = highest[np.argsort(-component[highest])]
80+
lowest = np.argpartition(component, top_k)[:top_k]
81+
lowest = lowest[np.argsort(component[lowest])]
7182
if show_scores:
72-
concat_words = ", ".join(
73-
[f"{word}({importance:.2f})" for word, importance in terms]
83+
concat_positive = ", ".join(
84+
[
85+
f"{word}({importance:.2f})"
86+
for word, importance in zip(
87+
vocab[highest], component[highest]
88+
)
89+
]
90+
)
91+
concat_negative = ", ".join(
92+
[
93+
f"{word}({importance:.2f})"
94+
for word, importance in zip(
95+
vocab[lowest], component[lowest]
96+
)
97+
]
7498
)
7599
else:
76-
concat_words = ", ".join([word for word, importance in terms])
77-
rows.append([f"{topic_id}", f"{concat_words}"])
100+
concat_positive = ", ".join([word for word in vocab[highest]])
101+
concat_negative = ", ".join([word for word in vocab[lowest]])
102+
row = [f"{topic_id}", f"{concat_positive}"]
103+
if show_negative:
104+
row.append(concat_negative)
105+
rows.append(row)
78106
return [columns, *rows]
79107

80-
def print_topics(self, top_k: int = 10, show_scores: bool = False):
108+
def print_topics(
109+
self,
110+
top_k: int = 10,
111+
show_scores: bool = False,
112+
show_negative: bool = False,
113+
):
81114
"""Pretty prints topics in the model in a table.
82115
83116
Parameters
@@ -86,23 +119,36 @@ def print_topics(self, top_k: int = 10, show_scores: bool = False):
86119
Number of top words to return for each topic.
87120
show_scores: bool, default False
88121
Indicates whether to show importance scores for each word.
122+
show_negative: bool, default False
123+
Indicates whether the most negative terms should also be displayed.
89124
"""
90-
columns, *rows = self._topics_table(top_k, show_scores)
125+
columns, *rows = self._topics_table(top_k, show_scores, show_negative)
91126
table = Table(show_lines=True)
92-
table.add_column(columns[0], style="blue", justify="right")
127+
table.add_column("Topic ID", style="blue", justify="right")
93128
table.add_column(
94-
columns[1],
129+
"Highest Ranking",
95130
justify="left",
96131
style="magenta",
97132
max_width=100,
98133
)
134+
if show_negative:
135+
table.add_column(
136+
"Lowest Ranking",
137+
justify="left",
138+
style="red",
139+
max_width=100,
140+
)
99141
for row in rows:
100142
table.add_row(*row)
101143
console = Console()
102144
console.print(table)
103145

104146
def export_topics(
105-
self, top_k: int = 10, show_scores: bool = False, format: str = "csv"
147+
self,
148+
top_k: int = 10,
149+
show_scores: bool = False,
150+
show_negative: bool = False,
151+
format: str = "csv",
106152
) -> str:
107153
"""Exports top K words from topics in a table in a given format.
108154
Returns table as a pure string.
@@ -113,15 +159,24 @@ def export_topics(
113159
Number of top words to return for each topic.
114160
show_scores: bool, default False
115161
Indicates whether to show importance scores for each word.
162+
show_negative: bool, default False
163+
Indicates whether the most negative terms should also be displayed.
116164
format: 'csv', 'latex' or 'markdown'
117165
Specifies which format should be used.
118166
'csv', 'latex' and 'markdown' are supported.
119167
"""
120-
table = self._topics_table(top_k, show_scores)
168+
table = self._topics_table(
169+
top_k, show_scores, show_negative=show_negative
170+
)
121171
return export_table(table, format=format)
122172

123-
def _highest_ranking_docs(
124-
self, topic_id, raw_documents, document_topic_matrix=None, top_k=5
173+
def _representative_docs(
174+
self,
175+
topic_id,
176+
raw_documents,
177+
document_topic_matrix=None,
178+
top_k=5,
179+
show_negative: bool = False,
125180
) -> list[list[str]]:
126181
if document_topic_matrix is None:
127182
try:
@@ -154,10 +209,30 @@ def _highest_ranking_docs(
154209
if len(doc) > 300:
155210
doc = doc[:300] + "..."
156211
rows.append([doc, f"{score:.2f}"])
212+
if show_negative:
213+
rows.append(["...", ""])
214+
lowest = np.argpartition(document_topic_matrix[:, topic_id], kth)[
215+
:kth
216+
]
217+
lowest = lowest[
218+
np.argsort(document_topic_matrix[lowest, topic_id])
219+
]
220+
scores = document_topic_matrix[lowest, topic_id]
221+
for document_id, score in zip(lowest, scores):
222+
doc = raw_documents[document_id]
223+
doc = remove_whitespace(doc)
224+
if len(doc) > 300:
225+
doc = doc[:300] + "..."
226+
rows.append([doc, f"{score:.2f}"])
157227
return [columns, *rows]
158228

159-
def print_highest_ranking_documents(
160-
self, topic_id, raw_documents, document_topic_matrix=None, top_k=5
229+
def print_representative_documents(
230+
self,
231+
topic_id,
232+
raw_documents,
233+
document_topic_matrix=None,
234+
top_k=5,
235+
show_negative: bool = False,
161236
):
162237
"""Pretty prints the highest ranking documents in a topic.
163238
@@ -172,9 +247,15 @@ def print_highest_ranking_documents(
172247
as they cannot infer topics from text.
173248
top_k: int, default 5
174249
Top K documents to show.
250+
show_negative: bool, default False
251+
Indicates whether lowest ranking documents should also be shown.
175252
"""
176-
columns, *rows = self._highest_ranking_docs(
177-
topic_id, raw_documents, document_topic_matrix, top_k
253+
columns, *rows = self._representative_docs(
254+
topic_id,
255+
raw_documents,
256+
document_topic_matrix,
257+
top_k,
258+
show_negative,
178259
)
179260
table = Table(show_lines=True)
180261
table.add_column(
@@ -186,12 +267,13 @@ def print_highest_ranking_documents(
186267
console = Console()
187268
console.print(table)
188269

189-
def export_highest_ranking_documents(
270+
def export_representative_documents(
190271
self,
191272
topic_id,
192273
raw_documents,
193274
document_topic_matrix=None,
194275
top_k=5,
276+
show_negative: bool = False,
195277
format: str = "csv",
196278
):
197279
"""Exports the highest ranking documents in a topic as a text table.
@@ -207,12 +289,18 @@ def export_highest_ranking_documents(
207289
as they cannot infer topics from text.
208290
top_k: int, default 5
209291
Top K documents to show.
292+
show_negative: bool, default False
293+
Indicates whether lowest ranking documents should also be shown.
210294
format: 'csv', 'latex' or 'markdown'
211295
Specifies which format should be used.
212296
'csv', 'latex' and 'markdown' are supported.
213297
"""
214298
table = self._representative_docs(
215-
topic_id, raw_documents, document_topic_matrix, top_k
299+
topic_id,
300+
raw_documents,
301+
document_topic_matrix,
302+
top_k,
303+
show_negative,
216304
)
217305
return export_table(table, format=format)
218306

turftopic/models/decomp.py

Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -108,3 +108,54 @@ def transform(
108108
if embeddings is None:
109109
embeddings = self.encoder_.encode(raw_documents)
110110
return self.decomposition.transform(embeddings)
111+
112+
def print_topics(
113+
self,
114+
top_k: int = 5,
115+
show_scores: bool = False,
116+
show_negative: bool = True,
117+
):
118+
super().print_topics(top_k, show_scores, show_negative)
119+
120+
def export_topics(
121+
self,
122+
top_k: int = 5,
123+
show_scores: bool = False,
124+
show_negative: bool = True,
125+
format: str = "csv",
126+
) -> str:
127+
return super().export_topics(top_k, show_scores, show_negative, format)
128+
129+
def print_representative_documents(
130+
self,
131+
topic_id,
132+
raw_documents,
133+
document_topic_matrix=None,
134+
top_k=5,
135+
show_negative: bool = True,
136+
):
137+
super().print_representative_documents(
138+
topic_id,
139+
raw_documents,
140+
document_topic_matrix,
141+
top_k,
142+
show_negative,
143+
)
144+
145+
def export_representative_documents(
146+
self,
147+
topic_id,
148+
raw_documents,
149+
document_topic_matrix=None,
150+
top_k=5,
151+
show_negative: bool = True,
152+
format: str = "csv",
153+
):
154+
return super().export_representative_documents(
155+
topic_id,
156+
raw_documents,
157+
document_topic_matrix,
158+
top_k,
159+
show_negative,
160+
format,
161+
)

0 commit comments

Comments
 (0)