@@ -1597,7 +1597,7 @@ def _get_partial_mp_func(mp_func, dask_client=None, device_id=None):
15971597 return partial_mp_func
15981598
15991599
1600- def compare_parameters (norm , non_norm , exclude = None ):
1600+ def compare_parameters (norm , non_norm , exclude = None , translate = None ):
16011601 """
16021602 Compare if the parameters in `norm` and `non_norm` are the same
16031603
@@ -1625,7 +1625,10 @@ def compare_parameters(norm, non_norm, exclude=None):
16251625
16261626 if exclude is not None :
16271627 for param in exclude :
1628- norm_params .remove (param )
1628+ if param in norm_params :
1629+ norm_params .remove (param )
1630+ if param in non_norm_params :
1631+ non_norm_params .remove (param )
16291632
16301633 is_same_params = set (norm_params ) == set (non_norm_params )
16311634 if not is_same_params :
@@ -1662,12 +1665,15 @@ def non_normalized(non_norm):
16621665 def outer_wrapper (norm ):
16631666 @functools .wraps (norm )
16641667 def inner_wrapper (* args , ** kwargs ):
1665- is_same_params = compare_parameters (norm , non_norm , exclude = ["normalize" ])
1668+ exclude = ["normalize" , "pre_scrump" , "pre_scraamp" ]
1669+ is_same_params = compare_parameters (norm , non_norm , exclude = exclude )
16661670
16671671 if not is_same_params or kwargs .get ("normalize" , True ):
16681672 return norm (* args , ** kwargs )
16691673 else :
16701674 kwargs = {k : v for k , v in kwargs .items () if k != "normalize" }
1675+ if "pre_scrump" in kwargs .keys ():
1676+ kwargs ["pre_scraamp" ] = kwargs .pop ("pre_scrump" )
16711677 return non_norm (* args , ** kwargs )
16721678
16731679 return inner_wrapper
0 commit comments