1+ import jax
2+ import numpy as np
3+ from ngclearn .utils .analysis .probe import Probe
4+ from ngclearn .utils .model_utils import kwta
5+ from jax import jit , random , numpy as jnp , lax , nn
6+ from functools import partial as bind
7+ from ngclearn .utils .distribution_generator import DistributionGenerator
8+
9+ @bind (jax .jit , static_argnums = [2 , 3 ])
10+ def _run_knn_probe (_embeddings , Wx , K , dist_fx = 1 ):
11+ ## Notes:
12+ ### We do some 3D tensor math to handle a batch of predictions that need to be made
13+ ### B = batch-size, D = embedding/input dim, C = number classes, N = number of memories
14+ _Wx = jnp .expand_dims (Wx , axis = 0 ) ## 3D tensor format of KNN params (1 x N x D)
15+ embed_tensor = jnp .expand_dims (_embeddings , axis = 1 ) ## 3D projection of input signals (B x 1 x D)
16+ D = embed_tensor - _Wx ## compute 3D batched delta tensor (B x N x D)
17+ ## get batched (negative) distance measurements
18+ dist = jnp .linalg .norm (D , ord = 2 , axis = 2 , keepdims = True ) ## (B x N x 1)
19+ if dist_fx == 2 :
20+ dist = jnp .linalg .norm (D , ord = 1 , axis = 2 , keepdims = True ) ## (B x N x 1)
21+ ## else, default -> euclidean
22+ ### Note: negative distance allows us to find minimal points w/ maximal functions
23+ dist = - jnp .squeeze (dist , axis = 2 ) ## (B x N)
24+ ## now get K winners per sample in batch
25+ values , indices = lax .top_k (dist , K )
26+ return values , indices
27+
class KNNProbe(Probe):
    """
    This implements a K-nearest neighbors (KNN) probe, which is useful for evaluating the quality of
    encodings/embeddings in light of some supervisory downstream data (e.g., label one-hot
    encodings or real-valued vector regression targets).

    Args:
        dkey: init seed key

        source_seq_length: length of input sequence (e.g., height x width of the image feature)

        input_dim: input dimensionality of probe

        out_dim: output dimensionality of probe

        batch_size: size of batches to process per internal call to update (or process)

        K: number of nearest neighbors to estimate output target

        dist_function: what distance function should be used in calculating nearest neighbors (Default: euclidean)
    """
    def __init__(
            self,
            dkey,
            source_seq_length,
            input_dim,
            out_dim,
            batch_size=1,
            K=1,  ## number of nearest neighbors (K) to find
            dist_function="euclidean",
            **kwargs
    ):
        super().__init__(dkey, batch_size, **kwargs)
        self.dkey, *subkeys = random.split(self.dkey, 3)
        self.source_seq_length = source_seq_length
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.K = K
        self.dist_function = dist_function
        ## integer code handed to the jit-compiled distance kernel
        ## (static arg); "manhattan" is encoded as 1, anything else falls
        ## through to the default euclidean behavior
        self.dist_fx = 0
        if self.dist_function == "manhattan":
            self.dist_fx = 1

        ## placeholder parameters until update() stores real data;
        ## Wy will be assumed to be one-hot encoded
        Wx = Wy = jnp.ones((1, 1))
        self.probe_params = (Wx, Wy)

    def _flatten(self, embeddings):
        """Collapse any trailing axes so the batch is rank-2, i.e., (B, D)."""
        if len(embeddings.shape) > 2:
            ## generalizes the old (B, H, W) -> (B, H*W) reshape to any rank
            return jnp.reshape(embeddings, (embeddings.shape[0], -1))
        return embeddings

    def process(self, embeddings, dkey=None):  ## TODO: JIT-i-fy this
        """
        Predict one-hot targets for a batch of embeddings via K-nearest-neighbor voting.

        Args:
            embeddings: batch of input codes; shape (B, D), or higher-rank with
                leading batch axis (flattened to (B, D) internally)

            dkey: unused; kept for interface compatibility

        Returns:
            one-hot prediction matrix of shape (B, C)
        """
        _embeddings = self._flatten(embeddings)

        Wx, Wy = self.probe_params  ## pull out KNN parameters
        values, indices = _run_knn_probe(_embeddings, Wx, self.K, self.dist_fx)

        ## K-neighbor voting scheme (find mode prediction), fully vectorized:
        ## Wy[indices] gathers each winner's stored target -> (B, K, C);
        ## summing over the K axis accumulates the votes -> (B, C)
        Y_counts = jnp.sum(Wy[indices], axis=1)
        Y_pred = nn.one_hot(jnp.argmax(Y_counts, axis=1), num_classes=Wy.shape[1])
        return Y_pred  ## (B, C)

    def update(self, embeddings, labels, dkey=None):
        """
        Fit this probe to data; a KNN's learning phase is just storing the
        data internally directly.

        Args:
            embeddings: batch of input codes; shape (B, D), or higher-rank with
                leading batch axis (flattened to (B, D) internally)

            labels: downstream targets (e.g., one-hot labels); shape (B, C)

            dkey: unused; kept for interface compatibility
        """
        _embeddings = self._flatten(embeddings)
        self.probe_params = (_embeddings, labels)
105+
if __name__ == '__main__':
    ## smoke test: fit the probe on random codes, two samples per class, and
    ## check that a 1-NN lookup of the training data reproduces its own labels
    seed = 42
    D = 7  ## embedding dimensionality
    C = 5  ## number of classes
    dkey = random.PRNGKey(seed)
    dkey, *subkeys = random.split(dkey, 3)
    knn = KNNProbe(
        subkeys[0], 1, input_dim=D, out_dim=C, K=1, dist_function="euclidean"
    )
    X = random.uniform(subkeys[1], shape=(10, D))
    ## one-hot label matrix: rows 0-1 -> class 0, rows 2-3 -> class 1, etc.
    Y = jnp.repeat(jnp.eye(C), 2, axis=0)
    knn.update(X, Y)  ## fit KNN to data
    print(knn.process(X))  ## should construct the (smeared) identity matrix, exactly same as Y