Skip to content

Commit ecffc37

Browse files
Fixed integration test for online models
1 parent d532403 commit ecffc37

1 file changed

Lines changed: 7 additions & 4 deletions

File tree

tests/test_integration.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
import tempfile
23
from datetime import datetime
34
from pathlib import Path
@@ -56,8 +57,7 @@ def generate_dates(
5657
models = [
5758
GMM(5, encoder=trf),
5859
SemanticSignalSeparation(5, encoder=trf),
59-
KeyNMF(5, encoder=trf, keyword_scope="document"),
60-
KeyNMF(5, encoder=trf, keyword_scope="corpus"),
60+
KeyNMF(5, encoder=trf),
6161
ClusteringTopicModel(
6262
n_reduce_to=5,
6363
feature_importance="c-tf-idf",
@@ -122,8 +122,11 @@ def test_fit_dynamic(model):
122122
@pytest.mark.parametrize("model", online_models)
123123
def test_fit_online(model):
124124
for epoch in range(5):
125-
for batch in batched(texts, 50):
126-
model.partial_fit(texts)
125+
for batch in batched(zip(texts, embeddings), 50):
126+
batch_text, batch_embedding = zip(*batch)
127+
batch_text = list(batch_text)
128+
batch_embedding = np.stack(batch_embedding)
129+
model.partial_fit(batch_text, embeddings=batch_embedding)
127130
table = model.export_topics(format="csv")
128131
with tempfile.TemporaryDirectory() as tmpdirname:
129132
out_path = Path(tmpdirname).joinpath("topics.csv")

0 commit comments

Comments
 (0)