|
9 | 9 | from wordcloud import WordCloud |
10 | 10 | import random |
11 | 11 | import os |
| 12 | +from pathlib import Path |
| 13 | +from typing import Optional |
| 14 | +import matplotlib.pyplot as plt |
| 15 | +from matplotlib.cm import get_cmap |
| 16 | +from matplotlib.colors import to_hex |
| 17 | +from scipy.spatial import ConvexHull |
| 18 | +import logging |
| 19 | +log = logging.getLogger(__name__) |
12 | 20 | from .file_system import check_path |
13 | 21 | from .data_structures import sum_dicts |
14 | 22 | from .maps import get_id_to_name |
15 | 23 | from .graphs import create_authors_graph |
16 | | -import math |
| 24 | + |
17 | 25 |
|
18 | 26 | def plot_authors_graph(df, id_col='s2_author_ids', name_col='s2_authors', title='Co-Authors Graph', |
19 | 27 | width=900, height=900, max_node_size=50, min_node_size=3): |
@@ -478,3 +486,60 @@ def plot_H_clustering(H, name="filename"): |
478 | 486 | plt.close() |
479 | 487 |
|
480 | 488 | return fig |
| 489 | + |
def plot_umap(
    coords: np.ndarray,
    labels: list,
    output_path: Path,
    label_column: str,
    model_name: str,
    accepted_mask: Optional[np.ndarray] = None
) -> None:
    """
    Save a UMAP scatterplot with optional accepted-hull overlay.

    Points are coloured by label using the "tab20" colormap; with more than
    20 distinct labels the colours repeat. When ``accepted_mask`` selects at
    least three non-degenerate points, the outline of their convex hull is
    drawn in green and listed in the legend. The figure is always closed
    before returning, including when plotting or saving fails.

    Parameters
    ----------
    coords : np.ndarray
        2D UMAP coordinates (n_samples, 2).
    labels : list
        Original labels corresponding to each coordinate. Labels must be
        hashable and mutually sortable.
    output_path : Path
        Filepath to save the resulting plot.
    label_column : str
        Name of the label column for legend entries.
    model_name : str
        Embedding model identifier for plot title.
    accepted_mask : Optional[np.ndarray]
        Boolean mask for accepted points; if provided, draws convex hull.
        The hull is skipped (with a warning for degenerate input) when it
        cannot be computed.
    """
    # Local import: only needed to recognise degenerate hull input.
    from scipy.spatial import QhullError

    uniq = sorted(set(labels))
    # Resolve the colormap once through pyplot; matplotlib.cm.get_cmap was
    # removed in Matplotlib 3.9, while plt.get_cmap remains available.
    cmap = plt.get_cmap("tab20")
    color_map = {v: to_hex(cmap(i % 20)) for i, v in enumerate(uniq)}

    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        ax.scatter(coords[:, 0], coords[:, 1],
                   c=[color_map[v] for v in labels],
                   s=25, alpha=0.85)

        # Track whether the hull was actually drawn so the legend never
        # advertises an overlay that is not in the figure.
        hull_drawn = False
        if accepted_mask is not None:
            accepted = coords[accepted_mask]
            if accepted.shape[0] >= 3:
                try:
                    hull = ConvexHull(accepted)
                except QhullError:
                    # Collinear or duplicate points have no 2D hull; the
                    # overlay is optional, so keep the scatterplot.
                    log.warning(
                        "Accepted points are degenerate; skipping convex hull overlay"
                    )
                else:
                    verts = accepted[hull.vertices]
                    # Repeat the first vertex to close the outline.
                    verts = np.vstack([verts, verts[0]])
                    ax.fill(verts[:, 0], verts[:, 1],
                            facecolor="none", edgecolor="green", lw=2, alpha=0.8,
                            label="accepted hull")
                    hull_drawn = True

        handles = [plt.Line2D([], [], marker="o", ls="", color=color_map[v]) for v in uniq]
        labels_legend = [f"{label_column}={v}" for v in uniq]
        if hull_drawn:
            handles.append(plt.Line2D([], [], color="green", lw=2))
            labels_legend.append("accepted hull")

        ax.legend(handles, labels_legend, fontsize=8, loc="upper right")
        ax.set(xticks=[], yticks=[], title=f"UMAP – {model_name} embeddings")
        fig.tight_layout()
        fig.savefig(output_path, dpi=300)
    finally:
        # Release the figure even if plotting or saving raised.
        plt.close(fig)
    log.info("Saved UMAP plot to %s", output_path)
0 commit comments