Skip to content

Commit e79d836

Browse files
committed
add tests for dynamic models
1 parent a9a015e commit e79d836

2 files changed

Lines changed: 59 additions & 1 deletion

File tree

tests/test_integration.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from datetime import datetime
12
import tempfile
23
from pathlib import Path
34

@@ -12,9 +13,24 @@
1213
AutoEncodingTopicModel,
1314
ClusteringTopicModel,
1415
KeyNMF,
15-
SemanticSignalSeparation,
16+
SemanticSignalSeparation
1617
)
1718

19+
20+
def generate_dates(
    n_dates: int,
) -> list[datetime]:
    """Generate `n_dates` random datetimes for exercising dynamic topic models.

    Days are drawn from 1-28 (valid for every month), months from 1-12,
    and years from 2000-2019 (NumPy's `randint` upper bound is exclusive).
    Uses the global NumPy RNG, so results are reproducible under a seed.
    """
    result: list[datetime] = []
    for _ in range(n_dates):
        # Draw day, month, year in this order to keep seeded output stable.
        day = np.random.randint(low=1, high=29)
        month = np.random.randint(low=1, high=13)
        year = np.random.randint(low=2000, high=2020)
        result.append(datetime(year=year, month=month, day=day))
    return result
32+
33+
1834
newsgroups = fetch_20newsgroups(
1935
subset="all",
2036
categories=[
@@ -25,6 +41,7 @@
2541
texts = newsgroups.data
2642
trf = SentenceTransformer("all-MiniLM-L6-v2")
2743
embeddings = np.asarray(trf.encode(texts))
44+
timestamps = generate_dates(n_dates=len(texts))
2845

2946
models = [
3047
GMM(5, encoder=trf),
@@ -46,6 +63,22 @@
4663
AutoEncodingTopicModel(5, combined=True),
4764
]
4865

66+
# Models exercised by test_fit_dynamic: each must implement
# fit_transform_dynamic(texts, embeddings=..., timestamps=...).
dynamic_models = [
    GMM(5, encoder=trf),
    ClusteringTopicModel(
        n_reduce_to=5,
        feature_importance="centroid",
        encoder=trf,
        reduction_method="smallest",
    ),
    ClusteringTopicModel(
        n_reduce_to=5,
        feature_importance="soft-c-tf-idf",
        encoder=trf,
        reduction_method="smallest",
    ),
]
81+
4982

5083
@pytest.mark.parametrize("model", models)
5184
def test_fit_export_table(model):
@@ -56,3 +89,16 @@ def test_fit_export_table(model):
5689
with out_path.open("w") as out_file:
5790
out_file.write(table)
5891
df = pd.read_csv(out_path)
92+
93+
94+
@pytest.mark.parametrize("model", dynamic_models)
def test_fit_dynamic(model):
    """Fit a dynamic topic model, export its topic table as CSV,
    and check the table round-trips through pandas."""
    doc_topic_matrix = model.fit_transform_dynamic(
        texts, embeddings=embeddings, timestamps=timestamps,
    )
    table = model.export_topics(format="csv")
    with tempfile.TemporaryDirectory() as tmp_dir:
        csv_path = Path(tmp_dir) / "topics.csv"
        csv_path.write_text(table)
        # Parsing back verifies the export produced well-formed CSV.
        df = pd.read_csv(csv_path)

turftopic/encoders/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import itertools
2+
from typing import Iterable, List
3+
4+
5+
def batched(iterable, n: int) -> Iterable[List[str]]:
    """Batch data into lists of length ``n``; the last batch may be shorter.

    batched('ABCDEFG', 3) --> ['A','B','C'] ['D','E','F'] ['G']

    Parameters
    ----------
    iterable
        Any iterable to split into batches.
    n
        Maximum batch size; must be at least one.

    Raises
    ------
    ValueError
        If ``n`` is smaller than one (raised on first iteration,
        since this is a generator).
    """
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    # islice consumes up to n items per pass; the walrus loop stops
    # once the underlying iterator is exhausted (empty batch is falsy).
    while batch := list(itertools.islice(it, n)):
        yield batch

0 commit comments

Comments
 (0)