Skip to content

Commit fe79139

Browse files
Merge pull request #123 from x-tabdeveloping/senstopic_imp
Improvements for SensTopic
2 parents c8fce20 + dc40a95 commit fe79139

3 files changed

Lines changed: 25 additions & 2 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ profile = "black"
99

1010
[project]
1111
name = "turftopic"
12-
version = "0.23.1"
12+
version = "0.23.2"
1313
description = "Topic modeling with contextual representations from sentence transformers."
1414
authors = [
1515
{ name = "Márton Kardos <power.up1163@gmail.com>", email = "martonkardos@cas.au.dk" }

turftopic/models/_snmf.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def update_G(X, G, F, sparsity=0):
5454
denominator = jnp.maximum(denominator, EPSILON)
5555
delta_G = jnp.sqrt(numerator / denominator)
5656
G *= delta_G
57+
G = G / jnp.linalg.norm(G)
5758
return G
5859

5960

@@ -128,7 +129,24 @@ def fit_timeslice(self, X_t: np.ndarray, G_t: np.ndarray):
128129
return F.T
129130

130131
def transform(self, X: np.ndarray):
131-
G = jnp.maximum(X @ jnp.linalg.pinv(self.components_), 0)
132+
G = init_G(
133+
X.T,
134+
n_components=self.n_components,
135+
random_state=self.random_state,
136+
)
137+
F = self.components_.T
138+
update = jit(lambda G: update_G(X.T, G, F, sparsity=self.sparsity))
139+
error_at_init = rec_err(X.T, F, G)
140+
prev_error = error_at_init
141+
for i in range(self.max_iter):
142+
G = update(G)
143+
err = rec_err(X.T, F, G)
144+
if (err < error_at_init) and (
145+
(prev_error - err) / error_at_init
146+
) < self.tol:
147+
if self.verbose:
148+
print(f"Converged after {i} iterations")
149+
break
132150
return np.array(G)
133151

134152
def inverse_transform(self, X):

turftopic/models/senstopic.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,11 @@ def fit_transform(
205205
console.log("Model fitting done.")
206206
return doc_topic
207207

208+
def transform(self, raw_documents, embeddings=None):
209+
if embeddings is None:
210+
embeddings = self.encoder_.encode(raw_documents)
211+
return self.decomposition.transform(embeddings)
212+
208213
def fit_transform_multimodal(
209214
self,
210215
raw_documents: list[str],

0 commit comments

Comments
 (0)