Skip to content

Commit d4613fe

Browse files
committed
Revert "MAINT: Remove deprecated attributes for BasicStatistics (#2897)"
This reverts commit c760965.
1 parent a398c4a commit d4613fe

7 files changed

Lines changed: 123 additions & 79 deletions

File tree

onedal/basic_statistics/basic_statistics.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,34 +35,34 @@ class BasicStatistics:
3535
3636
Attributes
3737
----------
38-
min_ : ndarray of shape (n_features,)
38+
min : ndarray of shape (n_features,)
3939
Minimum of each feature over all samples.
4040
41-
max_ : ndarray of shape (n_features,)
41+
max : ndarray of shape (n_features,)
4242
Maximum of each feature over all samples.
4343
44-
sum_ : ndarray of shape (n_features,)
44+
sum : ndarray of shape (n_features,)
4545
Sum of each feature over all samples.
4646
47-
mean_ : ndarray of shape (n_features,)
47+
mean : ndarray of shape (n_features,)
4848
Mean of each feature over all samples.
4949
50-
variance_ : ndarray of shape (n_features,)
50+
variance : ndarray of shape (n_features,)
5151
Variance of each feature over all samples.
5252
53-
variation_ : ndarray of shape (n_features,)
53+
variation : ndarray of shape (n_features,)
5454
Variation of each feature over all samples.
5555
56-
sum_squares_ : ndarray of shape (n_features,)
56+
sum_squares : ndarray of shape (n_features,)
5757
Sum of squares for each feature over all samples.
5858
59-
standard_deviation_ : ndarray of shape (n_features,)
59+
standard_deviation : ndarray of shape (n_features,)
6060
Standard deviation of each feature over all samples.
6161
62-
sum_squares_centered_ : ndarray of shape (n_features,)
62+
sum_squares_centered : ndarray of shape (n_features,)
6363
Centered sum of squares for each feature over all samples.
6464
65-
second_order_raw_moment_ : ndarray of shape (n_features,)
65+
second_order_raw_moment : ndarray of shape (n_features,)
6666
Second order moment of each feature over all samples.
6767
6868
Notes

onedal/basic_statistics/incremental_basic_statistics.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,34 +35,34 @@ class IncrementalBasicStatistics(BasicStatistics):
3535
3636
Attributes
3737
----------
38-
min_ : ndarray of shape (n_features,)
38+
min : ndarray of shape (n_features,)
3939
Minimum of each feature over all samples.
4040
41-
max_ : ndarray of shape (n_features,)
41+
max : ndarray of shape (n_features,)
4242
Maximum of each feature over all samples.
4343
44-
sum_ : ndarray of shape (n_features,)
44+
sum : ndarray of shape (n_features,)
4545
Sum of each feature over all samples.
4646
47-
mean_ : ndarray of shape (n_features,)
47+
mean : ndarray of shape (n_features,)
4848
Mean of each feature over all samples.
4949
50-
variance_ : ndarray of shape (n_features,)
50+
variance : ndarray of shape (n_features,)
5151
Variance of each feature over all samples.
5252
53-
variation_ : ndarray of shape (n_features,)
53+
variation : ndarray of shape (n_features,)
5454
Variation of each feature over all samples.
5555
56-
sum_squares_ : ndarray of shape (n_features,)
56+
sum_squares : ndarray of shape (n_features,)
5757
Sum of squares for each feature over all samples.
5858
59-
standard_deviation_ : ndarray of shape (n_features,)
59+
standard_deviation : ndarray of shape (n_features,)
6060
Standard deviation of each feature over all samples.
6161
62-
sum_squares_centered_ : ndarray of shape (n_features,)
62+
sum_squares_centered : ndarray of shape (n_features,)
6363
Centered sum of squares for each feature over all samples.
6464
65-
second_order_raw_moment_ : ndarray of shape (n_features,)
65+
second_order_raw_moment : ndarray of shape (n_features,)
6666
Second order moment of each feature over all samples.
6767
6868
Notes

sklearnex/_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,16 +187,16 @@ def reset_hyperparameters(self, op):
187187
return decorator
188188

189189

190-
def _add_inc_serialization_note(class_docstrings: str, plural: bool = False) -> str:
190+
def _add_inc_serialization_note(class_docstrings: str) -> str:
191191
"""Adds a small note note about serialization for extension estimators that are incremental.
192192
The class docstrings should leave a placeholder '%incremental_serialization_note%' inside
193193
their docstrings, which will be replaced by this note.
194194
"""
195195
# In python versions >=3.13, leading whitespace in docstrings defined through
196196
# static strings (but **not through other ways**) is automatically removed
197197
# from the final docstrings, while in earlier versions is kept.
198-
inc_serialization_note = f"""Note{'s' if plural else ''}
199-
----{'-' if plural else ''}
198+
inc_serialization_note = """Note
199+
----
200200
Serializing instances of this class will trigger a forced finalization of calculations
201201
when the inputs are in a sycl queue or when using GPUs. Since (internal method)
202202
finalize_fit can't be dispatched without directly provided queue and the dispatching

sklearnex/basic_statistics/basic_statistics.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# limitations under the License.
1515
# ==============================================================================
1616

17+
import warnings
18+
1719
from sklearn.base import BaseEstimator
1820

1921
from daal4py.sklearn._n_jobs_support import control_n_jobs
@@ -75,6 +77,9 @@ class BasicStatistics(oneDALEstimator, BaseEstimator):
7577
-----
7678
Attribute exists only if corresponding result option has been provided.
7779
80+
Names of attributes without the trailing underscore are
81+
supported currently but deprecated in 2025.1 and will be removed in 2026.0
82+
7883
Some results can exhibit small variations due to
7984
floating point error accumulation and multithreading.
8085
@@ -124,6 +129,24 @@ def _save_attributes(self):
124129
option += "_"
125130
setattr(self, option, getattr(self._onedal_estimator, option))
126131

132+
def __getattr__(self, attr):
133+
is_deprecated_attr = (
134+
attr in self._onedal_estimator.options
135+
if "_onedal_estimator" in self.__dict__
136+
else False
137+
)
138+
if is_deprecated_attr:
139+
warnings.warn(
140+
"Result attributes without a trailing underscore were deprecated in version 2025.1 and will be removed in 2026.0"
141+
)
142+
attr += "_"
143+
if attr in self.__dict__:
144+
return self.__dict__[attr]
145+
146+
raise AttributeError(
147+
f"'{self.__class__.__name__}' object has no attribute '{attr}'"
148+
)
149+
127150
def _onedal_cpu_supported(self, method_name, *data):
128151
patching_status = PatchingConditionsChain(
129152
f"sklearnex.basic_statistics.{self.__class__.__name__}.{method_name}"

sklearnex/basic_statistics/incremental_basic_statistics.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from sklearn.utils._param_validation import Interval, StrOptions
3535

3636
import numbers
37+
import warnings
3738

3839

3940
@enable_array_api
@@ -100,12 +101,17 @@ class IncrementalBasicStatistics(oneDALEstimator, BaseEstimator):
100101
n_features_in_ : int
101102
Number of features seen during :meth:`fit` or :meth:`partial_fit`.
102103
103-
%incremental_serialization_note%
104-
104+
Notes
105+
-----
105106
Attribute exists only if corresponding result option has been provided.
106107
108+
Names of attributes without the trailing underscore are supported
109+
currently but deprecated in 2025.1 and will be removed in 2026.0.
110+
107111
Sparse data formats are not supported. Input dtype must be ``float32`` or ``float64``.
108112
113+
%incremental_serialization_note%
114+
109115
Examples
110116
--------
111117
>>> import numpy as np
@@ -125,7 +131,7 @@ class IncrementalBasicStatistics(oneDALEstimator, BaseEstimator):
125131
np.array([3., 4.])
126132
"""
127133

128-
__doc__ = _add_inc_serialization_note(__doc__, plural=True)
134+
__doc__ = _add_inc_serialization_note(__doc__)
129135

130136
_onedal_incremental_basic_statistics = staticmethod(onedal_IncrementalBasicStatistics)
131137

@@ -246,8 +252,18 @@ def __getattr__(self, attr):
246252
if is_statistic_attr:
247253
if self._need_to_finalize:
248254
self._onedal_finalize_fit()
255+
if sattr == attr:
256+
warnings.warn(
257+
"Result attributes without a trailing underscore were deprecated in version 2025.1 and will be removed in 2026.0"
258+
)
259+
attr += "_"
249260
return getattr(self._onedal_estimator, attr)
250-
return self.__getattribute__(attr)
261+
if attr in self.__dict__:
262+
return self.__dict__[attr]
263+
264+
raise AttributeError(
265+
f"'{self.__class__.__name__}' object has no attribute '{attr}'"
266+
)
251267

252268
def partial_fit(self, X, sample_weight=None, check_input=True):
253269
"""Incremental fit with X. All of X is processed as a single batch.

sklearnex/basic_statistics/tests/test_basic_statistics.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -387,17 +387,19 @@ def test_1d_input_on_random_data(
387387
assert_allclose(gtr, res, atol=tol)
388388

389389

390-
@pytest.mark.parametrize("underscore_first", [False, True])
391-
def test_results_have_underscores(underscore_first):
392-
X = np.arange(10).reshape((-1, 1))
393-
bs = BasicStatistics().fit(X)
394-
395-
# Note: these are generated dynamically. Need to
396-
# test them in different order to ensure calling
397-
# one doesn't set the other and then change results.
398-
if underscore_first:
399-
assert hasattr(bs, "mean_")
400-
assert not hasattr(bs, "mean")
401-
else:
402-
assert not hasattr(bs, "mean")
403-
assert hasattr(bs, "mean_")
390+
def test_warning():
391+
basicstat = BasicStatistics("all")
392+
data = np.array([0, 1])
393+
394+
basicstat.fit(data)
395+
for i in basicstat._onedal_estimator.get_all_result_options():
396+
with pytest.warns(
397+
UserWarning,
398+
match="Result attributes without a trailing underscore were deprecated in version 2025.1 and will be removed in 2026.0",
399+
) as warn_record:
400+
getattr(basicstat, i)
401+
402+
if daal_check_version((2026, "P", 0)):
403+
assert len(warn_record) == 0, i
404+
else:
405+
assert len(warn_record) == 1, i

0 commit comments

Comments
 (0)