Skip to content

Commit c6ac90b

Browse files
Started implementing hierarchical topic joining in clustering models
1 parent b8688e1 commit c6ac90b

2 files changed

Lines changed: 110 additions & 3 deletions

File tree

turftopic/hierarchical.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,20 @@ class TopicNode:
116116
document_topic_vector: Optional[np.ndarray] = None
117117
children: Optional[list[TopicNode]] = None
118118

119+
@property
120+
def components_(self) -> np.ndarray:
121+
if self.children is None:
122+
raise ValueError("Current node is a leaf, no components.")
123+
return np.stack([child.word_importance for child in self.children])
124+
125+
@property
126+
def doc_topic_matrix(self) -> np.ndarray:
127+
if self.children is None:
128+
raise ValueError("Current node is a leaf, no doc_topic_matrix.")
129+
return np.stack(
130+
[child.document_topic_vector for child in self.children]
131+
).T
132+
119133
@classmethod
120134
def create_root(
121135
cls,
@@ -146,6 +160,14 @@ def create_root(
146160
children=children,
147161
)
148162

163+
def set_path(self, path: tuple[int]):
164+
"""Sets path for current node and all children accordingly."""
165+
self.path = path
166+
if self.children is None:
167+
return
168+
for i_child, child in enumerate(self.children):
169+
child.set_path((*self.path, i_child))
170+
149171
@property
150172
def level(self) -> int:
151173
"""Indicates how deep down the hierarchy the topic is."""
@@ -275,3 +297,29 @@ def divide_children(self, n_subtopics: int, **kwargs):
275297
def plot_tree(self):
276298
"""Plots hierarchy as an interactive tree in Plotly."""
277299
return _tree_plot(self)
300+
301+
def join(self, *subtopics: int, **kwargs):
302+
slot = min(subtopics)
303+
max_subtopics = max(subtopics)
304+
if len(self.children) < (max_subtopics - 1):
305+
raise ValueError(
306+
"These subtopics don't exist on the current node."
307+
)
308+
if slot < 0:
309+
raise ValueError(
310+
"Outlier topics (-1) cannot be merged with other topics."
311+
)
312+
if self.children is None:
313+
raise ValueError(
314+
"Current Node is a leaf, children can't be joined."
315+
)
316+
try:
317+
self.children[slot] = self.model.join_subtopics(
318+
subtopics, self, **kwargs
319+
)
320+
self.set_path(self.path)
321+
except AttributeError as e:
322+
raise AttributeError(
323+
"Looks like your model is not an agglomerative hierarchical model."
324+
) from e
325+
return self

turftopic/models/cluster.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import warnings
22
from datetime import datetime
3-
from typing import Literal, Optional, Union
3+
from typing import Iterable, Literal, Optional, Union
44

55
import numpy as np
66
from rich.console import Console
@@ -18,6 +18,7 @@
1818
from turftopic.feature_importance import (bayes_rule,
1919
cluster_centroid_distance, ctf_idf,
2020
soft_ctf_idf)
21+
from turftopic.hierarchical import TopicNode
2122
from turftopic.vectorizer import default_vectorizer
2223

2324
integer_message = """
@@ -230,11 +231,12 @@ def _merge_agglomerative(self, n_reduce_to: int) -> np.ndarray:
230231
]
231232
)
232233
old_labels = [label for label in self.classes_ if label != -1]
233-
new_labels = AgglomerativeClustering(
234+
clustering = AgglomerativeClustering(
234235
n_clusters=n_reduce_to,
235236
metric="cosine",
236237
linkage="average",
237-
).fit_predict(interesting_topic_vectors)
238+
)
239+
new_labels = clustering.fit_predict(interesting_topic_vectors)
238240
res = {}
239241
if -1 in self.classes_:
240242
res[-1] = -1
@@ -254,6 +256,58 @@ def _merge_smallest(self, n_reduce_to: int):
254256
labels[labels == from_topic] = to_topic
255257
return labels
256258

259+
def join_subtopics(
260+
self, subtopics: Iterable[int], hierarchy: Optional[TopicNode] = None
261+
) -> TopicNode:
262+
"""Joins subtopics in a topic hierarchy and returns the joint TopicNode.
263+
> Note that this method does not alter the underlying hierarchy!
264+
> You will need to use the join() method of a hierarchy for that.
265+
266+
Parameters
267+
----------
268+
subtopics: iterable of int
269+
Indices of subtopics to be joint.
270+
hierarchy: TopicNode, default None
271+
Hierarchy to join subtopics in, defaults to the root hierarchy of the model.
272+
273+
Returns
274+
-------
275+
TopicNode
276+
New topic made up of the joint subtopics.
277+
"""
278+
if hierarchy is None:
279+
hierarchy = self.hierarchy
280+
subtopics = list(set(subtopics))
281+
slot = min(subtopics)
282+
max_subtopics = max(subtopics)
283+
if len(self.children) < (max_subtopics - 1):
284+
raise ValueError(
285+
"These subtopics don't exist on the current node."
286+
)
287+
if slot < 0:
288+
raise ValueError(
289+
"Outlier topics (-1) cannot be merged with other topics."
290+
)
291+
if self.children is None:
292+
raise ValueError(
293+
"Current Node is a leaf, children can't be joined."
294+
)
295+
path = (*hierarchy.path, slot)
296+
children = [self.hierarchy[sub] for sub in subtopics]
297+
doc_topic_vector = self.hierarchy.doc_topic_matrix[:, subtopics].sum(
298+
axis=1
299+
)
300+
rest = [
301+
doc_topic_vector
302+
for i_topic, doc_topic_vector in enumerate(
303+
self.hierarchy.doc_topic_matrix.T
304+
)
305+
if i_topic not in subtopics
306+
]
307+
doc_topic_matrix = np.stack([doc_topic_vector, rest]).T
308+
# TODO
309+
pass
310+
257311
def reduce_topics(
258312
self,
259313
n_reduce_to: int,
@@ -286,6 +340,7 @@ def reduce_topics(
286340
self.labels_ = self._merge_smallest(n_reduce_to)
287341
elif reduction_method == "agglomerative":
288342
self.labels_ = self._merge_agglomerative(n_reduce_to)
343+
self.estimate_components(self.feature_importance)
289344
return self.labels_
290345

291346
def reset_reduction(self):
@@ -326,6 +381,10 @@ def estimate_components(
326381
)
327382
clusters = np.unique(self.labels_)
328383
self.classes_ = np.sort(clusters)
384+
if -1 in self.classes_:
385+
# Putting outliers in the last position, so that when you index things,
386+
# it works.
387+
self.classes_ = np.array([*self.classes_[1:], -1])
329388
self.topic_sizes_ = np.array(
330389
[np.sum(self.labels_ == label) for label in self.classes_]
331390
)

0 commit comments

Comments
 (0)