Skip to content

Commit e8d0c61

Browse files
authored
Merge pull request #42 from DataboyUsen/main
refactor(_mf_class.py): improve code quality and fix potential issues
2 parents 510d6b9 + 5b4f691 commit e8d0c61

1 file changed

Lines changed: 67 additions & 36 deletions

File tree

rehline/_mf_class.py

Lines changed: 67 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33
import warnings
44

55
import numpy as np
6-
import pandas as pd
76
from sklearn.base import BaseEstimator
87
from sklearn.exceptions import ConvergenceWarning
98
from sklearn.utils.validation import _check_sample_weight
@@ -199,42 +198,22 @@ def __init__(
199198
tol_CD=1e-4,
200199
verbose=0,
201200
):
202-
# check input
203-
errors = []
204-
checks = [
205-
(0 < rho < 1, "rho must be between 0 and 1"),
206-
(C > 0, "C must be positive"),
207-
(tol_CD > 0, "tol_CD must be positive"),
208-
(tol > 0, "tol must be positive"),
209-
]
210-
for condition, error_msg in checks:
211-
if not condition:
212-
errors.append(error_msg)
213-
if errors:
214-
raise ValueError("; ".join(errors))
215-
216201
# parameter initialization
217-
## -----------------------------basic perameters-----------------------------
202+
## -----------------------------basic parameters-----------------------------
218203
self.n_users = n_users
219204
self.n_items = n_items
220205
self.loss = loss
221206
self.constraint_user = constraint_user if constraint_user is not None else []
222207
self.constraint_item = constraint_item if constraint_item is not None else []
223208
self.biased = biased
224-
## -----------------------------hyper perameters-----------------------------
209+
## -----------------------------hyper parameters-----------------------------
225210
self.rank = rank
226211
self.C = C
227212
self.rho = rho
228-
## --------------------------coefficient perameters--------------------------
213+
## -------------------------initialization parameters------------------------
229214
self.init_mean = init_mean
230215
self.init_sd = init_sd
231216
self.random_state = random_state
232-
if self.random_state:
233-
np.random.seed(random_state)
234-
self.P = np.random.normal(loc=init_mean, scale=init_sd, size=(n_users, rank))
235-
self.Q = np.random.normal(loc=init_mean, scale=init_sd, size=(n_items, rank))
236-
self.bu = np.zeros(n_users) if self.biased else None
237-
self.bi = np.zeros(n_items) if self.biased else None
238217
## ----------------------------fitting parameters----------------------------
239218
self.max_iter_CD = max_iter_CD
240219
self.tol_CD = tol_CD
@@ -266,17 +245,62 @@ def fit(self, X, y, sample_weight=None):
266245
An instance of the estimator.
267246
268247
"""
248+
# check input
249+
## parameter validation
250+
errors = []
251+
checks = [
252+
(0 < self.rho < 1, "rho must be between 0 and 1"),
253+
(self.C > 0, "C must be positive"),
254+
(self.tol_CD > 0, "tol_CD must be positive"),
255+
(self.tol > 0, "tol must be positive"),
256+
]
257+
for condition, error_msg in checks:
258+
if not condition:
259+
errors.append(error_msg)
260+
if errors:
261+
raise ValueError("; ".join(errors))
262+
263+
## data validation
264+
X = np.asarray(X)
265+
y = np.asarray(y)
266+
if X.ndim != 2 or X.shape[1] != 2:
267+
raise ValueError("X must have shape (n_ratings, 2)")
268+
if X.shape[0] != len(y):
269+
raise ValueError("X and y must have the same number of samples")
270+
user_ids = X[:, 0].astype(int)
271+
item_ids = X[:, 1].astype(int)
272+
if np.any(user_ids < 0) or np.any(user_ids >= self.n_users):
273+
raise ValueError("User IDs must be in [0, n_users)")
274+
if np.any(item_ids < 0) or np.any(item_ids >= self.n_items):
275+
raise ValueError("Item IDs must be in [0, n_items)")
276+
269277
# Preparation
270-
self.n_ratings = len(y)
271-
self.history = np.nan * np.zeros((self.max_iter_CD + 1, 2))
278+
## number of training observations
279+
self.n_ratings = len(y)
280+
## convergence trace
281+
self.history = np.full((self.max_iter_CD + 1, 2), np.nan)
282+
## sample weights
272283
self.sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
273-
274-
X_df = pd.DataFrame(X, columns=["user", "item"])
275-
uidx_map = X_df.groupby("user").indices
276-
iidx_map = X_df.groupby("item").indices
277-
self.Iu = [uidx_map.get(u, np.array([], dtype=int)) for u in range(self.n_users)]
278-
self.Ui = [iidx_map.get(i, np.array([], dtype=int)) for i in range(self.n_items)]
279-
284+
## random number generator
285+
rng = np.random.default_rng(self.random_state)
286+
287+
## indices to locate interactions given a user or item id
288+
### user side: Iu[u] = row indices of interactions by user u
289+
sort_idx_users = np.argsort(X[:, 0], kind='stable')
290+
sorted_users = X[sort_idx_users, 0]
291+
counts = np.unique(sorted_users, return_counts=True)[1]
292+
self.Iu = [np.array([], dtype=int) for _ in range(self.n_users)]
293+
for u, idxs in zip(sorted_users[np.cumsum(counts) - counts], np.split(sort_idx_users, np.cumsum(counts)[:-1])):
294+
self.Iu[u] = idxs
295+
### item side: Ui[i] = row indices of interactions that involve item i
296+
sort_idx_items = np.argsort(X[:, 1], kind='stable')
297+
sorted_items = X[sort_idx_items, 1]
298+
counts = np.unique(sorted_items, return_counts=True)[1]
299+
self.Ui = [np.array([], dtype=int) for _ in range(self.n_items)]
300+
for i, idxs in zip(sorted_items[np.cumsum(counts) - counts], np.split(sort_idx_items, np.cumsum(counts)[:-1])):
301+
self.Ui[i] = idxs
302+
303+
## effective C when updating user/item blocks (to match rehline formulation: C * PLQ_loss + 0.5 * l_2)
280304
C_user = self.C * self.n_users / (self.rho) / 2
281305
C_item = self.C * self.n_items / (1 - self.rho) / 2
282306

@@ -289,6 +313,12 @@ def fit(self, X, y, sample_weight=None):
289313
)
290314
)
291315

316+
# Model Initialization
317+
self.P = rng.normal(loc=self.init_mean, scale=self.init_sd, size=(self.n_users, self.rank))
318+
self.Q = rng.normal(loc=self.init_mean, scale=self.init_sd, size=(self.n_items, self.rank))
319+
self.bu = np.zeros(self.n_users) if self.biased else None
320+
self.bi = np.zeros(self.n_items) if self.biased else None
321+
292322
# CD algorithm
293323
self.history[0] = self.obj(X, y)
294324
for iter_idx in range(self.max_iter_CD):
@@ -435,7 +465,7 @@ def fit(self, X, y, sample_weight=None):
435465
obj = f"{self.history[iter_idx + 1][1]:.6f}"
436466
print(f"{iter_idx + 1:<12} {mean_loss:<20} {obj:<20}")
437467

438-
if obj_diff < self.tol_CD:
468+
if abs(obj_diff) < self.tol_CD:
439469
break
440470

441471
return self
@@ -496,9 +526,10 @@ def obj(self, X, y):
496526
item_penalty = np.sum(self.Q**2) * (1 - self.rho) / self.n_items
497527
penalty = user_penalty + item_penalty
498528

499-
y_pred = self.decision_function(X)
500-
U, V, Tau, S, T = _make_loss_rehline_param(loss=self.loss, X=X, y=y)
529+
X_dummy = np.ones((len(y), 1)) # not used in loss computation, only shape matters for loss param construction
530+
U, V, Tau, S, T = _make_loss_rehline_param(loss=self.loss, X=X_dummy, y=y)
501531
loss = ReHLoss(U, V, S, T, Tau)
532+
y_pred = self.decision_function(X)
502533
loss_term = loss(y_pred)
503534

504535
return loss_term, self.C * loss_term + penalty

0 commit comments

Comments
 (0)