1+ from datetime import datetime
12import tempfile
23from pathlib import Path
34
1213 AutoEncodingTopicModel ,
1314 ClusteringTopicModel ,
1415 KeyNMF ,
15- SemanticSignalSeparation ,
16+ SemanticSignalSeparation
1617)
1718
19+
20+ def generate_dates (
21+ n_dates : int ,
22+ ) -> list [datetime ]:
23+ """ Generate random dates to test dynamic models """
24+ dates = []
25+ for n in range (n_dates ):
26+ d = np .random .randint (low = 1 , high = 29 )
27+ m = np .random .randint (low = 1 , high = 13 )
28+ y = np .random .randint (low = 2000 , high = 2020 )
29+ date = datetime (year = y , month = m , day = d )
30+ dates .append (date )
31+ return dates
32+
33+
1834newsgroups = fetch_20newsgroups (
1935 subset = "all" ,
2036 categories = [
2541texts = newsgroups .data
2642trf = SentenceTransformer ("all-MiniLM-L6-v2" )
2743embeddings = np .asarray (trf .encode (texts ))
44+ timestamps = generate_dates (n_dates = len (texts ))
2845
2946models = [
3047 GMM (5 , encoder = trf ),
4663 AutoEncodingTopicModel (5 , combined = True ),
4764]
4865
66+ dynamic_models = [
67+ GMM (5 , encoder = trf ),
68+ ClusteringTopicModel (
69+ n_reduce_to = 5 ,
70+ feature_importance = "centroid" ,
71+ encoder = trf ,
72+ reduction_method = "smallest"
73+ ),
74+ ClusteringTopicModel (
75+ n_reduce_to = 5 ,
76+ feature_importance = "soft-c-tf-idf" ,
77+ encoder = trf ,
78+ reduction_method = "smallest"
79+ )
80+ ]
81+
4982
5083@pytest .mark .parametrize ("model" , models )
5184def test_fit_export_table (model ):
@@ -56,3 +89,16 @@ def test_fit_export_table(model):
5689 with out_path .open ("w" ) as out_file :
5790 out_file .write (table )
5891 df = pd .read_csv (out_path )
92+
93+
94+ @pytest .mark .parametrize ("model" , dynamic_models )
95+ def test_fit_dynamic (model ):
96+ doc_topic_matrix = model .fit_transform_dynamic (
97+ texts , embeddings = embeddings , timestamps = timestamps ,
98+ )
99+ table = model .export_topics (format = "csv" )
100+ with tempfile .TemporaryDirectory () as tmpdirname :
101+ out_path = Path (tmpdirname ).joinpath ("topics.csv" )
102+ with out_path .open ("w" ) as out_file :
103+ out_file .write (table )
104+ df = pd .read_csv (out_path )
0 commit comments