Skip to content

Commit 79d2aab

Browse files
author
Alexander Ororbia
committed
integrated working som-synapse into competitive sub-package for synapses
1 parent 4f434f1 commit 79d2aab

6 files changed

Lines changed: 305 additions & 2 deletions

File tree

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ Contributors
1515
Maxbeth2 (Ohas)
1616
pagrawal-psu
1717
pulinagrawal
18+
antonvice

ngclearn/components/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from .synapses.hebbian.expSTDPSynapse import ExpSTDPSynapse
4040
from .synapses.hebbian.eventSTDPSynapse import EventSTDPSynapse
4141
from .synapses.hebbian.BCMSynapse import BCMSynapse
42+
from .synapses.competitive.SOMSynapse import SOMSynapse
4243
from .synapses.STPDenseSynapse import STPDenseSynapse
4344
from .synapses.exponentialSynapse import ExponentialSynapse
4445
from .synapses.doubleExpSynapse import DoubleExpSynapse

ngclearn/components/synapses/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
from .alphaSynapse import AlphaSynapse
99

1010
## dense synaptic components
11-
# from .hebbian.hebbianSynapse import HebbianSynapse
11+
from .hebbian.hebbianSynapse import HebbianSynapse ##
1212
from .hebbian.traceSTDPSynapse import TraceSTDPSynapse
1313
from .hebbian.expSTDPSynapse import ExpSTDPSynapse
1414
from .hebbian.eventSTDPSynapse import EventSTDPSynapse
1515
from .hebbian.BCMSynapse import BCMSynapse
1616
from .mpsSynapse import MPSSynapse
17+
### dense competitive synaptic components/elements
18+
from .competitive.SOMSynapse import SOMSynapse
1719

1820
## conv/deconv synaptic components
1921
from .convolution.convSynapse import ConvSynapse
Lines changed: 297 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
from jax import random, numpy as jnp, jit
2+
from ngclearn import compilable #from ngcsimlib.parser import compilable
3+
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
4+
from ngclearn.utils.model_utils import softmax
5+
6+
from ngclearn.components.synapses.denseSynapse import DenseSynapse
7+
8+
def _gaussian_kernel(dist, sigma): ## Gaussian neighborhood function
9+
density = jnp.exp(-jnp.power(dist, 2) / (2 * (sigma ** 2))) # n_units x 1
10+
#return jnp.prod(density, axis=1, keepdims=True) ## calc likelihood
11+
return density
12+
13+
def _ricker_marr_kernel(dist, sigma): ## mexican hat neighborhood function
14+
# p = jnp.square(dist)
15+
# d = sigma * sigma * 2.
16+
#density = jnp.exp(-p/d) * (1. - 2./d * p)
17+
# p = jnp.square(dist/ sigma)
18+
# density = (1. - p) * jnp.exp(p * -0.5)
19+
# #return jnp.prod(density, axis=1, keepdims=True) ## calc likelihood
20+
# return density
21+
gauss_density = _gaussian_kernel(dist, sigma)
22+
density = gauss_density * (1. - (jnp.power(dist, 2) / (sigma ** 2)))
23+
return density
24+
25+
def _euclidean_dist(a, b): ## Euclidean (L2) distance
26+
delta = a - b
27+
d = jnp.linalg.norm(delta, axis=0, keepdims=True)
28+
return d, delta
29+
30+
def _manhattan_dist(a, b): ## Manhattan (L1) distance
31+
delta = a - b
32+
d = jnp.linalg.norm(delta, ord=1, axis=0, keepdims=True)
33+
return d, delta
34+
35+
def _cosine_dist(a, b): ## Cosine-similarity distance
36+
delta = a - b
37+
d = 1. - (jnp.matmul(a.T, b) / (jnp.linalg.norm(a, axis=0) * jnp.linalg.norm(b, axis=0)))
38+
return d, delta
39+
40+
class SOMSynapse(DenseSynapse): # Self-organizing map (SOM) synaptic cable
41+
"""
42+
A synaptic cable that emulates a self-organizing map (or Kohonen map) that is adapted via
43+
competitive Hebbian learning. Many of this synapses internal compartments house dynamically-updated
44+
values for learning elements such as the SOM's neighborhood radius and learning rate.
45+
46+
Mathematically, a synaptic update performed according to SOM theory is:
47+
| Delta W_{ij} = (x.T - W) * n(BMU) * eta
48+
| where n(BMU) is a neighborhood weighting function centered around (topological) coordinates of BMU
49+
| where x is vector of pre-synaptic inputs, W is SOM's synaptic matrix, and BMU is best-matching unit for x
50+
51+
| --- Synapse Compartments: ---
52+
| inputs - input (takes in external signals)
53+
| outputs - output signals (transformation induced by synapses)
54+
| weights - current value matrix of synaptic efficacies
55+
| bmu - current best-matching unit (BMU), based on current inputs
56+
| delta - current differences between inputs and each weight vector of this SOM's synaptic matrix
57+
| i_tick - current internal tick / marker (gets incremented by 1 for each call to `evolve`)
58+
| eta - current learning rate value
59+
| radius - current radius value to control neighborhood function
60+
| key - JAX PRNG key
61+
| --- Synaptic Plasticity Compartments: ---
62+
| inputs - pre-synaptic signal/value to drive 1st term of SOM update (x)
63+
| outputs - post-synaptic signal/value to drive 2nd term of SOM update (y)
64+
| neighbor_weights - topology weighting applied to synaptic adjustments
65+
| dWeights - current delta matrix containing changes to be applied to synapses
66+
67+
| References:
68+
| Kohonen, Teuvo. "The self-organizing map." Proceedings of the IEEE 78.9 (2002): 1464-1480.
69+
70+
Args:
71+
name: the string name of this cell
72+
73+
n_inputs: number of input units to this SOM
74+
75+
n_units_x: number of output units along length of rectangular topology of this SOM
76+
77+
n_units_y: number of output units along width of rectangular topology of this SOM
78+
79+
eta: (initial) learning rate / step-size for this SOM (initial condition value for `eta`)
80+
81+
distance_function: string specifying distance function to use for finding best-matching units (BMUs)
82+
(Default: "euclidean").
83+
usage guide:
84+
"euclidean" = use L2 / Euclidean distance
85+
"manhattan" = use L1 / Manhattan / taxi-cab distance
86+
"cosine" = use cosine-similarity distance
87+
88+
neighbor_function: string specifying neighborhood function to compute approximate topology weighting across
89+
units in topology (based on BMU) (Default: "gaussian").
90+
usage guide:
91+
"gaussian" = use Gaussian kernel
92+
"ricker" = use Mexican-hat / Ricker-Marr kernel
93+
94+
weight_init: a kernel to drive initialization of this synaptic cable's values;
95+
typically a tuple with 1st element as a string calling the name of
96+
initialization to use
97+
98+
resist_scale: a fixed scaling factor to apply to synaptic transform
99+
(Default: 1.), i.e., yields: out = ((W * Rscale) * in)
100+
101+
p_conn: probability of a connection existing (default: 1.); setting
102+
this to < 1. will result in a sparser synaptic structure
103+
"""
104+
105+
def __init__(
106+
self,
107+
name,
108+
n_inputs,
109+
n_units_x, ## num units along width of SOM rectangular topology
110+
n_units_y, ## num units along length of SOM rectangular topology
111+
eta=0.5, ## learning rate
112+
distance_function="euclidean",
113+
neighbor_function="gaussian",
114+
weight_init=None,
115+
resist_scale=1.,
116+
p_conn=1.,
117+
batch_size=1,
118+
**kwargs
119+
):
120+
shape = (n_inputs, n_units_x * n_units_y)
121+
super().__init__(name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs)
122+
123+
### build (rectangular) topology coordinates
124+
coords = []
125+
for i in range(n_units_x):
126+
x = jnp.ones((n_units_x, 1)) * i
127+
y = jnp.expand_dims(jnp.arange(start=0, stop=n_units_y), axis=1)
128+
xy = jnp.concat((x, y), axis=1)
129+
coords.append(xy)
130+
self.coords = jnp.concat(coords, axis=0)
131+
132+
### Synapse and SOM hyper-parameters
133+
#self.radius = radius
134+
self.distance_function = distance_function
135+
self.dist_fx = 0 ## default := 0 (euclidean)
136+
if "manhattan" in distance_function:
137+
self.dist_fx = 1
138+
elif "cosine" in distance_function:
139+
self.dist_fx = 2
140+
self.neighbor_function = neighbor_function
141+
self.neighbor_fx = 0 ## default := 0 (Gaussian)
142+
if "ricker" in neighbor_function:
143+
self.neighbor_fx = 1 ## Mexican-hat function
144+
self.shape = shape ## shape of synaptic efficacy matrix
145+
146+
## exponential decay -> dz/dt = -kz has sol'n: z0 exp(-k t)
147+
# self.iterations = 50000
148+
# self.initial_eta = eta ## alpha (in SOM-lingo) #0.5
149+
# self.initial_radius = jnp.maximum(n_units_x, n_units_y) / 2 #n_units_x / 2
150+
# self.C = self.iterations / jnp.log(self.initial_radius)
151+
152+
## exponential decay -> dz/dt = -kz has sol'n: z0 exp(-k t)
153+
self.initial_eta = eta ## alpha (in SOM-lingo) #0.5
154+
self.initial_radius = jnp.maximum(n_units_x, n_units_y) / 2
155+
self.tau_eta = 50000
156+
self.tau_radius = self.tau_eta / jnp.log(self.initial_radius) ## C
157+
158+
## SOM Compartment setup
159+
self.radius = Compartment(jnp.zeros((1, 1)) + self.initial_radius)
160+
self.eta = Compartment(jnp.zeros((1, 1)) + self.initial_eta)
161+
self.i_tick = Compartment(jnp.zeros((1, 1)))
162+
self.bmu = Compartment(jnp.zeros((1, 1)))
163+
self.delta = Compartment(self.weights.get() * 0)
164+
self.neighbor_weights = Compartment(jnp.zeros((1, shape[1])))
165+
self.dWeights = Compartment(self.weights.get() * 0)
166+
167+
def _calc_bmu(self): ## obtain index of best-matching unit (BMU)
168+
x = self.inputs.get()
169+
W = self.weights.get()
170+
# W * I - x * I ?
171+
if self.dist_fx == 1: ## L1 distance
172+
d, delta = _manhattan_dist(x.T, W)
173+
elif self.dist_fx == 2: ## cosine distance
174+
d, delta = _cosine_dist(x.T, W)
175+
else: ## L2 distance
176+
d, delta = _euclidean_dist(x.T, W)
177+
bmu = jnp.argmin(d, axis=1, keepdims=True)
178+
bmu_idx = bmu #bmu[0, 0]
179+
return bmu_idx, delta
180+
181+
def _calc_neighborhood_weights(self): ## neighborhood function
182+
bmu = self.bmu.get() ## get best-matching unit
183+
bmu = bmu[0, 0]
184+
coords = self.coords ## constant coordinate array
185+
radius = self.radius.get()
186+
coord_bmu = coords[bmu:bmu + 1, :] ## TODO: might need to one-hot mask + sum
187+
delta = coords - coord_bmu ## raw differences (delta)
188+
189+
bmu_dist = jnp.linalg.norm(delta, axis=1, keepdims=True)
190+
if self.neighbor_fx == 1:
191+
neighbor_weights = _ricker_marr_kernel(bmu_dist, sigma=radius)
192+
else:
193+
neighbor_weights = _gaussian_kernel(bmu_dist, sigma=radius)
194+
195+
'''
196+
## calc distance values
197+
bmu_distance = jnp.sqrt(jnp.sum(jnp.square(delta), axis=1, keepdims=True))
198+
## apply kernel weighting function (below); e.g., Gaussian, Mexican-hat, triangular, etc.
199+
if self.neighbor_fx == 1:
200+
neighbor_weights = _ricker_marr_kernel(bmu_distance, sigma=radius)
201+
else:
202+
neighbor_weights = _gaussian_kernel(bmu_distance, sigma=radius)
203+
'''
204+
return neighbor_weights.T ## transpose to (1 x n_units)
205+
206+
@compilable
207+
def advance_state(self): ## forward-inference step of SOM
208+
bmu_idx, delta = self._calc_bmu()
209+
self.bmu.set(bmu_idx) ## store BMU
210+
self.delta.set(delta) ## store delta/differences
211+
neighbor_weights = self._calc_neighborhood_weights()
212+
self.neighbor_weights.set(neighbor_weights) ## store neighborhood weightings
213+
214+
## compute an approximate weighted activity output for input pattern
215+
#activity = jnp.sum(self.weights * self.resist_scale * neighbor_weights, axis=1, keepdims=True)
216+
### obtain weighted competitive activations (via softmax probs)
217+
activity = softmax(neighbor_weights * self.resist_scale)
218+
self.outputs.set(activity)
219+
220+
@compilable
221+
def evolve(self, t, dt): ## competitive Hebbian update step of SOM
222+
#bmu = self.bmu.get() ## best-matching unit
223+
delta = self.delta.get() ## deltas/differences between input & all SOM templates
224+
neighbor_weights = self.neighbor_weights.get() ## get neighborhood weight values
225+
226+
## exponential decay -> dz/dt = -kz has sol'n: z0 exp(-k t)
227+
#t = self.i_tick.get()
228+
## update radius
229+
r = self.radius.get()
230+
r = r + (-r) * (1./self.tau_radius)
231+
self.radius.set(r)
232+
## update learning rate alpha
233+
a = self.eta.get()
234+
a = a + (-a) * (1./self.tau_eta)
235+
self.eta.set(a)
236+
# self.radius.set(self.initial_radius * jnp.exp(-self.i_tick.get() / self.C)) ## update radius
237+
# self.eta.set(self.initial_eta * jnp.exp(-self.i_tick.get() / self.iterations)) ## update learning rate alpha
238+
239+
dWeights = delta * neighbor_weights * self.eta.get() ## calculate change-in-synapses
240+
self.dWeights.set(dWeights)
241+
_W = self.weights.get() + dWeights ## update via competitive Hebbian rule
242+
self.weights.set(_W)
243+
244+
self.i_tick.set(self.i_tick.get() + 1)
245+
246+
@compilable
247+
def reset(self):
248+
preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
249+
postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))
250+
251+
if not self.inputs.targeted:
252+
self.inputs.set(preVals)
253+
self.outputs.set(postVals)
254+
self.dWeights.set(jnp.zeros(self.shape.get()))
255+
self.delta.set(jnp.zeros(self.shape.get()))
256+
self.bmu.set(jnp.zeros((1, 1)))
257+
self.neighbor_weights.set(jnp.zeros((1, self.shape.get()[1])))
258+
259+
@classmethod
260+
def help(cls): ## component help function
261+
properties = {
262+
"synapse_type": "SOMSynapse - performs an adaptable synaptic transformation of inputs to produce output "
263+
"signals; synapses are adjusted via competitive Hebbian learning in accordance with a "
264+
"Kohonen map"
265+
}
266+
compartment_props = {
267+
"input_compartments":
268+
{"inputs": "Takes in external input signal values",
269+
"key": "JAX PRNG key"},
270+
"parameter_compartments":
271+
{"weights": "Synapse efficacy/strength parameter values"},
272+
"output_compartments":
273+
{"outputs": "Output of synaptic transformation",
274+
"bmu": "Best-matching unit (BMU)"},
275+
}
276+
hyperparams = {
277+
"shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
278+
"batch_size": "Batch size dimension of this component",
279+
"weight_init": "Initialization conditions for synaptic weight (W) values",
280+
"resist_scale": "Resistance level scaling factor (applied to output of transformation)",
281+
"p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
282+
"eta": "Global learning rate",
283+
"radius": "Radius parameter to control influence of neighborhood function",
284+
"distance_function": "Distance function used to compute BMU"
285+
}
286+
info = {cls.__name__: properties,
287+
"compartments": compartment_props,
288+
"dynamics": "outputs = [W * alpha(bmu)] ;"
289+
"dW = SOM competitive Hebbian update",
290+
"hyperparameters": hyperparams}
291+
return info
292+
293+
# if __name__ == '__main__':
294+
# from ngcsimlib.context import Context
295+
# with Context("Bar") as bar:
296+
# Wab = SOMSynapse("Wab", (2, 3), 4, 4, 1.)
297+
# print(Wab)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .SOMSynapse import SOMSynapse
2+

ngclearn/components/synapses/hebbian/BCMSynapse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def evolve(self, t, dt): #t, dt, tau_w, tau_theta, w_bound, w_decay, pre, post,
103103
dtheta = jnp.mean(jnp.square(self.post.get()), axis=0, keepdims=True) ## batch avg
104104
theta = self.theta.get() + (-self.theta.get() + dtheta) * dt / self.tau_theta
105105

106-
#self.weights.set(weights)
106+
self.weights.set(_W) ## TODO: this should update?
107107
self.theta.set(theta)
108108
self.dWeights.set(dWeights)
109109
self.post_term.set(post_term)

0 commit comments

Comments
 (0)