33import warnings
44
55import numpy as np
6- import pandas as pd
76from sklearn .base import BaseEstimator
87from sklearn .exceptions import ConvergenceWarning
98from sklearn .utils .validation import _check_sample_weight
@@ -199,42 +198,22 @@ def __init__(
199198 tol_CD = 1e-4 ,
200199 verbose = 0 ,
201200 ):
202- # check input
203- errors = []
204- checks = [
205- (0 < rho < 1 , "rho must be between 0 and 1" ),
206- (C > 0 , "C must be positive" ),
207- (tol_CD > 0 , "tol_CD must be positive" ),
208- (tol > 0 , "tol must be positive" ),
209- ]
210- for condition , error_msg in checks :
211- if not condition :
212- errors .append (error_msg )
213- if errors :
214- raise ValueError ("; " .join (errors ))
215-
216201 # parameter initialization
217- ## -----------------------------basic perameters -----------------------------
202+ ## -----------------------------basic parameters -----------------------------
218203 self .n_users = n_users
219204 self .n_items = n_items
220205 self .loss = loss
221206 self .constraint_user = constraint_user if constraint_user is not None else []
222207 self .constraint_item = constraint_item if constraint_item is not None else []
223208 self .biased = biased
224- ## -----------------------------hyper perameters -----------------------------
209+ ## -----------------------------hyper parameters -----------------------------
225210 self .rank = rank
226211 self .C = C
227212 self .rho = rho
228- ## --------------------------coefficient perameters-- ------------------------
213+ ## -------------------------initialization parameters ------------------------
229214 self .init_mean = init_mean
230215 self .init_sd = init_sd
231216 self .random_state = random_state
232- if self .random_state :
233- np .random .seed (random_state )
234- self .P = np .random .normal (loc = init_mean , scale = init_sd , size = (n_users , rank ))
235- self .Q = np .random .normal (loc = init_mean , scale = init_sd , size = (n_items , rank ))
236- self .bu = np .zeros (n_users ) if self .biased else None
237- self .bi = np .zeros (n_items ) if self .biased else None
238217 ## ----------------------------fitting parameters----------------------------
239218 self .max_iter_CD = max_iter_CD
240219 self .tol_CD = tol_CD
@@ -266,17 +245,62 @@ def fit(self, X, y, sample_weight=None):
266245 An instance of the estimator.
267246
268247 """
248+ # check input
249+ ## parameter validation
250+ errors = []
251+ checks = [
252+ (0 < self .rho < 1 , "rho must be between 0 and 1" ),
253+ (self .C > 0 , "C must be positive" ),
254+ (self .tol_CD > 0 , "tol_CD must be positive" ),
255+ (self .tol > 0 , "tol must be positive" ),
256+ ]
257+ for condition , error_msg in checks :
258+ if not condition :
259+ errors .append (error_msg )
260+ if errors :
261+ raise ValueError ("; " .join (errors ))
262+
263+ ## data validation
264+ X = np .asarray (X )
265+ y = np .asarray (y )
266+ if X .ndim != 2 or X .shape [1 ] != 2 :
267+ raise ValueError ("X must have shape (n_ratings, 2)" )
268+ if X .shape [0 ] != len (y ):
269+ raise ValueError ("X and y must have the same number of samples" )
270+ user_ids = X [:, 0 ].astype (int )
271+ item_ids = X [:, 1 ].astype (int )
272+ if np .any (user_ids < 0 ) or np .any (user_ids >= self .n_users ):
273+ raise ValueError ("User IDs must be in [0, n_users)" )
274+ if np .any (item_ids < 0 ) or np .any (item_ids >= self .n_items ):
275+ raise ValueError ("Item IDs must be in [0, n_items)" )
276+
269277 # Preparation
270- self .n_ratings = len (y )
271- self .history = np .nan * np .zeros ((self .max_iter_CD + 1 , 2 ))
278+ ## number of training observations
279+ self .n_ratings = len (y )
280+ ## convergence trace
281+ self .history = np .full ((self .max_iter_CD + 1 , 2 ), np .nan )
282+ ## sample weights
272283 self .sample_weight = _check_sample_weight (sample_weight , X , dtype = X .dtype )
273-
274- X_df = pd .DataFrame (X , columns = ["user" , "item" ])
275- uidx_map = X_df .groupby ("user" ).indices
276- iidx_map = X_df .groupby ("item" ).indices
277- self .Iu = [uidx_map .get (u , np .array ([], dtype = int )) for u in range (self .n_users )]
278- self .Ui = [iidx_map .get (i , np .array ([], dtype = int )) for i in range (self .n_items )]
279-
284+ ## random number generator
285+ rng = np .random .default_rng (self .random_state )
286+
287+ ## indices to locate interactions given a user or item id
288+ ### user side: Iu[u] = row indices of interactions by user u
289+ sort_idx_users = np .argsort (X [:, 0 ], kind = 'stable' )
290+ sorted_users = X [sort_idx_users , 0 ]
291+ counts = np .unique (sorted_users , return_counts = True )[1 ]
292+ self .Iu = [np .array ([], dtype = int ) for _ in range (self .n_users )]
293+ for u , idxs in zip (sorted_users [np .cumsum (counts ) - counts ], np .split (sort_idx_users , np .cumsum (counts )[:- 1 ])):
294+ self .Iu [u ] = idxs
295+ ### item side: Ui[i] = row indices of interactions that involve item i
296+ sort_idx_items = np .argsort (X [:, 1 ], kind = 'stable' )
297+ sorted_items = X [sort_idx_items , 1 ]
298+ counts = np .unique (sorted_items , return_counts = True )[1 ]
299+ self .Ui = [np .array ([], dtype = int ) for _ in range (self .n_items )]
300+ for i , idxs in zip (sorted_items [np .cumsum (counts ) - counts ], np .split (sort_idx_items , np .cumsum (counts )[:- 1 ])):
301+ self .Ui [i ] = idxs
302+
303+ ## effective C when updating user/item blocks (to match rehline formulation: C * PLQ_loss + 0.5 * l_2)
280304 C_user = self .C * self .n_users / (self .rho ) / 2
281305 C_item = self .C * self .n_items / (1 - self .rho ) / 2
282306
@@ -289,6 +313,12 @@ def fit(self, X, y, sample_weight=None):
289313 )
290314 )
291315
316+ # Model Initialization
317+ self .P = rng .normal (loc = self .init_mean , scale = self .init_sd , size = (self .n_users , self .rank ))
318+ self .Q = rng .normal (loc = self .init_mean , scale = self .init_sd , size = (self .n_items , self .rank ))
319+ self .bu = np .zeros (self .n_users ) if self .biased else None
320+ self .bi = np .zeros (self .n_items ) if self .biased else None
321+
292322 # CD algorithm
293323 self .history [0 ] = self .obj (X , y )
294324 for iter_idx in range (self .max_iter_CD ):
@@ -435,7 +465,7 @@ def fit(self, X, y, sample_weight=None):
435465 obj = f"{ self .history [iter_idx + 1 ][1 ]:.6f} "
436466 print (f"{ iter_idx + 1 :<12} { mean_loss :<20} { obj :<20} " )
437467
438- if obj_diff < self .tol_CD :
468+ if abs ( obj_diff ) < self .tol_CD :
439469 break
440470
441471 return self
@@ -496,9 +526,10 @@ def obj(self, X, y):
496526 item_penalty = np .sum (self .Q ** 2 ) * (1 - self .rho ) / self .n_items
497527 penalty = user_penalty + item_penalty
498528
499- y_pred = self . decision_function ( X )
500- U , V , Tau , S , T = _make_loss_rehline_param (loss = self .loss , X = X , y = y )
529+ X_dummy = np . ones (( len ( y ), 1 )) # not used in loss computation, only shape matters for loss param construction
530+ U , V , Tau , S , T = _make_loss_rehline_param (loss = self .loss , X = X_dummy , y = y )
501531 loss = ReHLoss (U , V , S , T , Tau )
532+ y_pred = self .decision_function (X )
502533 loss_term = loss (y_pred )
503534
504535 return loss_term , self .C * loss_term + penalty
0 commit comments