@@ -16,19 +16,27 @@ class VectorQuantizeSynapse(DenseSynapse): # Vector quantization (VQ) synaptic c
 
     | --- Synapse Compartments: ---
     | inputs - input (takes in external signals)
+    | labels - label input (optional)
     | outputs - output signals (transformation induced by synapses)
     | weights - current value matrix of synaptic efficacies
-    | bmu - current best-matching unit (BMU) mask, based on current inputs
+    | label_weights - current value matrix of label efficacies (if this VQ is supervised)
     | i_tick - current internal tick / marker (gets incremented by 1 for each call to `evolve`)
     | eta - current learning rate value
     | key - JAX PRNG key
     | --- Synaptic Plasticity Compartments: ---
     | inputs - pre-synaptic signal/value to drive 1st term of VQ update (x)
     | outputs - post-synaptic signal/value to drive 2nd term of VQ update (y)
+    | labels - (optional) pre-synaptic signal to drive 1st term of VQ update to label matrix
     | dWeights - current delta matrix containing changes to be applied to synapses
 
     | References:
+    | Somervuo, Panu, and Teuvo Kohonen. "Self-organizing maps and learning vector quantization
+    | for feature sequences." Neural Processing Letters 10.2 (1999): 151-159.
+    |
     | Kohonen, Teuvo. "The self-organizing map." Proceedings of the IEEE 78.9 (1990): 1464-1480.
+    |
+    | Ororbia, Alexander G. "Continual competitive memory: A neural system for online task-free
+    | lifelong learning." arXiv preprint arXiv:2106.13300 (2021).
 
     Args:
         name: the string name of this cell
@@ -38,6 +46,13 @@ class VectorQuantizeSynapse(DenseSynapse): # Vector quantization (VQ) synaptic c
 
         eta: (initial) learning rate / step-size for this VQ model (initial condition value for `eta`)
 
+        eta_decrement: a constant value to linearly decrease `eta` by per synaptic update
+            (Default: 0, which disables this)
+
+        syn_decay: a synaptic weight (L2) decay factor to apply to synapses per update
+
+        w_bound: upper soft bound to enforce over synapses post-update (Default: 0, which
+            disables this scaling term)
+
         distance_function: tuple specifying distance function and its order for computing best-matching units (BMUs)
             (Default: ("minkowski", 2)).
             usage guide:
@@ -46,6 +61,15 @@ class VectorQuantizeSynapse(DenseSynapse): # Vector quantization (VQ) synaptic c
4661 ("minkowksi", jnp.inf) or ("chebyshev", ?) => use Chebyshev distance;
4762 ("minkowski", p > 2) => use a Minkowski distance of p-th order
4863
+        label_dim: dimensionality of label neurons (corresponding to each memory/prototype); note
+            that this is inactive/unused if <= 0 (Default: 0)
+
+        initial_patterns: a tuple containing data vectors (and labels) to initialize the code-book
+            with; note that if `label_dim` <= 0, then only the first element of the tuple will be
+            used (Default: None)
+
+        langevin_noise_scale: scale factor to control the degree to which Langevin sampling noise
+            is applied to a given synaptic weight update (Default: 0, which disables this)
+
         weight_init: a kernel to drive initialization of this synaptic cable's values;
             typically a tuple with 1st element as a string calling the name of
             initialization to use
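
For reference, the usage guide above boils down to a simple order-resolution rule; the sketch below mirrors the constructor logic shown in the next hunk (the helper name `resolve_distance_order` is illustrative only and is not part of the class):

```python
import jax.numpy as jnp

def resolve_distance_order(distance_function=("minkowski", 2)):
    dist_fun, dist_order = distance_function
    if "euclidean" in dist_fun.lower():
        dist_order = 2  ## L2 norm
    elif "manhattan" in dist_fun.lower():
        dist_order = 1  ## L1 norm
    elif "chebyshev" in dist_fun.lower():
        dist_order = jnp.inf  ## L-infinity norm
    return dist_order  ## a plain "minkowski" keeps the user-supplied order p

assert resolve_distance_order(("euclidean", 7)) == 2
assert resolve_distance_order(("minkowski", 3)) == 3
```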
@@ -66,7 +90,9 @@ def __init__(
         syn_decay=0.,  ## weight decay term
         w_bound=0.,
         distance_function=("minkowski", 2),
+        label_dim=0,  ## if > 0, then this becomes supervised LVQ(1)
         initial_patterns=None,  ## possible class-based prototypes to init by
+        langevin_noise_scale=0.,  ## scale of Langevin noise to apply to updates
         weight_init=None,
         resist_scale=1.,
         p_conn=1.,
@@ -78,38 +104,66 @@ def __init__(
         )
 
         ### Synapse and VQ hyper-parameters
-        self.K = 1  ## number of winners for a bmu
+        self.label_dim = label_dim
+        self.K = 1  ## number of winners (for a bmu)
         dist_fun, dist_order = distance_function  ## Default: ("minkowski", 2) -> Euclidean
         if "euclidean" in dist_fun.lower():
             dist_order = 2
         elif "manhattan" in dist_fun.lower():
             dist_order = 1
         elif "chebyshev" in dist_fun.lower():
             dist_order = jnp.inf
-        ## TODO: add in cosine-distance (and maybe Mahalanobis distance)
-        self.dist_order = dist_order  ## set distance order p
+        self.dist_order = dist_order  ## set distance order p
 
         self.shape = shape  ## shape of synaptic efficacy matrix
         self.initial_eta = eta
         self.eta_decr = eta_decrement  #0.001
         self.syn_decay = syn_decay
         self.w_bound = w_bound  ## soft synaptic value bound (on magnitude)
+        self.zeta = langevin_noise_scale  ## Langevin dampening factor
97124
98125 ## VQ Compartment setup
99- self .eta = Compartment (jnp .zeros ((1 , 1 )) + self .initial_eta )
126+ label_syn_init = labels_init = jnp .zeros ((1 , 1 ))
127+ if self .label_dim > 0 :
128+ label_syn_init = jnp .zeros ((label_dim , self .shape [1 ]))
129+ labels_init = jnp .zeros ((self .batch_size , self .label_dim ))
130+ self .labels = Compartment (labels_init , display_name = "Label Units" )
131+ self .pred_labels = Compartment (labels_init , display_name = "Predicted Label Values" )
132+ self .label_weights = Compartment (label_syn_init , display_name = "Label Synapses / Memory" )
133+ self .eta = Compartment (jnp .zeros ((1 , 1 )) + self .initial_eta , display_name = "Dynamic step size" )
100134 self .i_tick = Compartment (jnp .zeros ((1 , 1 )))
101- self .bmu = Compartment (jnp .zeros ((1 , 1 )))
102- #self.delta = Compartment(self.weights.get() * 0)
135+ #self.bmu = Compartment(jnp.zeros((1, 1)), display_name="Best matching unit mask")
103136 self .dWeights = Compartment (self .weights .get () * 0 )
104137
138+ if initial_patterns is not None : ## preload memory synaptic matrix
139+ initX , initY = initial_patterns
140+ W = self .weights .get ()
141+ D , H = W .shape
142+ tmp_key , * subkeys = random .split (self .key .get (), 3 )
+            if initX.shape[1] < H:  ## randomly fill portions of memory with stored patterns/templates
+                ptrs = random.permutation(subkeys[0], H)
+                W = jnp.concat([initX, W[:, 0:(H - initX.shape[1])]], axis=1)
+                W = W[:, ptrs]  ## shuffle memories
+                self.weights.set(W)
+                if self.label_dim > 0:
+                    Wy = self.label_weights.get()
+                    Wy = jnp.concat([initY, Wy[:, 0:(H - initX.shape[1])]], axis=1)
+                    Wy = Wy[:, ptrs]  ## shuffle memories
+                    self.label_weights.set(Wy)
+            else:  ## memory is exactly the set of stored patterns/templates
+                self.weights.set(initX)
+                if self.label_dim > 0:
+                    self.label_weights.set(initY)
+
     @compilable
     def advance_state(self):  ## forward-inference step of VQ
         x_in = self.inputs.get()
+        x_in = x_in / jnp.linalg.norm(x_in, axis=1, keepdims=True)
+        self.inputs.set(x_in)
         W = self.weights.get().T  ## get (transposed) memory matrix
 
         ### We do some 3D tensor math to handle a batch of predictions that need to be made
         ### B = batch-size, D = embedding/input dim, C = number classes, N = number of memories
         _W = jnp.expand_dims(W, axis=0)  ## 3D tensor format of memory (1 x N x D)
         _x_in = jnp.expand_dims(x_in, axis=1)  ## 3D projection of input signals (B x 1 x D)
         D = _x_in - _W  ## compute 3D batched delta tensor (B x N x D)
 
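
For intuition, here is a minimal self-contained sketch of the broadcasted distance math above (shapes are invented for illustration, and `bkwta` is stood in for by a plain `argmin`, i.e., K = 1):

```python
import jax.numpy as jnp

B, D, N, p = 4, 16, 32, 2  ## batch size, input dim, number of memories, Minkowski order
x_in = jnp.ones((B, D))    ## toy (normalized) input batch
W = jnp.zeros((N, D))      ## toy (already transposed) memory matrix

_W = jnp.expand_dims(W, axis=0)               ## (1 x N x D)
_x_in = jnp.expand_dims(x_in, axis=1)         ## (B x 1 x D)
delta = _x_in - _W                            ## broadcasted delta tensor (B x N x D)
dist = jnp.linalg.norm(delta, ord=p, axis=2)  ## (B x N) Minkowski distances
bmu_idx = jnp.argmin(dist, axis=1)            ## best-matching unit per sample (K = 1)
```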
@@ -120,45 +174,62 @@ def advance_state(self): ## forward-inference step of VQ
 
         ## now get K winners per sample in batch
         #values, indices = lax.top_k(dist, K)
-        bmu_mask = bkwta(dist, self.K)
+        bmu_mask = bkwta(dist, nWTA=self.K)
         self.outputs.set(bmu_mask)
+        if self.label_dim > 0:  ## store a label prediction (if applicable)
+            pred_labels = jnp.matmul(bmu_mask, self.label_weights.get().T)
+            self.pred_labels.set(pred_labels)
 
     @compilable
     def evolve(self, t, dt):  ## competitive Hebbian update step of VQ
         W = self.weights.get()
         x_in = self.inputs.get()
         z_out = self.outputs.get()
-        tmp_key, *subkeys = random.split(self.key.get(), 3)
-        self.key.set(tmp_key)
-        ## synaptic update noise
-        eps = random.normal(subkeys[0], W.shape)  ## TODO: is this same size as tensor? or scalar?
-
+
         ## do the competitive Hebbian update
-        dW = jnp.matmul(x_in.T, z_out)  ## (N x D)
-        #print("dW ", jnp.linalg.norm(dW))
-        ## TODO: compute sign of dW given label match (-1 if no match, +1 if match)
-        self.dWeights.set(dW)
-        #print("W(t) ", jnp.linalg.norm(W))
-        dW = dW * self.eta.get() - (W * self.syn_decay)  ## inject weight decay
-        #dW = dW + jnp.sqrt(2. * self.eta.get()) * eps  ## inject Langevin noise
-        zeta = 0.2  #0.35 #1.  ## Langevin dampening factor
-        dW = dW + eps * (2. * self.eta.get()) * zeta  ## noise term (prevents going to zero in theory)
-        if self.w_bound > 0.:
-            ## enforce a soft value bound
+        signed_x_in = x_in
+        if self.label_dim > 0:  ## first, compute the sign of the update if labels are available
+            ## LVQ(1) => compute sign of dW given label match (-1 if no match, +1 if match)
+            y_in = self.labels.get()
+            YW = self.label_weights.get()
+            y_mem = jnp.matmul(z_out, YW.T)  ## decode BMU mask to get label memories
+            ## each row of `y_exists`: 1 if a label memory is stored, 0 otherwise
+            y_exists = (jnp.sum(y_mem, axis=1, keepdims=True) > 0.) * 1.
+            ## TODO: update YW with labels for initial cond?
+
+            y_in_l = jnp.argmax(y_in, axis=1, keepdims=True)  ## get label indices
+            y_mem_l = jnp.argmax(y_mem, axis=1, keepdims=True)  ## get memory label indices
+            ## each entry of `dy`: -1 => incorrect (push away), +1 => correct (pull towards)
+            dy = (y_in_l == y_mem_l) * 2. - 1.
+            dy = dy * y_exists + (1. - y_exists)  ## +1 for each "empty" memory
+            signed_x_in = x_in * dy  ## sign (-1 or +1) applied to each update
+        ## else, the sign of all updates is +1
+
+        ## second, given the above sign, compute the Hebbian adjustment and optional terms
+        dW = jnp.matmul(signed_x_in.T, z_out)  ## calc competitive Hebbian update (D x N)
+        dW = dW - (W * self.syn_decay)  ## inject weight decay
+
+        if self.zeta > 0.:  ## synaptic update noise
+            tmp_key, *subkeys = random.split(self.key.get(), 3)
+            self.key.set(tmp_key)
+            eps = random.normal(subkeys[0], W.shape)
+            #dW = dW + jnp.sqrt(2. * self.eta.get()) * eps  ## (alternative) classic Langevin noise
+            dW = dW + eps * (2. * self.eta.get()) * self.zeta  ## noise term (prevents updates from going to zero in theory)
+
+        if self.w_bound > 0.:  ## enforce a soft value bound
             dW = dW * (self.w_bound - jnp.abs(W))
-        ## else, do not apply soft-bounding
+        ## else, do not apply soft-bounding
+        self.dWeights.set(dW)
+
+        ## third, apply the synaptic update to memory matrix W
         W = W + dW * self.eta.get()
         self.weights.set(W)
-        #print("W(t+1) ", jnp.linalg.norm(W))
-        #exit()
 
-        ## update learning rate alpha
-        #a = self.eta.get()
-        #a = a + (-a) * (1./self.tau_eta)
+        ## update learning rate (eta)
         eta_tp1 = jnp.maximum(1e-5, self.eta.get() - self.eta_decr)
         self.eta.set(eta_tp1)
 
-        self.i_tick.set(self.i_tick.get() + 1)
+        self.i_tick.set(self.i_tick.get() + 1)  ## advance internal "tick"
 
     @compilable
     def reset(self):
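
To make the LVQ(1) sign rule in `evolve` concrete, here is a toy self-contained check (label values are invented; the `y_exists` masking of "empty" memories is omitted):

```python
import jax.numpy as jnp

y_in = jnp.array([[0., 1.], [1., 0.]])   ## presented one-hot labels (B x C)
y_mem = jnp.array([[0., 1.], [0., 1.]])  ## labels decoded from each sample's BMU

y_in_l = jnp.argmax(y_in, axis=1, keepdims=True)   ## presented label indices
y_mem_l = jnp.argmax(y_mem, axis=1, keepdims=True) ## BMU memory label indices
dy = (y_in_l == y_mem_l) * 2. - 1.  ## yields [[+1.], [-1.]] => pull BMU in / push BMU away
```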
@@ -168,10 +239,10 @@ def reset(self):
         if not self.inputs.targeted:
             self.inputs.set(preVals)
         self.outputs.set(postVals)
+        self.labels.set(self.labels.get() * 0)
+        self.pred_labels.set(self.pred_labels.get() * 0)
         self.dWeights.set(jnp.zeros(self.shape.get()))
-        #self.delta.set(jnp.zeros(self.shape.get()))
-        self.bmu.set(self.bmu.get() * 0)
-        #self.neighbor_weights.set(jnp.zeros((1, self.shape.get()[1])))
+        #self.bmu.set(self.bmu.get() * 0)
 
     @classmethod
     def help(cls):  ## component help function
@@ -183,20 +254,24 @@ def help(cls): ## component help function
         compartment_props = {
             "input_compartments":
                 {"inputs": "Takes in external input signal values",
+                 "labels": "Takes in (optional) label signal values",
                  "key": "JAX PRNG key"},
             "parameter_compartments":
-                {"weights": "Synapse efficacy/strength parameter values"},
+                {"weights": "Synapse efficacy/strength parameter values",
+                 "label_weights": "Label efficacy parameter values (if this VQ is supervised)"},
             "output_compartments":
-                {"outputs": "Output of synaptic transformation",
-                 "bmu": "Best-matching unit (BMU) mask"},
+                {"outputs": "Output of synaptic transformation",
+                 "pred_labels": "Predicted labels (if this VQ is supervised)"},
         }
         hyperparams = {
             "shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
+            "label_dim": "Dimensionality of labels (if this VQ is supervised)",
             "batch_size": "Batch size dimension of this component",
             "weight_init": "Initialization conditions for synaptic weight (W) values",
             "resist_scale": "Resistance level scaling factor (applied to output of transformation)",
             "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
             "eta": "Global learning rate",
+            "eta_decrement": "Constant to decrement `eta` by per update/call to `evolve()`",
             "distance_function": "Distance function tuple specifying how to compute BMUs"
         }
         info = {cls.__name__: properties,
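
Putting the new hyper-parameters together, a hypothetical construction of a supervised VQ/LVQ(1) synapse; the argument names follow this diff, while the surrounding ngclearn setup (contexts, PRNG key plumbing, and the `name`/`shape` arguments inherited from `DenseSynapse`) is sketched rather than definitive:

```python
vq = VectorQuantizeSynapse(
    name="vq0",
    shape=(16, 32),                      ## 16 input dims x 32 memory slots
    eta=0.1,                             ## initial step size
    eta_decrement=1e-4,                  ## linear step-size decay per update
    distance_function=("euclidean", 2),  ## resolves to Minkowski order p = 2
    label_dim=10,                        ## > 0 => supervised LVQ(1)
    langevin_noise_scale=0.2,            ## enable Langevin update noise
)
```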