Skip to content

Commit 91a769d

Browse files
authored
Merge branch 'main' into external-encoder-patch
2 parents 321655a + c9b3147 commit 91a769d

7 files changed

Lines changed: 174 additions & 76 deletions

File tree

.github/workflows/documentation.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@ jobs:
1515
deploy:
1616
runs-on: ubuntu-latest
1717
steps:
18-
- uses: actions/checkout@v3
18+
- uses: actions/checkout@v4
1919
- uses: actions/setup-python@v4
2020
with:
2121
python-version: '3.10'
2222

2323
- name: Dependencies
2424
run: |
2525
python -m pip install --upgrade pip
26-
pip install turftopic[pyro-ppl,docs]
26+
pip install "turftopic[pyro-ppl,docs]"
2727
2828
- name: Build and Deploy
2929
if: github.event_name == 'push'

.github/workflows/static.yml

Lines changed: 0 additions & 43 deletions
This file was deleted.

README.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@ I DO NOT recommend using this package for production and academic use!!!
2020
- [x] Pretty Printing
2121
- [x] Implement visualization utilites for these models in topicwizard
2222
- [x] Thorough documentation
23+
- [x] Dynamic modeling (currently `GMM` and `ClusteringTopicModel` others might follow)
2324
- [ ] Publish papers :hourglass_flowing_sand: (in progress..)
24-
- [ ] Dynamic modeling :hourglass_flowing_sand:
25-
- [ ] More Tutorials
2625
- [ ] High-level topic descriptions with LLMs.
2726
- [ ] Contextualized evaluation metrics.
2827

@@ -94,7 +93,7 @@ model.print_topics()
9493

9594
```python
9695
# Print highest ranking documents for topic 0
97-
model.print_highest_ranking_documents(0, corpus, document_topic_matrix)
96+
model.print_representative_documents(0, corpus, document_topic_matrix)
9897
```
9998

10099
<center>

docs/index.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ model.print_topics()
179179

180180
```python
181181
# Print highest ranking documents for topic 0
182-
model.print_highest_ranking_documents(0, corpus, document_topic_matrix)
182+
model.print_representative_documents(0, corpus, document_topic_matrix)
183183
```
184184

185185
<center>
@@ -217,7 +217,7 @@ csv_table: str = model.export_topic_distribution("something something", format="
217217

218218
latex_table: str = model.export_topics(format="latex")
219219

220-
md_table: str = model.export_highest_ranking_documents(0, corpus, document_topic_matrix, format="markdown")
220+
md_table: str = model.export_representative_documents(0, corpus, document_topic_matrix, format="markdown")
221221
```
222222

223223
### Visualization

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
[tool.black]
22
line-length=79
33

4+
[tool.ruff]
5+
line-length=79
6+
47
[tool.poetry]
58
name = "turftopic"
6-
version = "0.2.12"
9+
version = "0.2.13"
710
description = "Topic modeling with contextual representations from sentence transformers."
811
authors = ["Márton Kardos <power.up1163@gmail.com>"]
912
license = "MIT"

turftopic/base.py

Lines changed: 113 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@ def get_topics(
2929
"""Returns high-level topic representations in form of the top K words
3030
in each topic.
3131
32-
Parameters
33-
----------
34-
top_k: int, default 10
32+
Parameters ---------- top_k: int, default 10
3533
Number of top words to return for each topic.
3634
3735
Returns
@@ -62,22 +60,57 @@ def get_topics(
6260
return topics
6361

6462
def _topics_table(
65-
self, top_k: int = 10, show_scores: bool = False
63+
self,
64+
top_k: int = 10,
65+
show_scores: bool = False,
66+
show_negative: bool = False,
6667
) -> list[list[str]]:
67-
topics = self.get_topics(top_k)
68-
columns = ["Topic ID", f"Top {top_k} Words"]
68+
columns = ["Topic ID", "Highest Ranking"]
69+
if show_negative:
70+
columns.append("Lowest Ranking")
6971
rows = []
70-
for topic_id, terms in topics:
72+
try:
73+
classes = self.classes_
74+
except AttributeError:
75+
classes = list(range(self.components_.shape[0]))
76+
vocab = self.get_vocab()
77+
for topic_id, component in zip(classes, self.components_):
78+
highest = np.argpartition(-component, top_k)[:top_k]
79+
highest = highest[np.argsort(-component[highest])]
80+
lowest = np.argpartition(component, top_k)[:top_k]
81+
lowest = lowest[np.argsort(component[lowest])]
7182
if show_scores:
72-
concat_words = ", ".join(
73-
[f"{word}({importance:.2f})" for word, importance in terms]
83+
concat_positive = ", ".join(
84+
[
85+
f"{word}({importance:.2f})"
86+
for word, importance in zip(
87+
vocab[highest], component[highest]
88+
)
89+
]
90+
)
91+
concat_negative = ", ".join(
92+
[
93+
f"{word}({importance:.2f})"
94+
for word, importance in zip(
95+
vocab[lowest], component[lowest]
96+
)
97+
]
7498
)
7599
else:
76-
concat_words = ", ".join([word for word, importance in terms])
77-
rows.append([f"{topic_id}", f"{concat_words}"])
100+
concat_positive = ", ".join([word for word in vocab[highest]])
101+
concat_negative = ", ".join([word for word in vocab[lowest]])
102+
row = [f"{topic_id}", f"{concat_positive}"]
103+
if show_negative:
104+
row.append(concat_negative)
105+
rows.append(row)
78106
return [columns, *rows]
79107

80-
def print_topics(self, top_k: int = 10, show_scores: bool = False):
108+
def print_topics(
109+
self,
110+
top_k: int = 10,
111+
show_scores: bool = False,
112+
show_negative: bool = False,
113+
):
81114
"""Pretty prints topics in the model in a table.
82115
83116
Parameters
@@ -86,23 +119,36 @@ def print_topics(self, top_k: int = 10, show_scores: bool = False):
86119
Number of top words to return for each topic.
87120
show_scores: bool, default False
88121
Indicates whether to show importance scores for each word.
122+
show_negative: bool, default False
123+
Indicates whether the most negative terms should also be displayed.
89124
"""
90-
columns, *rows = self._topics_table(top_k, show_scores)
125+
columns, *rows = self._topics_table(top_k, show_scores, show_negative)
91126
table = Table(show_lines=True)
92-
table.add_column(columns[0], style="blue", justify="right")
127+
table.add_column("Topic ID", style="blue", justify="right")
93128
table.add_column(
94-
columns[1],
129+
"Highest Ranking",
95130
justify="left",
96131
style="magenta",
97132
max_width=100,
98133
)
134+
if show_negative:
135+
table.add_column(
136+
"Lowest Ranking",
137+
justify="left",
138+
style="red",
139+
max_width=100,
140+
)
99141
for row in rows:
100142
table.add_row(*row)
101143
console = Console()
102144
console.print(table)
103145

104146
def export_topics(
105-
self, top_k: int = 10, show_scores: bool = False, format: str = "csv"
147+
self,
148+
top_k: int = 10,
149+
show_scores: bool = False,
150+
show_negative: bool = False,
151+
format: str = "csv",
106152
) -> str:
107153
"""Exports top K words from topics in a table in a given format.
108154
Returns table as a pure string.
@@ -113,15 +159,24 @@ def export_topics(
113159
Number of top words to return for each topic.
114160
show_scores: bool, default False
115161
Indicates whether to show importance scores for each word.
162+
show_negative: bool, default False
163+
Indicates whether the most negative terms should also be displayed.
116164
format: 'csv', 'latex' or 'markdown'
117165
Specifies which format should be used.
118166
'csv', 'latex' and 'markdown' are supported.
119167
"""
120-
table = self._topics_table(top_k, show_scores)
168+
table = self._topics_table(
169+
top_k, show_scores, show_negative=show_negative
170+
)
121171
return export_table(table, format=format)
122172

123-
def _highest_ranking_docs(
124-
self, topic_id, raw_documents, document_topic_matrix=None, top_k=5
173+
def _representative_docs(
174+
self,
175+
topic_id,
176+
raw_documents,
177+
document_topic_matrix=None,
178+
top_k=5,
179+
show_negative: bool = False,
125180
) -> list[list[str]]:
126181
if document_topic_matrix is None:
127182
try:
@@ -154,10 +209,30 @@ def _highest_ranking_docs(
154209
if len(doc) > 300:
155210
doc = doc[:300] + "..."
156211
rows.append([doc, f"{score:.2f}"])
212+
if show_negative:
213+
rows.append(["...", ""])
214+
lowest = np.argpartition(document_topic_matrix[:, topic_id], kth)[
215+
:kth
216+
]
217+
lowest = lowest[
218+
np.argsort(document_topic_matrix[lowest, topic_id])
219+
]
220+
scores = document_topic_matrix[lowest, topic_id]
221+
for document_id, score in zip(lowest, scores):
222+
doc = raw_documents[document_id]
223+
doc = remove_whitespace(doc)
224+
if len(doc) > 300:
225+
doc = doc[:300] + "..."
226+
rows.append([doc, f"{score:.2f}"])
157227
return [columns, *rows]
158228

159-
def print_highest_ranking_documents(
160-
self, topic_id, raw_documents, document_topic_matrix=None, top_k=5
229+
def print_representative_documents(
230+
self,
231+
topic_id,
232+
raw_documents,
233+
document_topic_matrix=None,
234+
top_k=5,
235+
show_negative: bool = False,
161236
):
162237
"""Pretty prints the highest ranking documents in a topic.
163238
@@ -172,9 +247,15 @@ def print_highest_ranking_documents(
172247
as they cannot infer topics from text.
173248
top_k: int, default 5
174249
Top K documents to show.
250+
show_negative: bool, default False
251+
Indicates whether lowest ranking documents should also be shown.
175252
"""
176-
columns, *rows = self._highest_ranking_docs(
177-
topic_id, raw_documents, document_topic_matrix, top_k
253+
columns, *rows = self._representative_docs(
254+
topic_id,
255+
raw_documents,
256+
document_topic_matrix,
257+
top_k,
258+
show_negative,
178259
)
179260
table = Table(show_lines=True)
180261
table.add_column(
@@ -186,12 +267,13 @@ def print_highest_ranking_documents(
186267
console = Console()
187268
console.print(table)
188269

189-
def export_highest_ranking_documents(
270+
def export_representative_documents(
190271
self,
191272
topic_id,
192273
raw_documents,
193274
document_topic_matrix=None,
194275
top_k=5,
276+
show_negative: bool = False,
195277
format: str = "csv",
196278
):
197279
"""Exports the highest ranking documents in a topic as a text table.
@@ -207,12 +289,18 @@ def export_highest_ranking_documents(
207289
as they cannot infer topics from text.
208290
top_k: int, default 5
209291
Top K documents to show.
292+
show_negative: bool, default False
293+
Indicates whether lowest ranking documents should also be shown.
210294
format: 'csv', 'latex' or 'markdown'
211295
Specifies which format should be used.
212296
'csv', 'latex' and 'markdown' are supported.
213297
"""
214298
table = self._highest_ranking_docs(
215-
topic_id, raw_documents, document_topic_matrix, top_k
299+
topic_id,
300+
raw_documents,
301+
document_topic_matrix,
302+
top_k,
303+
show_negative,
216304
)
217305
return export_table(table, format=format)
218306

0 commit comments

Comments
 (0)