Skip to content

Commit 49020d2

Browse files
Merge branch 'main' into bayes_rule
2 parents 7ee21fd + 21fd848 commit 49020d2

3 files changed

Lines changed: 57 additions & 19 deletions

File tree

README.md

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,20 +34,17 @@ model.hierarchy.divide_children(n_subtopics=3)
3434
print(model.hierarchy)
3535
```
3636

37-
<div style="background-color: #F5F5F5; padding: 10px; padding-left: 20px; padding-right: 20px;">
38-
<tt style="font-size: 11pt">
39-
<b>Root </b><br>
40-
├── <b style="color: blue">0</b>: windows, dos, os, disk, card, drivers, file, pc, files, microsoft <br>
41-
│ ├── <b style="color: magenta">0.0</b>: dos, file, disk, files, program, windows, disks, shareware, norton, memory <br>
42-
│ ├── <b style="color: magenta">0.1</b>: os, unix, windows, microsoft, apps, nt, ibm, ms, os2, platform <br>
43-
│ └── <b style="color: magenta">0.2</b>: card, drivers, monitor, driver, vga, ram, motherboard, cards, graphics, ati <br>
44-
└── <b style="color: blue">1</b>: atheism, atheist, atheists, religion, christians, religious, belief, christian, god, beliefs <br>
45-
. ├── <b style="color: magenta">1.0</b>: atheism, alt, newsgroup, reading, faq, islam, questions, read, newsgroups, readers <br>
46-
. ├── <b style="color: magenta">1.1</b>: atheists, atheist, belief, theists, beliefs, religious, religion, agnostic, gods, religions <br>
47-
. └── <b style="color: magenta">1.2</b>: morality, bible, christian, christians, moral, christianity, biblical, immoral, god, religion <br>
48-
</tt>
49-
</div>
50-
37+
```
38+
Root
39+
├── windows, dos, os, disk, card, drivers, file, pc, files, microsoft
40+
│ ├── 0.0: dos, file, disk, files, program, windows, disks, shareware, norton, memory
41+
│ ├── 0.1: os, unix, windows, microsoft, apps, nt, ibm, ms, os2, platform
42+
│ └── 0.2: card, drivers, monitor, driver, vga, ram, motherboard, cards, graphics, ati
43+
└── 1: atheism, atheist, atheists, religion, christians, religious, belief, christian, god, beliefs
44+
. ├── 1.0: atheism, alt, newsgroup, reading, faq, islam, questions, read, newsgroups, readers
45+
. ├── 1.1: atheists, atheist, belief, theists, beliefs, religious, religion, agnostic, gods, religions
46+
. └── 1.2: morality, bible, christian, christians, moral, christianity, biblical, immoral, god, religion
47+
```
5148

5249
#### FASTopic *(Experimental)*
5350

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ line-length=79
66

77
[tool.poetry]
88
name = "turftopic"
9-
version = "0.5.1"
9+
version = "0.5.4"
1010
description = "Topic modeling with contextual representations from sentence transformers."
1111
authors = ["Márton Kardos <power.up1163@gmail.com>"]
1212
license = "MIT"

turftopic/models/decomp.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from sklearn.base import TransformerMixin
77
from sklearn.decomposition import FastICA
88
from sklearn.feature_extraction.text import CountVectorizer
9-
from sklearn.metrics.pairwise import euclidean_distances
9+
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
1010

1111
from turftopic.base import ContextualModel, Encoder
1212
from turftopic.vectorizer import default_vectorizer
@@ -42,6 +42,9 @@ class SemanticSignalSeparation(ContextualModel):
4242
If not specified, FastICA is used.
4343
max_iter: int, default 200
4444
Maximum number of iterations for ICA.
45+
feature_importance: "axial", "angular" or "combined", default "combined"
46+
Defines whether the word's position on an axis ('axial'), it's angle to the axis ('angular')
47+
or their combination ('combined') should determine the word's importance for a topic.
4548
random_state: int, default None
4649
Random state to use so that results are exactly reproducible.
4750
"""
@@ -55,10 +58,14 @@ def __init__(
5558
vectorizer: Optional[CountVectorizer] = None,
5659
decomposition: Optional[TransformerMixin] = None,
5760
max_iter: int = 200,
61+
feature_importance: Literal[
62+
"axial", "angular", "combined"
63+
] = "combined",
5864
random_state: Optional[int] = None,
5965
):
6066
self.n_components = n_components
6167
self.encoder = encoder
68+
self.feature_importance = feature_importance
6269
if isinstance(encoder, str):
6370
self.encoder_ = SentenceTransformer(encoder)
6471
else:
@@ -76,6 +83,20 @@ def __init__(
7683
else:
7784
self.decomposition = decomposition
7885

86+
def estimate_components(
87+
self, feature_importance: Literal["axial", "angular", "combined"]
88+
) -> np.ndarray:
89+
"""Reestimates components based on the chosen feature_importance method."""
90+
if feature_importance == "axial":
91+
self.components_ = self.axial_components_
92+
elif feature_importance == "angular":
93+
self.components_ = self.angular_components_
94+
elif feature_importance == "combined":
95+
self.components_ = (
96+
np.square(self.axial_components_) * self.angular_components_
97+
)
98+
return self.components_
99+
79100
def fit_transform(
80101
self, raw_documents, y=None, embeddings: Optional[np.ndarray] = None
81102
) -> np.ndarray:
@@ -96,10 +117,30 @@ def fit_transform(
96117
console.log("Vocabulary encoded.")
97118
status.update("Estimating term importances")
98119
vocab_topic = self.decomposition.transform(vocab_embeddings)
99-
self.components_ = vocab_topic.T
120+
self.axial_components_ = vocab_topic.T
121+
if self.feature_importance == "axial":
122+
self.components_ = self.axial_components_
123+
elif self.feature_importance == "angular":
124+
self.components_ = self.angular_components_
125+
elif self.feature_importance == "combined":
126+
self.components_ = (
127+
np.square(self.axial_components_)
128+
* self.angular_components_
129+
)
100130
console.log("Model fitting done.")
101131
return doc_topic
102132

133+
@property
134+
def angular_components_(self):
135+
"""Reweights words based on their angle in ICA-space to the axis
136+
base vectors.
137+
"""
138+
word_vectors = self.axial_components_.T
139+
n_topics = self.axial_components_.shape[0]
140+
axis_vectors = np.eye(n_topics)
141+
cosine_components = cosine_similarity(axis_vectors, word_vectors)
142+
return cosine_components
143+
103144
def transform(
104145
self, raw_documents, embeddings: Optional[np.ndarray] = None
105146
) -> np.ndarray:
@@ -211,8 +252,8 @@ def concept_compass(
211252
raise ValueError(
212253
f"{topic_y} is not a valid topic name or index."
213254
) from e
214-
x = self.components_[topic_x]
215-
y = self.components_[topic_y]
255+
x = self.axial_components_[topic_x]
256+
y = self.axial_components_[topic_y]
216257
vocab = self.get_vocab()
217258
points = np.array(list(zip(x, y)))
218259
xx, yy = np.meshgrid(

0 commit comments

Comments
 (0)