Skip to content

Commit f362feb

Browse files
Added concept compass plot to S3
1 parent e9bbe37 commit f362feb

1 file changed

Lines changed: 74 additions & 0 deletions

File tree

turftopic/models/decomp.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from sklearn.base import TransformerMixin
77
from sklearn.decomposition import FastICA
88
from sklearn.feature_extraction.text import CountVectorizer
9+
from sklearn.metrics.pairwise import euclidean_distances
910

1011
from turftopic.base import ContextualModel, Encoder
1112
from turftopic.vectorizer import default_vectorizer
@@ -170,3 +171,76 @@ def export_representative_documents(
170171
show_negative,
171172
format,
172173
)
174+
175+
def concept_compass(
176+
self, topic_x: Union[int, str], topic_y: Union[str, int]
177+
):
178+
"""Display a compass of concepts along two semantic axes.
179+
180+
Parameters
181+
----------
182+
topic_x: int or str
183+
Index or name of the topic to display on the X axis.
184+
topic_y: int or str
185+
Index or name of the topic to display on the Y axis.
186+
187+
Returns
188+
-------
189+
go.Figure
190+
Plotly interactive plot of the concept compass.
191+
"""
192+
try:
193+
import plotly.express as px
194+
except (ImportError, ModuleNotFoundError) as e:
195+
raise ModuleNotFoundError(
196+
"Please install plotly if you intend to use plots in Turftopic."
197+
) from e
198+
if isinstance(topic_x, str):
199+
try:
200+
topic_x = list(self.topic_names).index(topic_x)
201+
except ValueError as e:
202+
raise ValueError(
203+
f"{topic_x} is not a valid topic name or index."
204+
) from e
205+
if isinstance(topic_y, str):
206+
try:
207+
topic_y = list(self.topic_names).index(topic_y)
208+
except ValueError as e:
209+
raise ValueError(
210+
f"{topic_y} is not a valid topic name or index."
211+
) from e
212+
x = self.components_[topic_x]
213+
y = self.components_[topic_y]
214+
vocab = self.get_vocab()
215+
points = np.array(list(zip(x, y)))
216+
xx, yy = np.meshgrid(
217+
np.linspace(np.min(x), np.max(x), 20),
218+
np.linspace(np.min(y), np.max(y), 20),
219+
)
220+
coords = np.array(list(zip(np.ravel(xx), np.ravel(yy))))
221+
coords = coords + np.random.default_rng(0).normal(
222+
[0, 0], [0.1, 0.1], size=coords.shape
223+
)
224+
dist = euclidean_distances(coords, points)
225+
idxs = np.argmin(dist, axis=1)
226+
fig = px.scatter(
227+
x=x[idxs],
228+
y=y[idxs],
229+
text=vocab[idxs],
230+
template="plotly_white",
231+
)
232+
fig = fig.update_traces(
233+
mode="text", textfont_color="black", marker=dict(color="black")
234+
).update_layout(
235+
xaxis_title=f"{self.topic_names[topic_x]}",
236+
yaxis_title=f"{self.topic_names[topic_y]}",
237+
)
238+
fig = fig.update_layout(
239+
width=1000,
240+
height=1000,
241+
font=dict(family="Times New Roman", color="black", size=21),
242+
margin=dict(l=5, r=5, t=5, b=5),
243+
)
244+
fig = fig.add_hline(y=0, line_color="black", line_width=4)
245+
fig = fig.add_vline(x=0, line_color="black", line_width=4)
246+
return fig

0 commit comments

Comments
 (0)