88from sentence_transformers import SentenceTransformer
99from sklearn .datasets import fetch_20newsgroups
1010
11- from turftopic import (GMM , AutoEncodingTopicModel , ClusteringTopicModel ,
12- KeyNMF , SemanticSignalSeparation )
11+ from turftopic import (
12+ GMM ,
13+ AutoEncodingTopicModel ,
14+ ClusteringTopicModel ,
15+ KeyNMF ,
16+ SemanticSignalSeparation ,
17+ )
18+
19+
def batched(iterable, n: int):
    """Yield successive lists of at most ``n`` items from ``iterable``.

    The final batch is shorter when the number of items is not an exact
    multiple of ``n``, e.g. batched('ABCDEFG', 3) --> ABC DEF G.
    Raises ValueError if ``n`` is smaller than one.
    """
    if n < 1:
        raise ValueError("n must be at least one")
    source = iter(iterable)
    while True:
        chunk = list(itertools.islice(source, n))
        if not chunk:
            return
        yield chunk
1328
1429
1530def generate_dates (
@@ -75,6 +90,8 @@ def generate_dates(
7590 KeyNMF (5 , encoder = trf ),
7691]
7792
# Models that support incremental (online) training via partial_fit;
# currently only KeyNMF — used by the online-fitting test below.
online_models = [KeyNMF(5, encoder=trf)]
94+
7895
7996@pytest .mark .parametrize ("model" , models )
8097def test_fit_export_table (model ):
@@ -100,3 +117,16 @@ def test_fit_dynamic(model):
100117 with out_path .open ("w" ) as out_file :
101118 out_file .write (table )
102119 df = pd .read_csv (out_path )
120+
121+
@pytest.mark.parametrize("model", online_models)
def test_fit_online(model):
    """Fit a model incrementally over several epochs of mini-batches,
    then verify its topic table round-trips through CSV."""
    for epoch in range(5):
        for batch in batched(texts, 50):
            # Fit on the current mini-batch, not the whole corpus:
            # the original passed `texts` here, which refit the full
            # dataset every iteration and never exercised batching.
            model.partial_fit(batch)
    table = model.export_topics(format="csv")
    with tempfile.TemporaryDirectory() as tmpdirname:
        out_path = Path(tmpdirname).joinpath("topics.csv")
        with out_path.open("w") as out_file:
            out_file.write(table)
        # Reading the CSV back validates that the exported table parses.
        df = pd.read_csv(out_path)
0 commit comments