Skip to content

Commit 923196b

Browse files
Merge pull request #24 from x-tabdeveloping/tests
Added Integration tests for all models
2 parents f13a4ff + 165ee45 commit 923196b

2 files changed

Lines changed: 93 additions & 0 deletions

File tree

.github/workflows/tests.yml

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
name: Tests
2+
on:
3+
push:
4+
branches: [main]
5+
pull_request:
6+
branches: [main]
7+
8+
jobs:
9+
pytest:
10+
runs-on: ubuntu-latest
11+
strategy:
12+
matrix:
13+
python-version: ["3.9"]
14+
#
15+
# This allows a subsequently queued workflow run to interrupt previous runs
16+
concurrency:
17+
group: "${{ github.workflow }}-${{ matrix.python-version}}-${{ matrix.os }} @ ${{ github.ref }}"
18+
cancel-in-progress: true
19+
20+
steps:
21+
- uses: actions/checkout@v4
22+
- name: Set up Python ${{ matrix.python-version }}
23+
uses: actions/setup-python@v4
24+
with:
25+
python-version: ${{ matrix.python-version }}
26+
cache: "pip"
27+
# You can test your matrix by printing the current Python version
28+
- name: Display Python version
29+
run: python3 -c "import sys; print(sys.version)"
30+
31+
- name: Install dependencies
32+
run: python3 -m pip install --upgrade turftopic[pyro-ppl] pandas pytest
33+
34+
- name: Run tests
35+
run: python3 -m pytest tests/
36+

tests/test_integration.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import tempfile
2+
from pathlib import Path
3+
4+
import numpy as np
5+
import pandas as pd
6+
import pytest
7+
from sentence_transformers import SentenceTransformer
8+
from sklearn.datasets import fetch_20newsgroups
9+
10+
from turftopic import (
11+
GMM,
12+
AutoEncodingTopicModel,
13+
ClusteringTopicModel,
14+
KeyNMF,
15+
SemanticSignalSeparation,
16+
)
17+
# Shared test data, built once at import time so every parametrized case
# reuses it: a small 20 Newsgroups corpus, one sentence-transformer encoder,
# the precomputed document embeddings, and one instance of each topic model.
newsgroups = fetch_20newsgroups(
    subset="all",
    categories=["misc.forsale"],
    remove=("headers", "footers", "quotes"),
)
texts = newsgroups.data

# Encode the corpus a single time; models receive these embeddings directly.
trf = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = np.asarray(trf.encode(texts))

# One configured instance per model family under test; all reduce to 5 topics.
models = [
    GMM(5, encoder=trf),
    SemanticSignalSeparation(5, encoder=trf),
    KeyNMF(5, encoder=trf),
    ClusteringTopicModel(
        n_reduce_to=5,
        feature_importance="c-tf-idf",
        encoder=trf,
        reduction_method="agglomerative",
    ),
    ClusteringTopicModel(
        n_reduce_to=5,
        feature_importance="centroid",
        encoder=trf,
        reduction_method="smallest",
    ),
    AutoEncodingTopicModel(5, combined=True),
]
47+
48+
49+
@pytest.mark.parametrize("model", models)
def test_fit_export_table(model):
    """Fit *model* on the shared corpus, export its topics as CSV, and
    verify the export round-trips through pandas.

    NOTE(review): the original test computed ``doc_topic_matrix`` and
    ``df`` but asserted nothing, so it only verified that no exception
    was raised; explicit assertions were added.
    """
    doc_topic_matrix = model.fit_transform(texts, embeddings=embeddings)
    # One row per document: the document-topic matrix must cover the corpus.
    assert doc_topic_matrix.shape[0] == len(texts)
    table = model.export_topics(format="csv")
    with tempfile.TemporaryDirectory() as tmpdirname:
        out_path = Path(tmpdirname).joinpath("topics.csv")
        with out_path.open("w") as out_file:
            out_file.write(table)
        df = pd.read_csv(out_path)
        # The exported table must parse and contain at least one topic row.
        assert not df.empty

0 commit comments

Comments (0)