Skip to content

Commit 55dee2d

Browse files
Fixed import
1 parent a98d5e1 commit 55dee2d

1 file changed

Lines changed: 14 additions & 11 deletions

File tree

turftopic/models/cluster.py

Lines changed: 14 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -44,8 +44,8 @@
4444
)
4545
from turftopic.types import VALID_DISTANCE_METRICS, DistanceMetric
4646
from turftopic.utils import safe_binarize
47-
from turftopic.vectorizers import PhraseVectorizer
4847
from turftopic.vectorizers.default import default_vectorizer
48+
from turftopic.vectorizers.phrases import PhraseVectorizer
4949

5050
integer_message = """
5151
You tried to pass an integer to ClusteringTopicModel as its first argument.
@@ -719,12 +719,12 @@ def transform(
719719
X = self.vectorizer.transform(raw_documents)
720720
X = normalize(X, axis=1, norm="l1", copy=False)
721721
X = X * idf_diag
722-
doc_topic_matrix = np.exp(cosine_similarity(X, self.components_))
722+
doc_topic_matrix = cosine_similarity(X, self.components_)
723723
elif self.feature_importance == "centroid":
724724
if embeddings is None:
725725
embeddings = self.encode_documents(raw_documents)
726-
doc_topic_matrix = np.exp(
727-
cosine_similarity(embeddings, self._calculate_topic_vectors())
726+
doc_topic_matrix = cosine_similarity(
727+
embeddings, self._calculate_topic_vectors()
728728
)
729729
else:
730730
doc_topic_matrix = safe_binarize(
@@ -909,7 +909,7 @@ def __init__(
909909
reduction_topic_representation: TopicRepresentation = "centroid",
910910
window_size: Optional[int] = 50,
911911
step_size: Optional[int] = 40,
912-
pooling: Optional[Callable] = np.mean,
912+
pooling: Optional[Callable] = np.nanmean,
913913
random_state: Optional[int] = None,
914914
):
915915
if dimensionality_reduction is None:
@@ -933,7 +933,10 @@ def __init__(
933933
cluster_selection_method="eom",
934934
)
935935
self.encoder = encoder
936-
self.vectorizer = vectorizer
936+
if isinstance(encoder, str):
937+
encoder = LateSentenceTransformer(encoder)
938+
if vectorizer is None:
939+
vectorizer = PhraseVectorizer()
937940
self.dimensionality_reduction = dimensionality_reduction
938941
self.clustering = clustering
939942
self.feature_importance = feature_importance
@@ -942,7 +945,7 @@ def __init__(
942945
self.reduction_distance_metric = reduction_distance_metric
943946
self.reduction_topic_representation = reduction_topic_representation
944947
self.random_state = random_state
945-
self.model = ClusteringTopicModel(
948+
model = ClusteringTopicModel(
946949
encoder=encoder,
947950
vectorizer=vectorizer,
948951
dimensionality_reduction=dimensionality_reduction,
@@ -955,8 +958,8 @@ def __init__(
955958
reduction_topic_representation=reduction_topic_representation,
956959
)
957960
super().__init__(
958-
self.model,
959-
window_size=self.window_size,
960-
step_size=self.step_size,
961-
pooling=self.pooling,
961+
model,
962+
window_size=window_size,
963+
step_size=step_size,
964+
pooling=pooling,
962965
)

0 commit comments

Comments
 (0)