22import warnings
33from collections import defaultdict
44from datetime import datetime
5- from typing import Iterable , Literal , Optional
5+ from functools import partial
6+ from typing import Iterable , Literal , Optional , Union
67
78import igraph as ig
89import numpy as np
2122from sklearn .utils .validation import check_non_negative
2223
2324from turftopic .base import Encoder
25+ from turftopic .optimization import (
26+ decomposition_gaussian_bic ,
27+ optimize_n_components ,
28+ )
2429
2530NOT_MATCHING_ERROR = (
2631 "Document embedding dimensionality ({n_dims}) doesn't match term embedding dimensionality ({n_word_dims}). "
@@ -242,7 +247,7 @@ def batch_extract_keywords(
242247class KeywordNMF :
243248 def __init__ (
244249 self ,
245- n_components : int ,
250+ n_components : Union [ int , Literal [ "auto" ]] ,
246251 seed : Optional [int ] = None ,
247252 top_n : Optional [int ] = None ,
248253 ):
@@ -318,6 +323,15 @@ def vectorize(
318323
319324 def fit_transform (self , keywords : list [dict [str , float ]]) -> np .ndarray :
320325 X = self .vectorize (keywords , fitting = True )
326+ if self .n_components == "auto" :
327+ # Finding N components with BIC
328+ bic_fn = partial (
329+ decomposition_gaussian_bic ,
330+ decomp_class = NMF ,
331+ X = X ,
332+ )
333+ n_components = optimize_n_components (bic_fn , min_n = 1 , verbose = True )
334+ self .n_components = n_components
321335 check_non_negative (X , "NMF (input X)" )
322336 W , H = _initialize_nmf (X , self .n_components , random_state = self .seed )
323337 W , H , self .n_iter = NMF (
@@ -339,6 +353,10 @@ def transform(self, keywords: list[dict[str, float]]):
339353 return W .astype (X .dtype )
340354
341355 def partial_fit (self , keyword_batch : list [dict [str , float ]]):
356+ if self .n_components == "auto" :
357+ raise ValueError (
358+ "Cannot infer number of components with BIC when online fitting the model."
359+ )
342360 X = self .vectorize (keyword_batch , fitting = True )
343361 try :
344362 check_non_negative (X , "NMF (input X)" )
@@ -365,6 +383,15 @@ def fit_transform_dynamic(
365383 n_bins = len (time_bin_edges ) - 1
366384 document_term_matrix = self .vectorize (keywords , fitting = True )
367385 check_non_negative (document_term_matrix , "NMF (input X)" )
386+ if self .n_components == "auto" :
387+ # Finding N components with BIC
388+ bic_fn = partial (
389+ decomposition_gaussian_bic ,
390+ decomp_class = NMF ,
391+ X = X ,
392+ )
393+ n_components = optimize_n_components (bic_fn , verbose = True )
394+ self .n_components = n_components
368395 document_topic_matrix , H = _initialize_nmf (
369396 document_term_matrix ,
370397 self .n_components ,
0 commit comments