33import numpy as np
44from sklearn .datasets import make_hastie_10_2
55
6- from rehline import plqERM_Ridge_path_sol
6+ from rehline import CQR_Ridge_path_sol , plqERM_Ridge_path_sol
77
88
99def test_path_sol_warm_start_shapes ():
1010 """plqERM_Ridge_path_sol should return arrays with consistent shapes."""
1111 X , y = make_hastie_10_2 (random_state = 1 )
1212 loss = {"name" : "svm" }
1313 # Use a small number of C values so the test is fast
14- Cs = np .logspace (- 3 , 3 , 10 , base = 2 )
14+ Cs = np .logspace (- 3 , 3 , 7 , base = 2 )
1515
1616 (Cs_out , times , n_iters , loss_vals , l2_norms , coefs ) = plqERM_Ridge_path_sol (
1717 X ,
@@ -33,7 +33,7 @@ def test_path_sol_warm_start_shapes():
3333 assert len (times ) == n_path , f"times length should be { n_path } , got { len (times )} "
3434 assert len (n_iters ) == n_path , f"n_iters length should be { n_path } , got { len (n_iters )} "
3535 assert len (loss_vals ) == n_path , f"loss_vals length should be { n_path } , got { len (loss_vals )} "
36- assert coefs .shape == (n_path , n_features ), f"coefs shape should be ({ n_path } , { n_features } ), got { coefs .shape } "
36+ assert coefs .shape == (n_features , n_path ), f"coefs shape should be ({ n_features } , { n_path } ), got { coefs .shape } "
3737
3838 # All timing values should be non-negative
3939 assert np .all (np .array (times ) >= 0 ), "All timing values should be non-negative"
@@ -68,3 +68,108 @@ def test_path_sol_loss_range_with_larger_C():
6868 assert loss_vals [- 1 ] <= loss_vals [0 ] * 1.05 , (
6969 f"Loss at C=10 ({ loss_vals [- 1 ]:.2f} ) should be ≤ 105% of loss at C=0.01 ({ loss_vals [0 ]:.2f} )"
7070 )
71+
72+
73+ def test_path_sol_generates_default_Cs_when_not_provided ():
74+ """plqERM_Ridge_path_sol should generate a sorted path when Cs is omitted."""
75+ X , y = make_hastie_10_2 (random_state = 1 )
76+ loss = {"name" : "svm" }
77+
78+ Cs_out , n_iters , loss_vals , l2_norms , coefs = plqERM_Ridge_path_sol (
79+ X ,
80+ y ,
81+ loss = loss ,
82+ eps = 1e-2 ,
83+ n_Cs = 4 ,
84+ max_iter = 100000 ,
85+ tol = 1e-3 ,
86+ verbose = 0 ,
87+ warm_start = False ,
88+ constraint = None ,
89+ return_time = False ,
90+ )
91+
92+ assert len (Cs_out ) == 4
93+ assert np .all (np .diff (Cs_out ) >= 0 ), "Generated Cs should be sorted in ascending order"
94+ assert len (n_iters ) == 4
95+ assert len (loss_vals ) == 4
96+ assert len (l2_norms ) == 4
97+ assert coefs .shape == (X .shape [1 ], 4 )
98+
99+
100+ def test_cqr_path_sol_shapes_without_times ():
101+ """CQR_Ridge_path_sol should return consistently shaped outputs without timing."""
102+ np .random .seed (42 )
103+ X = np .random .randn (200 , 2 )
104+ y = X @ np .array ([1.0 , 2.0 ]) + np .random .randn (200 )
105+ quantiles = [0.1 , 0.5 , 0.9 ]
106+ Cs = np .array ([0.1 , 1.0 ])
107+
108+ Cs_out , models , coefs , intercepts = CQR_Ridge_path_sol (
109+ X ,
110+ y ,
111+ quantiles = quantiles ,
112+ Cs = Cs ,
113+ max_iter = 20000 ,
114+ tol = 1e-3 ,
115+ verbose = 0 ,
116+ warm_start = False ,
117+ return_time = False ,
118+ )
119+
120+ assert np .array_equal (Cs_out , Cs )
121+ assert len (models ) == len (Cs )
122+ assert coefs .shape == (len (Cs ), len (quantiles ), X .shape [1 ])
123+ assert intercepts .shape == (len (Cs ), len (quantiles ))
124+
125+
126+ def test_cqr_path_sol_generates_default_Cs_with_times ():
127+ """CQR_Ridge_path_sol should generate default Cs and return timing info."""
128+ np .random .seed (0 )
129+ X = np .random .randn (120 , 3 )
130+ y = X @ np .array ([1.0 , - 0.5 , 2.0 ]) + np .random .randn (120 )
131+ quantiles = [0.25 , 0.5 , 0.75 ]
132+
133+ Cs_out , models , coefs , intercepts , fit_times = CQR_Ridge_path_sol (
134+ X ,
135+ y ,
136+ quantiles = quantiles ,
137+ eps = 1e-3 ,
138+ n_Cs = 3 ,
139+ max_iter = 20000 ,
140+ tol = 1e-3 ,
141+ verbose = 0 ,
142+ warm_start = True ,
143+ return_time = True ,
144+ )
145+
146+ expected_Cs = np .power (10.0 , np .linspace (np .log10 (1e-3 ), np .log10 (10 ), 3 ))
147+
148+ assert np .allclose (Cs_out , expected_Cs )
149+ assert len (models ) == 3
150+ assert coefs .shape == (3 , len (quantiles ), X .shape [1 ])
151+ assert intercepts .shape == (3 , len (quantiles ))
152+ assert len (fit_times ) == 3
153+ assert np .all (np .array (fit_times ) >= 0 )
154+
155+
156+ def test_cqr_path_sol_verbose_reports_progress (capsys ):
157+ """CQR_Ridge_path_sol should print per-C progress when verbose is enabled."""
158+ np .random .seed (1 )
159+ X = np .random .randn (80 , 2 )
160+ y = X @ np .array ([1.5 , - 0.5 ]) + np .random .randn (80 )
161+
162+ CQR_Ridge_path_sol (
163+ X ,
164+ y ,
165+ quantiles = [0.2 , 0.8 ],
166+ Cs = np .array ([0.5 ]),
167+ max_iter = 20000 ,
168+ tol = 1e-3 ,
169+ verbose = 1 ,
170+ warm_start = False ,
171+ return_time = True ,
172+ )
173+
174+ captured = capsys .readouterr ()
175+ assert "[OK] C=" in captured .out
0 commit comments