Skip to content

Commit d532403

Browse files
Added test for online fitting
1 parent 10d3cad commit d532403

1 file changed

Lines changed: 32 additions & 2 deletions

File tree

tests/test_integration.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,23 @@
88
from sentence_transformers import SentenceTransformer
99
from sklearn.datasets import fetch_20newsgroups
1010

11-
from turftopic import (GMM, AutoEncodingTopicModel, ClusteringTopicModel,
12-
KeyNMF, SemanticSignalSeparation)
11+
from turftopic import (
12+
GMM,
13+
AutoEncodingTopicModel,
14+
ClusteringTopicModel,
15+
KeyNMF,
16+
SemanticSignalSeparation,
17+
)
18+
19+
20+
def batched(iterable, n: int):
21+
"Batch data into tuples of length n. The last batch may be shorter."
22+
# batched('ABCDEFG', 3) --> ABC DEF G
23+
if n < 1:
24+
raise ValueError("n must be at least one")
25+
it = iter(iterable)
26+
while batch := list(itertools.islice(it, n)):
27+
yield batch
1328

1429

1530
def generate_dates(
@@ -75,6 +90,8 @@ def generate_dates(
7590
KeyNMF(5, encoder=trf),
7691
]
7792

93+
online_models = [KeyNMF(5, encoder=trf)]
94+
7895

7996
@pytest.mark.parametrize("model", models)
8097
def test_fit_export_table(model):
@@ -100,3 +117,16 @@ def test_fit_dynamic(model):
100117
with out_path.open("w") as out_file:
101118
out_file.write(table)
102119
df = pd.read_csv(out_path)
120+
121+
122+
@pytest.mark.parametrize("model", online_models)
123+
def test_fit_online(model):
124+
for epoch in range(5):
125+
for batch in batched(texts, 50):
126+
model.partial_fit(texts)
127+
table = model.export_topics(format="csv")
128+
with tempfile.TemporaryDirectory() as tmpdirname:
129+
out_path = Path(tmpdirname).joinpath("topics.csv")
130+
with out_path.open("w") as out_file:
131+
out_file.write(table)
132+
df = pd.read_csv(out_path)

0 commit comments

Comments
 (0)