11import jax
22import numpy as np
3+ from ngcsimlib import deprecate_args
34from ngclearn .utils .analysis .probe import Probe
45from ngclearn .utils .model_utils import kwta
56from jax import jit , random , numpy as jnp , lax , nn
67from functools import partial as bind
78from ngclearn .utils .distribution_generator import DistributionGenerator
89
910@bind (jax .jit , static_argnums = [2 , 3 ])
10- def _run_knn_probe (_embeddings , Wx , K , dist_fx = 1 ):
11+ def _run_knn_probe (_embeddings , Wx , K , dist_order = 2 ):
1112 ## Notes:
1213 ### We do some 3D tensor math to handle a batch of predictions that need to be made
1314 ### B = batch-size, D = embedding/input dim, C = number classes, N = number of memories
1415 _Wx = jnp .expand_dims (Wx , axis = 0 ) ## 3D tensor format of KNN params (1 x N x D)
1516 embed_tensor = jnp .expand_dims (_embeddings , axis = 1 ) ## 3D projection of input signals (B x 1 x D)
1617 D = embed_tensor - _Wx ## compute 3D batched delta tensor (B x N x D)
1718 ## get batched (negative) distance measurements
18- dist = jnp .linalg .norm (D , ord = 2 , axis = 2 , keepdims = True ) ## (B x N x 1)
19- if dist_fx == 2 :
20- dist = jnp .linalg .norm (D , ord = 1 , axis = 2 , keepdims = True ) ## (B x N x 1)
19+ dist = jnp .linalg .norm (D , ord = dist_order , axis = 2 , keepdims = True ) ## (B x N x 1)
2120 ## else, default -> euclidean
2221 ### Note: negative distance allows us to find minimal points w/ maximal functions
2322 dist = - jnp .squeeze (dist , axis = 2 ) ## (B x N)
@@ -40,34 +39,63 @@ class KNNProbe(Probe):
4039
4140 out_dim: output dimensionality of probe
4241
42+ num_neighbors: number of nearest neighbors to perform estimate of output target with
43+
4344 batch_size: size of batches to process per internal call to update (or process)
4445
    K: (deprecated) alias of num_neighbors — number of nearest neighbors to estimate output target
4647
47- dist_function: what distance function should be used in calculating nearest neighbors (Default: euclidean)
48+ distance_function: tuple specifying distance function and its order for calculating nearest neighbors
49+ (Default: ("minkowski", 2)).
50+ usage guide:
51+ ("minkowski", 2) or ("euclidean", ?) => use L2 norm (Euclidean) distance;
52+ ("minkowski", 1) or ("manhattan", ?) => use L1 norm (taxi-cab/city-block) distance;
            ("minkowski", jnp.inf) or ("chebyshev", ?) => use Chebyshev distance;
54+ ("minkowski", p > 2) => use a Minkowski distance of p-th order
55+
    predictor_type: what type of problem this K-NN solves; "classifier" (discrete targets) or "regressor" (continuous targets)
57+
    vote_style: voting scheme used to aggregate the K neighbor targets; "mode" (majority) or "mean" (average)
4859
4960 """
61+
62+ @deprecate_args (K = "num_neighbors" )
5063 def __init__ (
5164 self ,
5265 dkey ,
5366 source_seq_length ,
5467 input_dim ,
5568 out_dim ,
5669 batch_size = 1 ,
57- K = 1 , ## number of nearest neighbors (K) to find
58- dist_function = "euclidean" ,
70+ num_neighbors = 1 , ## number of nearest neighbors (K) to find
71+ distance_function = ("minkowski" , 2 ),
72+ predictor_type = "classifier" , ## "classifier"; "regressor"
73+ vote_style = "mode" , ## "mode", "mean"
5974 ** kwargs
6075 ):
6176 super ().__init__ (dkey , batch_size , ** kwargs )
6277 self .dkey , * subkeys = random .split (self .dkey , 3 )
6378 self .source_seq_length = source_seq_length
6479 self .input_dim = input_dim
6580 self .out_dim = out_dim
66- self .K = K
67- self .dist_function = dist_function
68- self .dist_fx = 0
69- if self .dist_function == "manhattan" :
70- self .dist_fx = 1
81+ self .K = num_neighbors
82+ self .vote_fx = 0 ## 0 -> mode prediction; 1 -> mean prediction
83+ if vote_style == "mean" :
84+ self .vote_fx = 1
85+ self .distance_function = distance_function
86+ dist_fun , dist_order = distance_function ## Default: ("minkowski", 2) -> Euclidean
87+ if "euclidean" in dist_fun .lower ():
88+ dist_order = 2
89+ elif "manhattan" in dist_fun .lower ():
90+ dist_order = 1
91+ elif "chebyshev" in dist_fun .lower ():
92+ dist_order = jnp .inf
93+ ## TODO: add in cosine-distance (and maybe Mahalanobis distance)
94+ self .dist_order = dist_order ## set distance order p
95+ self .predictor_type = predictor_type
96+ self .pred_fx = 0
97+ if "regressor" == predictor_type :
98+ self .pred_fx = 1
7199
72100 #flat_input_dim = input_dim * source_seq_length
73101 #W = jnp.zeros((flat_input_dim, out_dim))
@@ -81,15 +109,23 @@ def process(self, embeddings, dkey=None): ## TODO: JIT-i-fy this
81109 _embeddings = jnp .reshape (_embeddings , (embeddings .shape [0 ], flat_dim ))
82110
83111 Wx , Wy = self .probe_params ## pull out KNN parameters
84- values , indices = _run_knn_probe (_embeddings , Wx , self .K , self .dist_fx )
112+ values , indices = _run_knn_probe (_embeddings , Wx , self .K , self .dist_order )
85113
86- ## do K-neighbor voting scheme (find mode prediction)
114+ ## do K-neighbor voting scheme (find mode/frequency prediction)
87115 Y_counts = jnp .zeros ((_embeddings .shape [0 ], Wy .shape [1 ]))
88116 for k in range (self .K ):
89- winner_k_indx = indices [:, k ] ## batch of k-th winner of K winners
117+ winner_k_indx = indices [:, k ] ## batch of k-th set of K winners
90118 Y_k = Wy [winner_k_indx , :] ## predicted Y's of k-th winner batch
91119 Y_counts = Y_counts + Y_k
92- Y_pred = nn .one_hot (jnp .argmax (Y_counts , axis = 1 ), num_classes = Wy .shape [1 ]) #, keepdims=True)
120+ ## do post-processing to conform to problem-type being solved by this K-NN
121+ if self .pred_fx == 1 : ## (regressor, contus outputs)
122+ Y_pred = Y_counts * (1. / self .K )
123+ else : ## pred_fx == 0 (classifier, discrete outputs)
124+ Y_pred = Y_counts
125+ if self .vote_fx == 1 : ## calc mean prediction
126+ Y_pred = Y_counts * (1. / self .K )
127+ ## vote_fx == 0 (mode prediction)
128+ Y_pred = nn .one_hot (jnp .argmax (Y_pred , axis = 1 ), num_classes = Wy .shape [1 ]) # , keepdims=True)
93129 return Y_pred ## (B, C)
94130
95131 def update (self , embeddings , labels , dkey = None ):
@@ -124,4 +160,4 @@ def update(self, embeddings, labels, dkey=None):
124160 axis = 0
125161 )
126162 knn .update (X , Y ) ## fit KNN to data
127- print (knn .process (X )) ## should construct the (smeared) identity matrix, exactly same as Y
163+ print (knn .process (X )) ## should construct the (smeared) identity matrix, exactly same as Y
0 commit comments