2828 MultimodalModel ,
2929)
3030from turftopic .optimization import optimize_n_components
31+ from turftopic .utils import confidence_ellipse
3132from turftopic .vectorizers .default import default_vectorizer
3233
3334FEATURE_IMPORTANCE_METHODS = {
@@ -411,7 +412,11 @@ def plot_components_datamapplot(
411412 return plot
412413
413414 def plot_density (
414- self , hover_text : list [str ] = None , show_points = False , light_mode = False
415+ self ,
416+ hover_text : list [str ] = None ,
417+ show_keywords = True ,
418+ show_points = False ,
419+ light_mode = False ,
415420 ):
416421 try :
417422 import plotly .graph_objects as go
@@ -428,9 +433,9 @@ def plot_density(
428433 warnings .warn (
429434 "Embeddings are not in 2d space, only using first 2 dimensions"
430435 )
431-
432- coord_min , coord_max = np .min (self . reduced_embeddings ), np .max (
433- self . reduced_embeddings
436+ reduced_embeddings = self . reduced_embeddings [:, : 2 ]
437+ coord_min , coord_max = np .min (reduced_embeddings ), np .max (
438+ reduced_embeddings
434439 )
435440 coord_spread = coord_max - coord_min
436441 coord_min = coord_min - coord_spread * 0.05
@@ -464,8 +469,8 @@ def plot_density(
464469 ]
465470 if show_points :
466471 scatter = go .Scatter (
467- x = self . reduced_embeddings [:, 0 ],
468- y = self . reduced_embeddings [:, 1 ],
472+ x = reduced_embeddings [:, 0 ],
473+ y = reduced_embeddings [:, 1 ],
469474 mode = "markers" ,
470475 showlegend = False ,
471476 text = hover_text ,
@@ -488,13 +493,14 @@ def plot_density(
488493 self .gmm_ .means_ , self .topic_names , self .get_top_words ()
489494 ):
490495 _keys = ""
491- for i , key in enumerate (keywords ):
492- if (i % 5 ) == 0 :
493- _keys += "<br> "
494- _keys += key
495- if i < (len (keywords ) - 1 ):
496- _keys += ","
497- _keys += " "
496+ if show_keywords :
497+ for i , key in enumerate (keywords ):
498+ if (i % 5 ) == 0 :
499+ _keys += "<br> "
500+ _keys += key
501+ if i < (len (keywords ) - 1 ):
502+ _keys += ","
503+ _keys += " "
498504 text = f"<b>{ name } </b> <i>{ _keys } </i> "
499505 fig .add_annotation (
500506 text = text ,
@@ -510,3 +516,196 @@ def plot_density(
510516 borderwidth = 2 ,
511517 )
512518 return fig
519+
def plot_density_3d(self, show_keywords=False):
    """Draw the mixture density over the reduced embedding plane as a 3D surface.

    The density is evaluated with ``self.gmm_.score_samples`` on a
    100 x 100 grid covering the first two reduced dimensions (padded by 5%
    of the value range on each side), and every component mean is annotated
    with its topic name.

    Parameters
    ----------
    show_keywords: bool, default False
        When true, the top words of each topic are appended to its label,
        five per line.

    Returns
    -------
    plotly.graph_objects.Figure
        Figure holding a single surface trace plus the scene annotations.

    Raises
    ------
    ModuleNotFoundError
        If plotly cannot be imported.
    ValueError
        If the model has no ``reduced_embeddings`` attribute.
    """
    try:
        import plotly.graph_objects as go
    except (ImportError, ModuleNotFoundError) as e:
        raise ModuleNotFoundError(
            "Please install plotly if you intend to use plots in Turftopic."
        ) from e

    if not hasattr(self, "reduced_embeddings"):
        raise ValueError(
            "No reduced embeddings found, can't display in 2d space."
        )
    if self.reduced_embeddings.shape[1] != 2:
        warnings.warn(
            "Embeddings are not in 2d space, only using first 2 dimensions"
        )
    points_2d = self.reduced_embeddings[:, :2]
    # One shared value range is used for both axes, so the grid is square.
    lowest, highest = np.min(points_2d), np.max(points_2d)
    padding = (highest - lowest) * 0.05
    axis = np.linspace(lowest - padding, highest + padding, num=100)
    # Row i of the surface holds the density along x at y = axis[i].
    density_rows = [
        np.exp(
            self.gmm_.score_samples(
                np.stack([axis, np.full(axis.shape, y_value)]).T
            )
        )
        for y_value in axis
    ]
    surface_heights = np.stack(density_rows)
    centers = self.gmm_.means_
    # Labels are placed at the density value found at each component mean.
    center_heights = np.exp(self.gmm_.score_samples(centers))
    label_font = dict(family="Roboto Mono", size=18, color="black")
    scene_labels = []
    for (center_x, center_y), center_z, topic_name, top_words in zip(
        centers, center_heights, self.topic_names, self.get_top_words()
    ):
        keyword_text = ""
        if show_keywords:
            last_index = len(top_words) - 1
            # Line break before every fifth word, comma after all but the last.
            keyword_text = "".join(
                ("<br> " if position % 5 == 0 else "")
                + word
                + ("," if position < last_index else "")
                + " "
                for position, word in enumerate(top_words)
            )
        scene_labels.append(
            dict(
                showarrow=True,
                x=center_x,
                y=center_y,
                z=center_z,
                text=f"<b>{topic_name} </b> <i>{keyword_text} </i> ",
                font=label_font,
                bgcolor="rgba(255,255,255,0.9)",
                bordercolor="black",
                borderwidth=2,
            )
        )
    # Dark blue at low density fading to white at the peaks.
    blue_to_white = [
        (0.0, "#01014B"),
        (0.25, "#000080"),
        (0.5, "#5D5DEF"),
        (0.75, "#B7B7FF"),
        (1.0, "#ffffff"),
    ]
    surface = go.Surface(
        z=surface_heights,
        x=axis,
        y=axis,
        colorscale=blue_to_white,
    )
    fig = go.Figure(data=[surface])
    return fig.update_layout(
        margin=dict(l=0, r=0, b=0, t=0),
        template="plotly_white",
        scene=dict(annotations=scene_labels),
    )
603+
def plot_components(
    self,
    show_points=False,
    show_keywords=True,
    hover_text: Optional[list[str]] = None,
):
    """Plot each mixture component as nested ellipses over a density contour.

    A faint grey contour of the mixture density (evaluated with
    ``self.gmm_.score_samples`` on a 100 x 100 grid over the first two
    reduced dimensions) is drawn first. On top of it, every component gets
    five filled ellipses at ``n_std`` values from 0.1 to 3.0, and a label
    with its topic name at the component mean.

    Parameters
    ----------
    show_points: bool, default False
        When true, the reduced document embeddings are added as a scatter
        trace.
    show_keywords: bool, default True
        When true, the top words of each topic are appended to its label,
        five per line.
    hover_text: list of str, optional
        Text passed to the scatter trace; only used when ``show_points``
        is true.

    Returns
    -------
    plotly.graph_objects.Figure
        The assembled figure. It is returned rather than shown, in line
        with ``plot_density`` and ``plot_density_3d``; call ``.show()`` on
        the result to display it.

    Raises
    ------
    ModuleNotFoundError
        If plotly cannot be imported.
    ValueError
        If the model has no ``reduced_embeddings`` attribute.
    """
    try:
        import plotly.express as px
        import plotly.graph_objects as go
    except (ImportError, ModuleNotFoundError) as e:
        raise ModuleNotFoundError(
            "Please install plotly if you intend to use plots in Turftopic."
        ) from e

    if not hasattr(self, "reduced_embeddings"):
        raise ValueError(
            "No reduced embeddings found, can't display in 2d space."
        )
    if self.reduced_embeddings.shape[1] != 2:
        warnings.warn(
            "Embeddings are not in 2d space, only using first 2 dimensions"
        )
    reduced_embeddings = self.reduced_embeddings[:, :2]
    # One shared value range is used for both axes, padded by 5% per side.
    coord_min, coord_max = np.min(reduced_embeddings), np.max(
        reduced_embeddings
    )
    coord_spread = coord_max - coord_min
    coord_min = coord_min - coord_spread * 0.05
    coord_max = coord_max + coord_spread * 0.05
    coord = np.linspace(coord_min, coord_max, num=100)
    # Row i of z holds the density along x at y = coord[i].
    z = np.stack(
        [
            np.exp(
                self.gmm_.score_samples(
                    np.stack([coord, np.full(coord.shape, yval)]).T
                )
            )
            for yval in coord
        ]
    )
    fig = go.Figure(
        [
            go.Contour(
                z=z,
                x=coord,
                y=coord,
                colorscale="Greys",
                opacity=0.25,
                hoverinfo="skip",
                showscale=False,
            ),
        ]
    )
    gmm_colors = px.colors.qualitative.Antique
    n_colors = len(gmm_colors)
    for n_std in np.linspace(0.1, 3.0, num=5):
        for i_component, (mean, cov) in enumerate(
            zip(self.gmm_.means_, self.gmm_.covariances_)
        ):
            fig.add_shape(
                type="path",
                # NOTE(review): confidence_ellipse is a project helper;
                # presumably it returns an SVG path string - confirm.
                path=confidence_ellipse(mean, cov, n_std=n_std),
                # Wrap around the palette so components beyond its length
                # still get an ellipse instead of being dropped by zip.
                fillcolor=gmm_colors[i_component % n_colors],
                opacity=0.2,
            )
    for mean, name, keywords in zip(
        self.gmm_.means_, self.topic_names, self.get_top_words()
    ):
        _keys = ""
        if show_keywords:
            last_index = len(keywords) - 1
            # Line break before every fifth word, comma after all but the last.
            _keys = "".join(
                ("<br> " if i % 5 == 0 else "")
                + key
                + ("," if i < last_index else "")
                + " "
                for i, key in enumerate(keywords)
            )
        text = f"<b>{name} </b> <i>{_keys} </i> "
        fig.add_annotation(
            text=text,
            x=mean[0],
            y=mean[1],
            align="left",
            showarrow=False,
            xshift=0,
            yshift=50,
            font=dict(family="Roboto Mono", size=18, color="black"),
            bgcolor="rgba(255,255,255,0.9)",
            bordercolor="black",
            borderwidth=2,
        )
    fig = fig.update_layout(
        margin=dict(l=0, r=0, b=0, t=0),
        template="plotly_white",
    )
    if show_points:
        scatter = go.Scatter(
            x=reduced_embeddings[:, 0],
            y=reduced_embeddings[:, 1],
            mode="markers",
            showlegend=False,
            text=hover_text,
            marker=dict(
                symbol="circle",
                opacity=0.5,
                color="white",
                size=8,
                line=dict(width=1),
            ),
        )
        fig.add_trace(scatter)
    fig = fig.update_layout(coloraxis=dict(showscale=False))
    return fig
0 commit comments