Skip to content

Commit e1773db

Browse files
author
Alexander Ororbia
committed
updates to art2a, cleanup of probes
1 parent be47ab0 commit e1773db

1 file changed

Lines changed: 127 additions & 0 deletions

File tree

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import jax
2+
import numpy as np
3+
from ngclearn.utils.analysis.probe import Probe
4+
from ngclearn.utils.model_utils import kwta
5+
from jax import jit, random, numpy as jnp, lax, nn
6+
from functools import partial as bind
7+
from ngclearn.utils.distribution_generator import DistributionGenerator
8+
9+
@bind(jax.jit, static_argnums=[2, 3])
10+
def _run_knn_probe(_embeddings, Wx, K, dist_fx=1):
11+
## Notes:
12+
### We do some 3D tensor math to handle a batch of predictions that need to be made
13+
### B = batch-size, D = embedding/input dim, C = number classes, N = number of memories
14+
_Wx = jnp.expand_dims(Wx, axis=0) ## 3D tensor format of KNN params (1 x N x D)
15+
embed_tensor = jnp.expand_dims(_embeddings, axis=1) ## 3D projection of input signals (B x 1 x D)
16+
D = embed_tensor - _Wx ## compute 3D batched delta tensor (B x N x D)
17+
## get batched (negative) distance measurements
18+
dist = jnp.linalg.norm(D, ord=2, axis=2, keepdims=True) ## (B x N x 1)
19+
if dist_fx == 2:
20+
dist = jnp.linalg.norm(D, ord=1, axis=2, keepdims=True) ## (B x N x 1)
21+
## else, default -> euclidean
22+
### Note: negative distance allows us to find minimal points w/ maximal functions
23+
dist = -jnp.squeeze(dist, axis=2) ## (B x N)
24+
## now get K winners per sample in batch
25+
values, indices = lax.top_k(dist, K)
26+
return values, indices
27+
28+
class KNNProbe(Probe):
    """
    This implements a K-nearest neighbors (KNN) probe, which is useful for evaluating the quality of
    encodings/embeddings in light of some supervisory downstream data (e.g., label one-hot
    encodings or real-valued vector regression targets).

    Args:
        dkey: init seed key

        source_seq_length: length of input sequence (e.g., height x width of the image feature)

        input_dim: input dimensionality of probe

        out_dim: output dimensionality of probe

        batch_size: size of batches to process per internal call to update (or process)

        K: number of nearest neighbors to estimate output target

        dist_function: what distance function should be used in calculating nearest neighbors (Default: euclidean)

    """
    def __init__(
            self,
            dkey,
            source_seq_length,
            input_dim,
            out_dim,
            batch_size=1,
            K=1, ## number of nearest neighbors (K) to find
            dist_function="euclidean",
            **kwargs
    ):
        super().__init__(dkey, batch_size, **kwargs)
        self.dkey, *subkeys = random.split(self.dkey, 3)
        self.source_seq_length = source_seq_length
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.K = K
        self.dist_function = dist_function
        ## integer code handed to the jit-compiled kernel:
        ## 0 => euclidean (default), 1 => manhattan
        self.dist_fx = 0
        if self.dist_function == "manhattan":
            self.dist_fx = 1

        ## placeholder parameters until `update` stores real data;
        ## Wy will be assumed to be one-hot encoded
        Wx = Wy = jnp.ones((1, 1))
        self.probe_params = (Wx, Wy)

    def process(self, embeddings, dkey=None): ## TODO: JIT-i-fy this
        """
        Predict output targets for a batch of embeddings via K-nearest-neighbor voting.

        Args:
            embeddings: batch of embeddings, shape (B, D) or (B, S, D)
                (3D inputs are flattened to (B, S*D) before matching)

            dkey: unused; kept for interface compatibility with other probes

        Returns:
            one-hot prediction matrix of shape (B, C)
        """
        _embeddings = embeddings
        if len(_embeddings.shape) > 2: ## flatten (B, S, D) inputs to (B, S*D)
            flat_dim = embeddings.shape[1] * embeddings.shape[2]
            _embeddings = jnp.reshape(_embeddings, (embeddings.shape[0], flat_dim))

        Wx, Wy = self.probe_params ## pull out KNN parameters
        values, indices = _run_knn_probe(_embeddings, Wx, self.K, self.dist_fx)

        ## do K-neighbor voting scheme (find mode prediction): gather the stored
        ## one-hot targets of all K winners at once -- (B, K, C) -- and sum the
        ## votes over the K axis (vectorized form of the previous per-k loop)
        Y_counts = jnp.sum(Wy[indices, :], axis=1) ## (B, C)
        Y_pred = nn.one_hot(jnp.argmax(Y_counts, axis=1), num_classes=Wy.shape[1])
        return Y_pred ## (B, C)

    def update(self, embeddings, labels, dkey=None):
        """
        Fit this probe to a batch of (embedding, label) pairs.

        Args:
            embeddings: batch of embeddings, shape (B, D) or (B, S, D)
                (3D inputs are flattened to (B, S*D) before storage)

            labels: batch of target vectors (assumed one-hot), shape (B, C)

            dkey: unused; kept for interface compatibility with other probes
        """
        _embeddings = embeddings
        if len(_embeddings.shape) > 2:
            flat_dim = embeddings.shape[1] * embeddings.shape[2]
            _embeddings = jnp.reshape(_embeddings, (embeddings.shape[0], flat_dim))

        ## a K-NN's learning phase is just storing the data internally directly
        Wx = _embeddings
        Wy = labels
        self.probe_params = (Wx, Wy)
105+
106+
if __name__ == '__main__':
    ## sanity check: fit the probe on a tiny synthetic dataset and verify that
    ## querying with the training inputs reproduces the training targets
    seed = 42
    D = 7 ## embedding dimensionality
    C = 5 ## number of classes
    dkey = random.PRNGKey(seed)
    dkey, *subkeys = random.split(dkey, 3)
    knn = KNNProbe(
        subkeys[0], 1, input_dim=D, out_dim=C, K=1, dist_function="euclidean"
    )
    X = random.uniform(subkeys[1], shape=(10, D))
    ## one-hot label matrix w/ two consecutive samples assigned to each class
    Y = jnp.repeat(jnp.eye(C), 2, axis=0)
    knn.update(X, Y) ## fit KNN to data
    print(knn.process(X)) ## should construct the (smeared) identity matrix, exactly same as Y

0 commit comments

Comments
 (0)