Skip to content

Commit 4a8487e

Browse files
Added option to use any decomposition method in S3
1 parent 450184b commit 4a8487e

1 file changed

Lines changed: 14 additions & 4 deletions

File tree

turftopic/models/decomp.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import numpy as np
44
from rich.console import Console
55
from sentence_transformers import SentenceTransformer
6-
from sklearn.decomposition import PCA, FastICA
6+
from sklearn.base import TransformerMixin
7+
from sklearn.decomposition import FastICA
78
from sklearn.feature_extraction.text import CountVectorizer
89

910
from turftopic.base import ContextualModel, Encoder
@@ -33,6 +34,11 @@ class SemanticSignalSeparation(ContextualModel):
3334
vectorizer: CountVectorizer, default None
3435
Vectorizer used for term extraction.
3536
Can be used to prune or filter the vocabulary.
37+
decomposition: TransformerMixin, default None
38+
Custom decomposition method to use.
39+
Can be an instance of FastICA or PCA, or basically any dimensionality
40+
reduction method. Has to have `fit_transform` and `fit` methods.
41+
If not specified, FastICA is used, configured with `n_components`, `max_iter` and `random_state`.
3642
max_iter: int, default 200
3743
Maximum number of iterations for ICA. Only passed to the default FastICA; not applied to a custom `decomposition`.
3844
random_state: int, default None
@@ -46,6 +52,7 @@ def __init__(
4652
Encoder, str
4753
] = "sentence-transformers/all-MiniLM-L6-v2",
4854
vectorizer: Optional[CountVectorizer] = None,
55+
decomposition: Optional[TransformerMixin] = None,
4956
max_iter: int = 200,
5057
random_state: Optional[int] = None,
5158
):
@@ -61,9 +68,12 @@ def __init__(
6168
self.vectorizer = vectorizer
6269
self.max_iter = max_iter
6370
self.random_state = random_state
64-
self.decomposition = FastICA(
65-
n_components, max_iter=max_iter, random_state=random_state
66-
)
71+
if decomposition is None:
72+
self.decomposition = FastICA(
73+
n_components, max_iter=max_iter, random_state=random_state
74+
)
75+
else:
76+
self.decomposition = decomposition
6777

6878
def fit_transform(
6979
self, raw_documents, y=None, embeddings: Optional[np.ndarray] = None

0 commit comments

Comments
 (0)