Skip to content

Commit d8c875e

Browse files
committed
improve cov
1 parent f2ca291 commit d8c875e

6 files changed

Lines changed: 142 additions & 13 deletions

File tree

rehline/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,13 @@
1212
from ._internal import rehline_internal, rehline_result
1313
from ._loss import ReHLoss
1414
from ._mf_class import plqMF_Ridge
15-
from ._path_sol import plqERM_Ridge_path_sol
16-
from ._sklearn_mixin import plq_Ridge_Classifier, plq_Ridge_Regressor, plq_ElasticNet_Classifier, plq_ElasticNet_Regressor
15+
from ._path_sol import CQR_Ridge_path_sol, plqERM_Ridge_path_sol
16+
from ._sklearn_mixin import (
17+
plq_ElasticNet_Classifier,
18+
plq_ElasticNet_Regressor,
19+
plq_Ridge_Classifier,
20+
plq_Ridge_Regressor,
21+
)
1722

1823
__all__ = (
1924
"_BaseReHLine",
@@ -23,6 +28,7 @@
2328
"CQR_Ridge",
2429
"plqERM_ElasticNet",
2530
"plqMF_Ridge",
31+
"CQR_Ridge_path_sol",
2632
"plqERM_Ridge_path_sol",
2733
"plq_Ridge_Classifier",
2834
"plq_Ridge_Regressor",

tests/test_bugfixes.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def test_ratings_are_rounded_to_half(self):
135135
class TestPathSolVerbose:
136136
"""Verify that plqERM_Ridge_path_sol does not crash with verbose + no timing."""
137137

138-
def test_verbose_without_return_time(self):
138+
def test_verbose_without_return_time(self, capsys):
139139
"""verbose=1 + return_time=False must not raise NameError."""
140140
X, y = _make_classification_data(n=100, d=3)
141141
loss = {"name": "svm"}
@@ -158,8 +158,11 @@ def test_verbose_without_return_time(self):
158158
Cs_out, n_iters, loss_vals, l2_norms, coefs = result
159159
assert len(Cs_out) == 2
160160
assert len(n_iters) == 2
161+
captured = capsys.readouterr()
162+
assert "PLQ ERM Path Solution Results" in captured.out
163+
assert "Time (s)" not in captured.out
161164

162-
def test_verbose_with_return_time(self):
165+
def test_verbose_with_return_time(self, capsys):
163166
"""verbose=1 + return_time=True should still work."""
164167
X, y = _make_classification_data(n=100, d=3)
165168
loss = {"name": "svm"}
@@ -177,6 +180,9 @@ def test_verbose_with_return_time(self):
177180
)
178181

179182
assert len(result) == 6, f"Expected 6 return values, got {len(result)}"
183+
captured = capsys.readouterr()
184+
assert "PLQ ERM Path Solution Results" in captured.out
185+
assert "Total Time" in captured.out
180186

181187

182188
# ===========================================================================

tests/test_multiclass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def test_decision_function_shapes():
183183

184184
# Binary
185185
y_bin = np.random.randint(0, 2, n_samples)
186-
clf = plq_Ridge_Classifier(loss={"name": "svm"}, C=1.0, tol=1e-5)
186+
clf = plq_Ridge_Classifier(loss={"name": "svm"}, C=1.0, tol=1e-5, max_iter=1_000_000)
187187
clf.fit(X, y_bin)
188188
assert clf.decision_function(X).shape == (n_samples,), "Binary decision_function should have shape (n_samples,)"
189189

@@ -206,6 +206,7 @@ def test_decision_function_shapes():
206206
C=1.0,
207207
multi_class="ovo",
208208
tol=1e-5,
209+
max_iter=1_000_000,
209210
)
210211
clf_ovo.fit(X, y_multi)
211212
assert clf_ovo.decision_function(X).shape == (n_samples, 6), (

tests/test_path_sol.py

Lines changed: 108 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
import numpy as np
44
from sklearn.datasets import make_hastie_10_2
55

6-
from rehline import plqERM_Ridge_path_sol
6+
from rehline import CQR_Ridge_path_sol, plqERM_Ridge_path_sol
77

88

99
def test_path_sol_warm_start_shapes():
1010
"""plqERM_Ridge_path_sol should return arrays with consistent shapes."""
1111
X, y = make_hastie_10_2(random_state=1)
1212
loss = {"name": "svm"}
1313
# Use a small number of C values so the test is fast
14-
Cs = np.logspace(-3, 3, 10, base=2)
14+
Cs = np.logspace(-3, 3, 7, base=2)
1515

1616
(Cs_out, times, n_iters, loss_vals, l2_norms, coefs) = plqERM_Ridge_path_sol(
1717
X,
@@ -33,7 +33,7 @@ def test_path_sol_warm_start_shapes():
3333
assert len(times) == n_path, f"times length should be {n_path}, got {len(times)}"
3434
assert len(n_iters) == n_path, f"n_iters length should be {n_path}, got {len(n_iters)}"
3535
assert len(loss_vals) == n_path, f"loss_vals length should be {n_path}, got {len(loss_vals)}"
36-
assert coefs.shape == (n_path, n_features), f"coefs shape should be ({n_path}, {n_features}), got {coefs.shape}"
36+
assert coefs.shape == (n_features, n_path), f"coefs shape should be ({n_features}, {n_path}), got {coefs.shape}"
3737

3838
# All timing values should be non-negative
3939
assert np.all(np.array(times) >= 0), "All timing values should be non-negative"
@@ -68,3 +68,108 @@ def test_path_sol_loss_range_with_larger_C():
6868
assert loss_vals[-1] <= loss_vals[0] * 1.05, (
6969
f"Loss at C=10 ({loss_vals[-1]:.2f}) should be ≤ 105% of loss at C=0.01 ({loss_vals[0]:.2f})"
7070
)
71+
72+
73+
def test_path_sol_generates_default_Cs_when_not_provided():
74+
"""plqERM_Ridge_path_sol should generate a sorted path when Cs is omitted."""
75+
X, y = make_hastie_10_2(random_state=1)
76+
loss = {"name": "svm"}
77+
78+
Cs_out, n_iters, loss_vals, l2_norms, coefs = plqERM_Ridge_path_sol(
79+
X,
80+
y,
81+
loss=loss,
82+
eps=1e-2,
83+
n_Cs=4,
84+
max_iter=100000,
85+
tol=1e-3,
86+
verbose=0,
87+
warm_start=False,
88+
constraint=None,
89+
return_time=False,
90+
)
91+
92+
assert len(Cs_out) == 4
93+
assert np.all(np.diff(Cs_out) >= 0), "Generated Cs should be sorted in ascending order"
94+
assert len(n_iters) == 4
95+
assert len(loss_vals) == 4
96+
assert len(l2_norms) == 4
97+
assert coefs.shape == (X.shape[1], 4)
98+
99+
100+
def test_cqr_path_sol_shapes_without_times():
101+
"""CQR_Ridge_path_sol should return consistently shaped outputs without timing."""
102+
np.random.seed(42)
103+
X = np.random.randn(200, 2)
104+
y = X @ np.array([1.0, 2.0]) + np.random.randn(200)
105+
quantiles = [0.1, 0.5, 0.9]
106+
Cs = np.array([0.1, 1.0])
107+
108+
Cs_out, models, coefs, intercepts = CQR_Ridge_path_sol(
109+
X,
110+
y,
111+
quantiles=quantiles,
112+
Cs=Cs,
113+
max_iter=20000,
114+
tol=1e-3,
115+
verbose=0,
116+
warm_start=False,
117+
return_time=False,
118+
)
119+
120+
assert np.array_equal(Cs_out, Cs)
121+
assert len(models) == len(Cs)
122+
assert coefs.shape == (len(Cs), len(quantiles), X.shape[1])
123+
assert intercepts.shape == (len(Cs), len(quantiles))
124+
125+
126+
def test_cqr_path_sol_generates_default_Cs_with_times():
127+
"""CQR_Ridge_path_sol should generate default Cs and return timing info."""
128+
np.random.seed(0)
129+
X = np.random.randn(120, 3)
130+
y = X @ np.array([1.0, -0.5, 2.0]) + np.random.randn(120)
131+
quantiles = [0.25, 0.5, 0.75]
132+
133+
Cs_out, models, coefs, intercepts, fit_times = CQR_Ridge_path_sol(
134+
X,
135+
y,
136+
quantiles=quantiles,
137+
eps=1e-3,
138+
n_Cs=3,
139+
max_iter=20000,
140+
tol=1e-3,
141+
verbose=0,
142+
warm_start=True,
143+
return_time=True,
144+
)
145+
146+
expected_Cs = np.power(10.0, np.linspace(np.log10(1e-3), np.log10(10), 3))
147+
148+
assert np.allclose(Cs_out, expected_Cs)
149+
assert len(models) == 3
150+
assert coefs.shape == (3, len(quantiles), X.shape[1])
151+
assert intercepts.shape == (3, len(quantiles))
152+
assert len(fit_times) == 3
153+
assert np.all(np.array(fit_times) >= 0)
154+
155+
156+
def test_cqr_path_sol_verbose_reports_progress(capsys):
157+
"""CQR_Ridge_path_sol should print per-C progress when verbose is enabled."""
158+
np.random.seed(1)
159+
X = np.random.randn(80, 2)
160+
y = X @ np.array([1.5, -0.5]) + np.random.randn(80)
161+
162+
CQR_Ridge_path_sol(
163+
X,
164+
y,
165+
quantiles=[0.2, 0.8],
166+
Cs=np.array([0.5]),
167+
max_iter=20000,
168+
tol=1e-3,
169+
verbose=1,
170+
warm_start=False,
171+
return_time=True,
172+
)
173+
174+
captured = capsys.readouterr()
175+
assert "[OK] C=" in captured.out

tests/test_sklearn_mixin.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_classifier_pipeline_fits_and_predicts():
3737
pipe = Pipeline(
3838
[
3939
("scaler", StandardScaler()),
40-
("clf", plq_Ridge_Classifier(loss={"name": "svm"}, C=1.0)),
40+
("clf", plq_Ridge_Classifier(loss={"name": "svm"}, C=1.0, tol=1e-3, max_iter=1_000_000)),
4141
]
4242
)
4343
pipe.fit(X, y)
@@ -50,7 +50,16 @@ def test_classifier_pipeline_fits_and_predicts():
5050

5151
def test_classifier_cross_val_score():
5252
"""cross_val_score on plq_Ridge_Classifier pipeline should return reasonable scores."""
53-
X, y = _clf_dataset()
53+
X, y = make_classification(
54+
n_samples=500,
55+
n_features=10,
56+
n_informative=5,
57+
n_redundant=2,
58+
n_classes=2,
59+
class_sep=1.5,
60+
flip_y=0.0,
61+
random_state=42,
62+
)
5463
pipe = Pipeline(
5564
[
5665
("scaler", StandardScaler()),
@@ -72,6 +81,7 @@ def test_classifier_with_intercept_scaling():
7281
C=1.0,
7382
fit_intercept=True,
7483
intercept_scaling=1.0,
84+
max_iter=1_000_000,
7585
)
7686
clf.fit(X_tr, y_tr)
7787
preds = clf.predict(X_te)
@@ -86,6 +96,7 @@ def test_classifier_with_nonneg_constraint():
8696
loss={"name": "svm"},
8797
C=1.0,
8898
constraint=[{"name": "nonnegative"}],
99+
max_iter=1_000_000,
89100
)
90101
clf.fit(X, y)
91102
# Allow 1e-2 numerical slack — the solver may not satisfy the constraint

tests/test_svr.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_plqERM_Ridge_svr_matches_sklearn():
3333
reg_skl.fit(X, y)
3434
coef_skl = reg_skl.coef_.flatten()
3535

36-
reg_reh = plqERM_Ridge(loss={"name": "svr", "epsilon": epsilon}, C=C)
36+
reg_reh = plqERM_Ridge(loss={"name": "svr", "epsilon": epsilon}, C=C, tol=1e-4, max_iter=100000)
3737
reg_reh.fit(X=X, y=y)
3838
coef_reh = reg_reh.coef_.flatten()
3939

@@ -53,7 +53,7 @@ def test_ReHLine_manual_svr_params_match_builtin():
5353
n = X.shape[0]
5454

5555
# Built-in loss
56-
reg_builtin = plqERM_Ridge(loss={"name": "svr", "epsilon": epsilon}, C=C)
56+
reg_builtin = plqERM_Ridge(loss={"name": "svr", "epsilon": epsilon}, C=C, tol=1e-6, max_iter=1_000_000)
5757
reg_builtin.fit(X=X, y=y)
5858
coef_builtin = reg_builtin.coef_.flatten()
5959

@@ -65,7 +65,7 @@ def test_ReHLine_manual_svr_params_match_builtin():
6565
V[1] = C * (y - epsilon)
6666

6767
# When U/V are pre-scaled by C, use C=1.0 to avoid double-counting
68-
reg_manual = ReHLine(C=1.0)
68+
reg_manual = ReHLine(C=1.0, tol=1e-6, max_iter=1_000_000)
6969
reg_manual._U, reg_manual._V = U, V
7070
reg_manual.fit(X=X)
7171
coef_manual = reg_manual.coef_.flatten()

0 commit comments

Comments
 (0)