Commit d21fc65

Author: Alexander Ororbia (committed)
Message: cleaned up vq-synapse
1 parent 5097184 commit d21fc65

1 file changed: ngclearn/components/synapses/competitive/vectorQuantizeSynapse.py
Lines changed: 113 additions & 38 deletions
@@ -16,19 +16,27 @@ class VectorQuantizeSynapse(DenseSynapse): # Vector quantization (VQ) synaptic c
 
     | --- Synapse Compartments: ---
     | inputs - input (takes in external signals)
+    | labels - label input (optional)
     | outputs - output signals (transformation induced by synapses)
     | weights - current value matrix of synaptic efficacies
-    | bmu - current best-matching unit (BMU) mask, based on current inputs
+    | label_weights - current value matrix of label efficacies (if this VQ is supervised)
     | i_tick - current internal tick / marker (gets incremented by 1 for each call to `evolve`)
     | eta - current learning rate value
     | key - JAX PRNG key
     | --- Synaptic Plasticity Compartments: ---
     | inputs - pre-synaptic signal/value to drive 1st term of VQ update (x)
     | outputs - post-synaptic signal/value to drive 2nd term of VQ update (y)
+    | labels - (optional) pre-synaptic signal to drive 1st term of VQ update to the label matrix
     | dWeights - current delta matrix containing changes to be applied to synapses
 
     | References:
+    | Somervuo, Panu, and Teuvo Kohonen. "Self-organizing maps and learning vector quantization for
+    | feature sequences." Neural Processing Letters 10.2 (1999): 151-159.
+    |
     | Kohonen, Teuvo. "The self-organizing map." Proceedings of the IEEE 78.9 (1990): 1464-1480.
+    |
+    | Ororbia, Alexander G. "Continual competitive memory: A neural system for online task-free
+    | lifelong learning." arXiv preprint arXiv:2106.13300 (2021).
 
     Args:
         name: the string name of this cell
@@ -38,6 +46,13 @@ class VectorQuantizeSynapse(DenseSynapse): # Vector quantization (VQ) synaptic c
 
         eta: (initial) learning rate / step-size for this VQ model (initial condition value for `eta`)
 
+        eta_decrement: a constant value to linearly decrease `eta` by per synaptic update
+            (Default: 0, which disables this)
+
+        syn_decay: a synaptic weight (L2) decay to apply to synapses per update
+
+        w_bound: upper soft bound to enforce over synapses post-update (Default: 0, which disables this scaling term)
+
         distance_function: tuple specifying the distance function and its order for computing best-matching units (BMUs)
             (Default: ("minkowski", 2)).
             usage guide:
@@ -46,6 +61,15 @@ class VectorQuantizeSynapse(DenseSynapse): # Vector quantization (VQ) synaptic c
             ("minkowski", jnp.inf) or ("chebyshev", ?) => use Chebyshev distance;
             ("minkowski", p > 2) => use a Minkowski distance of p-th order
 
+        label_dim: dimensionality of label neurons (corresponding to each memory/prototype); note this is
+            inactive/unused if <= 0 (Default: 0)
+
+        initial_patterns: a tuple containing data vectors (and labels) to initialize the code-book with;
+            note that if `label_dim` <= 0, then only the first element of the tuple will be used (Default: None)
+
+        langevin_noise_scale: scale factor controlling the degree to which Langevin sampling noise is
+            applied to a given synaptic weight update (Default: 0, which disables this)
+
         weight_init: a kernel to drive initialization of this synaptic cable's values;
             typically a tuple with 1st element as a string calling the name of
             initialization to use
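
To make the usage guide above concrete, here is a minimal sketch (not part of this commit; shapes and names are illustrative) of how the order `p` selects the distance used for BMU matching:

```python
import jax.numpy as jnp

def minkowski_distance(x, W, p=2):
    """Distance between a batch of inputs x (B x D) and a memory matrix
    W (D x N, one prototype per column); returns a (B x N) matrix.
    p=1 -> Manhattan, p=2 -> Euclidean, p=jnp.inf -> Chebyshev."""
    delta = x[:, None, :] - W.T[None, :, :]       # (B x N x D) difference tensor
    return jnp.linalg.norm(delta, ord=p, axis=2)  # reduce over the feature axis

x = jnp.ones((4, 8))    # toy batch: 4 samples, 8 features
W = jnp.zeros((8, 16))  # toy memory: 16 prototypes
d_euclidean = minkowski_distance(x, W, p=2)
d_chebyshev = minkowski_distance(x, W, p=jnp.inf)
```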
@@ -66,7 +90,9 @@ def __init__(
             syn_decay=0., ## weight decay term
             w_bound=0.,
             distance_function=("minkowski", 2),
+            label_dim=0, ## if > 0, then this becomes supervised LVQ(1)
             initial_patterns=None, ## possible class-based prototypes to init by
+            langevin_noise_scale=0., ## scale of Langevin noise to apply to updates
             weight_init=None,
             resist_scale=1.,
             p_conn=1.,
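
For orientation, a hypothetical construction call under the signature above; the argument values (and the assumption that `name` and `shape` lead the argument list, as in other ngclearn synapses) are illustrative, not taken from the commit:

```python
# Hypothetical usage sketch of the updated constructor (assumes
# VectorQuantizeSynapse has already been imported).
vq = VectorQuantizeSynapse(
    name="vq0",
    shape=(64, 100),                    # 64 input features x 100 prototypes
    eta=0.1,                            # initial learning rate
    eta_decrement=1e-4,                 # linear step-size decay per update
    syn_decay=1e-3,                     # L2 weight decay on synapses
    w_bound=1.,                         # soft magnitude bound (0 disables)
    distance_function=("minkowski", 2), # Euclidean BMU matching
    label_dim=10,                       # > 0 => supervised LVQ(1) behavior
    langevin_noise_scale=0.2,           # Langevin update noise (0 disables)
)
```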
@@ -78,38 +104,66 @@ def __init__(
         )
 
         ### Synapse and VQ hyper-parameters
-        self.K = 1 ## number of winners for a bmu
+        self.label_dim = label_dim
+        self.K = 1 ## number of winners (for a bmu)
         dist_fun, dist_order = distance_function ## Default: ("minkowski", 2) -> Euclidean
         if "euclidean" in dist_fun.lower():
             dist_order = 2
         elif "manhattan" in dist_fun.lower():
             dist_order = 1
         elif "chebyshev" in dist_fun.lower():
             dist_order = jnp.inf
-        ## TODO: add in cosine-distance (and maybe Mahalanobis distance)
-        self.dist_order = dist_order ## set distance order p
+        self.dist_order = dist_order ## set distance order p
 
         self.shape = shape ## shape of synaptic efficacy matrix
         self.initial_eta = eta
         self.eta_decr = eta_decrement #0.001
         self.syn_decay = syn_decay
         self.w_bound = w_bound ## soft synaptic value bound (on magnitude)
+        self.zeta = langevin_noise_scale #0.2 #0.35 #1. ## Langevin dampening factor
 
         ## VQ Compartment setup
-        self.eta = Compartment(jnp.zeros((1, 1)) + self.initial_eta)
+        label_syn_init = labels_init = jnp.zeros((1, 1))
+        if self.label_dim > 0:
+            label_syn_init = jnp.zeros((label_dim, self.shape[1]))
+            labels_init = jnp.zeros((self.batch_size, self.label_dim))
+        self.labels = Compartment(labels_init, display_name="Label Units")
+        self.pred_labels = Compartment(labels_init, display_name="Predicted Label Values")
+        self.label_weights = Compartment(label_syn_init, display_name="Label Synapses / Memory")
+        self.eta = Compartment(jnp.zeros((1, 1)) + self.initial_eta, display_name="Dynamic step size")
         self.i_tick = Compartment(jnp.zeros((1, 1)))
-        self.bmu = Compartment(jnp.zeros((1, 1)))
-        #self.delta = Compartment(self.weights.get() * 0)
+        #self.bmu = Compartment(jnp.zeros((1, 1)), display_name="Best matching unit mask")
         self.dWeights = Compartment(self.weights.get() * 0)
 
+        if initial_patterns is not None: ## preload memory synaptic matrix
+            initX, initY = initial_patterns
+            W = self.weights.get()
+            D, H = W.shape
+            tmp_key, *subkeys = random.split(self.key.get(), 3)
+            if initX.shape[1] < H: ## randomly fill portions of memory with stored patterns/templates
+                ptrs = random.permutation(subkeys[0], H)
+                W = jnp.concat([initX, W[:, 0:(H - initX.shape[1])]], axis=1)
+                W = W[:, ptrs] ## shuffle memories
+                self.weights.set(W)
+                if self.label_dim > 0:
+                    Wy = self.label_weights.get()
+                    Wy = jnp.concat([initY, Wy[:, 0:(H - initX.shape[1])]], axis=1)
+                    Wy = Wy[:, ptrs] ## shuffle memories
+                    self.label_weights.set(Wy)
+            else: ## memory is exactly the set of stored patterns/templates
+                self.weights.set(initX)
+                if self.label_dim > 0:
+                    self.label_weights.set(initY)
     @compilable
     def advance_state(self): ## forward-inference step of VQ
         x_in = self.inputs.get()
+        x_in = x_in / jnp.linalg.norm(x_in, axis=1, keepdims=True)
+        self.inputs.set(x_in)
         W = self.weights.get().T ## get (transposed) memory matrix
 
         ### We do some 3D tensor math to handle a batch of predictions that need to be made
         ### B = batch-size, D = embedding/input dim, C = number classes, N = number of memories
         _W = jnp.expand_dims(W, axis=0) ## 3D tensor format of memory (1 x N x D)
         _x_in = jnp.expand_dims(x_in, axis=1) ## 3D projection of input signals (B x 1 x D)
         D = _x_in - _W ## compute 3D batched delta tensor (B x N x D)
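The normalization and broadcasting in `advance_state` can be read as the following standalone sketch (toy shapes; `bkwta`, ngclearn's batched k-winners-take-all, is emulated here with a one-hot argmin since K = 1):

```python
import jax
import jax.numpy as jnp

B, D, N = 4, 8, 16          # batch size, feature dim, number of memories
x = jnp.ones((B, D))
W = jnp.full((D, N), 0.1)   # memory matrix, one prototype per column

x = x / jnp.linalg.norm(x, axis=1, keepdims=True)  # row-normalize inputs
delta = x[:, None, :] - W.T[None, :, :]            # (B x N x D) batched delta tensor
dist = jnp.linalg.norm(delta, ord=2, axis=2)       # (B x N) distances to each memory

# With K = 1, the BMU mask reduces to a one-hot row per sample:
bmu_mask = jax.nn.one_hot(jnp.argmin(dist, axis=1), N)  # (B x N) winner mask
```

Note that the added in-place normalization means downstream consumers of the `inputs` compartment now see unit-norm vectors.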
@@ -120,45 +174,62 @@ def advance_state(self): ## forward-inference step of VQ
 
         ## now get K winners per sample in batch
         #values, indices = lax.top_k(dist, K)
-        bmu_mask = bkwta(dist, self.K)
+        bmu_mask = bkwta(dist, nWTA=self.K)
         self.outputs.set(bmu_mask)
+        if self.label_dim > 0: ## store a label prediction (if applicable)
+            pred_labels = jnp.matmul(bmu_mask, self.label_weights.get().T)
+            self.pred_labels.set(pred_labels)
 
     @compilable
     def evolve(self, t, dt): ## competitive Hebbian update step of VQ
         W = self.weights.get()
         x_in = self.inputs.get()
         z_out = self.outputs.get()
-        tmp_key, *subkeys = random.split(self.key.get(), 3)
-        self.key.set(tmp_key)
-        ## synaptic update noise
-        eps = random.normal(subkeys[0], W.shape) ## TODO: is this same size as tensor? or scalar?
 
         ## do the competitive Hebbian update
-        dW = jnp.matmul(x_in.T, z_out) ## (N X D)
-        #print("dW ", jnp.linalg.norm(dW))
-        ## TODO: compute sign of dW given label match (-1 if no match, +1 if match)
-        self.dWeights.set(dW)
-        #print("W(t) ", jnp.linalg.norm(W))
-        dW = dW * self.eta.get() - (W * self.syn_decay) ## inject weight decay
-        #dW = dW + jnp.sqrt(2. * self.eta.get()) * eps ## inject Langevin noise
-        zeta = 0.2 #0.35 #1. ## Langevin dampening factor
-        dW = dW + eps * (2. * self.eta.get()) * zeta ## noise term (prevents going to zero in theory)
-        if self.w_bound > 0.:
-            ## enforce a soft value bound
+        signed_x_in = x_in
+        if self.label_dim > 0: ## first, compute the sign of the update if labels are available
+            ## LVQ(1) => compute sign of dW given label match (-1 if no match, +1 if match)
+            y_in = self.labels.get()
+            YW = self.label_weights.get()
+            y_mem = jnp.matmul(z_out, YW.T) ## decode to get label memories
+            ## each row of `y_exists`: 1 if a label memory is stored, 0 otherwise
+            y_exists = (jnp.sum(y_mem, axis=1, keepdims=True) > 0.) * 1.
+            ## TODO: update YW with labels for initial cond?
+
+            y_in_l = jnp.argmax(y_in, axis=1, keepdims=True) ## get label indices
+            y_mem_l = jnp.argmax(y_mem, axis=1, keepdims=True) ## get memory label indices
+            ## each entry of `dy`: -1 => incorrect (push away), +1 => correct (push towards)
+            dy = (y_in_l == y_mem_l) * 2. - 1.
+            dy = dy * y_exists + (1. - y_exists) ## +1 for each "empty" memory
+            signed_x_in = x_in * dy ## sign (-1, +1) each update
+        ## else, the sign of all updates is +1
+
+        ## second, given the above sign, compute the Hebbian adjustment and optional terms
+        dW = jnp.matmul(signed_x_in.T, z_out) ## calc competitive Hebbian update (N x D)
+        dW = dW - (W * self.syn_decay) ## inject weight decay
+
+        if self.zeta > 0.: ## synaptic update noise
+            tmp_key, *subkeys = random.split(self.key.get(), 3)
+            self.key.set(tmp_key)
+            eps = random.normal(subkeys[0], W.shape)
+            #dW = dW + jnp.sqrt(2. * self.eta.get()) * eps ## (alternative) Langevin noise scaling
+            dW = dW + eps * (2. * self.eta.get()) * self.zeta ## noise term (prevents going to zero in theory)
+
+        if self.w_bound > 0.: ## enforce a soft value bound
             dW = dW * (self.w_bound - jnp.abs(W))
         ## else, do not apply soft-bounding
+        self.dWeights.set(dW)
+
+        ## third, apply the synaptic update to memory matrix W
         W = W + dW * self.eta.get()
         self.weights.set(W)
-        #print("W(t+1) ", jnp.linalg.norm(W))
-        #exit()
 
-        ## update learning rate alpha
-        #a = self.eta.get()
-        #a = a + (-a) * (1./self.tau_eta)
+        ## update learning rate (eta)
         eta_tp1 = jnp.maximum(1e-5, self.eta.get() - self.eta_decr)
         self.eta.set(eta_tp1)
 
-        self.i_tick.set(self.i_tick.get() + 1)
+        self.i_tick.set(self.i_tick.get() + 1) ## advance internal "tick"
 
     @compilable
     def reset(self):
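
The LVQ(1) sign logic in `evolve` can be isolated as the following sketch (assumed toy shapes; a reading aid, not library code). In the component, the resulting per-sample sign multiplies `x_in` before the Hebbian product, after which `eta` is linearly annealed toward a floor of 1e-5:

```python
import jax.numpy as jnp

def lvq1_sign(y_in, bmu_mask, YW):
    """Per-sample update sign: +1 when the BMU's stored label matches the
    input label (or when no label has been stored yet), -1 on a mismatch."""
    y_mem = jnp.matmul(bmu_mask, YW.T)  # decode BMU label memories (B x C)
    y_exists = (jnp.sum(y_mem, axis=1, keepdims=True) > 0.) * 1.
    match = (jnp.argmax(y_in, axis=1, keepdims=True) ==
             jnp.argmax(y_mem, axis=1, keepdims=True))
    dy = match * 2. - 1.                    # +1 => pull toward, -1 => push away
    return dy * y_exists + (1. - y_exists)  # "empty" memories default to +1

# Toy check: a winner storing the wrong class yields sign -1.
YW = jnp.eye(3)                   # label memory: prototype i stores class i
y_in = jnp.array([[0., 1., 0.]])  # true class = 1
bmu = jnp.array([[1., 0., 0.]])   # winning prototype 0 stores class 0
print(lvq1_sign(y_in, bmu, YW))   # -> [[-1.]]
```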
@@ -168,10 +239,10 @@ def reset(self):
         if not self.inputs.targeted:
             self.inputs.set(preVals)
         self.outputs.set(postVals)
+        self.labels.set(self.labels.get() * 0)
+        self.pred_labels.set(self.pred_labels.get() * 0)
         self.dWeights.set(jnp.zeros(self.shape.get()))
-        #self.delta.set(jnp.zeros(self.shape.get()))
-        self.bmu.set(self.bmu.get() * 0)
-        #self.neighbor_weights.set(jnp.zeros((1, self.shape.get()[1])))
+        #self.bmu.set(self.bmu.get() * 0)
 
     @classmethod
     def help(cls): ## component help function
@@ -183,20 +254,24 @@ def help(cls): ## component help function
         compartment_props = {
             "input_compartments":
                 {"inputs": "Takes in external input signal values",
+                 "labels": "Takes in (optional) label signal values",
                  "key": "JAX PRNG key"},
             "parameter_compartments":
-                {"weights": "Synapse efficacy/strength parameter values"},
+                {"weights": "Synapse efficacy/strength parameter values",
+                 "label_weights": "Label efficacy parameter values (if this VQ is supervised)"},
             "output_compartments":
-                {"outputs": "Output of synaptic transformation",
-                 "bmu": "Best-matching unit (BMU) mask"},
+                {"outputs": "Output of synaptic transformation",
+                 "pred_labels": "Predicted labels (if this VQ is supervised)"},
         }
         hyperparams = {
             "shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
+            "label_dim": "Dimensionality of labels (if this VQ is supervised)",
             "batch_size": "Batch size dimension of this component",
             "weight_init": "Initialization conditions for synaptic weight (W) values",
             "resist_scale": "Resistance level scaling factor (applied to output of transformation)",
             "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
             "eta": "Global learning rate",
+            "eta_decrement": "Constant to decrement `eta` by per update/call to `evolve()`",
             "distance_function": "Distance function tuple specifying how to compute BMUs"
         }
         info = {cls.__name__: properties,
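
Finally, the expanded metadata can be inspected directly; a hypothetical usage sketch (the import path mirrors the file path above):

```python
from ngclearn.components.synapses.competitive.vectorQuantizeSynapse import VectorQuantizeSynapse
# Print the component's self-description, including the new `labels`,
# `label_weights`, and `pred_labels` entries added by this commit.
print(VectorQuantizeSynapse.help())
```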
