|
8 | 8 | from .results import ECTResult |
9 | 9 |
|
10 | 10 |
|
| 11 | +def _thresholds_are_uniform(thresholds: np.ndarray) -> bool: |
| 12 | + thresholds = np.asarray(thresholds, dtype=float) |
| 13 | + if thresholds.ndim != 1: |
| 14 | + raise ValueError("thresholds must be a 1-dimensional array") |
| 15 | + n = thresholds.size |
| 16 | + if n <= 1: |
| 17 | + return True |
| 18 | + diffs = np.diff(thresholds) |
| 19 | + first = diffs[0] |
| 20 | + if first == 0.0: |
| 21 | + return bool(np.all(diffs == 0.0)) |
| 22 | + tol = 1e-12 * max(1.0, abs(first)) |
| 23 | + return bool(np.all(np.abs(diffs - first) <= tol)) |
| 24 | + |
| 25 | + |
11 | 26 | class ECT: |
12 | | - """ |
| 27 | + r""" |
13 | 28 | A class to calculate the Euler Characteristic Transform (ECT) from an input :class:`ect.embed_complex.EmbeddedComplex`, |
14 | 29 | using a set of directions to project the complex onto and thresholds to filter the projections. |
15 | 30 |
|
@@ -55,6 +70,22 @@ def __init__( |
55 | 70 | self.bound_radius = bound_radius |
56 | 71 | self.thresholds = thresholds |
57 | 72 | self.dtype = dtype |
| 73 | + self._thresholds_validated = False |
| 74 | + if self.thresholds is not None: |
| 75 | + self.thresholds = np.asarray(self.thresholds, dtype=float) |
| 76 | + if self.thresholds.ndim != 1: |
| 77 | + raise ValueError("thresholds must be a 1-dimensional array") |
| 78 | + self._thresholds_validated = True |
| 79 | + if num_thresh is not None: |
| 80 | + self.is_uniform = True |
| 81 | + elif self.thresholds is not None: |
| 82 | + self.is_uniform = False |
| 83 | + if not _thresholds_are_uniform(self.thresholds): |
| 84 | + raise ValueError( |
| 85 | + "thresholds must be uniform if num_thresh is not provided" |
| 86 | + ) |
| 87 | + else: |
| 88 | + self.is_uniform = True |
58 | 89 |
|
59 | 90 | def _ensure_directions(self, graph_dim, theta=None): |
60 | 91 | """Ensures directions is a valid Directions object of correct dimension""" |
@@ -97,11 +128,14 @@ def _ensure_thresholds(self, graph, override_bound_radius=None): |
97 | 128 | or graph.get_bounding_radius() |
98 | 129 | ) |
99 | 130 | self.thresholds = np.linspace(-radius, radius, self.num_thresh, dtype=float) |
| 131 | + self.is_uniform = True |
| 132 | + self._thresholds_validated = True |
100 | 133 | else: |
101 | | - # validate and convert existing thresholds |
102 | | - self.thresholds = np.asarray(self.thresholds, dtype=float) |
103 | | - if self.thresholds.ndim != 1: |
104 | | - raise ValueError("thresholds must be a 1-dimensional array") |
| 134 | + if not self._thresholds_validated: |
| 135 | + self.thresholds = np.asarray(self.thresholds, dtype=float) |
| 136 | + if self.thresholds.ndim != 1: |
| 137 | + raise ValueError("thresholds must be a 1-dimensional array") |
| 138 | + self._thresholds_validated = True |
105 | 139 |
|
106 | 140 | def calculate( |
107 | 141 | self, |
@@ -132,14 +166,25 @@ def _compute_ect( |
132 | 166 | H = X @ V.T # (N, m) |
133 | 167 | H_T = np.ascontiguousarray(H.T) # (m, N) for contiguous per-direction rows |
134 | 168 |
|
135 | | - out64 = _ect_all_dirs( |
136 | | - H_T, |
137 | | - cell_vertex_pointers, |
138 | | - cell_vertex_indices_flat, |
139 | | - cell_euler_signs, |
140 | | - thresholds, |
141 | | - N, |
142 | | - ) |
| 169 | + is_uniform = bool(self.is_uniform) and thresholds[0] != thresholds[-1] |
| 170 | + if is_uniform: |
| 171 | + out64 = _ect_all_dirs_uniform( |
| 172 | + H_T, |
| 173 | + cell_vertex_pointers, |
| 174 | + cell_vertex_indices_flat, |
| 175 | + cell_euler_signs, |
| 176 | + thresholds, |
| 177 | + N, |
| 178 | + ) |
| 179 | + else: |
| 180 | + out64 = _ect_all_dirs_search( |
| 181 | + H_T, |
| 182 | + cell_vertex_pointers, |
| 183 | + cell_vertex_indices_flat, |
| 184 | + cell_euler_signs, |
| 185 | + thresholds, |
| 186 | + N, |
| 187 | + ) |
143 | 188 | if dtype == np.int32: |
144 | 189 | return out64.astype(np.int32) |
145 | 190 | return out64 |
@@ -176,74 +221,123 @@ def _compute_simplex_projections(self, graph: EmbeddedComplex, directions): |
176 | 221 |
|
177 | 222 |
|
178 | 223 | @njit(cache=True, parallel=True) |
179 | | -def _ect_all_dirs( |
180 | | - heights_by_direction, # shape (num_directions, num_vertices) |
181 | | - cell_vertex_pointers, # shape (num_cells + 1,) |
182 | | - cell_vertex_indices_flat, # concatenated vertex indices for all cells |
183 | | - cell_euler_signs, # per-cell sign: (+1) for even-dim, (-1) for odd-dim |
184 | | - threshold_values, # shape (num_thresholds,), assumed nondecreasing |
| 224 | +def _ect_all_dirs_uniform( |
| 225 | + heights_by_direction, |
| 226 | + cell_vertex_pointers, |
| 227 | + cell_vertex_indices_flat, |
| 228 | + cell_euler_signs, |
| 229 | + threshold_values, |
| 230 | + num_vertices, |
| 231 | +): |
| 232 | + num_directions = heights_by_direction.shape[0] |
| 233 | + num_thresholds = threshold_values.shape[0] |
| 234 | + t_min = threshold_values[0] if num_thresholds > 0 else 0.0 |
| 235 | + t_max = threshold_values[-1] if num_thresholds > 0 else 0.0 |
| 236 | + span = t_max - t_min |
| 237 | + inv_span = 1.0 / span |
| 238 | + n_minus_1 = num_thresholds - 1 |
| 239 | + |
| 240 | + ect_values = np.empty((num_directions, num_thresholds), dtype=np.int64) |
| 241 | + |
| 242 | + for dir_idx in prange(num_directions): |
| 243 | + heights = heights_by_direction[dir_idx] |
| 244 | + |
| 245 | + diff = np.zeros(num_thresholds, dtype=np.int64) |
| 246 | + vertex_thresh_index = np.empty(num_vertices, dtype=np.int64) |
| 247 | + |
| 248 | + for v in range(num_vertices): |
| 249 | + h = heights[v] |
| 250 | + u = (h - t_min) * inv_span |
| 251 | + idx = int(np.ceil(u * n_minus_1)) |
| 252 | + if idx < 0: |
| 253 | + idx = 0 |
| 254 | + elif idx >= num_thresholds: |
| 255 | + idx = num_thresholds |
| 256 | + |
| 257 | + vertex_thresh_index[v] = idx |
| 258 | + if idx < num_thresholds: |
| 259 | + diff[idx] += 1 |
| 260 | + |
| 261 | + num_cells = cell_vertex_pointers.shape[0] - 1 |
| 262 | + |
| 263 | + for cell_idx in range(num_cells): |
| 264 | + start = cell_vertex_pointers[cell_idx] |
| 265 | + end = cell_vertex_pointers[cell_idx + 1] |
| 266 | + |
| 267 | + entrance_idx = -1 |
| 268 | + for k in range(start, end): |
| 269 | + v = cell_vertex_indices_flat[k] |
| 270 | + t_idx = vertex_thresh_index[v] |
| 271 | + if t_idx > entrance_idx: |
| 272 | + entrance_idx = t_idx |
| 273 | + |
| 274 | + if 0 <= entrance_idx < num_thresholds: |
| 275 | + diff[entrance_idx] += cell_euler_signs[cell_idx] |
| 276 | + |
| 277 | + running = 0 |
| 278 | + for j in range(num_thresholds): |
| 279 | + running += diff[j] |
| 280 | + ect_values[dir_idx, j] = running |
| 281 | + |
| 282 | + return ect_values |
| 283 | + |
| 284 | + |
| 285 | +@njit(cache=True, parallel=True) |
| 286 | +def _ect_all_dirs_search( |
| 287 | + heights_by_direction, |
| 288 | + cell_vertex_pointers, |
| 289 | + cell_vertex_indices_flat, |
| 290 | + cell_euler_signs, |
| 291 | + threshold_values, |
185 | 292 | num_vertices, |
186 | 293 | ): |
187 | | - """ |
188 | | - Calculate the Euler Characteristic Transform (ECT) for a given direction and thresholds. |
189 | | -
|
190 | | - Args: |
191 | | - heights_by_direction: The heights of the vertices for each direction |
192 | | - cell_vertex_pointers: The pointers to the vertices for each cell |
193 | | - cell_vertex_indices_flat: The indices of the vertices for each cell |
194 | | - cell_euler_signs: The signs of the cells |
195 | | - threshold_values: The thresholds to calculate the ECT for |
196 | | - num_vertices: The number of vertices in the graph |
197 | | - """ |
198 | 294 | num_directions = heights_by_direction.shape[0] |
199 | 295 | num_thresholds = threshold_values.shape[0] |
| 296 | + |
200 | 297 | ect_values = np.empty((num_directions, num_thresholds), dtype=np.int64) |
201 | 298 |
|
202 | 299 | for dir_idx in prange(num_directions): |
203 | 300 | heights = heights_by_direction[dir_idx] |
204 | 301 |
|
205 | | - sort_order = np.argsort(heights) |
| 302 | + diff = np.zeros(num_thresholds, dtype=np.int64) |
| 303 | + vertex_thresh_index = np.empty(num_vertices, dtype=np.int64) |
| 304 | + |
| 305 | + for v in range(num_vertices): |
| 306 | + h = heights[v] |
206 | 307 |
|
207 | | - # calculate what position each vertex is in the sorted heights starting from 1 (the rank) |
208 | | - vertex_rank_1based = np.empty(num_vertices, dtype=np.int32) |
209 | | - for rnk in range(num_vertices): |
210 | | - vertex_index = sort_order[rnk] |
211 | | - vertex_rank_1based[vertex_index] = rnk + 1 |
| 308 | + left = 0 |
| 309 | + right = num_thresholds |
| 310 | + while left < right: |
| 311 | + mid = (left + right) // 2 |
| 312 | + if threshold_values[mid] >= h: |
| 313 | + right = mid |
| 314 | + else: |
| 315 | + left = mid + 1 |
| 316 | + idx = left |
212 | 317 |
|
213 | | - # euler char can only jump at each vertex value |
214 | | - # we know vertices add +1 so wait until end to add |
215 | | - # |
216 | | - jump_amount = np.zeros(num_vertices + 1, dtype=np.int64) |
| 318 | + vertex_thresh_index[v] = idx |
| 319 | + if idx < num_thresholds: |
| 320 | + diff[idx] += 1 |
217 | 321 |
|
218 | | - # each pair of pointers defines a cell, so we iterate over them |
219 | 322 | num_cells = cell_vertex_pointers.shape[0] - 1 |
| 323 | + |
220 | 324 | for cell_idx in range(num_cells): |
221 | 325 | start = cell_vertex_pointers[cell_idx] |
222 | 326 | end = cell_vertex_pointers[cell_idx + 1] |
223 | | - # cells come in when the highest vertex enters |
224 | | - entrance_rank = 0 |
| 327 | + |
| 328 | + entrance_idx = -1 |
225 | 329 | for k in range(start, end): |
226 | 330 | v = cell_vertex_indices_flat[k] |
227 | | - rnk = vertex_rank_1based[v] |
228 | | - if rnk > entrance_rank: |
229 | | - entrance_rank = rnk |
230 | | - # record at what rank the cell enters and how much the euler char changes |
231 | | - jump_amount[entrance_rank] += cell_euler_signs[cell_idx] |
232 | | - |
233 | | - # calculate euler char at the moment each vertex enters |
234 | | - euler_prefix = np.empty(num_vertices + 1, dtype=np.int64) |
235 | | - running_sum = 0 |
236 | | - for r in range(num_vertices + 1): |
237 | | - running_sum += jump_amount[r] |
238 | | - euler_prefix[r] = running_sum + r # +r because vertices add +1 |
239 | | - |
240 | | - # now find euler char at each threshold wrt the sorted heights |
241 | | - sorted_heights = heights[sort_order] |
242 | | - rank_cursor = 0 # equals r(t) = # { i : h_i <= t } |
243 | | - for thresh_idx in range(num_thresholds): |
244 | | - t = threshold_values[thresh_idx] |
245 | | - while rank_cursor < num_vertices and sorted_heights[rank_cursor] <= t: |
246 | | - rank_cursor += 1 |
247 | | - ect_values[dir_idx, thresh_idx] = euler_prefix[rank_cursor] |
| 331 | + t_idx = vertex_thresh_index[v] |
| 332 | + if t_idx > entrance_idx: |
| 333 | + entrance_idx = t_idx |
| 334 | + |
| 335 | + if 0 <= entrance_idx < num_thresholds: |
| 336 | + diff[entrance_idx] += cell_euler_signs[cell_idx] |
| 337 | + |
| 338 | + running = 0 |
| 339 | + for j in range(num_thresholds): |
| 340 | + running += diff[j] |
| 341 | + ect_values[dir_idx, j] = running |
248 | 342 |
|
249 | 343 | return ect_values |
0 commit comments