33import numpy as np
44from rich .console import Console
55from sentence_transformers import SentenceTransformer
6- from sklearn .decomposition import PCA , FastICA
6+ from sklearn .base import TransformerMixin
7+ from sklearn .decomposition import FastICA
78from sklearn .feature_extraction .text import CountVectorizer
89
910from turftopic .base import ContextualModel , Encoder
@@ -33,6 +34,11 @@ class SemanticSignalSeparation(ContextualModel):
3334 vectorizer: CountVectorizer, default None
3435 Vectorizer used for term extraction.
3536 Can be used to prune or filter the vocabulary.
37+ decomposition: TransformerMixin, default None
38+ Custom decomposition method to use.
39+ Can be an instance of FastICA or PCA, or basically any dimensionality
40+ reduction method. Has to have `fit_transform` and `fit` methods.
41+ If not specified, FastICA is used.
3642 max_iter: int, default 200
3743 Maximum number of iterations for ICA.
3844 random_state: int, default None
@@ -46,6 +52,7 @@ def __init__(
4652 Encoder , str
4753 ] = "sentence-transformers/all-MiniLM-L6-v2" ,
4854 vectorizer : Optional [CountVectorizer ] = None ,
55+ decomposition : Optional [TransformerMixin ] = None ,
4956 max_iter : int = 200 ,
5057 random_state : Optional [int ] = None ,
5158 ):
@@ -61,9 +68,12 @@ def __init__(
6168 self .vectorizer = vectorizer
6269 self .max_iter = max_iter
6370 self .random_state = random_state
64- self .decomposition = FastICA (
65- n_components , max_iter = max_iter , random_state = random_state
66- )
71+ if decomposition is None :
72+ self .decomposition = FastICA (
73+ n_components , max_iter = max_iter , random_state = random_state
74+ )
75+ else :
76+ self .decomposition = decomposition
6777
6878 def fit_transform (
6979 self , raw_documents , y = None , embeddings : Optional [np .ndarray ] = None
0 commit comments