Commit 55cc374

add multi-class classification function with tutorial examples
1 parent 8a4d932 commit 55cc374

3 files changed

Lines changed: 405 additions & 25 deletions

doc/source/example.rst

Lines changed: 4 additions & 2 deletions
@@ -17,7 +17,8 @@ Example Gallery
    examples/RankRegression.ipynb
    examples/Path_solution.ipynb
    examples/Warm_start.ipynb
-   examples/Sklearn_Mixin.ipynb
+   examples/Sklearn_Mixin.ipynb
+   examples/Multiclass_Classification.ipynb
    examples/NMF.ipynb
 
 List of Examples
@@ -35,5 +36,6 @@ List of Examples
    examples/RankRegression.ipynb
    examples/Path_solution.ipynb
    examples/Warm_start.ipynb
-   examples/Sklearn_Mixin.ipynb
+   examples/Sklearn_Mixin.ipynb
+   examples/Multiclass_Classification.ipynb
    examples/NMF.ipynb
Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "provenance": []
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# Multi-class Classification Support\n",
+        "\n",
+        "[![Slides](https://img.shields.io/badge/🦌-ReHLine-blueviolet)](https://rehline-python.readthedocs.io/en/latest/)\n",
+        "\n",
+        "\n",
+        "`plq_Ridge_Classifier` extends binary PLQ-ERM to multi-class problems via the `multi_class`\n",
+        "parameter, supporting two standard decomposition strategies.\n",
+        "\n",
+        "**One-vs-Rest (OvR)** fits $K$ binary classifiers, one per class, each trained on the\n",
+        "full dataset with relabelled targets ($+1$ for the class, $-1$ for all others).\n",
+        "Prediction selects the class with the highest decision score:\n",
+        "\n",
+        "$$\\widehat{y} = \\arg\\max_{k \\in \\{1, \\dots, K\\}} f_k(\\mathbf{x})$$\n",
+        "\n",
+        "**One-vs-One (OvO)** fits $\\binom{K}{2}$ binary classifiers, one per class pair $(i, j)$,\n",
+        "each trained only on samples belonging to those two classes.\n",
+        "Prediction uses majority voting across all pairwise classifiers:\n",
+        "\n",
+        "$$\\widehat{y} = \\arg\\max_{k \\in \\{1, \\dots, K\\}} \\sum_{j \\neq k} \\mathbf{1}\\bigl[f_{kj}(\\mathbf{x}) > 0\\bigr]$$\n",
+        "\n",
+        "In both cases, each binary sub-problem is solved by a standard `plq_Ridge_Classifier`\n",
+        "with the same loss, regularizer, and solver settings passed to the parent estimator."
+      ],
+      "metadata": {
+        "id": "qEniIUh2yiGb"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import numpy as np\n",
+        "import pandas as pd\n",
+        "from sklearn.datasets import make_classification\n",
+        "from sklearn.model_selection import train_test_split\n",
+        "from sklearn.metrics import accuracy_score\n",
+        "from rehline import plq_Ridge_Classifier\n",
+        "\n",
+        "\n",
+        "# generate a synthetic 4-class dataset and hold out 20% for testing\n",
+        "X_mc, y_mc = make_classification(\n",
+        "    n_samples=10000, n_features=20, n_informative=10,\n",
+        "    n_classes=4, n_clusters_per_class=1, random_state=42\n",
+        ")\n",
+        "X_train, X_test, y_train, y_test = train_test_split(X_mc, y_mc, test_size=0.2, random_state=42)"
+      ],
+      "metadata": {
+        "id": "W5JaV_p2Ztua"
+      },
+      "execution_count": 2,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### One-vs-Rest (OvR)\n",
+        "\n",
+        "\n",
+        "In OvR, we train $K$ binary classifiers, one per class. Classifier $k$ learns to distinguish\n",
+        "class $k$ from all other classes. The final prediction selects the class whose classifier\n",
+        "reports the highest decision score:\n",
+        "\n",
+        "$$\\widehat{y} \\;=\\; \\arg\\max_{k \\in \\{1, \\dots, K\\}} f_k(\\mathbf{x})$$\n",
+        "\n",
+        "where $f_k(\\mathbf{x})$ is the decision score of classifier $k$ at $\\mathbf{x}$, which for a linear classifier is proportional to the signed distance from $\\mathbf{x}$ to its decision boundary."
+      ],
+      "metadata": {
+        "id": "P_kBHXWG_qxa"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# fit and predict with the one-vs-rest strategy\n",
+        "plq_ovr = plq_Ridge_Classifier(\n",
+        "    loss={'name': 'svm'}, C=1.0,\n",
+        "    fit_intercept=True, max_iter=50000, multi_class='ovr',\n",
+        ")\n",
+        "plq_ovr.fit(X_train, y_train)\n",
+        "\n",
+        "y_pred = plq_ovr.predict(X_test)\n",
+        "print(f\"plq OvR accuracy: {accuracy_score(y_test, y_pred):.4f}\")"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "eNgPMJdctWyi",
+        "outputId": "b654fab7-2aff-4fa0-cf56-1b4ef6e82313"
+      },
+      "execution_count": 3,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "plq OvR accuracy: 0.7770\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### One-vs-One (OvO)\n",
+        "\n",
+        "In OvO, we train $\\binom{K}{2}$ binary classifiers, one for each pair of classes $(i, j)$.\n",
+        "Each classifier $f_{ij}$ votes for either class $i$ or class $j$. The final prediction\n",
+        "is the class that receives the most votes:\n",
+        "\n",
+        "$$\\widehat{y} = \\arg\\max_{k \\in \\{1, \\dots, K\\}} \\sum_{j \\neq k} \\mathbf{1}\\bigl[f_{kj}(\\mathbf{x}) > 0\\bigr]$$\n",
+        "\n",
+        "where $\\mathbf{1}[\\cdot]$ is the indicator function, and $f_{kj}(\\mathbf{x}) > 0$ means\n",
+        "classifier $(k, j)$ votes for class $k$."
+      ],
+      "metadata": {
+        "id": "bnl41NChAlUr"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# fit and predict with the one-vs-one strategy\n",
+        "plq_ovo = plq_Ridge_Classifier(\n",
+        "    loss={'name': 'svm'}, C=1.0,\n",
+        "    fit_intercept=True, max_iter=50000, multi_class='ovo',\n",
+        ")\n",
+        "plq_ovo.fit(X_train, y_train)\n",
+        "y_pred = plq_ovo.predict(X_test)\n",
+        "\n",
+        "print(f\"plq OvO accuracy: {accuracy_score(y_test, y_pred):.4f}\")"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "s9nalNTit9Jl",
+        "outputId": "134c9f27-444b-41ee-e606-9985dbbdb9c9"
+      },
+      "execution_count": 4,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "plq OvO accuracy: 0.8035\n"
+          ]
+        }
+      ]
+    }
+  ]
+}
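
Review note: the two prediction rules described in the notebook's markdown cells can be sketched in plain Python. This is a minimal illustration of the formulas only, not ReHLine's implementation. The helper names `ovr_predict` and `ovo_predict`, the score values, and the lowest-index tie-break are all assumptions made for the example.

```python
from itertools import combinations


def ovr_predict(scores):
    """One-vs-Rest: return the index k with the largest decision score f_k(x)."""
    return max(range(len(scores)), key=lambda k: scores[k])


def ovo_predict(pair_scores, n_classes):
    """One-vs-One: majority vote over pairwise decision scores.

    pair_scores maps (i, j) with i < j to f_ij(x). A positive score is a
    vote for class i; anything else is a vote for class j.
    """
    votes = [0] * n_classes
    for (i, j), score in pair_scores.items():
        votes[i if score > 0 else j] += 1
    # Ties go to the lowest class index (an assumption of this sketch).
    return max(range(n_classes), key=lambda k: votes[k])


# Hypothetical decision scores for a single sample with K = 4 classes.
K = 4
ovr_scores = [-0.3, 1.2, 0.4, -2.0]

# OvO needs one score per unordered class pair: K * (K - 1) / 2 = 6 here.
pairs = list(combinations(range(K), 2))
pair_values = [-0.5, 0.7, 1.1, 0.2, 0.9, -0.4]
ovo_scores = dict(zip(pairs, pair_values))

print("OvR prediction:", ovr_predict(ovr_scores))
print("OvO prediction:", ovo_predict(ovo_scores, K))
```

The sketch also makes the cost difference visible: OvR evaluates K scores per sample, while OvO evaluates one score per class pair, so the number of sub-problems grows quadratically with K.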
