Skip to content

Commit 0100d47

Browse files
Added more advanced plotting for Topeax
1 parent f52dff4 commit 0100d47

2 files changed

Lines changed: 244 additions & 13 deletions

File tree

turftopic/models/gmm.py

Lines changed: 212 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
MultimodalModel,
2929
)
3030
from turftopic.optimization import optimize_n_components
31+
from turftopic.utils import confidence_ellipse
3132
from turftopic.vectorizers.default import default_vectorizer
3233

3334
FEATURE_IMPORTANCE_METHODS = {
@@ -411,7 +412,11 @@ def plot_components_datamapplot(
411412
return plot
412413

413414
def plot_density(
414-
self, hover_text: list[str] = None, show_points=False, light_mode=False
415+
self,
416+
hover_text: list[str] = None,
417+
show_keywords=True,
418+
show_points=False,
419+
light_mode=False,
415420
):
416421
try:
417422
import plotly.graph_objects as go
@@ -428,9 +433,9 @@ def plot_density(
428433
warnings.warn(
429434
"Embeddings are not in 2d space, only using first 2 dimensions"
430435
)
431-
432-
coord_min, coord_max = np.min(self.reduced_embeddings), np.max(
433-
self.reduced_embeddings
436+
reduced_embeddings = self.reduced_embeddings[:, :2]
437+
coord_min, coord_max = np.min(reduced_embeddings), np.max(
438+
reduced_embeddings
434439
)
435440
coord_spread = coord_max - coord_min
436441
coord_min = coord_min - coord_spread * 0.05
@@ -464,8 +469,8 @@ def plot_density(
464469
]
465470
if show_points:
466471
scatter = go.Scatter(
467-
x=self.reduced_embeddings[:, 0],
468-
y=self.reduced_embeddings[:, 1],
472+
x=reduced_embeddings[:, 0],
473+
y=reduced_embeddings[:, 1],
469474
mode="markers",
470475
showlegend=False,
471476
text=hover_text,
@@ -488,13 +493,14 @@ def plot_density(
488493
self.gmm_.means_, self.topic_names, self.get_top_words()
489494
):
490495
_keys = ""
491-
for i, key in enumerate(keywords):
492-
if (i % 5) == 0:
493-
_keys += "<br> "
494-
_keys += key
495-
if i < (len(keywords) - 1):
496-
_keys += ","
497-
_keys += " "
496+
if show_keywords:
497+
for i, key in enumerate(keywords):
498+
if (i % 5) == 0:
499+
_keys += "<br> "
500+
_keys += key
501+
if i < (len(keywords) - 1):
502+
_keys += ","
503+
_keys += " "
498504
text = f"<b>{name}</b> <i>{_keys}</i> "
499505
fig.add_annotation(
500506
text=text,
@@ -510,3 +516,196 @@ def plot_density(
510516
borderwidth=2,
511517
)
512518
return fig
519+
520+
def plot_density_3d(self, show_keywords=False):
521+
try:
522+
import plotly.graph_objects as go
523+
except (ImportError, ModuleNotFoundError) as e:
524+
raise ModuleNotFoundError(
525+
"Please install plotly if you intend to use plots in Turftopic."
526+
) from e
527+
528+
if not hasattr(self, "reduced_embeddings"):
529+
raise ValueError(
530+
"No reduced embeddings found, can't display in 2d space."
531+
)
532+
if self.reduced_embeddings.shape[1] != 2:
533+
warnings.warn(
534+
"Embeddings are not in 2d space, only using first 2 dimensions"
535+
)
536+
reduced_embeddings = self.reduced_embeddings[:, :2]
537+
coord_min, coord_max = np.min(reduced_embeddings), np.max(
538+
reduced_embeddings
539+
)
540+
coord_spread = coord_max - coord_min
541+
coord_min = coord_min - coord_spread * 0.05
542+
coord_max = coord_max + coord_spread * 0.05
543+
coord = np.linspace(coord_min, coord_max, num=100)
544+
z = []
545+
for yval in coord:
546+
points = np.stack([coord, np.full(coord.shape, yval)]).T
547+
prob = np.exp(self.gmm_.score_samples(points))
548+
z.append(prob)
549+
z = np.stack(z)
550+
means = self.gmm_.means_
551+
means_z = np.exp(self.gmm_.score_samples(means))
552+
annotations = []
553+
for (x_mean, y_mean), z_mean, name, keywords in zip(
554+
means, means_z, self.topic_names, self.get_top_words()
555+
):
556+
_keys = ""
557+
if show_keywords:
558+
for i, key in enumerate(keywords):
559+
if (i % 5) == 0:
560+
_keys += "<br> "
561+
_keys += key
562+
if i < (len(keywords) - 1):
563+
_keys += ","
564+
_keys += " "
565+
text = f"<b>{name}</b> <i>{_keys}</i> "
566+
annotations.append(
567+
dict(
568+
showarrow=True,
569+
x=x_mean,
570+
y=y_mean,
571+
z=z_mean,
572+
text=text,
573+
font=dict(family="Roboto Mono", size=18, color="black"),
574+
bgcolor="rgba(255,255,255,0.9)",
575+
bordercolor="black",
576+
borderwidth=2,
577+
)
578+
)
579+
color_grid = [0.0, 0.25, 0.5, 0.75, 1.0]
580+
colorscale = [
581+
"#01014B",
582+
"#000080",
583+
"#5D5DEF",
584+
"#B7B7FF",
585+
"#ffffff",
586+
]
587+
fig = go.Figure(
588+
data=[
589+
go.Surface(
590+
z=z,
591+
x=coord,
592+
y=coord,
593+
colorscale=list(zip(color_grid, colorscale)),
594+
)
595+
]
596+
)
597+
fig = fig.update_layout(
598+
margin=dict(l=0, r=0, b=0, t=0),
599+
template="plotly_white",
600+
scene=dict(annotations=annotations),
601+
)
602+
return fig
603+
604+
def plot_components(
605+
self,
606+
show_points=False,
607+
show_keywords=True,
608+
hover_text: Optional[list[str]] = None,
609+
):
610+
try:
611+
import plotly.express as px
612+
import plotly.graph_objects as go
613+
except (ImportError, ModuleNotFoundError) as e:
614+
raise ModuleNotFoundError(
615+
"Please install plotly if you intend to use plots in Turftopic."
616+
) from e
617+
618+
if not hasattr(self, "reduced_embeddings"):
619+
raise ValueError(
620+
"No reduced embeddings found, can't display in 2d space."
621+
)
622+
if self.reduced_embeddings.shape[1] != 2:
623+
warnings.warn(
624+
"Embeddings are not in 2d space, only using first 2 dimensions"
625+
)
626+
reduced_embeddings = self.reduced_embeddings[:, :2]
627+
coord_min, coord_max = np.min(reduced_embeddings), np.max(
628+
reduced_embeddings
629+
)
630+
coord_spread = coord_max - coord_min
631+
coord_min = coord_min - coord_spread * 0.05
632+
coord_max = coord_max + coord_spread * 0.05
633+
coord = np.linspace(coord_min, coord_max, num=100)
634+
z = []
635+
for yval in coord:
636+
points = np.stack([coord, np.full(coord.shape, yval)]).T
637+
prob = np.exp(self.gmm_.score_samples(points))
638+
z.append(prob)
639+
z = np.stack(z)
640+
fig = go.Figure(
641+
[
642+
go.Contour(
643+
z=z,
644+
x=coord,
645+
y=coord,
646+
colorscale="Greys",
647+
opacity=0.25,
648+
hoverinfo="skip",
649+
showscale=False,
650+
),
651+
]
652+
)
653+
gmm_colors = px.colors.qualitative.Antique
654+
for i_std, n_std in enumerate(np.linspace(0.1, 3.0, num=5)):
655+
for color, mean, cov in zip(
656+
gmm_colors, self.gmm_.means_, self.gmm_.covariances_
657+
):
658+
fig.add_shape(
659+
type="path",
660+
path=confidence_ellipse(mean, cov, n_std=n_std),
661+
fillcolor=color,
662+
opacity=0.2,
663+
)
664+
for mean, name, keywords in zip(
665+
self.gmm_.means_, self.topic_names, self.get_top_words()
666+
):
667+
_keys = ""
668+
if show_keywords:
669+
for i, key in enumerate(keywords):
670+
if (i % 5) == 0:
671+
_keys += "<br> "
672+
_keys += key
673+
if i < (len(keywords) - 1):
674+
_keys += ","
675+
_keys += " "
676+
text = f"<b>{name}</b> <i>{_keys}</i> "
677+
fig.add_annotation(
678+
text=text,
679+
x=mean[0],
680+
y=mean[1],
681+
align="left",
682+
showarrow=False,
683+
xshift=0,
684+
yshift=50,
685+
font=dict(family="Roboto Mono", size=18, color="black"),
686+
bgcolor="rgba(255,255,255,0.9)",
687+
bordercolor="black",
688+
borderwidth=2,
689+
)
690+
fig = fig.update_layout(
691+
margin=dict(l=0, r=0, b=0, t=0),
692+
template="plotly_white",
693+
)
694+
if show_points:
695+
scatter = go.Scatter(
696+
x=reduced_embeddings[:, 0],
697+
y=reduced_embeddings[:, 1],
698+
mode="markers",
699+
showlegend=False,
700+
text=hover_text,
701+
marker=dict(
702+
symbol="circle",
703+
opacity=0.5,
704+
color="white",
705+
size=8,
706+
line=dict(width=1),
707+
),
708+
)
709+
fig.add_trace(scatter)
710+
fig = fig.update_layout(coloraxis=dict(showscale=False))
711+
fig.show()

turftopic/utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,35 @@ def sanitize_for_html(text: str) -> str:
7171
# Removing unnecessary whitespace
7272
text = " ".join(text.split())
7373
return text
74+
75+
76+
def confidence_ellipse(mean, cov, n_std=1, size=100):
77+
pearson = cov[0, 1] / np.sqrt(cov[0, 0] * cov[1, 1])
78+
ell_radius_x = np.sqrt(1 + pearson)
79+
ell_radius_y = np.sqrt(1 - pearson)
80+
theta = np.linspace(0, 2 * np.pi, size)
81+
ellipse_coords = np.column_stack(
82+
[ell_radius_x * np.cos(theta), ell_radius_y * np.sin(theta)]
83+
)
84+
x_scale = np.sqrt(cov[0, 0]) * n_std
85+
y_scale = np.sqrt(cov[1, 1]) * n_std
86+
x_mean, y_mean = mean
87+
translation_matrix = np.tile(
88+
[x_mean, y_mean], (ellipse_coords.shape[0], 1)
89+
)
90+
rotation_matrix = np.array(
91+
[
92+
[np.cos(np.pi / 4), np.sin(np.pi / 4)],
93+
[-np.sin(np.pi / 4), np.cos(np.pi / 4)],
94+
]
95+
)
96+
scale_matrix = np.array([[x_scale, 0], [0, y_scale]])
97+
ellipse_coords = (
98+
ellipse_coords.dot(rotation_matrix).dot(scale_matrix)
99+
+ translation_matrix
100+
)
101+
path = f"M {ellipse_coords[0, 0]}, {ellipse_coords[0, 1]}"
102+
for k in range(1, len(ellipse_coords)):
103+
path += f"L{ellipse_coords[k, 0]}, {ellipse_coords[k, 1]}"
104+
path += " Z"
105+
return path

0 commit comments

Comments
 (0)