Skip to content

Commit fc7fa59

Browse files
committed
2 parents 0d01f86 + 650a10d commit fc7fa59

5 files changed

Lines changed: 257 additions & 39 deletions

File tree

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -30,10 +30,10 @@ jobs:
3030
python-version: ${{ matrix.python-version }}
3131

3232
- name: Install package and test dependencies
33-
run: pip install ".[test]" pytest-cov
33+
run: pip install -e ".[test]" pytest-cov
3434

3535
- name: Run tests
36-
run: pytest -v --tb=short --cov=rehline --cov-report=xml --cov-report=term
36+
run: pytest -v --tb=short --cov-report=xml
3737

3838
- name: Upload coverage to Codecov
3939
if: success()

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,11 @@ rehline.egg-info/
1414
*.log
1515
env/
1616

17+
# Coverage reports
18+
.coverage
19+
coverage.xml
20+
htmlcov/
21+
1722
# Files during package building
1823
eigen-3.4.0/
1924
eigen-5.0.1/

pyproject.toml

Lines changed: 20 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,18 +34,37 @@ py-modules = ["build"]
3434
build = "cp*"
3535

3636
[project.optional-dependencies]
37-
test = ["pytest >= 7.0", "pandas >= 1.5.0"]
37+
test = ["pytest >= 7.0", "pandas >= 1.5.0", "pytest-cov"]
3838

3939
[tool.pytest.ini_options]
4040
testpaths = ["tests"]
4141
python_files = ["test_*.py"]
4242
python_classes = ["Test*"]
4343
python_functions = ["test_*"]
44+
addopts = "--cov=rehline --cov-report=term-missing"
4445

4546
[build-system]
4647
requires = ["requests ~= 2.31.0", "pybind11 ~= 3.0.0", "setuptools >= 80.0.0", "wheel >= 0.42.0", "setuptools-scm >= 8.0"]
4748
build-backend = "setuptools.build_meta"
4849

50+
# --- Coverage ---
51+
[tool.coverage.run]
52+
source = ["rehline"]
53+
omit = ["rehline/_internal*"]
54+
relative_files = true
55+
56+
[tool.coverage.paths]
57+
source = [
58+
"rehline/",
59+
"*/site-packages/rehline/",
60+
]
61+
62+
[tool.coverage.report]
63+
exclude_lines = [
64+
"pragma: no cover",
65+
"if TYPE_CHECKING:",
66+
]
67+
4968
# --- Ruff (linter + formatter) ---
5069
[tool.ruff]
5170
target-version = "py310"

rehline/_mf_class.py

Lines changed: 67 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33
import warnings
44

55
import numpy as np
6-
import pandas as pd
76
from sklearn.base import BaseEstimator
87
from sklearn.exceptions import ConvergenceWarning
98
from sklearn.utils.validation import _check_sample_weight
@@ -199,42 +198,22 @@ def __init__(
199198
tol_CD=1e-4,
200199
verbose=0,
201200
):
202-
# check input
203-
errors = []
204-
checks = [
205-
(0 < rho < 1, "rho must be between 0 and 1"),
206-
(C > 0, "C must be positive"),
207-
(tol_CD > 0, "tol_CD must be positive"),
208-
(tol > 0, "tol must be positive"),
209-
]
210-
for condition, error_msg in checks:
211-
if not condition:
212-
errors.append(error_msg)
213-
if errors:
214-
raise ValueError("; ".join(errors))
215-
216201
# parameter initialization
217-
## -----------------------------basic perameters-----------------------------
202+
## -----------------------------basic parameters-----------------------------
218203
self.n_users = n_users
219204
self.n_items = n_items
220205
self.loss = loss
221206
self.constraint_user = constraint_user if constraint_user is not None else []
222207
self.constraint_item = constraint_item if constraint_item is not None else []
223208
self.biased = biased
224-
## -----------------------------hyper perameters-----------------------------
209+
## -----------------------------hyper parameters-----------------------------
225210
self.rank = rank
226211
self.C = C
227212
self.rho = rho
228-
## --------------------------coefficient perameters--------------------------
213+
## -------------------------initialization parameters------------------------
229214
self.init_mean = init_mean
230215
self.init_sd = init_sd
231216
self.random_state = random_state
232-
if self.random_state:
233-
np.random.seed(random_state)
234-
self.P = np.random.normal(loc=init_mean, scale=init_sd, size=(n_users, rank))
235-
self.Q = np.random.normal(loc=init_mean, scale=init_sd, size=(n_items, rank))
236-
self.bu = np.zeros(n_users) if self.biased else None
237-
self.bi = np.zeros(n_items) if self.biased else None
238217
## ----------------------------fitting parameters----------------------------
239218
self.max_iter_CD = max_iter_CD
240219
self.tol_CD = tol_CD
@@ -266,17 +245,62 @@ def fit(self, X, y, sample_weight=None):
266245
An instance of the estimator.
267246
268247
"""
248+
# check input
249+
## parameter validation
250+
errors = []
251+
checks = [
252+
(0 < self.rho < 1, "rho must be between 0 and 1"),
253+
(self.C > 0, "C must be positive"),
254+
(self.tol_CD > 0, "tol_CD must be positive"),
255+
(self.tol > 0, "tol must be positive"),
256+
]
257+
for condition, error_msg in checks:
258+
if not condition:
259+
errors.append(error_msg)
260+
if errors:
261+
raise ValueError("; ".join(errors))
262+
263+
## data validation
264+
X = np.asarray(X)
265+
y = np.asarray(y)
266+
if X.ndim != 2 or X.shape[1] != 2:
267+
raise ValueError("X must have shape (n_ratings, 2)")
268+
if X.shape[0] != len(y):
269+
raise ValueError("X and y must have the same number of samples")
270+
user_ids = X[:, 0].astype(int)
271+
item_ids = X[:, 1].astype(int)
272+
if np.any(user_ids < 0) or np.any(user_ids >= self.n_users):
273+
raise ValueError("User IDs must be in [0, n_users)")
274+
if np.any(item_ids < 0) or np.any(item_ids >= self.n_items):
275+
raise ValueError("Item IDs must be in [0, n_items)")
276+
269277
# Preparation
270-
self.n_ratings = len(y)
271-
self.history = np.nan * np.zeros((self.max_iter_CD + 1, 2))
278+
## number of training observations
279+
self.n_ratings = len(y)
280+
## convergence trace
281+
self.history = np.full((self.max_iter_CD + 1, 2), np.nan)
282+
## sample weights
272283
self.sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
273-
274-
X_df = pd.DataFrame(X, columns=["user", "item"])
275-
uidx_map = X_df.groupby("user").indices
276-
iidx_map = X_df.groupby("item").indices
277-
self.Iu = [uidx_map.get(u, np.array([], dtype=int)) for u in range(self.n_users)]
278-
self.Ui = [iidx_map.get(i, np.array([], dtype=int)) for i in range(self.n_items)]
279-
284+
## random number generator
285+
rng = np.random.default_rng(self.random_state)
286+
287+
## indices to locate interactions given a user or item id
288+
### user side: Iu[u] = row indices of interactions by user u
289+
sort_idx_users = np.argsort(X[:, 0], kind='stable')
290+
sorted_users = X[sort_idx_users, 0]
291+
counts = np.unique(sorted_users, return_counts=True)[1]
292+
self.Iu = [np.array([], dtype=int) for _ in range(self.n_users)]
293+
for u, idxs in zip(sorted_users[np.cumsum(counts) - counts], np.split(sort_idx_users, np.cumsum(counts)[:-1])):
294+
self.Iu[u] = idxs
295+
### item side: Ui[i] = row indices of interactions that involve item i
296+
sort_idx_items = np.argsort(X[:, 1], kind='stable')
297+
sorted_items = X[sort_idx_items, 1]
298+
counts = np.unique(sorted_items, return_counts=True)[1]
299+
self.Ui = [np.array([], dtype=int) for _ in range(self.n_items)]
300+
for i, idxs in zip(sorted_items[np.cumsum(counts) - counts], np.split(sort_idx_items, np.cumsum(counts)[:-1])):
301+
self.Ui[i] = idxs
302+
303+
## effective C when updating user/item blocks (to match rehline formulation: C * PLQ_loss + 0.5 * l_2)
280304
C_user = self.C * self.n_users / (self.rho) / 2
281305
C_item = self.C * self.n_items / (1 - self.rho) / 2
282306

@@ -289,6 +313,12 @@ def fit(self, X, y, sample_weight=None):
289313
)
290314
)
291315

316+
# Model Initialization
317+
self.P = rng.normal(loc=self.init_mean, scale=self.init_sd, size=(self.n_users, self.rank))
318+
self.Q = rng.normal(loc=self.init_mean, scale=self.init_sd, size=(self.n_items, self.rank))
319+
self.bu = np.zeros(self.n_users) if self.biased else None
320+
self.bi = np.zeros(self.n_items) if self.biased else None
321+
292322
# CD algorithm
293323
self.history[0] = self.obj(X, y)
294324
for iter_idx in range(self.max_iter_CD):
@@ -435,7 +465,7 @@ def fit(self, X, y, sample_weight=None):
435465
obj = f"{self.history[iter_idx + 1][1]:.6f}"
436466
print(f"{iter_idx + 1:<12} {mean_loss:<20} {obj:<20}")
437467

438-
if obj_diff < self.tol_CD:
468+
if abs(obj_diff) < self.tol_CD:
439469
break
440470

441471
return self
@@ -496,9 +526,10 @@ def obj(self, X, y):
496526
item_penalty = np.sum(self.Q**2) * (1 - self.rho) / self.n_items
497527
penalty = user_penalty + item_penalty
498528

499-
y_pred = self.decision_function(X)
500-
U, V, Tau, S, T = _make_loss_rehline_param(loss=self.loss, X=X, y=y)
529+
X_dummy = np.ones((len(y), 1)) # not used in loss computation, only shape matters for loss param construction
530+
U, V, Tau, S, T = _make_loss_rehline_param(loss=self.loss, X=X_dummy, y=y)
501531
loss = ReHLoss(U, V, S, T, Tau)
532+
y_pred = self.decision_function(X)
502533
loss_term = loss(y_pred)
503534

504535
return loss_term, self.C * loss_term + penalty

tests/test_mf.py

Lines changed: 163 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -110,6 +110,169 @@ def test_mf_hinge_classification_fits(mf_data):
110110
assert accuracy > 0.5, f"Hinge-loss MF accuracy ({accuracy:.3f}) should be > 0.5"
111111

112112

113+
def test_mf_data_validation_errors():
114+
"""Test data validation raises appropriate errors."""
115+
# Test X with wrong shape (not 2 columns)
116+
with pytest.raises(ValueError, match="X must have shape"):
117+
model = plqMF_Ridge(n_users=10, n_items=10, loss={"name": "mae"})
118+
model.fit(np.array([[0, 0, 0]]), np.array([1.0]))
119+
120+
# Test X and y mismatch
121+
with pytest.raises(ValueError, match="X and y must have the same number"):
122+
model = plqMF_Ridge(n_users=10, n_items=10, loss={"name": "mae"})
123+
model.fit(np.array([[0, 0]]), np.array([1.0, 2.0]))
124+
125+
# Test invalid user ID (negative)
126+
with pytest.raises(ValueError, match="User IDs must be in"):
127+
model = plqMF_Ridge(n_users=10, n_items=10, loss={"name": "mae"})
128+
model.fit(np.array([[-1, 0]]), np.array([1.0]))
129+
130+
# Test invalid user ID (>= n_users)
131+
with pytest.raises(ValueError, match="User IDs must be in"):
132+
model = plqMF_Ridge(n_users=10, n_items=10, loss={"name": "mae"})
133+
model.fit(np.array([[10, 0]]), np.array([1.0]))
134+
135+
# Test invalid item ID (negative)
136+
with pytest.raises(ValueError, match="Item IDs must be in"):
137+
model = plqMF_Ridge(n_users=10, n_items=10, loss={"name": "mae"})
138+
model.fit(np.array([[0, -1]]), np.array([1.0]))
139+
140+
# Test invalid item ID (>= n_items)
141+
with pytest.raises(ValueError, match="Item IDs must be in"):
142+
model = plqMF_Ridge(n_users=10, n_items=10, loss={"name": "mae"})
143+
model.fit(np.array([[0, 10]]), np.array([1.0]))
144+
145+
146+
def test_mf_cold_start_users_items():
147+
"""Test cold start handling: users/items with no interactions."""
148+
# Create data where user 0 and item 0 have no interactions
149+
# n_users=3, n_items=3, but only users 1,2 and items 1,2 interact
150+
X = np.array([[1, 1], [1, 2], [2, 1], [2, 2]])
151+
y = np.array([3.0, 4.0, 2.0, 5.0])
152+
153+
model = plqMF_Ridge(
154+
n_users=3,
155+
n_items=3,
156+
loss={"name": "mae"},
157+
rank=2,
158+
C=0.1,
159+
max_iter=1000,
160+
tol=0.01,
161+
)
162+
model.fit(X, y)
163+
164+
# Cold start user (user 0) should have zero factors and bias
165+
assert np.allclose(model.P[0, :], 0.0)
166+
assert model.bu[0] == 0.0
167+
168+
# Cold start item (item 0) should have zero factors and bias
169+
assert np.allclose(model.Q[0, :], 0.0)
170+
assert model.bi[0] == 0.0
171+
172+
173+
def test_mf_biased_false():
174+
"""Test plqMF_Ridge with biased=False (no bias terms)."""
175+
n_users, n_items = 20, 30
176+
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1], [2, 2], [3, 3]])
177+
y = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
178+
179+
model = plqMF_Ridge(
180+
n_users=n_users,
181+
n_items=n_items,
182+
loss={"name": "mae"},
183+
biased=False,
184+
rank=3,
185+
C=0.1,
186+
max_iter=1000,
187+
tol=0.01,
188+
)
189+
model.fit(X, y)
190+
191+
# bu and bi should be None when biased=False
192+
assert model.bu is None
193+
assert model.bi is None
194+
195+
# decision_function should work without biases
196+
scores = model.decision_function(X)
197+
assert scores.shape == (len(X),)
198+
199+
# obj should work without biases
200+
loss_term, obj_val = model.obj(X, y)
201+
assert np.isfinite(loss_term)
202+
assert np.isfinite(obj_val)
203+
204+
205+
def test_mf_verbose_output(capsys):
206+
"""Test that verbose=1 prints the CD iteration progress to stdout."""
207+
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
208+
y = np.array([1.0, 2.0, 3.0, 4.0])
209+
210+
# Test verbose=1 (CD iteration progress)
211+
model = plqMF_Ridge(
212+
n_users=2,
213+
n_items=2,
214+
loss={"name": "mae"},
215+
rank=2,
216+
C=0.1,
217+
max_iter=500,
218+
tol=0.01,
219+
max_iter_CD=2,
220+
verbose=1,
221+
)
222+
model.fit(X, y)
223+
captured = capsys.readouterr()
224+
assert "Iteration" in captured.out
225+
assert "Average Loss" in captured.out
226+
227+
228+
def test_mf_convergence_warning():
229+
"""Test convergence warning when max_iter is too small."""
230+
from sklearn.exceptions import ConvergenceWarning
231+
232+
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
233+
y = np.array([1.0, 2.0, 3.0, 4.0])
234+
235+
model = plqMF_Ridge(
236+
n_users=2,
237+
n_items=2,
238+
loss={"name": "mae"},
239+
rank=2,
240+
C=0.1,
241+
max_iter=1, # Only 1 iteration to guarantee non-convergence
242+
tol=1e-10,
243+
max_iter_CD=1,
244+
)
245+
with pytest.warns(ConvergenceWarning, match="ReHLine failed to converge"):
246+
model.fit(X, y)
247+
248+
249+
def test_mf_param_validation_errors():
250+
"""Test parameter validation raises appropriate errors."""
251+
# Test invalid rho (must be between 0 and 1)
252+
with pytest.raises(ValueError, match="rho must be between 0 and 1"):
253+
model = plqMF_Ridge(n_users=10, n_items=10, loss={"name": "mae"}, rho=0.0)
254+
model.fit(np.array([[0, 0]]), np.array([1.0]))
255+
256+
with pytest.raises(ValueError, match="rho must be between 0 and 1"):
257+
model = plqMF_Ridge(n_users=10, n_items=10, loss={"name": "mae"}, rho=1.0)
258+
model.fit(np.array([[0, 0]]), np.array([1.0]))
259+
260+
# Test invalid C (must be positive)
261+
with pytest.raises(ValueError, match="C must be positive"):
262+
model = plqMF_Ridge(n_users=10, n_items=10, loss={"name": "mae"}, C=0.0)
263+
model.fit(np.array([[0, 0]]), np.array([1.0]))
264+
265+
# Test invalid tol_CD (must be positive)
266+
with pytest.raises(ValueError, match="tol_CD must be positive"):
267+
model = plqMF_Ridge(n_users=10, n_items=10, loss={"name": "mae"}, tol_CD=0.0)
268+
model.fit(np.array([[0, 0]]), np.array([1.0]))
269+
270+
# Test invalid tol (must be positive)
271+
with pytest.raises(ValueError, match="tol must be positive"):
272+
model = plqMF_Ridge(n_users=10, n_items=10, loss={"name": "mae"}, tol=0.0)
273+
model.fit(np.array([[0, 0]]), np.array([1.0]))
274+
275+
113276
def test_mf_nonneg_constraint(mf_data):
114277
"""plqMF_Ridge with non-negative constraints should produce non-negative factors."""
115278
d = mf_data

0 commit comments

Comments
 (0)