11import warnings
22from datetime import datetime
3- from typing import Literal , Optional , Union
3+ from typing import Iterable , Literal , Optional , Union
44
55import numpy as np
66from rich .console import Console
1818from turftopic .feature_importance import (bayes_rule ,
1919 cluster_centroid_distance , ctf_idf ,
2020 soft_ctf_idf )
21+ from turftopic .hierarchical import TopicNode
2122from turftopic .vectorizer import default_vectorizer
2223
2324integer_message = """
@@ -230,11 +231,12 @@ def _merge_agglomerative(self, n_reduce_to: int) -> np.ndarray:
230231 ]
231232 )
232233 old_labels = [label for label in self .classes_ if label != - 1 ]
233- new_labels = AgglomerativeClustering (
234+ clustering = AgglomerativeClustering (
234235 n_clusters = n_reduce_to ,
235236 metric = "cosine" ,
236237 linkage = "average" ,
237- ).fit_predict (interesting_topic_vectors )
238+ )
239+ new_labels = clustering .fit_predict (interesting_topic_vectors )
238240 res = {}
239241 if - 1 in self .classes_ :
240242 res [- 1 ] = - 1
@@ -254,6 +256,58 @@ def _merge_smallest(self, n_reduce_to: int):
254256 labels [labels == from_topic ] = to_topic
255257 return labels
256258
def join_subtopics(
    self, subtopics: Iterable[int], hierarchy: Optional["TopicNode"] = None
) -> "TopicNode":
    """Joins subtopics in a topic hierarchy and returns the joint TopicNode.

    > Note that this method does not alter the underlying hierarchy!
    > You will need to use the join() method of a hierarchy for that.

    Parameters
    ----------
    subtopics: iterable of int
        Indices of the subtopics to be joined. Duplicates are ignored.
    hierarchy: TopicNode, default None
        Hierarchy to join subtopics in, defaults to the root hierarchy of the model.

    Returns
    -------
    TopicNode
        New topic made up of the joined subtopics.
        (Intended contract; see Raises - construction is not implemented yet.)

    Raises
    ------
    ValueError
        If the node is a leaf, if no subtopics are given, if a negative
        (outlier) index is included, or if an index does not refer to an
        existing child of the node.
    NotImplementedError
        Whenever validation passes, since building the joint node is still
        unfinished. This replaces a silent ``None`` return.
    """
    if hierarchy is None:
        hierarchy = self.hierarchy
    # The leaf check must come before anything that calls len() on the
    # children, otherwise a leaf produces a TypeError instead of this message.
    if hierarchy.children is None:
        raise ValueError("Current Node is a leaf, children can't be joined.")
    # Sorting makes the processing order deterministic after deduplication.
    subtopics = sorted(set(subtopics))
    if not subtopics:
        raise ValueError("No subtopics were given to join.")
    # The joint topic takes the position of the lowest joined index.
    slot = subtopics[0]
    if slot < 0:
        raise ValueError(
            "Outlier topics (-1) cannot be merged with other topics."
        )
    # Valid child indices run from 0 to len(children) - 1.
    # NOTE(review): assumes doc_topic_matrix has one column per child, in
    # the same order as hierarchy.children - confirm against TopicNode.
    if subtopics[-1] >= len(hierarchy.children):
        raise ValueError("These subtopics don't exist on the current node.")
    path = (*hierarchy.path, slot)
    children = [hierarchy[sub] for sub in subtopics]
    # Document-level weight of the joint topic: sum of the joined columns.
    joint_vector = hierarchy.doc_topic_matrix[:, subtopics].sum(axis=1)
    joined = set(subtopics)
    rest = [
        column
        for i_topic, column in enumerate(hierarchy.doc_topic_matrix.T)
        if i_topic not in joined
    ]
    # The remaining columns have to be unpacked: stacking a vector together
    # with a *list* of vectors fails with a shape mismatch.
    doc_topic_matrix = np.stack([joint_vector, *rest]).T
    # TODO: build the joint TopicNode from `path`, `children` and
    # `doc_topic_matrix` and return it.
    raise NotImplementedError(
        "Joining subtopics into a new TopicNode is not implemented yet."
    )
310+
257311 def reduce_topics (
258312 self ,
259313 n_reduce_to : int ,
@@ -286,6 +340,7 @@ def reduce_topics(
286340 self .labels_ = self ._merge_smallest (n_reduce_to )
287341 elif reduction_method == "agglomerative" :
288342 self .labels_ = self ._merge_agglomerative (n_reduce_to )
343+ self .estimate_components (self .feature_importance )
289344 return self .labels_
290345
291346 def reset_reduction (self ):
@@ -326,6 +381,10 @@ def estimate_components(
326381 )
327382 clusters = np .unique (self .labels_ )
328383 self .classes_ = np .sort (clusters )
384+ if - 1 in self .classes_ :
385+ # Putting outliers in the last position, so that when you index things,
386+ # it works.
387+ self .classes_ = np .array ([* self .classes_ [1 :], - 1 ])
329388 self .topic_sizes_ = np .array (
330389 [np .sum (self .labels_ == label ) for label in self .classes_ ]
331390 )
0 commit comments