|
6 | 6 | from sklearn.base import TransformerMixin |
7 | 7 | from sklearn.decomposition import FastICA |
8 | 8 | from sklearn.feature_extraction.text import CountVectorizer |
| 9 | +from sklearn.metrics.pairwise import euclidean_distances |
9 | 10 |
|
10 | 11 | from turftopic.base import ContextualModel, Encoder |
11 | 12 | from turftopic.vectorizer import default_vectorizer |
@@ -170,3 +171,76 @@ def export_representative_documents( |
170 | 171 | show_negative, |
171 | 172 | format, |
172 | 173 | ) |
| 174 | + |
def concept_compass(
    self, topic_x: Union[int, str], topic_y: Union[str, int]
):
    """Display a compass of concepts along two semantic axes.

    The words shown are chosen by laying a jittered 20x20 grid over the
    plane spanned by the two topics and keeping, for every grid cell,
    the vocabulary item closest to it. Each selected word is drawn once
    at its own (x, y) position.

    Parameters
    ----------
    topic_x: int or str
        Index or name of the topic to display on the X axis.
    topic_y: int or str
        Index or name of the topic to display on the Y axis.

    Returns
    -------
    go.Figure
        Plotly interactive plot of the concept compass.

    Raises
    ------
    ValueError
        If a topic is given as a string that is not in ``topic_names``.
    ModuleNotFoundError
        If plotly is not installed.
    """

    def _topic_index(topic: Union[int, str]):
        """Map a topic name to its position in ``topic_names``.

        Anything that is not a string is assumed to already be an index
        and is returned untouched.
        """
        if not isinstance(topic, str):
            return topic
        try:
            return list(self.topic_names).index(topic)
        except ValueError as e:
            raise ValueError(
                f"{topic} is not a valid topic name or index."
            ) from e

    # Validate the arguments first so that a misspelled topic name is
    # reported immediately, before the optional dependency is touched.
    topic_x = _topic_index(topic_x)
    topic_y = _topic_index(topic_y)
    try:
        import plotly.express as px
    except ImportError as e:
        # ModuleNotFoundError is a subclass of ImportError, so this
        # single clause covers both.
        raise ModuleNotFoundError(
            "Please install plotly if you intend to use plots in Turftopic."
        ) from e
    # Assumes components_ is a 2D array with one row per topic and one
    # column per vocabulary item -- TODO confirm against the fit method.
    x = self.components_[topic_x]
    y = self.components_[topic_y]
    vocab = self.get_vocab()
    # One (x, y) point per vocabulary item.
    points = np.column_stack((x, y))
    xx, yy = np.meshgrid(
        np.linspace(np.min(x), np.max(x), 20),
        np.linspace(np.min(y), np.max(y), 20),
    )
    coords = np.column_stack((np.ravel(xx), np.ravel(yy)))
    # Fixed seed keeps the figure reproducible between calls.
    # NOTE(review): the jitter's standard deviation is an absolute 0.1,
    # not relative to the axis range -- confirm it suits the scale of
    # components_.
    coords = coords + np.random.default_rng(0).normal(
        [0, 0], [0.1, 0.1], size=coords.shape
    )
    dist = euclidean_distances(coords, points)
    # Several grid cells usually share the same nearest word; keep each
    # word once so its label is not drawn repeatedly on top of itself.
    idxs = np.unique(np.argmin(dist, axis=1))
    fig = px.scatter(
        x=x[idxs],
        y=y[idxs],
        text=vocab[idxs],
        template="plotly_white",
    )
    fig = fig.update_traces(
        mode="text", textfont_color="black", marker=dict(color="black")
    )
    fig = fig.update_layout(
        xaxis_title=f"{self.topic_names[topic_x]}",
        yaxis_title=f"{self.topic_names[topic_y]}",
        width=1000,
        height=1000,
        font=dict(family="Times New Roman", color="black", size=21),
        margin=dict(l=5, r=5, t=5, b=5),
    )
    # Thick lines through the origin form the cross of the compass.
    fig = fig.add_hline(y=0, line_color="black", line_width=4)
    fig = fig.add_vline(x=0, line_color="black", line_width=4)
    return fig
0 commit comments