11import matplotlib .pyplot as plt
22import numpy as np
33from ect .directions import Sampling
4- from scipy .spatial .distance import cdist
4+ from scipy .spatial .distance import cdist , pdist , squareform
55from typing import Union , List , Callable
66
77
@@ -319,7 +319,7 @@ def _plot_ecc(self, theta):
319319 def dist (
320320 self ,
321321 other : Union ["ECTResult" , List ["ECTResult" ]],
322- metric : Union [str , Callable ] = "cityblock " ,
322+ metric : Union [str , Callable ] = "frobenius " ,
323323 ** kwargs ,
324324 ):
325325 """
@@ -365,7 +365,15 @@ def dist(
365365 f"Shape mismatch at index { i } : { self .shape } vs { ect .shape } "
366366 )
367367
368- # use ravel to avoid copying the data and compute distances
368+ if isinstance (metric , str ) and metric .lower () in ("frobenius" , "fro" ):
369+ a = np .asarray (self , dtype = np .float64 )
370+ if single :
371+ b = np .asarray (other , dtype = np .float64 )
372+ return float (np .sqrt (np .sum ((a - b ) ** 2 )))
373+ b = np .stack ([np .asarray (ect , dtype = np .float64 ) for ect in others ], axis = 0 )
374+ diff = b - a
375+ return np .sqrt (np .sum (diff * diff , axis = (1 , 2 )))
376+
369377 distances = cdist (
370378 self .ravel ()[np .newaxis , :],
371379 np .vstack ([ect .ravel () for ect in others ]),
@@ -374,3 +382,30 @@ def dist(
374382 )[0 ]
375383
376384 return distances [0 ] if single else distances
385+
386+ @classmethod
387+ def dist_matrix (
388+ cls ,
389+ results : List ["ECTResult" ],
390+ metric : Union [str , Callable ] = "frobenius" ,
391+ ** kwargs ,
392+ ) -> np .ndarray :
393+ if not results :
394+ return np .empty ((0 , 0 ), dtype = np .float64 )
395+
396+ shape0 = results [0 ].shape
397+ for i , r in enumerate (results ):
398+ if r .shape != shape0 :
399+ raise ValueError (f"Shape mismatch at index { i } : { shape0 } vs { r .shape } " )
400+
401+ if isinstance (metric , str ) and metric .lower () in ("frobenius" , "fro" ):
402+ return np .vstack ([results [i ].dist (results , metric = "frobenius" ) for i in range (len (results ))])
403+
404+ if isinstance (metric , str ):
405+ X = np .stack ([np .asarray (r , dtype = np .float64 ).ravel () for r in results ], axis = 0 )
406+ try :
407+ return squareform (pdist (X , metric = metric , ** kwargs ))
408+ except TypeError :
409+ return cdist (X , X , metric = metric , ** kwargs )
410+
411+ return np .vstack ([results [i ].dist (results , metric = metric , ** kwargs ) for i in range (len (results ))])
0 commit comments