|
| 1 | +import itertools |
1 | 2 | import tempfile |
2 | 3 | from datetime import datetime |
3 | 4 | from pathlib import Path |
@@ -56,8 +57,7 @@ def generate_dates( |
56 | 57 | models = [ |
57 | 58 | GMM(5, encoder=trf), |
58 | 59 | SemanticSignalSeparation(5, encoder=trf), |
59 | | - KeyNMF(5, encoder=trf, keyword_scope="document"), |
60 | | - KeyNMF(5, encoder=trf, keyword_scope="corpus"), |
| 60 | + KeyNMF(5, encoder=trf), |
61 | 61 | ClusteringTopicModel( |
62 | 62 | n_reduce_to=5, |
63 | 63 | feature_importance="c-tf-idf", |
@@ -122,8 +122,11 @@ def test_fit_dynamic(model): |
122 | 122 | @pytest.mark.parametrize("model", online_models) |
123 | 123 | def test_fit_online(model): |
124 | 124 | for epoch in range(5): |
125 | | - for batch in batched(texts, 50): |
126 | | - model.partial_fit(texts) |
| 125 | + for batch in batched(zip(texts, embeddings), 50): |
| 126 | + batch_text, batch_embedding = zip(*batch) |
| 127 | + batch_text = list(batch_text) |
| 128 | + batch_embedding = np.stack(batch_embedding) |
| 129 | + model.partial_fit(batch_text, embeddings=batch_embedding) |
127 | 130 | table = model.export_topics(format="csv") |
128 | 131 | with tempfile.TemporaryDirectory() as tmpdirname: |
129 | 132 | out_path = Path(tmpdirname).joinpath("topics.csv") |
|
0 commit comments