Skip to content

Commit 5219ac3

Browse files
authored
Merge pull request #97 from MunchLab/update-matisse
add dist matrix function and update tutorial
2 parents 79e0278 + e084362 commit 5219ac3

5 files changed

Lines changed: 383 additions & 361 deletions

File tree

doc_source/notebooks/Matisse/example_matisse.ipynb

Lines changed: 328 additions & 351 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "ect"
3-
version = "1.2.3"
3+
version = "1.2.4"
44
authors = [
55
{ name="Liz Munch", email="muncheli@msu.edu" },
66
]

src/ect/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .ect import ECT
1414
from .embed_complex import EmbeddedComplex, EmbeddedGraph, EmbeddedCW
1515
from .directions import Directions
16+
from .results import ECTResult
1617
from .sect import SECT
1718
from .dect import DECT
1819
from .utils import examples
@@ -25,5 +26,6 @@
2526
"EmbeddedGraph",
2627
"EmbeddedCW",
2728
"Directions",
29+
"ECTResult",
2830
"examples",
2931
]

src/ect/results.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import matplotlib.pyplot as plt
22
import numpy as np
33
from ect.directions import Sampling
4-
from scipy.spatial.distance import cdist
4+
from scipy.spatial.distance import cdist, pdist, squareform
55
from typing import Union, List, Callable
66

77

@@ -319,7 +319,7 @@ def _plot_ecc(self, theta):
319319
def dist(
320320
self,
321321
other: Union["ECTResult", List["ECTResult"]],
322-
metric: Union[str, Callable] = "cityblock",
322+
metric: Union[str, Callable] = "frobenius",
323323
**kwargs,
324324
):
325325
"""
@@ -365,7 +365,15 @@ def dist(
365365
f"Shape mismatch at index {i}: {self.shape} vs {ect.shape}"
366366
)
367367

368-
# use ravel to avoid copying the data and compute distances
368+
if isinstance(metric, str) and metric.lower() in ("frobenius", "fro"):
369+
a = np.asarray(self, dtype=np.float64)
370+
if single:
371+
b = np.asarray(other, dtype=np.float64)
372+
return float(np.sqrt(np.sum((a - b) ** 2)))
373+
b = np.stack([np.asarray(ect, dtype=np.float64) for ect in others], axis=0)
374+
diff = b - a
375+
return np.sqrt(np.sum(diff * diff, axis=(1, 2)))
376+
369377
distances = cdist(
370378
self.ravel()[np.newaxis, :],
371379
np.vstack([ect.ravel() for ect in others]),
@@ -374,3 +382,30 @@ def dist(
374382
)[0]
375383

376384
return distances[0] if single else distances
385+
386+
@classmethod
387+
def dist_matrix(
388+
cls,
389+
results: List["ECTResult"],
390+
metric: Union[str, Callable] = "frobenius",
391+
**kwargs,
392+
) -> np.ndarray:
393+
if not results:
394+
return np.empty((0, 0), dtype=np.float64)
395+
396+
shape0 = results[0].shape
397+
for i, r in enumerate(results):
398+
if r.shape != shape0:
399+
raise ValueError(f"Shape mismatch at index {i}: {shape0} vs {r.shape}")
400+
401+
if isinstance(metric, str) and metric.lower() in ("frobenius", "fro"):
402+
return np.vstack([results[i].dist(results, metric="frobenius") for i in range(len(results))])
403+
404+
if isinstance(metric, str):
405+
X = np.stack([np.asarray(r, dtype=np.float64).ravel() for r in results], axis=0)
406+
try:
407+
return squareform(pdist(X, metric=metric, **kwargs))
408+
except TypeError:
409+
return cdist(X, X, metric=metric, **kwargs)
410+
411+
return np.vstack([results[i].dist(results, metric=metric, **kwargs) for i in range(len(results))])

tests/test_ect_result.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,19 @@ def test_dist_single_ectresult(self):
9090
result2_modified.directions = result2.directions
9191
result2_modified.thresholds = result2.thresholds
9292

93-
# Test L1 distance (default)
94-
dist_l1 = self.result.dist(result2_modified)
95-
expected_l1 = np.abs(self.result - result2_modified).sum()
96-
self.assertAlmostEqual(dist_l1, expected_l1)
97-
self.assertIsInstance(dist_l1, (float, np.floating))
93+
# Test frobenius distance (default)
94+
dist_frobenius = self.result.dist(result2_modified)
95+
expected_frobenius = np.sqrt(
96+
np.sum(
97+
(
98+
np.asarray(self.result, dtype=np.float64)
99+
- np.asarray(result2_modified, dtype=np.float64)
100+
).ravel()
101+
** 2
102+
)
103+
)
104+
self.assertAlmostEqual(dist_frobenius, expected_frobenius)
105+
self.assertIsInstance(dist_frobenius, (float, np.floating))
98106

99107
# Test L2 distance
100108
dist_l2 = self.result.dist(result2_modified, metric="euclidean")
@@ -119,7 +127,7 @@ def test_dist_list_of_ectresults(self):
119127
r.thresholds = self.result.thresholds
120128

121129
# Test batch distances
122-
distances = self.result.dist([result2, result3, result4])
130+
distances = self.result.dist([result2, result3, result4], metric="cityblock")
123131

124132
# Check return type is array
125133
self.assertIsInstance(distances, np.ndarray)

0 commit comments

Comments
 (0)