
Commit 1362b11
Author: Alexander Ororbia
integrated prototype for vector-quantize memory model/synapse
1 parent d6b1ecf

1 file changed: 213 additions & 0 deletions

@@ -0,0 +1,213 @@
from jax import random, numpy as jnp, jit
from ngclearn import compilable #from ngcsimlib.parser import compilable
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
from ngclearn.utils.model_utils import softmax, bkwta #, chebyshev_norm

from ngclearn.components.synapses.denseSynapse import DenseSynapse


def _gaussian_kernel(dist, sigma): ## Gaussian weighting function
    density = jnp.exp(-jnp.power(dist, 2) / (2 * (sigma ** 2))) # n_units x 1
    return density


class VectorQuantizeSynapse(DenseSynapse): # Vector quantization (VQ) synaptic cable
    """
    A synaptic cable that emulates a vector quantization (VQ) memory model (the base case
    of this model is referred to as "learning vector quantization"; LVQ).

    | --- Synapse Compartments: ---
    | inputs - input (takes in external signals)
    | outputs - output signals (transformation induced by synapses)
    | weights - current value matrix of synaptic efficacies
    | bmu - current best-matching unit (BMU) mask, based on current inputs
    | i_tick - current internal tick/marker (incremented by 1 upon each call to `evolve`)
    | eta - current learning rate value
    | key - JAX PRNG key
    | --- Synaptic Plasticity Compartments: ---
    | inputs - pre-synaptic signal/value that drives the 1st term of the VQ update (x)
    | outputs - post-synaptic signal/value that drives the 2nd term of the VQ update (y)
    | dWeights - current delta matrix containing changes to be applied to synapses

    | References:
    | Kohonen, Teuvo. "The self-organizing map." Proceedings of the IEEE 78.9 (1990): 1464-1480.

    Args:
        name: the string name of this cell

        shape: tuple specifying the shape of this synaptic cable (usually a 2-tuple with the
            number of inputs by the number of outputs)

        eta: (initial) learning rate / step size for this VQ model (initial condition for the
            `eta` compartment)

        eta_decrement: amount by which `eta` is linearly decreased per update (Default: 0.00001)

        syn_decay: synaptic weight decay coefficient (Default: 0.)

        w_bound: soft bound on synaptic magnitudes; values <= 0. disable bounding (Default: 0.)

        distance_function: tuple specifying the distance function and its order for computing
            best-matching units (BMUs) (Default: ("minkowski", 2)).
            Usage guide:
            ("minkowski", 2) or ("euclidean", ?) => use L2 norm (Euclidean) distance;
            ("minkowski", 1) or ("manhattan", ?) => use L1 norm (taxi-cab/city-block) distance;
            ("minkowski", jnp.inf) or ("chebyshev", ?) => use Chebyshev distance;
            ("minkowski", p > 2) => use a Minkowski distance of p-th order

        initial_patterns: optional (class-based) prototype patterns to initialize the codebook with

        weight_init: a kernel to drive initialization of this synaptic cable's values;
            typically a tuple whose 1st element is a string naming the initialization
            scheme to use

        resist_scale: a fixed scaling factor to apply to the synaptic transform
            (Default: 1.), i.e., yields: out = ((W * Rscale) * in)

        p_conn: probability of a connection existing (Default: 1.); setting
            this to < 1. will result in a sparser synaptic structure
    """

    def __init__(
            self,
            name,
            shape, ## determines codebook size
            eta=0.3, ## learning rate
            eta_decrement=0.00001, ## learning rate linear decrease (per update)
            syn_decay=0., ## weight decay term
            w_bound=0., ## soft synaptic magnitude bound (<= 0. disables bounding)
            distance_function=("minkowski", 2),
            initial_patterns=None, ## possible class-based prototypes to init by
            weight_init=None,
            resist_scale=1.,
            p_conn=1.,
            batch_size=1,
            **kwargs
    ):
        super().__init__(
            name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs
        )

        ### Synapse and VQ hyper-parameters
        self.K = 1 ## number of winners for a BMU mask
        dist_fun, dist_order = distance_function ## Default: ("minkowski", 2) -> Euclidean
        if "euclidean" in dist_fun.lower():
            dist_order = 2
        elif "manhattan" in dist_fun.lower():
            dist_order = 1
        elif "chebyshev" in dist_fun.lower():
            dist_order = jnp.inf
        ## TODO: add in cosine distance (and maybe Mahalanobis distance)
        self.dist_order = dist_order ## set distance order p

        self.shape = shape ## shape of synaptic efficacy matrix
        self.initial_eta = eta
        self.eta_decr = eta_decrement
        self.syn_decay = syn_decay
        self.w_bound = w_bound ## soft synaptic value bound (on magnitude)

        ## VQ compartment setup
        self.eta = Compartment(jnp.zeros((1, 1)) + self.initial_eta)
        self.i_tick = Compartment(jnp.zeros((1, 1)))
        self.bmu = Compartment(jnp.zeros((1, 1)))
        self.dWeights = Compartment(self.weights.get() * 0)

    @compilable
    def advance_state(self): ## forward-inference step of VQ
        x_in = self.inputs.get()
        W = self.weights.get().T ## get (transposed) memory matrix

        ### We do some 3D tensor math to handle a batch of predictions that need to be made:
        ### B = batch size, D = embedding/input dim, C = number of classes, N = number of memories
        _W = jnp.expand_dims(W, axis=0) ## 3D tensor format of memory (1 x N x D)
        _x_in = jnp.expand_dims(x_in, axis=1) ## 3D projection of input signals (B x 1 x D)
        D = _x_in - _W ## compute 3D batched delta tensor (B x N x D)

        ## now apply the distance function measurement over the 3D tensor of deltas
        ## to get batched (negative) distance measurements
        dist = jnp.linalg.norm(D, ord=self.dist_order, axis=2, keepdims=True) ## (B x N x 1)
        dist = -jnp.squeeze(dist, axis=2) ## (B x N) (negate distances so minima become maxima)

        ## now get the K winners per sample in the batch
        bmu_mask = bkwta(dist, self.K)
        self.bmu.set(bmu_mask) ## record the current BMU mask
        self.outputs.set(bmu_mask)

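    ## Shape walk-through of the BMU math above (illustrative; assumes `bkwta` acts as a
    ## batched K-winners-take-all that returns a binary mask): with B=2, D=3, N=4,
    ##   x_in: (2, 3) and W^T: (4, 3) broadcast as (2, 1, 3) - (1, 4, 3) -> (2, 4, 3);
    ##   the p-norm over axis=2 yields (2, 4) distances; negating them makes the closest
    ##   memory the largest-scoring entry, so bkwta(dist, K=1) marks one BMU per row.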
    @compilable
    def evolve(self, t, dt): ## competitive Hebbian update step of VQ
        W = self.weights.get()
        x_in = self.inputs.get()
        z_out = self.outputs.get()
        tmp_key, *subkeys = random.split(self.key.get(), 3)
        self.key.set(tmp_key)
        ## synaptic update noise (drawn with the same shape as the weight matrix)
        eps = random.normal(subkeys[0], W.shape)

        ## do the competitive Hebbian update
        dW = jnp.matmul(x_in.T, z_out) ## (D x N), matching the shape of W
        ## TODO: compute sign of dW given label match (-1 if no match, +1 if match)
        self.dWeights.set(dW)
        dW = dW * self.eta.get() - (W * self.syn_decay) ## scale by eta and inject weight decay
        #dW = dW + jnp.sqrt(2. * self.eta.get()) * eps ## inject Langevin noise
        zeta = 0.2 ## Langevin dampening factor
        dW = dW + eps * (2. * self.eta.get()) * zeta ## noise term (prevents collapse to zero in theory)
        if self.w_bound > 0.:
            ## enforce a soft value bound
            dW = dW * (self.w_bound - jnp.abs(W))
        ## else, do not apply soft-bounding
        W = W + dW ## note: eta has already been folded into dW above
        self.weights.set(W)

        ## linearly anneal the learning rate eta (floored at 1e-5)
        eta_tp1 = jnp.maximum(1e-5, self.eta.get() - self.eta_decr)
        self.eta.set(eta_tp1)

        self.i_tick.set(self.i_tick.get() + 1)

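    ## In equation form, the update applied above reads (as we read the code):
    ##   dW = eta * (x^T y) - syn_decay * W + 2 * eta * zeta * eps,  with eps ~ N(0, I)
    ##   W  <- W + dW
    ## where y is the BMU mask; each winning prototype accumulates its current input, while
    ## global decay and the damped Langevin-style noise keep the codebook from freezing.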
    @compilable
    def reset(self):
        preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
        postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))

        if not self.inputs.targeted:
            self.inputs.set(preVals)
        self.outputs.set(postVals)
        self.dWeights.set(jnp.zeros(self.shape.get()))
        self.bmu.set(self.bmu.get() * 0)

    @classmethod
    def help(cls): ## component help function
        properties = {
            "synapse_type": "VectorQuantizeSynapse - performs an adaptable synaptic transformation "
                            "of inputs to produce output signals; synapses are adjusted via "
                            "competitive Hebbian learning in accordance with a vector quantization model"
        }
        compartment_props = {
            "input_compartments":
                {"inputs": "Takes in external input signal values",
                 "key": "JAX PRNG key"},
            "parameter_compartments":
                {"weights": "Synapse efficacy/strength parameter values"},
            "output_compartments":
                {"outputs": "Output of synaptic transformation",
                 "bmu": "Best-matching unit (BMU) mask"},
        }
        hyperparams = {
            "shape": "Shape of synaptic weight value matrix; number of inputs x number of outputs",
            "batch_size": "Batch size dimension of this component",
            "weight_init": "Initialization conditions for synaptic weight (W) values",
            "resist_scale": "Resistance level scaling factor (applied to output of transformation)",
            "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
            "eta": "Global learning rate",
            "distance_function": "Distance function tuple specifying how to compute BMUs"
        }
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "outputs = [bmu_mask] ; "
                            "dW = VQ competitive Hebbian update",
                "hyperparameters": hyperparams}
        return info

# if __name__ == '__main__':
#     from ngcsimlib.context import Context
#     with Context("Bar") as bar:
#         Wab = VectorQuantizeSynapse("Wab", (2, 3), 4, 4, 1.)
#         print(Wab)
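
A minimal usage sketch of this prototype (hypothetical driver code: the compartment names follow this file, but the exact `Context` compile/run workflow is assumed from ngclearn convention rather than confirmed by this commit):

# from jax import numpy as jnp
# from ngcsimlib.context import Context
# with Context("demo") as ctx:
#     W_vq = VectorQuantizeSynapse("W_vq", shape=(3, 5), eta=0.1, batch_size=2)
# x = jnp.asarray([[0.1, 0.5, 0.9], [0.9, 0.5, 0.1]]) ## batch of two 3-dim inputs
# W_vq.inputs.set(x)
# W_vq.advance_state() ## computes the (negative-)distance BMU mask into `outputs`/`bmu`
# W_vq.evolve(t=0., dt=1.) ## applies the competitive Hebbian/VQ update and anneals eta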
