|
2 | 2 | # Copyright 2019 TD Ameritrade. Released under the terms of the 3-Clause BSD license. # noqa: E501 |
3 | 3 | # STUMPY is a trademark of TD Ameritrade IP Company, Inc. All rights reserved. |
4 | 4 |
|
5 | | -from functools import partial |
| 5 | +import logging |
| 6 | +import functools |
| 7 | +import inspect |
| 8 | + |
6 | 9 | import numpy as np |
7 | 10 | from numba import njit, prange |
8 | 11 | from scipy.signal import convolve |
|
17 | 20 | except ImportError: |
18 | 21 | pass |
19 | 22 |
|
| 23 | +logger = logging.getLogger(__name__) |
| 24 | + |
20 | 25 |
|
21 | 26 | def driver_not_found(*args, **kwargs): # pragma: no cover |
22 | 27 | """ |
@@ -1583,10 +1588,88 @@ def _get_partial_mp_func(mp_func, dask_client=None, device_id=None): |
1583 | 1588 | `device_id` into `functools.partial` function where possible |
1584 | 1589 | """ |
1585 | 1590 | if dask_client is not None: |
1586 | | - partial_mp_func = partial(mp_func, dask_client) |
| 1591 | + partial_mp_func = functools.partial(mp_func, dask_client) |
1587 | 1592 | elif device_id is not None: |
1588 | | - partial_mp_func = partial(mp_func, device_id=device_id) |
| 1593 | + partial_mp_func = functools.partial(mp_func, device_id=device_id) |
1589 | 1594 | else: |
1590 | 1595 | partial_mp_func = mp_func |
1591 | 1596 |
|
1592 | 1597 | return partial_mp_func |
| 1598 | + |
| 1599 | + |
| 1600 | +def compare_parameters(norm, non_norm, exclude=None): |
| 1601 | + """ |
| 1602 | + Compare if the parameters in `norm` and `non_norm` are the same |
| 1603 | +
|
| 1604 | + Parameters |
| 1605 | + ---------- |
| 1606 | + norm : object |
| 1607 | + The normalized function (or class) that is complementary to the |
| 1608 | + non-normalized function (or class) |
| 1609 | +
|
| 1610 | + non_norm : object |
| 1611 | + The non-normalized function (or class) that is complementary to the |
| 1612 | + z-normalized function (or class) |
| 1613 | +
|
| 1614 | + exclude : list |
| 1615 | + A list of parameters to exclude |
| 1616 | +
|
| 1617 | + Returns |
| 1618 | + ------- |
| 1619 | + is_same_params : bool |
| 1620 | + `True` if parameters from both `norm` and `non-norm` are the same. `False` |
| 1621 | + otherwise. |
| 1622 | + """ |
| 1623 | + norm_params = list(inspect.signature(norm).parameters.keys()) |
| 1624 | + non_norm_params = list(inspect.signature(non_norm).parameters.keys()) |
| 1625 | + |
| 1626 | + if exclude is not None: |
| 1627 | + for param in exclude: |
| 1628 | + norm_params.remove(param) |
| 1629 | + |
| 1630 | + is_same_params = set(norm_params) == set(non_norm_params) |
| 1631 | + if not is_same_params: |
| 1632 | + if exclude is not None: |
| 1633 | + logger.warning(f"Excluding `{exclude}` parameters, ") |
| 1634 | + logger.warning(f"`{norm}` and `{non_norm}` have different parameters.") |
| 1635 | + |
| 1636 | + return is_same_params |
| 1637 | + |
| 1638 | + |
| 1639 | +def non_normalized(non_norm): |
| 1640 | + """ |
| 1641 | + Decorator for swapping a z-normalized function (or class) for its complementary |
| 1642 | + non-normalized function (or class) as defined by `non_norm`. This requires that |
| 1643 | + the z-normalized function (or class) has a `normalize` parameter. |
| 1644 | +
|
| 1645 | + With the exception of `normalize` parameter, the `non_norm` function (or class) |
| 1646 | + must have the same siganture as the `norm` function (or class) signature in order |
| 1647 | + to be compatible. |
| 1648 | +
|
| 1649 | + Parameters |
| 1650 | + ---------- |
| 1651 | + non_norm : object |
| 1652 | + The non-normalized function (or class) that is complementary to the |
| 1653 | + z-normalized function (or class) |
| 1654 | +
|
| 1655 | + Returns |
| 1656 | + ------- |
| 1657 | + outer_wrapper : object |
| 1658 | + The desired z-normalized/non-normalized function (or class) |
| 1659 | + """ |
| 1660 | + |
| 1661 | + @functools.wraps(non_norm) |
| 1662 | + def outer_wrapper(norm): |
| 1663 | + @functools.wraps(norm) |
| 1664 | + def inner_wrapper(*args, **kwargs): |
| 1665 | + is_same_params = compare_parameters(norm, non_norm, exclude=["normalize"]) |
| 1666 | + |
| 1667 | + if not is_same_params or kwargs.get("normalize", True): |
| 1668 | + return norm(*args, **kwargs) |
| 1669 | + else: |
| 1670 | + kwargs = {k: v for k, v in kwargs.items() if k != "normalize"} |
| 1671 | + return non_norm(*args, **kwargs) |
| 1672 | + |
| 1673 | + return inner_wrapper |
| 1674 | + |
| 1675 | + return outer_wrapper |
0 commit comments