Skip to content

Commit 4b39434

Browse files
committed
Fixed #299 Refactored compute_P_ABBA
1 parent 9b69a75 commit 4b39434

2 files changed

Lines changed: 45 additions & 17 deletions

File tree

stumpy/core.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Copyright 2019 TD Ameritrade. Released under the terms of the 3-Clause BSD license. # noqa: E501
33
# STUMPY is a trademark of TD Ameritrade IP Company, Inc. All rights reserved.
44

5+
from functools import partial
56
import numpy as np
67
from numba import njit, prange
78
from scipy.signal import convolve
@@ -1484,3 +1485,42 @@ def rolling_isfinite(a, w):
14841485
a_subseq_isfinite[~a_isfinite[w - 1 :]] = False
14851486

14861487
return a_isfinite[: a_isfinite.shape[0] - w + 1]
1488+
1489+
1490+
def _get_partial_mp_func(mp_func, dask_client=None, device_id=None):
1491+
"""
1492+
A convenience function for creating a `functools.partial` matrix profile function
1493+
for single server (parallel CPU), multi-server with Dask distributed (parallel CPU),
1494+
and multi-GPU implementations.
1495+
1496+
Parameters
1497+
----------
1498+
mp_func : object
1499+
The matrix profile function to be used for computing a matrix profile
1500+
1501+
dask_client : client, default None
1502+
A Dask Distributed client that is connected to a Dask scheduler and
1503+
Dask workers. Setting up a Dask distributed cluster is beyond the
1504+
scope of this library. Please refer to the Dask Distributed
1505+
documentation.
1506+
1507+
device_id : int or list, default None
1508+
The (GPU) device number to use. The default value is `0`. A list of
1509+
valid device ids (int) may also be provided for parallel GPU-STUMP
1510+
computation. A list of all valid device ids can be obtained by
1511+
executing `[device.id for device in numba.cuda.list_devices()]`.
1512+
1513+
Returns
1514+
-------
1515+
partial_mp_func : object
1516+
A generic matrix profile function that wraps the `dask_client` or GPU
1517+
`device_id` into `functools.partial` function where possible
1518+
"""
1519+
if dask_client is not None:
1520+
partial_mp_func = partial(mp_func, dask_client)
1521+
elif device_id is not None:
1522+
partial_mp_func = partial(mp_func, device_id=device_id)
1523+
else:
1524+
partial_mp_func = mp_func
1525+
1526+
return partial_mp_func

stumpy/mpdist.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -65,24 +65,12 @@ def _compute_P_ABBA(
6565
See Section III
6666
"""
6767
n_A = T_A.shape[0]
68+
partial_mp_func = core._get_partial_mp_func(
69+
mp_func, dask_client=dask_client, device_id=device_id
70+
)
6871

69-
if dask_client is not None:
70-
P_ABBA[: n_A - m + 1] = mp_func(dask_client, T_A, m, T_B, ignore_trivial=False)[
71-
:, 0
72-
]
73-
P_ABBA[n_A - m + 1 :] = mp_func(dask_client, T_B, m, T_A, ignore_trivial=False)[
74-
:, 0
75-
]
76-
elif device_id is not None:
77-
P_ABBA[: n_A - m + 1] = mp_func(
78-
T_A, m, T_B, ignore_trivial=False, device_id=device_id
79-
)[:, 0]
80-
P_ABBA[n_A - m + 1 :] = mp_func(
81-
T_B, m, T_A, ignore_trivial=False, device_id=device_id
82-
)[:, 0]
83-
else:
84-
P_ABBA[: n_A - m + 1] = mp_func(T_A, m, T_B, ignore_trivial=False)[:, 0]
85-
P_ABBA[n_A - m + 1 :] = mp_func(T_B, m, T_A, ignore_trivial=False)[:, 0]
72+
P_ABBA[: n_A - m + 1] = partial_mp_func(T_A, m, T_B, ignore_trivial=False)[:, 0]
73+
P_ABBA[n_A - m + 1 :] = partial_mp_func(T_B, m, T_A, ignore_trivial=False)[:, 0]
8674

8775

8876
def _select_P_ABBA_value(P_ABBA, k, custom_func=None):

0 commit comments

Comments
 (0)