Skip to content

Commit 56565b4

Browse files
authored
Merge pull request #32 from Leona-LYT/main
add multi-class classification function with tutorial examples
2 parents 8a4d932 + f63b8b7 commit 56565b4

3 files changed

Lines changed: 515 additions & 25 deletions

File tree

doc/source/example.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ Example Gallery
1717
examples/RankRegression.ipynb
1818
examples/Path_solution.ipynb
1919
examples/Warm_start.ipynb
20-
examples/Sklearn_Mixin.ipynb
20+
examples/Sklearn_Mixin.ipynb
21+
examples/Multiclass_Classification.ipynb
2122
examples/NMF.ipynb
2223

2324
List of Examples
@@ -35,5 +36,6 @@ List of Examples
3536
examples/RankRegression.ipynb
3637
examples/Path_solution.ipynb
3738
examples/Warm_start.ipynb
38-
examples/Sklearn_Mixin.ipynb
39+
examples/Sklearn_Mixin.ipynb
40+
examples/Multiclass_Classification.ipynb
3941
examples/NMF.ipynb

doc/source/examples/Multiclass_Classification.ipynb

Lines changed: 280 additions & 0 deletions
Large diffs are not rendered by default.

rehline/_sklearn_mixin.py

Lines changed: 231 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import numpy as np
2-
from sklearn.base import ClassifierMixin, RegressorMixin
2+
from itertools import combinations
3+
from sklearn.base import ClassifierMixin, RegressorMixin, clone
34
from sklearn.preprocessing import LabelEncoder
45
from sklearn.utils._tags import ClassifierTags, RegressorTags
56
from sklearn.utils.class_weight import compute_class_weight
67
from sklearn.utils.multiclass import check_classification_targets
78
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y
9+
from joblib import Parallel, delayed
10+
811

912
from ._class import plqERM_Ridge
1013

@@ -21,6 +24,7 @@ class plq_Ridge_Classifier(plqERM_Ridge, ClassifierMixin):
2124
- Supports optional intercept fitting (via an augmented constant feature).
2225
- Provides standard methods ``fit``, ``predict``, and ``decision_function``.
2326
- Integrates with scikit-learn ecosystem (e.g., GridSearchCV, Pipeline).
27+
- Supports multiclass classification via OvR or OvO method.
2428
2529
Parameters
2630
----------
@@ -81,21 +85,38 @@ class plq_Ridge_Classifier(plqERM_Ridge, ClassifierMixin):
8185
- 'balanced' uses n_samples / (n_classes * n_j).
8286
- dict maps label -> weight in the ORIGINAL label space.
8387
88+
multi_class : str or list, default=[]
89+
Method for multiclass classification. Options:
90+
- 'ovo': One-vs-One, trains K*(K-1)/2 binary classifiers.
91+
- 'ovr': One-vs-Rest, trains K binary classifiers.
92+
- [] (default): ignored when only 2 classes are present.
93+
94+
n_jobs : int or None, default=None
95+
Number of parallel jobs for multiclass fitting.
96+
None means 1 (serial). -1 means use all available CPUs.
97+
Passed directly to joblib.Parallel.
98+
99+
84100
Attributes
85101
----------
86-
``coef_`` : ndarray of shape (n_features,)
87-
Coefficients excluding the intercept.
102+
``coef_`` : ndarray of shape (n_features,) for binary, (n_estimators, n_features) for multiclass
103+
Coefficients of all fitted classifiers, excluding the intercept.
88104
89-
``intercept_`` : float
90-
Intercept term. 0.0 if ``fit_intercept=False``.
105+
``intercept_`` : float for binary, ndarray of shape (n_estimators,) for multiclass
106+
Intercept term(s). 0.0 if ``fit_intercept=False``.
91107
92-
classes_ : ndarray of shape (2,)
108+
classes_ : ndarray of shape (n_classes,)
93109
Unique class labels in the original label space.
110+
111+
estimators_ : list, only present for multiclass
112+
For OvR: list of (coef, intercept) tuples, length K.
113+
For OvO: list of (coef, intercept, cls_i, cls_j) tuples, length K*(K-1)/2.
94114
95115
_label_encoder : LabelEncoder
96116
Encodes original labels into {0,1} for internal training.
97117
"""
98118

119+
99120
def __init__(
100121
self,
101122
loss,
@@ -117,6 +138,8 @@ def __init__(
117138
fit_intercept=True,
118139
intercept_scaling=1.0,
119140
class_weight=None,
141+
multi_class=[],
142+
n_jobs=None,
120143
):
121144
self.loss = loss
122145
self.constraint = constraint
@@ -148,6 +171,68 @@ def __init__(
148171

149172
self._label_encoder = None
150173
self.classes_ = None
174+
self.multi_class = multi_class
175+
self.n_jobs = n_jobs
176+
177+
@staticmethod
def _fit_subproblem(estimator, X_aug, y_pm, sample_weight, fit_intercept):
    """Solve one binary subproblem of a multiclass fit.

    Builds a fresh ``plqERM_Ridge`` from *estimator*'s hyperparameters and
    fits it on already-preprocessed data, deliberately bypassing
    ``plq_Ridge_Classifier.fit`` preprocessing (label encoding, intercept
    augmentation) since ``X_aug`` and ``y_pm`` are preprocessed upstream.

    Parameters
    ----------
    estimator : plq_Ridge_Classifier
        Read-only source of hyperparameters; never fitted itself.
    X_aug : ndarray of shape (n_samples, n_features[+1])
        Feature matrix, possibly already augmented with the intercept
        column; forwarded to ``plqERM_Ridge.fit`` untouched.
    y_pm : ndarray of shape (n_samples,)
        Binary targets already encoded in {-1, +1}; forwarded untouched.
    sample_weight : ndarray of shape (n_samples,) or None
        Optional per-sample weights.
    fit_intercept : bool
        When True, the last fitted coefficient is split off as the
        intercept.  Should match ``estimator.fit_intercept``.

    Returns
    -------
    coef : ndarray of shape (n_features,)
        Fitted coefficients, excluding the intercept column.
    intercept : float
        Fitted intercept; 0.0 when ``fit_intercept`` is False.
    """
    base = plqERM_Ridge(
        loss=estimator.loss,
        constraint=estimator.constraint,
        C=estimator.C,
        max_iter=estimator.max_iter,
        tol=estimator.tol,
        shrink=estimator.shrink,
        warm_start=estimator.warm_start,
        verbose=estimator.verbose,
        trace_freq=estimator.trace_freq,
    )
    base.fit(X_aug, y_pm, sample_weight=sample_weight)

    weights = base.coef_
    if not fit_intercept:
        return weights.copy(), 0.0
    # Last entry is the coefficient of the constant column -> intercept.
    return weights[:-1].copy(), float(weights[-1])
235+
151236

152237
def fit(self, X, y, sample_weight=None):
153238
"""
@@ -183,9 +268,9 @@ def fit(self, X, y, sample_weight=None):
183268

184269
# Establish classes_ on ORIGINAL labels
185270
self.classes_ = np.unique(y)
186-
if self.classes_.size != 2:
271+
if self.classes_.size < 2:
187272
raise ValueError(
188-
f"plqERMClassifier currently supports only binary classification, "
273+
f"plqERMClassifier requires at least 2 classes, "
189274
f"but received {self.classes_.size} classes: {self.classes_}."
190275
)
191276

@@ -205,49 +290,135 @@ def fit(self, X, y, sample_weight=None):
205290
# Encode -> {0,1} -> {-1,+1}
206291
le = LabelEncoder().fit(self.classes_)
207292
self._label_encoder = le
208-
y01 = le.transform(y)
209-
y_pm = 2 * y01 - 1
210293

211294
# Add constant column for intercept
212295
X_aug = X
213296
if self.fit_intercept:
214297
col = np.full((X.shape[0], 1), self.intercept_scaling, dtype=X.dtype)
215298
X_aug = np.hstack([X, col])
216-
217-
super().fit(X_aug, y_pm, sample_weight=sample_weight)
218-
219-
# Split intercept
220-
if self.fit_intercept:
221-
self.intercept_ = float(self.coef_[-1])
222-
self.coef_ = self.coef_[:-1].copy()
299+
300+
if self.classes_.size == 2:
301+
y01 = le.transform(y)
302+
y_pm = 2 * y01 - 1
303+
304+
super().fit(X_aug, y_pm, sample_weight=sample_weight)
305+
306+
# Split intercept
307+
if self.fit_intercept:
308+
self.intercept_ = float(self.coef_[-1])
309+
self.coef_ = self.coef_[:-1].copy()
310+
else:
311+
self.intercept_ = 0.0
312+
223313
else:
224-
self.intercept_ = 0.0
314+
# Multiclass classification
315+
if self.multi_class not in ('ovr', 'ovo'):
316+
raise ValueError(
317+
f"multi_class must be 'ovr' or 'ovo' for multiclass problems, "
318+
f"got '{self.multi_class}'."
319+
)
320+
self._fit_multiclass(X_aug, y, sample_weight)
225321

226322
return self
323+
324+
325+
def _fit_multiclass(self, X_aug, y, sample_weight=None):
    """Fit the binary sub-classifiers that make up a multiclass model.

    OvR trains K classifiers (each class vs. the rest); OvO trains
    K*(K-1)/2 classifiers (one per unordered pair of classes).  The
    subproblems are mutually independent, so they are dispatched through
    ``joblib.Parallel`` and the fitted weights are then stacked into
    ``coef_`` / ``intercept_`` for a single-matmul decision function.

    Parameters
    ----------
    X_aug : ndarray of shape (n_samples, n_features[+1])
        Feature matrix, possibly augmented with the intercept column.
    y : ndarray of shape (n_samples,)
        Targets in the original (non-encoded) label space.
    sample_weight : ndarray of shape (n_samples,) or None
        Optional per-sample weights.
    """
    ovo = self.multi_class == 'ovo'

    subproblems = []
    pairs = [] if ovo else None
    if ovo:
        # One task per unordered pair; +1 encodes the first class.
        for first, second in combinations(self.classes_, 2):
            keep = np.isin(y, [first, second])
            signs = np.where(y[keep] == first, 1, -1).astype(np.float64)
            weights = None if sample_weight is None else sample_weight[keep]
            subproblems.append((X_aug[keep], signs, weights))
            pairs.append((first, second))
    else:
        # One task per class: positive = that class, negative = the rest.
        for label in self.classes_:
            signs = np.where(y == label, 1, -1).astype(np.float64)
            subproblems.append((X_aug, signs, sample_weight))

    # All subproblems run independently; threads suffice since the heavy
    # lifting happens inside the solver.
    fitted = Parallel(n_jobs=self.n_jobs, prefer="threads")(
        delayed(self._fit_subproblem)(self, X_sub, signs, weights, self.fit_intercept)
        for X_sub, signs, weights in subproblems
    )

    if ovo:
        # OvO entries carry the class pair so predict() can route votes.
        self.estimators_ = [
            (coef, b, first, second)
            for (coef, b), (first, second) in zip(fitted, pairs)
        ]
    else:
        self.estimators_ = [(coef, b) for coef, b in fitted]

    # Stack into matrices for decision_function:
    # OvR -> coef_ (K, n_features); OvO -> (K*(K-1)/2, n_features).
    self.coef_ = np.array([entry[0] for entry in self.estimators_])
    self.intercept_ = np.array([entry[1] for entry in self.estimators_])
387+
227388

228389
def decision_function(self, X):
    """Return continuous classifier scores for samples in X.

    Binary models yield a 1-D array of scores; multiclass models yield
    one column per fitted sub-classifier (K columns for OvR,
    K*(K-1)/2 for OvO).  Multiplying by ``coef_.T`` covers both the
    (n_features,) and (n_estimators, n_features) layouts uniformly.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input samples.

    Returns
    -------
    ndarray of shape (n_samples,) or (n_samples, n_estimators)
        Continuous decision scores for each sample.
    """
    check_is_fitted(
        self, attributes=["coef_", "intercept_", "_label_encoder", "classes_"]
    )
    validated = check_array(X, accept_sparse=False, dtype=np.float64, order="C")
    return validated @ self.coef_.T + self.intercept_
247415

248416
def predict(self, X):
    """Predict class labels for samples in X.

    Binary models threshold the decision score at 0.  OvR takes the
    argmax over the K per-class scores.  OvO combines majority voting
    across the K*(K-1)/2 pairwise classifiers with a bounded confidence
    term that breaks ties without overturning full-vote decisions.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input samples.

    Returns
    -------
    ndarray of shape (n_samples,)
        Predicted class labels in the original label space.

    Raises
    ------
    ValueError
        If the model is multiclass but ``multi_class`` is neither 'ovr'
        nor 'ovo' (e.g. attributes were altered after fitting); the
        original code silently returned None in that state.
    """
    scores = self.decision_function(X)

    if self.classes_.size == 2:
        pred01 = (scores >= 0).astype(int)
        return self._label_encoder.inverse_transform(pred01)

    if self.multi_class == 'ovr':
        # OvR: the class with the highest decision score wins.
        return self.classes_[np.argmax(scores, axis=1)]

    if self.multi_class == 'ovo':
        # OvO: discrete votes plus normalized confidences to break ties.
        # A positive score favors cls_i (the first class of the pair).
        n_samples = X.shape[0]
        n_classes = len(self.classes_)
        votes = np.zeros((n_samples, n_classes), dtype=np.float64)
        sum_of_confidences = np.zeros((n_samples, n_classes), dtype=np.float64)

        # Hoist the label -> column lookup out of the loop (previously an
        # O(K) np.where scan per estimator).
        col = {cls: idx for idx, cls in enumerate(self.classes_)}

        for k, (_, _, cls_i, cls_j) in enumerate(self.estimators_):
            i = col[cls_i]
            j = col[cls_j]

            # Discrete vote: score > 0 favors cls_i, score <= 0 favors cls_j.
            pred = (scores[:, k] > 0).astype(int)
            votes[:, i] += pred
            votes[:, j] += 1 - pred

            # Continuous confidence: a positive score backs cls_i.
            sum_of_confidences[:, i] += scores[:, k]
            sum_of_confidences[:, j] -= scores[:, k]

        # Monotone map into (-1/3, 1/3): breaks ties without overriding
        # any decision already separated by >= 1 full vote.
        transformed_confidences = sum_of_confidences / (
            3 * (np.abs(sum_of_confidences) + 1)
        )
        return self.classes_[np.argmax(votes + transformed_confidences, axis=1)]

    raise ValueError(
        f"multi_class must be 'ovr' or 'ovo' for multiclass prediction, "
        f"got '{self.multi_class}'."
    )
472+
265473

266474
def __sklearn_tags__(self):
267475
"""

0 commit comments

Comments
 (0)