11import warnings
22from datetime import datetime
3- from typing import Iterable , Literal , Optional , Union
3+ from typing import Literal , Optional , Union
44
55import numpy as np
66from rich .console import Console
1818from turftopic .feature_importance import (bayes_rule ,
1919 cluster_centroid_distance , ctf_idf ,
2020 soft_ctf_idf )
21- from turftopic .hierarchical import TopicNode
2221from turftopic .vectorizer import default_vectorizer
2322
2423integer_message = """
@@ -231,12 +230,11 @@ def _merge_agglomerative(self, n_reduce_to: int) -> np.ndarray:
231230 ]
232231 )
233232 old_labels = [label for label in self .classes_ if label != - 1 ]
234- clustering = AgglomerativeClustering (
233+ new_labels = AgglomerativeClustering (
235234 n_clusters = n_reduce_to ,
236235 metric = "cosine" ,
237236 linkage = "average" ,
238- )
239- new_labels = clustering .fit_predict (interesting_topic_vectors )
237+ ).fit_predict (interesting_topic_vectors )
240238 res = {}
241239 if - 1 in self .classes_ :
242240 res [- 1 ] = - 1
@@ -256,58 +254,6 @@ def _merge_smallest(self, n_reduce_to: int):
256254 labels [labels == from_topic ] = to_topic
257255 return labels
258256
259- def join_subtopics (
260- self , subtopics : Iterable [int ], hierarchy : Optional [TopicNode ] = None
261- ) -> TopicNode :
262- """Joins subtopics in a topic hierarchy and returns the joint TopicNode.
263- > Note that this method does not alter the underlying hierarchy!
264- > You will need to use the join() method of a hierarchy for that.
265-
266- Parameters
267- ----------
268- subtopics: iterable of int
269- Indices of subtopics to be joint.
270- hierarchy: TopicNode, default None
271- Hierarchy to join subtopics in, defaults to the root hierarchy of the model.
272-
273- Returns
274- -------
275- TopicNode
276- New topic made up of the joint subtopics.
277- """
278- if hierarchy is None :
279- hierarchy = self .hierarchy
280- subtopics = list (set (subtopics ))
281- slot = min (subtopics )
282- max_subtopics = max (subtopics )
283- if len (self .children ) < (max_subtopics - 1 ):
284- raise ValueError (
285- "These subtopics don't exist on the current node."
286- )
287- if slot < 0 :
288- raise ValueError (
289- "Outlier topics (-1) cannot be merged with other topics."
290- )
291- if self .children is None :
292- raise ValueError (
293- "Current Node is a leaf, children can't be joined."
294- )
295- path = (* hierarchy .path , slot )
296- children = [self .hierarchy [sub ] for sub in subtopics ]
297- doc_topic_vector = self .hierarchy .doc_topic_matrix [:, subtopics ].sum (
298- axis = 1
299- )
300- rest = [
301- doc_topic_vector
302- for i_topic , doc_topic_vector in enumerate (
303- self .hierarchy .doc_topic_matrix .T
304- )
305- if i_topic not in subtopics
306- ]
307- doc_topic_matrix = np .stack ([doc_topic_vector , rest ]).T
308- # TODO
309- pass
310-
311257 def reduce_topics (
312258 self ,
313259 n_reduce_to : int ,
@@ -340,7 +286,6 @@ def reduce_topics(
340286 self .labels_ = self ._merge_smallest (n_reduce_to )
341287 elif reduction_method == "agglomerative" :
342288 self .labels_ = self ._merge_agglomerative (n_reduce_to )
343- self .estimate_components (self .feature_importance )
344289 return self .labels_
345290
346291 def reset_reduction (self ):
@@ -381,10 +326,6 @@ def estimate_components(
381326 )
382327 clusters = np .unique (self .labels_ )
383328 self .classes_ = np .sort (clusters )
384- if - 1 in self .classes_ :
385- # Putting outliers in the last position, so that when you index things,
386- # it works.
387- self .classes_ = np .array ([* self .classes_ [1 :], - 1 ])
388329 self .topic_sizes_ = np .array (
389330 [np .sum (self .labels_ == label ) for label in self .classes_ ]
390331 )
0 commit comments