@@ -54,6 +54,7 @@ def update_G(X, G, F, sparsity=0):
5454 denominator = jnp .maximum (denominator , EPSILON )
5555 delta_G = jnp .sqrt (numerator / denominator )
5656 G *= delta_G
57+ G = G / jnp .linalg .norm (G )
5758 return G
5859
5960
@@ -128,7 +129,24 @@ def fit_timeslice(self, X_t: np.ndarray, G_t: np.ndarray):
128129 return F .T
129130
130131 def transform (self , X : np .ndarray ):
131- G = jnp .maximum (X @ jnp .linalg .pinv (self .components_ ), 0 )
132+ G = init_G (
133+ X .T ,
134+ n_components = self .n_components ,
135+ random_state = self .random_state ,
136+ )
137+ F = self .components_ .T
138+ update = jit (lambda G : update_G (X .T , G , F , sparsity = self .sparsity ))
139+ error_at_init = rec_err (X .T , F , G )
140+ prev_error = error_at_init
141+ for i in range (self .max_iter ):
142+ G = update (G )
143+ err = rec_err (X .T , F , G )
144+ if (err < error_at_init ) and (
145+ (prev_error - err ) / error_at_init
146+ ) < self .tol :
147+ if self .verbose :
148+ print (f"Converged after { i } iterations" )
149+ break
132150 return np .array (G )
133151
134152 def inverse_transform (self , X ):
0 commit comments