22import numpy as np
33from ect .directions import Sampling
44from scipy .spatial .distance import cdist , pdist , squareform
5- from typing import Union , List , Callable
5+ from typing import Union , List , Callable , cast
66
77
88# ---------- CSR <-> Dense helpers (prefix-difference over thresholds) ----------
@@ -352,14 +352,15 @@ def dist(
352352 >>> # Batch distances with custom function
353353 >>> dists = ect1.dist([ect2, ect3, ect4], metric=my_distance)
354354 """
355- # normalize input to list
356355 single = isinstance (other , ECTResult )
357- others = [other ] if single else other
356+ others_list : List ["ECTResult" ] = cast (
357+ List ["ECTResult" ], [other ] if single else other
358+ )
358359
359- if not others :
360+ if not others_list :
360361 return np .array ([])
361362
362- for i , ect in enumerate (others ):
363+ for i , ect in enumerate (others_list ):
363364 if ect .shape != self .shape :
364365 raise ValueError (
365366 f"Shape mismatch at index { i } : { self .shape } vs { ect .shape } "
@@ -370,13 +371,15 @@ def dist(
370371 if single :
371372 b = np .asarray (other , dtype = np .float64 )
372373 return float (np .sqrt (np .sum ((a - b ) ** 2 )))
373- b = np .stack ([np .asarray (ect , dtype = np .float64 ) for ect in others ], axis = 0 )
374+ b = np .stack (
375+ [np .asarray (ect , dtype = np .float64 ) for ect in others_list ], axis = 0
376+ )
374377 diff = b - a
375378 return np .sqrt (np .sum (diff * diff , axis = (1 , 2 )))
376379
377380 distances = cdist (
378381 self .ravel ()[np .newaxis , :],
379- np .vstack ([ect .ravel () for ect in others ]),
382+ np .vstack ([ect .ravel () for ect in others_list ]),
380383 metric = metric ,
381384 ** kwargs ,
382385 )[0 ]
@@ -399,13 +402,25 @@ def dist_matrix(
399402 raise ValueError (f"Shape mismatch at index { i } : { shape0 } vs { r .shape } " )
400403
401404 if isinstance (metric , str ) and metric .lower () in ("frobenius" , "fro" ):
402- return np .vstack ([results [i ].dist (results , metric = "frobenius" ) for i in range (len (results ))])
405+ return np .vstack (
406+ [
407+ results [i ].dist (results , metric = "frobenius" )
408+ for i in range (len (results ))
409+ ]
410+ )
403411
404412 if isinstance (metric , str ):
405- X = np .stack ([np .asarray (r , dtype = np .float64 ).ravel () for r in results ], axis = 0 )
413+ X = np .stack (
414+ [np .asarray (r , dtype = np .float64 ).ravel () for r in results ], axis = 0
415+ )
406416 try :
407417 return squareform (pdist (X , metric = metric , ** kwargs ))
408418 except TypeError :
409419 return cdist (X , X , metric = metric , ** kwargs )
410420
411- return np .vstack ([results [i ].dist (results , metric = metric , ** kwargs ) for i in range (len (results ))])
421+ return np .vstack (
422+ [
423+ results [i ].dist (results , metric = metric , ** kwargs )
424+ for i in range (len (results ))
425+ ]
426+ )
0 commit comments