Skip to content

Commit 5097184

Browse files
author
Alexander Ororbia
committed
added in knn-probe for utils.analysis
1 parent e1773db commit 5097184

1 file changed

Lines changed: 53 additions & 17 deletions

File tree

ngclearn/utils/analysis/knn_probe.py

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
11
import jax
22
import numpy as np
3+
from ngcsimlib import deprecate_args
34
from ngclearn.utils.analysis.probe import Probe
45
from ngclearn.utils.model_utils import kwta
56
from jax import jit, random, numpy as jnp, lax, nn
67
from functools import partial as bind
78
from ngclearn.utils.distribution_generator import DistributionGenerator
89

910
@bind(jax.jit, static_argnums=[2, 3])
10-
def _run_knn_probe(_embeddings, Wx, K, dist_fx=1):
11+
def _run_knn_probe(_embeddings, Wx, K, dist_order=2):
1112
## Notes:
1213
### We do some 3D tensor math to handle a batch of predictions that need to be made
1314
### B = batch-size, D = embedding/input dim, C = number classes, N = number of memories
1415
_Wx = jnp.expand_dims(Wx, axis=0) ## 3D tensor format of KNN params (1 x N x D)
1516
embed_tensor = jnp.expand_dims(_embeddings, axis=1) ## 3D projection of input signals (B x 1 x D)
1617
D = embed_tensor - _Wx ## compute 3D batched delta tensor (B x N x D)
1718
## get batched (negative) distance measurements
18-
dist = jnp.linalg.norm(D, ord=2, axis=2, keepdims=True) ## (B x N x 1)
19-
if dist_fx == 2:
20-
dist = jnp.linalg.norm(D, ord=1, axis=2, keepdims=True) ## (B x N x 1)
19+
dist = jnp.linalg.norm(D, ord=dist_order, axis=2, keepdims=True) ## (B x N x 1)
2120
## else, default -> euclidean
2221
### Note: negative distance allows us to find minimal points w/ maximal functions
2322
dist = -jnp.squeeze(dist, axis=2) ## (B x N)
@@ -40,34 +39,63 @@ class KNNProbe(Probe):
4039
4140
out_dim: output dimensionality of probe
4241
42+
num_neighbors: number of nearest neighbors to perform estimate of output target with
43+
4344
batch_size: size of batches to process per internal call to update (or process)
4445
4546
K: number of nearest neighbors to estimate output target
4647
47-
dist_function: what distance function should be used in calculating nearest neighbors (Default: euclidean)
48+
distance_function: tuple specifying distance function and its order for calculating nearest neighbors
49+
(Default: ("minkowski", 2)).
50+
usage guide:
51+
("minkowski", 2) or ("euclidean", ?) => use L2 norm (Euclidean) distance;
52+
("minkowski", 1) or ("manhattan", ?) => use L1 norm (taxi-cab/city-block) distance;
53+
("minkowksi", jnp.inf) or ("chebyshev", ?) => use Chebyshev distance;
54+
("minkowski", p > 2) => use a Minkowski distance of p-th order
55+
56+
predictor_type: Str what type of problem is this K-NN solving?
57+
58+
vote_style:
4859
4960
"""
61+
62+
@deprecate_args(K="num_neighbors")
5063
def __init__(
5164
self,
5265
dkey,
5366
source_seq_length,
5467
input_dim,
5568
out_dim,
5669
batch_size=1,
57-
K=1, ## number of nearest neighbors (K) to find
58-
dist_function="euclidean",
70+
num_neighbors=1, ## number of nearest neighbors (K) to find
71+
distance_function=("minkowski", 2),
72+
predictor_type="classifier", ## "classifier"; "regressor"
73+
vote_style="mode", ## "mode", "mean"
5974
**kwargs
6075
):
6176
super().__init__(dkey, batch_size, **kwargs)
6277
self.dkey, *subkeys = random.split(self.dkey, 3)
6378
self.source_seq_length = source_seq_length
6479
self.input_dim = input_dim
6580
self.out_dim = out_dim
66-
self.K = K
67-
self.dist_function = dist_function
68-
self.dist_fx = 0
69-
if self.dist_function == "manhattan":
70-
self.dist_fx = 1
81+
self.K = num_neighbors
82+
self.vote_fx = 0 ## 0 -> mode prediction; 1 -> mean prediction
83+
if vote_style == "mean":
84+
self.vote_fx = 1
85+
self.distance_function = distance_function
86+
dist_fun, dist_order = distance_function ## Default: ("minkowski", 2) -> Euclidean
87+
if "euclidean" in dist_fun.lower():
88+
dist_order = 2
89+
elif "manhattan" in dist_fun.lower():
90+
dist_order = 1
91+
elif "chebyshev" in dist_fun.lower():
92+
dist_order = jnp.inf
93+
## TODO: add in cosine-distance (and maybe Mahalanobis distance)
94+
self.dist_order = dist_order ## set distance order p
95+
self.predictor_type = predictor_type
96+
self.pred_fx = 0
97+
if "regressor" == predictor_type:
98+
self.pred_fx = 1
7199

72100
#flat_input_dim = input_dim * source_seq_length
73101
#W = jnp.zeros((flat_input_dim, out_dim))
@@ -81,15 +109,23 @@ def process(self, embeddings, dkey=None): ## TODO: JIT-i-fy this
81109
_embeddings = jnp.reshape(_embeddings, (embeddings.shape[0], flat_dim))
82110

83111
Wx, Wy = self.probe_params ## pull out KNN parameters
84-
values, indices = _run_knn_probe(_embeddings, Wx, self.K, self.dist_fx)
112+
values, indices = _run_knn_probe(_embeddings, Wx, self.K, self.dist_order)
85113

86-
## do K-neighbor voting scheme (find mode prediction)
114+
## do K-neighbor voting scheme (find mode/frequency prediction)
87115
Y_counts = jnp.zeros((_embeddings.shape[0], Wy.shape[1]))
88116
for k in range(self.K):
89-
winner_k_indx = indices[:, k] ## batch of k-th winner of K winners
117+
winner_k_indx = indices[:, k] ## batch of k-th set of K winners
90118
Y_k = Wy[winner_k_indx, :] ## predicted Y's of k-th winner batch
91119
Y_counts = Y_counts + Y_k
92-
Y_pred = nn.one_hot(jnp.argmax(Y_counts, axis=1), num_classes=Wy.shape[1]) #, keepdims=True)
120+
## do post-processing to conform to problem-type being solved by this K-NN
121+
if self.pred_fx == 1: ## (regressor, contus outputs)
122+
Y_pred = Y_counts * (1. / self.K)
123+
else: ## pred_fx == 0 (classifier, discrete outputs)
124+
Y_pred = Y_counts
125+
if self.vote_fx == 1: ## calc mean prediction
126+
Y_pred = Y_counts * (1. / self.K)
127+
## vote_fx == 0 (mode prediction)
128+
Y_pred = nn.one_hot(jnp.argmax(Y_pred, axis=1), num_classes=Wy.shape[1]) # , keepdims=True)
93129
return Y_pred ## (B, C)
94130

95131
def update(self, embeddings, labels, dkey=None):
@@ -124,4 +160,4 @@ def update(self, embeddings, labels, dkey=None):
124160
axis=0
125161
)
126162
knn.update(X, Y) ## fit KNN to data
127-
print(knn.process(X)) ## should construct the (smeared) identity matrix, exactly same as Y
163+
print(knn.process(X)) ## should construct the (smeared) identity matrix, exactly same as Y

0 commit comments

Comments
 (0)