@@ -29,9 +29,7 @@ def get_topics(
2929 """Returns high-level topic representations in form of the top K words
3030 in each topic.
3131
32- Parameters
33- ----------
34- top_k: int, default 10
32+ Parameters ---------- top_k: int, default 10
3533 Number of top words to return for each topic.
3634
3735 Returns
@@ -62,22 +60,57 @@ def get_topics(
6260 return topics
6361
6462 def _topics_table (
65- self , top_k : int = 10 , show_scores : bool = False
63+ self ,
64+ top_k : int = 10 ,
65+ show_scores : bool = False ,
66+ show_negative : bool = False ,
6667 ) -> list [list [str ]]:
67- topics = self .get_topics (top_k )
68- columns = ["Topic ID" , f"Top { top_k } Words" ]
68+ columns = ["Topic ID" , "Highest Ranking" ]
69+ if show_negative :
70+ columns .append ("Lowest Ranking" )
6971 rows = []
70- for topic_id , terms in topics :
72+ try :
73+ classes = self .classes_
74+ except AttributeError :
75+ classes = list (range (self .components_ .shape [0 ]))
76+ vocab = self .get_vocab ()
77+ for topic_id , component in zip (classes , self .components_ ):
78+ highest = np .argpartition (- component , top_k )[:top_k ]
79+ highest = highest [np .argsort (- component [highest ])]
80+ lowest = np .argpartition (component , top_k )[:top_k ]
81+ lowest = lowest [np .argsort (component [lowest ])]
7182 if show_scores :
72- concat_words = ", " .join (
73- [f"{ word } ({ importance :.2f} )" for word , importance in terms ]
83+ concat_positive = ", " .join (
84+ [
85+ f"{ word } ({ importance :.2f} )"
86+ for word , importance in zip (
87+ vocab [highest ], component [highest ]
88+ )
89+ ]
90+ )
91+ concat_negative = ", " .join (
92+ [
93+ f"{ word } ({ importance :.2f} )"
94+ for word , importance in zip (
95+ vocab [lowest ], component [lowest ]
96+ )
97+ ]
7498 )
7599 else :
76- concat_words = ", " .join ([word for word , importance in terms ])
77- rows .append ([f"{ topic_id } " , f"{ concat_words } " ])
100+ concat_positive = ", " .join ([word for word in vocab [highest ]])
101+ concat_negative = ", " .join ([word for word in vocab [lowest ]])
102+ row = [f"{ topic_id } " , f"{ concat_positive } " ]
103+ if show_negative :
104+ row .append (concat_negative )
105+ rows .append (row )
78106 return [columns , * rows ]
79107
80- def print_topics (self , top_k : int = 10 , show_scores : bool = False ):
108+ def print_topics (
109+ self ,
110+ top_k : int = 10 ,
111+ show_scores : bool = False ,
112+ show_negative : bool = False ,
113+ ):
81114 """Pretty prints topics in the model in a table.
82115
83116 Parameters
@@ -86,23 +119,36 @@ def print_topics(self, top_k: int = 10, show_scores: bool = False):
86119 Number of top words to return for each topic.
87120 show_scores: bool, default False
88121 Indicates whether to show importance scores for each word.
122+ show_negative: bool, default False
123+ Indicates whether the most negative terms should also be displayed.
89124 """
90- columns , * rows = self ._topics_table (top_k , show_scores )
125+ columns , * rows = self ._topics_table (top_k , show_scores , show_negative )
91126 table = Table (show_lines = True )
92- table .add_column (columns [ 0 ] , style = "blue" , justify = "right" )
127+ table .add_column ("Topic ID" , style = "blue" , justify = "right" )
93128 table .add_column (
94- columns [ 1 ] ,
129+ "Highest Ranking" ,
95130 justify = "left" ,
96131 style = "magenta" ,
97132 max_width = 100 ,
98133 )
134+ if show_negative :
135+ table .add_column (
136+ "Lowest Ranking" ,
137+ justify = "left" ,
138+ style = "red" ,
139+ max_width = 100 ,
140+ )
99141 for row in rows :
100142 table .add_row (* row )
101143 console = Console ()
102144 console .print (table )
103145
104146 def export_topics (
105- self , top_k : int = 10 , show_scores : bool = False , format : str = "csv"
147+ self ,
148+ top_k : int = 10 ,
149+ show_scores : bool = False ,
150+ show_negative : bool = False ,
151+ format : str = "csv" ,
106152 ) -> str :
107153 """Exports top K words from topics in a table in a given format.
108154 Returns table as a pure string.
@@ -113,15 +159,24 @@ def export_topics(
113159 Number of top words to return for each topic.
114160 show_scores: bool, default False
115161 Indicates whether to show importance scores for each word.
162+ show_negative: bool, default False
163+ Indicates whether the most negative terms should also be displayed.
116164 format: 'csv', 'latex' or 'markdown'
117165 Specifies which format should be used.
118166 'csv', 'latex' and 'markdown' are supported.
119167 """
120- table = self ._topics_table (top_k , show_scores )
168+ table = self ._topics_table (
169+ top_k , show_scores , show_negative = show_negative
170+ )
121171 return export_table (table , format = format )
122172
123- def _highest_ranking_docs (
124- self , topic_id , raw_documents , document_topic_matrix = None , top_k = 5
173+ def _representative_docs (
174+ self ,
175+ topic_id ,
176+ raw_documents ,
177+ document_topic_matrix = None ,
178+ top_k = 5 ,
179+ show_negative : bool = False ,
125180 ) -> list [list [str ]]:
126181 if document_topic_matrix is None :
127182 try :
@@ -154,10 +209,30 @@ def _highest_ranking_docs(
154209 if len (doc ) > 300 :
155210 doc = doc [:300 ] + "..."
156211 rows .append ([doc , f"{ score :.2f} " ])
212+ if show_negative :
213+ rows .append (["..." , "" ])
214+ lowest = np .argpartition (document_topic_matrix [:, topic_id ], kth )[
215+ :kth
216+ ]
217+ lowest = lowest [
218+ np .argsort (document_topic_matrix [lowest , topic_id ])
219+ ]
220+ scores = document_topic_matrix [lowest , topic_id ]
221+ for document_id , score in zip (lowest , scores ):
222+ doc = raw_documents [document_id ]
223+ doc = remove_whitespace (doc )
224+ if len (doc ) > 300 :
225+ doc = doc [:300 ] + "..."
226+ rows .append ([doc , f"{ score :.2f} " ])
157227 return [columns , * rows ]
158228
159- def print_highest_ranking_documents (
160- self , topic_id , raw_documents , document_topic_matrix = None , top_k = 5
229+ def print_representative_documents (
230+ self ,
231+ topic_id ,
232+ raw_documents ,
233+ document_topic_matrix = None ,
234+ top_k = 5 ,
235+ show_negative : bool = False ,
161236 ):
162237 """Pretty prints the highest ranking documents in a topic.
163238
@@ -172,9 +247,15 @@ def print_highest_ranking_documents(
172247 as they cannot infer topics from text.
173248 top_k: int, default 5
174249 Top K documents to show.
250+ show_negative: bool, default False
251+ Indicates whether lowest ranking documents should also be shown.
175252 """
176- columns , * rows = self ._highest_ranking_docs (
177- topic_id , raw_documents , document_topic_matrix , top_k
253+ columns , * rows = self ._representative_docs (
254+ topic_id ,
255+ raw_documents ,
256+ document_topic_matrix ,
257+ top_k ,
258+ show_negative ,
178259 )
179260 table = Table (show_lines = True )
180261 table .add_column (
@@ -186,12 +267,13 @@ def print_highest_ranking_documents(
186267 console = Console ()
187268 console .print (table )
188269
189- def export_highest_ranking_documents (
270+ def export_representative_documents (
190271 self ,
191272 topic_id ,
192273 raw_documents ,
193274 document_topic_matrix = None ,
194275 top_k = 5 ,
276+ show_negative : bool = False ,
195277 format : str = "csv" ,
196278 ):
197279 """Exports the highest ranking documents in a topic as a text table.
@@ -207,12 +289,18 @@ def export_highest_ranking_documents(
207289 as they cannot infer topics from text.
208290 top_k: int, default 5
209291 Top K documents to show.
292+ show_negative: bool, default False
293+ Indicates whether lowest ranking documents should also be shown.
210294 format: 'csv', 'latex' or 'markdown'
211295 Specifies which format should be used.
212296 'csv', 'latex' and 'markdown' are supported.
213297 """
214298 table = self ._highest_ranking_docs (
215- topic_id , raw_documents , document_topic_matrix , top_k
299+ topic_id ,
300+ raw_documents ,
301+ document_topic_matrix ,
302+ top_k ,
303+ show_negative ,
216304 )
217305 return export_table (table , format = format )
218306
0 commit comments