|
2 | 2 | # Copyright 2019 TD Ameritrade. Released under the terms of the 3-Clause BSD license. # noqa: E501 |
3 | 3 | # STUMPY is a trademark of TD Ameritrade IP Company, Inc. All rights reserved. |
4 | 4 |
|
| 5 | +from functools import partial |
5 | 6 | import numpy as np |
6 | 7 | from numba import njit, prange |
7 | 8 | from scipy.signal import convolve |
@@ -1484,3 +1485,42 @@ def rolling_isfinite(a, w): |
1484 | 1485 | a_subseq_isfinite[~a_isfinite[w - 1 :]] = False |
1485 | 1486 |
|
1486 | 1487 | return a_isfinite[: a_isfinite.shape[0] - w + 1] |
| 1488 | + |
| 1489 | + |
| 1490 | +def _get_partial_mp_func(mp_func, dask_client=None, device_id=None): |
| 1491 | + """ |
| 1492 | + A convenience function for creating a `functools.partial` matrix profile function |
| 1493 | + for single server (parallel CPU), multi-server with Dask distributed (parallel CPU), |
| 1494 | + and multi-GPU implementations. |
| 1495 | +
|
| 1496 | + Parameters |
| 1497 | + ---------- |
| 1498 | + mp_func : object |
| 1499 | + The matrix profile function to be used for computing a matrix profile |
| 1500 | +
|
| 1501 | + dask_client : client, default None |
| 1502 | + A Dask Distributed client that is connected to a Dask scheduler and |
| 1503 | + Dask workers. Setting up a Dask distributed cluster is beyond the |
| 1504 | + scope of this library. Please refer to the Dask Distributed |
| 1505 | + documentation. |
| 1506 | +
|
| 1507 | + device_id : int or list, default None |
| 1508 | + The (GPU) device number to use. The default value is `0`. A list of |
| 1509 | + valid device ids (int) may also be provided for parallel GPU-STUMP |
| 1510 | + computation. A list of all valid device ids can be obtained by |
| 1511 | + executing `[device.id for device in numba.cuda.list_devices()]`. |
| 1512 | +
|
| 1513 | + Returns |
| 1514 | + ------- |
| 1515 | + partial_mp_func : object |
| 1516 | + A generic matrix profile function that wraps the `dask_client` or GPU |
| 1517 | + `device_id` into `functools.partial` function where possible |
| 1518 | + """ |
| 1519 | + if dask_client is not None: |
| 1520 | + partial_mp_func = partial(mp_func, dask_client) |
| 1521 | + elif device_id is not None: |
| 1522 | + partial_mp_func = partial(mp_func, device_id=device_id) |
| 1523 | + else: |
| 1524 | + partial_mp_func = mp_func |
| 1525 | + |
| 1526 | + return partial_mp_func |
0 commit comments