Commit 9755d9f

Merge pull request #60 from x-tabdeveloping/bayes_rule
Term importance estimation with Bayes' rule.
2 parents 34c3761 + 81edd97 commit 9755d9f

2 files changed: 249 additions & 83 deletions

File tree

turftopic/feature_importance.py: 46 additions & 14 deletions
@@ -1,12 +1,11 @@
 import numpy as np
 import scipy.sparse as spr
-from sklearn.metrics import pairwise_distances
+from sklearn.metrics.pairwise import cosine_similarity
 
 
 def cluster_centroid_distance(
     cluster_centroids: np.ndarray,
     vocab_embeddings: np.ndarray,
-    metric="cosine",
 ) -> np.ndarray:
     """Computes feature importances based on distances between
     topic vectors (cluster centroids) and term embeddings
@@ -17,25 +16,21 @@ def cluster_centroid_distance(
         Coordinates of cluster centroids of shape (n_topics, embedding_size)
     vocab_embeddings: np.ndarray
         Term embeddings of shape (vocab_size, embedding_size)
-    metric: str, default 'cosine'
-        Metric used to compute distance from centroid.
-        See documentation for sklearn.metrics.pairwise.distance_metrics
-        for valid values.
 
     Returns
     -------
     ndarray of shape (n_topics, vocab_size)
         Term importance matrix.
     """
-    distances = pairwise_distances(
-        cluster_centroids, vocab_embeddings, metric=metric
+    n_components = cluster_centroids.shape[0]
+    n_vocab = vocab_embeddings.shape[0]
+    components = np.full((n_components, n_vocab), np.nan)
+    valid_centroids = np.all(np.isfinite(cluster_centroids), axis=1)
+    similarities = cosine_similarity(
+        cluster_centroids[valid_centroids], vocab_embeddings
     )
-    similarities = -distances / np.max(distances)
-    # Z-score transformation
-    similarities = (similarities - np.mean(similarities)) / np.std(
-        similarities
-    )
-    return similarities
+    components[valid_centroids, :] = similarities
+    return components
 
 
 def soft_ctf_idf(
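
A quick sketch (not part of the commit; the toy arrays below are made up for illustration) of what the rewritten cluster_centroid_distance now returns: rows for centroids containing non-finite values stay NaN instead of skewing the old distance normalization, while valid rows hold plain cosine similarities.

    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity

    cluster_centroids = np.array(
        [
            [1.0, 0.0],        # topic 0
            [0.0, 1.0],        # topic 1
            [np.nan, np.nan],  # e.g. an outlier cluster without a centroid
        ]
    )
    vocab_embeddings = np.array([[0.9, 0.1], [0.2, 0.8]])  # two terms

    # Mirrors the new function body: mask out non-finite centroids,
    # leave their rows as NaN, and use cosine similarity for the rest.
    n_topics, n_vocab = cluster_centroids.shape[0], vocab_embeddings.shape[0]
    components = np.full((n_topics, n_vocab), np.nan)
    valid_centroids = np.all(np.isfinite(cluster_centroids), axis=1)
    components[valid_centroids, :] = cosine_similarity(
        cluster_centroids[valid_centroids], vocab_embeddings
    )
    print(components)  # last row stays all-NaN; rows 0-1 are cosine similarities

Note that the old z-score rescaling of negated distances is gone: importances are now raw cosine similarities in [-1, 1].
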
@@ -87,10 +82,47 @@ def ctf_idf(
     components = []
     overall_freq = np.ravel(np.asarray(doc_term_matrix.sum(axis=0)))
     average = overall_freq.sum() / n_topics
+    overall_freq[overall_freq == 0] = np.finfo(float).eps
     for i_topic in range(n_topics):
         freq = np.ravel(
             np.asarray(doc_term_matrix[labels == i_topic].sum(axis=0))
         )
         component = freq * np.log(1 + average / overall_freq)
         components.append(component)
     return np.stack(components)
+
+
+def bayes_rule(
+    doc_topic_matrix: np.ndarray, doc_term_matrix: spr.csr_matrix
+) -> np.ndarray:
+    """Computes feature importance based on Bayes' rule.
+    The importance of a word for a topic is the probability of the topic conditional on the word.
+
+    $$p(t|w) = \\frac{p(w|t) * p(t)}{p(w)}$$
+
+    Parameters
+    ----------
+    doc_topic_matrix: np.ndarray
+        Document-topic matrix of shape (n_documents, n_topics)
+    doc_term_matrix: spr.csr_matrix
+        Document-term matrix of shape (n_documents, vocab_size)
+
+    Returns
+    -------
+    ndarray of shape (n_topics, vocab_size)
+        Term importance matrix.
+    """
+    eps = np.finfo(float).eps
+    p_w = np.squeeze(np.asarray(doc_term_matrix.sum(axis=0)))
+    p_w = p_w / p_w.sum()
+    p_w[p_w <= 0] = eps
+    p_t = doc_topic_matrix.sum(axis=0)
+    p_t = p_t / p_t.sum()
+    term_importance = doc_topic_matrix.T @ doc_term_matrix
+    overall_in_topic = np.abs(term_importance).sum(axis=1)
+    overall_in_topic[overall_in_topic <= 0] = eps
+    p_wt = (term_importance.T / overall_in_topic).T
+    p_wt /= p_wt.sum(axis=1)[:, None]
+    p_tw = (p_wt.T * p_t).T / p_w
+    p_tw /= np.nansum(p_tw, axis=0)
+    return p_tw
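
In bayes_rule, p(w|t) is estimated by row-normalizing the topic-term matrix doc_topic_matrix.T @ doc_term_matrix, p(t) and p(w) come from the marginal sums, and the final division by np.nansum(p_tw, axis=0) renormalizes each column so a word's importances sum to one across topics. A minimal usage sketch (the toy matrices are made up; it assumes the merged module is importable as turftopic.feature_importance):

    import numpy as np
    import scipy.sparse as spr
    from turftopic.feature_importance import bayes_rule

    # 3 documents, 2 topics, 4 vocabulary terms
    doc_topic_matrix = np.array(
        [
            [0.9, 0.1],
            [0.2, 0.8],
            [0.5, 0.5],
        ]
    )
    doc_term_matrix = spr.csr_matrix(
        [
            [2, 0, 1, 0],
            [0, 3, 0, 1],
            [1, 1, 1, 1],
        ]
    )

    p_tw = bayes_rule(doc_topic_matrix, doc_term_matrix)
    print(p_tw.shape)        # (2, 4): one row of term importances per topic
    print(p_tw.sum(axis=0))  # each column sums to ~1: p(t|w) over topics
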

0 commit comments
