Skip to content

Commit 74aef42

Browse files
authored
Merge pull request #100 from MunchLab/uniform-threshold-speedup
Add threshold validation and search support for ECT calculation
2 parents 8c4bbc9 + e1db4ba commit 74aef42

2 files changed

Lines changed: 159 additions & 65 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "ect"
3-
version = "1.2.4"
3+
version = "1.3.0"
44
authors = [
55
{ name="Liz Munch", email="muncheli@msu.edu" },
66
]

src/ect/ect.py

Lines changed: 158 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,23 @@
88
from .results import ECTResult
99

1010

11+
def _thresholds_are_uniform(thresholds: np.ndarray) -> bool:
12+
thresholds = np.asarray(thresholds, dtype=float)
13+
if thresholds.ndim != 1:
14+
raise ValueError("thresholds must be a 1-dimensional array")
15+
n = thresholds.size
16+
if n <= 1:
17+
return True
18+
diffs = np.diff(thresholds)
19+
first = diffs[0]
20+
if first == 0.0:
21+
return bool(np.all(diffs == 0.0))
22+
tol = 1e-12 * max(1.0, abs(first))
23+
return bool(np.all(np.abs(diffs - first) <= tol))
24+
25+
1126
class ECT:
12-
"""
27+
r"""
1328
A class to calculate the Euler Characteristic Transform (ECT) from an input :class:`ect.embed_complex.EmbeddedComplex`,
1429
using a set of directions to project the complex onto and thresholds to filter the projections.
1530
@@ -55,6 +70,22 @@ def __init__(
5570
self.bound_radius = bound_radius
5671
self.thresholds = thresholds
5772
self.dtype = dtype
73+
self._thresholds_validated = False
74+
if self.thresholds is not None:
75+
self.thresholds = np.asarray(self.thresholds, dtype=float)
76+
if self.thresholds.ndim != 1:
77+
raise ValueError("thresholds must be a 1-dimensional array")
78+
self._thresholds_validated = True
79+
if num_thresh is not None:
80+
self.is_uniform = True
81+
elif self.thresholds is not None:
82+
self.is_uniform = False
83+
if not _thresholds_are_uniform(self.thresholds):
84+
raise ValueError(
85+
"thresholds must be uniform if num_thresh is not provided"
86+
)
87+
else:
88+
self.is_uniform = True
5889

5990
def _ensure_directions(self, graph_dim, theta=None):
6091
"""Ensures directions is a valid Directions object of correct dimension"""
@@ -97,11 +128,14 @@ def _ensure_thresholds(self, graph, override_bound_radius=None):
97128
or graph.get_bounding_radius()
98129
)
99130
self.thresholds = np.linspace(-radius, radius, self.num_thresh, dtype=float)
131+
self.is_uniform = True
132+
self._thresholds_validated = True
100133
else:
101-
# validate and convert existing thresholds
102-
self.thresholds = np.asarray(self.thresholds, dtype=float)
103-
if self.thresholds.ndim != 1:
104-
raise ValueError("thresholds must be a 1-dimensional array")
134+
if not self._thresholds_validated:
135+
self.thresholds = np.asarray(self.thresholds, dtype=float)
136+
if self.thresholds.ndim != 1:
137+
raise ValueError("thresholds must be a 1-dimensional array")
138+
self._thresholds_validated = True
105139

106140
def calculate(
107141
self,
@@ -132,14 +166,25 @@ def _compute_ect(
132166
H = X @ V.T # (N, m)
133167
H_T = np.ascontiguousarray(H.T) # (m, N) for contiguous per-direction rows
134168

135-
out64 = _ect_all_dirs(
136-
H_T,
137-
cell_vertex_pointers,
138-
cell_vertex_indices_flat,
139-
cell_euler_signs,
140-
thresholds,
141-
N,
142-
)
169+
is_uniform = bool(self.is_uniform) and thresholds[0] != thresholds[-1]
170+
if is_uniform:
171+
out64 = _ect_all_dirs_uniform(
172+
H_T,
173+
cell_vertex_pointers,
174+
cell_vertex_indices_flat,
175+
cell_euler_signs,
176+
thresholds,
177+
N,
178+
)
179+
else:
180+
out64 = _ect_all_dirs_search(
181+
H_T,
182+
cell_vertex_pointers,
183+
cell_vertex_indices_flat,
184+
cell_euler_signs,
185+
thresholds,
186+
N,
187+
)
143188
if dtype == np.int32:
144189
return out64.astype(np.int32)
145190
return out64
@@ -176,74 +221,123 @@ def _compute_simplex_projections(self, graph: EmbeddedComplex, directions):
176221

177222

178223
@njit(cache=True, parallel=True)
179-
def _ect_all_dirs(
180-
heights_by_direction, # shape (num_directions, num_vertices)
181-
cell_vertex_pointers, # shape (num_cells + 1,)
182-
cell_vertex_indices_flat, # concatenated vertex indices for all cells
183-
cell_euler_signs, # per-cell sign: (+1) for even-dim, (-1) for odd-dim
184-
threshold_values, # shape (num_thresholds,), assumed nondecreasing
224+
def _ect_all_dirs_uniform(
225+
heights_by_direction,
226+
cell_vertex_pointers,
227+
cell_vertex_indices_flat,
228+
cell_euler_signs,
229+
threshold_values,
230+
num_vertices,
231+
):
232+
num_directions = heights_by_direction.shape[0]
233+
num_thresholds = threshold_values.shape[0]
234+
t_min = threshold_values[0] if num_thresholds > 0 else 0.0
235+
t_max = threshold_values[-1] if num_thresholds > 0 else 0.0
236+
span = t_max - t_min
237+
inv_span = 1.0 / span
238+
n_minus_1 = num_thresholds - 1
239+
240+
ect_values = np.empty((num_directions, num_thresholds), dtype=np.int64)
241+
242+
for dir_idx in prange(num_directions):
243+
heights = heights_by_direction[dir_idx]
244+
245+
diff = np.zeros(num_thresholds, dtype=np.int64)
246+
vertex_thresh_index = np.empty(num_vertices, dtype=np.int64)
247+
248+
for v in range(num_vertices):
249+
h = heights[v]
250+
u = (h - t_min) * inv_span
251+
idx = int(np.ceil(u * n_minus_1))
252+
if idx < 0:
253+
idx = 0
254+
elif idx >= num_thresholds:
255+
idx = num_thresholds
256+
257+
vertex_thresh_index[v] = idx
258+
if idx < num_thresholds:
259+
diff[idx] += 1
260+
261+
num_cells = cell_vertex_pointers.shape[0] - 1
262+
263+
for cell_idx in range(num_cells):
264+
start = cell_vertex_pointers[cell_idx]
265+
end = cell_vertex_pointers[cell_idx + 1]
266+
267+
entrance_idx = -1
268+
for k in range(start, end):
269+
v = cell_vertex_indices_flat[k]
270+
t_idx = vertex_thresh_index[v]
271+
if t_idx > entrance_idx:
272+
entrance_idx = t_idx
273+
274+
if 0 <= entrance_idx < num_thresholds:
275+
diff[entrance_idx] += cell_euler_signs[cell_idx]
276+
277+
running = 0
278+
for j in range(num_thresholds):
279+
running += diff[j]
280+
ect_values[dir_idx, j] = running
281+
282+
return ect_values
283+
284+
285+
@njit(cache=True, parallel=True)
286+
def _ect_all_dirs_search(
287+
heights_by_direction,
288+
cell_vertex_pointers,
289+
cell_vertex_indices_flat,
290+
cell_euler_signs,
291+
threshold_values,
185292
num_vertices,
186293
):
187-
"""
188-
Calculate the Euler Characteristic Transform (ECT) for a given direction and thresholds.
189-
190-
Args:
191-
heights_by_direction: The heights of the vertices for each direction
192-
cell_vertex_pointers: The pointers to the vertices for each cell
193-
cell_vertex_indices_flat: The indices of the vertices for each cell
194-
cell_euler_signs: The signs of the cells
195-
threshold_values: The thresholds to calculate the ECT for
196-
num_vertices: The number of vertices in the graph
197-
"""
198294
num_directions = heights_by_direction.shape[0]
199295
num_thresholds = threshold_values.shape[0]
296+
200297
ect_values = np.empty((num_directions, num_thresholds), dtype=np.int64)
201298

202299
for dir_idx in prange(num_directions):
203300
heights = heights_by_direction[dir_idx]
204301

205-
sort_order = np.argsort(heights)
302+
diff = np.zeros(num_thresholds, dtype=np.int64)
303+
vertex_thresh_index = np.empty(num_vertices, dtype=np.int64)
304+
305+
for v in range(num_vertices):
306+
h = heights[v]
206307

207-
# calculate what position each vertex is in the sorted heights starting from 1 (the rank)
208-
vertex_rank_1based = np.empty(num_vertices, dtype=np.int32)
209-
for rnk in range(num_vertices):
210-
vertex_index = sort_order[rnk]
211-
vertex_rank_1based[vertex_index] = rnk + 1
308+
left = 0
309+
right = num_thresholds
310+
while left < right:
311+
mid = (left + right) // 2
312+
if threshold_values[mid] >= h:
313+
right = mid
314+
else:
315+
left = mid + 1
316+
idx = left
212317

213-
# euler char can only jump at each vertex value
214-
# we know vertices add +1 so wait until end to add
215-
#
216-
jump_amount = np.zeros(num_vertices + 1, dtype=np.int64)
318+
vertex_thresh_index[v] = idx
319+
if idx < num_thresholds:
320+
diff[idx] += 1
217321

218-
# each pair of pointers defines a cell, so we iterate over them
219322
num_cells = cell_vertex_pointers.shape[0] - 1
323+
220324
for cell_idx in range(num_cells):
221325
start = cell_vertex_pointers[cell_idx]
222326
end = cell_vertex_pointers[cell_idx + 1]
223-
# cells come in when the highest vertex enters
224-
entrance_rank = 0
327+
328+
entrance_idx = -1
225329
for k in range(start, end):
226330
v = cell_vertex_indices_flat[k]
227-
rnk = vertex_rank_1based[v]
228-
if rnk > entrance_rank:
229-
entrance_rank = rnk
230-
# record at what rank the cell enters and how much the euler char changes
231-
jump_amount[entrance_rank] += cell_euler_signs[cell_idx]
232-
233-
# calculate euler char at the moment each vertex enters
234-
euler_prefix = np.empty(num_vertices + 1, dtype=np.int64)
235-
running_sum = 0
236-
for r in range(num_vertices + 1):
237-
running_sum += jump_amount[r]
238-
euler_prefix[r] = running_sum + r # +r because vertices add +1
239-
240-
# now find euler char at each threshold wrt the sorted heights
241-
sorted_heights = heights[sort_order]
242-
rank_cursor = 0 # equals r(t) = # { i : h_i <= t }
243-
for thresh_idx in range(num_thresholds):
244-
t = threshold_values[thresh_idx]
245-
while rank_cursor < num_vertices and sorted_heights[rank_cursor] <= t:
246-
rank_cursor += 1
247-
ect_values[dir_idx, thresh_idx] = euler_prefix[rank_cursor]
331+
t_idx = vertex_thresh_index[v]
332+
if t_idx > entrance_idx:
333+
entrance_idx = t_idx
334+
335+
if 0 <= entrance_idx < num_thresholds:
336+
diff[entrance_idx] += cell_euler_signs[cell_idx]
337+
338+
running = 0
339+
for j in range(num_thresholds):
340+
running += diff[j]
341+
ect_values[dir_idx, j] = running
248342

249343
return ect_values

0 commit comments

Comments
 (0)