Skip to content

Commit 8fc28e2

Browse files
committed
Fixed #340 Added scraamp function
1 parent 692ce8c commit 8fc28e2

9 files changed

Lines changed: 1493 additions & 76 deletions

File tree

stumpy/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
from .floss import floss, fluss # noqa: F401
1414
from .ostinato import ostinato, ostinatoed # noqa: F401
1515
from .aamp_ostinato import aamp_ostinato, aamp_ostinatoed # noqa: F401
16-
from .scrump import scrump # noqa: F401
16+
from .scrump import scrump, prescrump # noqa: F401
17+
from .scraamp import scraamp, prescraamp # noqa: F401
1718
from .stumpi import stumpi # noqa: F401
1819
from .mpdist import mpdist, mpdisted # noqa: F401
1920
from .aampdist import aampdist, aampdisted # noqa: F401

stumpy/core.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1597,7 +1597,7 @@ def _get_partial_mp_func(mp_func, dask_client=None, device_id=None):
15971597
return partial_mp_func
15981598

15991599

1600-
def compare_parameters(norm, non_norm, exclude=None):
1600+
def compare_parameters(norm, non_norm, exclude=None, translate=None):
16011601
"""
16021602
Compare if the parameters in `norm` and `non_norm` are the same
16031603
@@ -1625,7 +1625,10 @@ def compare_parameters(norm, non_norm, exclude=None):
16251625

16261626
if exclude is not None:
16271627
for param in exclude:
1628-
norm_params.remove(param)
1628+
if param in norm_params:
1629+
norm_params.remove(param)
1630+
if param in non_norm_params:
1631+
non_norm_params.remove(param)
16291632

16301633
is_same_params = set(norm_params) == set(non_norm_params)
16311634
if not is_same_params:
@@ -1662,12 +1665,15 @@ def non_normalized(non_norm):
16621665
def outer_wrapper(norm):
16631666
@functools.wraps(norm)
16641667
def inner_wrapper(*args, **kwargs):
1665-
is_same_params = compare_parameters(norm, non_norm, exclude=["normalize"])
1668+
exclude = ["normalize", "pre_scrump", "pre_scraamp"]
1669+
is_same_params = compare_parameters(norm, non_norm, exclude=exclude)
16661670

16671671
if not is_same_params or kwargs.get("normalize", True):
16681672
return norm(*args, **kwargs)
16691673
else:
16701674
kwargs = {k: v for k, v in kwargs.items() if k != "normalize"}
1675+
if "pre_scrump" in kwargs.keys():
1676+
kwargs["pre_scraamp"] = kwargs.pop("pre_scrump")
16711677
return non_norm(*args, **kwargs)
16721678

16731679
return inner_wrapper

0 commit comments

Comments
 (0)