Skip to content

Commit 8ac8a2a

Browse files
Added more advanced plotting for topeax
1 parent 2f12aa2 commit 8ac8a2a

2 files changed

Lines changed: 160 additions & 19 deletions

File tree

turftopic/models/gmm.py

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,7 @@ def plot_density_3d(self, show_keywords=False):
591591
x=coord,
592592
y=coord,
593593
colorscale=list(zip(color_grid, colorscale)),
594+
showscale=False,
594595
)
595596
]
596597
)
@@ -650,16 +651,24 @@ def plot_components(
650651
),
651652
]
652653
)
653-
gmm_colors = px.colors.qualitative.Antique
654+
gmm_colors = px.colors.qualitative.Dark24
654655
for i_std, n_std in enumerate(np.linspace(0.1, 3.0, num=5)):
655-
for color, mean, cov in zip(
656-
gmm_colors, self.gmm_.means_, self.gmm_.covariances_
656+
for name, color, mean, cov in zip(
657+
self.topic_names,
658+
gmm_colors,
659+
self.gmm_.means_,
660+
self.gmm_.covariances_,
657661
):
658662
fig.add_shape(
663+
legend="legend",
664+
showlegend=False,
659665
type="path",
660666
path=confidence_ellipse(mean, cov, n_std=n_std),
667+
legendgroup=name,
668+
name=0,
669+
legendwidth=0,
661670
fillcolor=color,
662-
opacity=0.2,
671+
opacity=0.1,
663672
)
664673
for mean, name, keywords in zip(
665674
self.gmm_.means_, self.topic_names, self.get_top_words()
@@ -692,20 +701,36 @@ def plot_components(
692701
template="plotly_white",
693702
)
694703
if show_points:
695-
scatter = go.Scatter(
696-
x=reduced_embeddings[:, 0],
697-
y=reduced_embeddings[:, 1],
698-
mode="markers",
699-
showlegend=False,
700-
text=hover_text,
701-
marker=dict(
702-
symbol="circle",
703-
opacity=0.5,
704-
color="white",
705-
size=8,
706-
line=dict(width=1),
707-
),
708-
)
709-
fig.add_trace(scatter)
704+
for i, (name, color) in enumerate(
705+
zip(self.topic_names, gmm_colors)
706+
):
707+
include = self.labels_ == i
708+
text = (
709+
None
710+
if hover_text is None
711+
else [
712+
text
713+
for text, in_cluster in zip(hover_text, include)
714+
if in_cluster
715+
]
716+
)
717+
scatter = go.Scatter(
718+
x=reduced_embeddings[:, 0][include],
719+
y=reduced_embeddings[:, 1][include],
720+
mode="markers",
721+
showlegend=False,
722+
text=text,
723+
name=name,
724+
legendgroup=name,
725+
hovertemplate=f"<b>{name}</b><br>%{{text}}",
726+
marker=dict(
727+
symbol="circle",
728+
opacity=0.5,
729+
color=color,
730+
size=6,
731+
line=dict(width=1),
732+
),
733+
)
734+
fig.add_trace(scatter)
710735
fig = fig.update_layout(coloraxis=dict(showscale=False))
711736
return fig

turftopic/models/topeax.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,3 +211,119 @@ def estimate_components(
211211
def _init_model(self, n_components: int):
212212
mixture = Peax()
213213
return mixture
214+
215+
def plot_steps(self, hover_text=None):
216+
try:
217+
import plotly.express as px
218+
from plotly.subplots import make_subplots
219+
except (ImportError, ModuleNotFoundError) as e:
220+
raise ModuleNotFoundError(
221+
"Please install plotly if you intend to use plots in Turftopic."
222+
) from e
223+
dens_3d = self.plot_density_3d()
224+
component_plot = self.plot_components(
225+
show_points=True, hover_text=hover_text
226+
)
227+
points_plot = px.scatter(
228+
x=self.reduced_embeddings[:, 0],
229+
y=self.reduced_embeddings[:, 1],
230+
template="plotly_white",
231+
)
232+
points_plot = points_plot.update_layout(
233+
margin=dict(l=0, r=0, b=0, t=0),
234+
)
235+
points_plot = points_plot.update_traces(
236+
marker=dict(
237+
color="#B7B7FF",
238+
size=6,
239+
opacity=0.5,
240+
line=dict(color="#01014B", width=2),
241+
)
242+
)
243+
colormap = {
244+
name: color
245+
for name, color in zip(
246+
self.topic_names, px.colors.qualitative.Dark24
247+
)
248+
}
249+
bar = px.bar(
250+
y=self.topic_names,
251+
x=self.weights_,
252+
template="plotly_white",
253+
color_discrete_map=colormap,
254+
color=self.topic_names,
255+
text=[f"{p:.2f}" for p in self.weights_],
256+
)
257+
bar = bar.update_traces(
258+
marker_line_color="black",
259+
marker_line_width=1.5,
260+
opacity=0.8,
261+
)
262+
263+
def update_annotation(a):
264+
name = a.text.removeprefix("<b>").split("<")[0]
265+
return a.update(
266+
# text=name,
267+
font=dict(size=8, color=colormap[name]),
268+
arrowsize=1,
269+
arrowhead=1,
270+
arrowwidth=1,
271+
bgcolor=None,
272+
opacity=0.7,
273+
# bgcolor=colormap[name],
274+
bordercolor=colormap[name],
275+
borderwidth=0,
276+
)
277+
278+
fig = make_subplots(
279+
horizontal_spacing=0.0,
280+
vertical_spacing=0.1,
281+
rows=2,
282+
cols=2,
283+
subplot_titles=[
284+
"t-SN Embeddings",
285+
"Peaks in Kernel Density Estimate",
286+
"Gaussian Mixture Approximation",
287+
"Component Probabilities",
288+
],
289+
specs=[
290+
[
291+
{"type": "xy"},
292+
{"type": "surface"},
293+
],
294+
[
295+
{"type": "xy"},
296+
{"type": "bar"},
297+
],
298+
],
299+
)
300+
for i, sub in enumerate([points_plot, dens_3d, component_plot, bar]):
301+
row = i // 2
302+
col = i % 2
303+
for trace in sub.data:
304+
fig.add_trace(trace, row=row + 1, col=col + 1)
305+
for shape in sub.layout.shapes:
306+
fig.add_shape(shape, row=row + 1, col=col + 1)
307+
fig = fig.update_layout(
308+
template="plotly_white",
309+
font=dict(family="Merriweather", size=14, color="black"),
310+
width=1200,
311+
height=800,
312+
autosize=False,
313+
margin=dict(r=0, l=0, t=40, b=0),
314+
legend=dict(yanchor="top", y=0.45),
315+
)
316+
fig = fig.update_scenes(
317+
annotations=[
318+
update_annotation(annotation)
319+
for annotation in dens_3d.layout.scene.annotations
320+
],
321+
col=2,
322+
row=1,
323+
)
324+
fig = fig.for_each_annotation(lambda a: a.update(yshift=0))
325+
fig = fig.update_yaxes(visible=False, row=2, col=2)
326+
fig = fig.update_xaxes(
327+
title=dict(text="$P(z)$", font=dict(size=16)), row=2, col=2
328+
)
329+
return fig

0 commit comments

Comments
 (0)