|
| 1 | +from jax import random, numpy as jnp, jit |
| 2 | +from ngclearn import compilable #from ngcsimlib.parser import compilable |
| 3 | +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment |
| 4 | +from ngclearn.utils.model_utils import softmax |
| 5 | + |
| 6 | +from ngclearn.components.synapses.denseSynapse import DenseSynapse |
| 7 | + |
| 8 | +def _gaussian_kernel(dist, sigma): ## Gaussian neighborhood function |
| 9 | + density = jnp.exp(-jnp.power(dist, 2) / (2 * (sigma ** 2))) # n_units x 1 |
| 10 | + #return jnp.prod(density, axis=1, keepdims=True) ## calc likelihood |
| 11 | + return density |
| 12 | + |
| 13 | +def _ricker_marr_kernel(dist, sigma): ## mexican hat neighborhood function |
| 14 | + # p = jnp.square(dist) |
| 15 | + # d = sigma * sigma * 2. |
| 16 | + #density = jnp.exp(-p/d) * (1. - 2./d * p) |
| 17 | + # p = jnp.square(dist/ sigma) |
| 18 | + # density = (1. - p) * jnp.exp(p * -0.5) |
| 19 | + # #return jnp.prod(density, axis=1, keepdims=True) ## calc likelihood |
| 20 | + # return density |
| 21 | + gauss_density = _gaussian_kernel(dist, sigma) |
| 22 | + density = gauss_density * (1. - (jnp.power(dist, 2) / (sigma ** 2))) |
| 23 | + return density |
| 24 | + |
| 25 | +def _euclidean_dist(a, b): ## Euclidean (L2) distance |
| 26 | + delta = a - b |
| 27 | + d = jnp.linalg.norm(delta, axis=0, keepdims=True) |
| 28 | + return d, delta |
| 29 | + |
| 30 | +def _manhattan_dist(a, b): ## Manhattan (L1) distance |
| 31 | + delta = a - b |
| 32 | + d = jnp.linalg.norm(delta, ord=1, axis=0, keepdims=True) |
| 33 | + return d, delta |
| 34 | + |
| 35 | +def _cosine_dist(a, b): ## Cosine-similarity distance |
| 36 | + delta = a - b |
| 37 | + d = 1. - (jnp.matmul(a.T, b) / (jnp.linalg.norm(a, axis=0) * jnp.linalg.norm(b, axis=0))) |
| 38 | + return d, delta |
| 39 | + |
class SOMSynapse(DenseSynapse): # Self-organizing map (SOM) synaptic cable
    """
    A synaptic cable that emulates a self-organizing map (or Kohonen map) that is adapted via
    competitive Hebbian learning. Many of this synapses internal compartments house dynamically-updated
    values for learning elements such as the SOM's neighborhood radius and learning rate.

    Mathematically, a synaptic update performed according to SOM theory is:
    | Delta W_{ij} = (x.T - W) * n(BMU) * eta
    | where n(BMU) is a neighborhood weighting function centered around (topological) coordinates of BMU
    | where x is vector of pre-synaptic inputs, W is SOM's synaptic matrix, and BMU is best-matching unit for x

    | --- Synapse Compartments: ---
    | inputs - input (takes in external signals)
    | outputs - output signals (transformation induced by synapses)
    | weights - current value matrix of synaptic efficacies
    | bmu - current best-matching unit (BMU), based on current inputs
    | delta - current differences between inputs and each weight vector of this SOM's synaptic matrix
    | i_tick - current internal tick / marker (gets incremented by 1 for each call to `evolve`)
    | eta - current learning rate value
    | radius - current radius value to control neighborhood function
    | key - JAX PRNG key
    | --- Synaptic Plasticity Compartments: ---
    | inputs - pre-synaptic signal/value to drive 1st term of SOM update (x)
    | outputs - post-synaptic signal/value to drive 2nd term of SOM update (y)
    | neighbor_weights - topology weighting applied to synaptic adjustments
    | dWeights - current delta matrix containing changes to be applied to synapses

    | References:
    | Kohonen, Teuvo. "The self-organizing map." Proceedings of the IEEE 78.9 (1990): 1464-1480.

    Args:
        name: the string name of this cell

        n_inputs: number of input units to this SOM

        n_units_x: number of output units along length of rectangular topology of this SOM

        n_units_y: number of output units along width of rectangular topology of this SOM

        eta: (initial) learning rate / step-size for this SOM (initial condition value for `eta`)

        distance_function: string specifying distance function to use for finding best-matching units (BMUs)
            (Default: "euclidean").
            usage guide:
            "euclidean" = use L2 / Euclidean distance
            "manhattan" = use L1 / Manhattan / taxi-cab distance
            "cosine" = use cosine-similarity distance

        neighbor_function: string specifying neighborhood function to compute approximate topology weighting across
            units in topology (based on BMU) (Default: "gaussian").
            usage guide:
            "gaussian" = use Gaussian kernel
            "ricker" = use Mexican-hat / Ricker-Marr kernel

        weight_init: a kernel to drive initialization of this synaptic cable's values;
            typically a tuple with 1st element as a string calling the name of
            initialization to use

        resist_scale: a fixed scaling factor to apply to synaptic transform
            (Default: 1.), i.e., yields: out = ((W * Rscale) * in)

        p_conn: probability of a connection existing (default: 1.); setting
            this to < 1. will result in a sparser synaptic structure
    """

    def __init__(
        self,
        name,
        n_inputs,
        n_units_x, ## num units along width of SOM rectangular topology
        n_units_y, ## num units along length of SOM rectangular topology
        eta=0.5, ## learning rate
        distance_function="euclidean",
        neighbor_function="gaussian",
        weight_init=None,
        resist_scale=1.,
        p_conn=1.,
        batch_size=1,
        **kwargs
    ):
        ## synaptic matrix maps n_inputs features onto the flattened 2D unit grid
        shape = (n_inputs, n_units_x * n_units_y)
        super().__init__(name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs)

        ### build (rectangular) topology coordinates
        ## each row of self.coords is the (x, y) grid position of one SOM unit,
        ## in the same flattened order as the columns of the weight matrix
        coords = []
        for i in range(n_units_x):
            ## NOTE(review): `x` is built with n_units_x rows while `y` has
            ## n_units_y rows; jnp.concat(axis=1) needs equal row counts, so
            ## this appears to only work when n_units_x == n_units_y — confirm
            ## behavior for non-square topologies
            x = jnp.ones((n_units_x, 1)) * i
            y = jnp.expand_dims(jnp.arange(start=0, stop=n_units_y), axis=1)
            xy = jnp.concat((x, y), axis=1)
            coords.append(xy)
        self.coords = jnp.concat(coords, axis=0)

        ### Synapse and SOM hyper-parameters
        #self.radius = radius
        ## distance function is resolved to an integer code so branching in
        ## _calc_bmu stays trace-friendly (0 = L2, 1 = L1, 2 = cosine)
        self.distance_function = distance_function
        self.dist_fx = 0 ## default := 0 (euclidean)
        if "manhattan" in distance_function:
            self.dist_fx = 1
        elif "cosine" in distance_function:
            self.dist_fx = 2
        ## neighborhood kernel resolved the same way (0 = Gaussian, 1 = Ricker)
        self.neighbor_function = neighbor_function
        self.neighbor_fx = 0 ## default := 0 (Gaussian)
        if "ricker" in neighbor_function:
            self.neighbor_fx = 1 ## Mexican-hat function
        self.shape = shape ## shape of synaptic efficacy matrix

        ## exponential decay -> dz/dt = -kz has sol'n: z0 exp(-k t)
        # self.iterations = 50000
        # self.initial_eta = eta ## alpha (in SOM-lingo) #0.5
        # self.initial_radius = jnp.maximum(n_units_x, n_units_y) / 2 #n_units_x / 2
        # self.C = self.iterations / jnp.log(self.initial_radius)

        ## exponential decay -> dz/dt = -kz has sol'n: z0 exp(-k t)
        ## eta and radius both decay exponentially over `evolve` calls; the
        ## time constants below set the decay rates
        self.initial_eta = eta ## alpha (in SOM-lingo) #0.5
        self.initial_radius = jnp.maximum(n_units_x, n_units_y) / 2
        self.tau_eta = 50000
        self.tau_radius = self.tau_eta / jnp.log(self.initial_radius) ## C

        ## SOM Compartment setup
        self.radius = Compartment(jnp.zeros((1, 1)) + self.initial_radius) # neighborhood radius (decays)
        self.eta = Compartment(jnp.zeros((1, 1)) + self.initial_eta) # learning rate (decays)
        self.i_tick = Compartment(jnp.zeros((1, 1))) # number of `evolve` calls so far
        self.bmu = Compartment(jnp.zeros((1, 1))) # index of current best-matching unit
        self.delta = Compartment(self.weights.get() * 0) # input-minus-weights differences
        self.neighbor_weights = Compartment(jnp.zeros((1, shape[1]))) # topology weighting per unit
        self.dWeights = Compartment(self.weights.get() * 0) # last synaptic update applied

    def _calc_bmu(self): ## obtain index of best-matching unit (BMU)
        """
        Compute the index of the best-matching unit (the SOM unit whose
        weight column is nearest to the current input under the configured
        distance function) along with the raw input-minus-weights differences.
        """
        x = self.inputs.get()
        W = self.weights.get()
        # W * I - x * I ?
        if self.dist_fx == 1: ## L1 distance
            d, delta = _manhattan_dist(x.T, W)
        elif self.dist_fx == 2: ## cosine distance
            d, delta = _cosine_dist(x.T, W)
        else: ## L2 distance
            d, delta = _euclidean_dist(x.T, W)
        bmu = jnp.argmin(d, axis=1, keepdims=True) ## index of minimum-distance unit
        bmu_idx = bmu #bmu[0, 0]
        return bmu_idx, delta

    def _calc_neighborhood_weights(self): ## neighborhood function
        """
        Compute topology weighting for every SOM unit based on its grid
        distance to the current BMU, using the configured kernel (Gaussian
        or Ricker/Mexican-hat) with the current (decaying) radius as sigma.
        """
        bmu = self.bmu.get() ## get best-matching unit
        ## NOTE(review): indexing [0, 0] takes the BMU of the first sample
        ## only — assumes batch size 1; verify for larger batches
        bmu = bmu[0, 0]
        coords = self.coords ## constant coordinate array
        radius = self.radius.get()
        coord_bmu = coords[bmu:bmu + 1, :] ## TODO: might need to one-hot mask + sum
        delta = coords - coord_bmu ## raw differences (delta)

        ## Euclidean grid distance of every unit to the BMU
        bmu_dist = jnp.linalg.norm(delta, axis=1, keepdims=True)
        if self.neighbor_fx == 1:
            neighbor_weights = _ricker_marr_kernel(bmu_dist, sigma=radius)
        else:
            neighbor_weights = _gaussian_kernel(bmu_dist, sigma=radius)

        '''
        ## calc distance values
        bmu_distance = jnp.sqrt(jnp.sum(jnp.square(delta), axis=1, keepdims=True))
        ## apply kernel weighting function (below); e.g., Gaussian, Mexican-hat, triangular, etc.
        if self.neighbor_fx == 1:
            neighbor_weights = _ricker_marr_kernel(bmu_distance, sigma=radius)
        else:
            neighbor_weights = _gaussian_kernel(bmu_distance, sigma=radius)
        '''
        return neighbor_weights.T ## transpose to (1 x n_units)

    @compilable
    def advance_state(self): ## forward-inference step of SOM
        """
        Forward-inference step: find the BMU for the current inputs, cache
        the BMU index / delta / neighborhood weights in compartments for the
        subsequent `evolve` call, and emit softmax-normalized competitive
        activations as the output.
        """
        bmu_idx, delta = self._calc_bmu()
        self.bmu.set(bmu_idx) ## store BMU
        self.delta.set(delta) ## store delta/differences
        neighbor_weights = self._calc_neighborhood_weights()
        self.neighbor_weights.set(neighbor_weights) ## store neighborhood weightings

        ## compute an approximate weighted activity output for input pattern
        #activity = jnp.sum(self.weights * self.resist_scale * neighbor_weights, axis=1, keepdims=True)
        ### obtain weighted competitive activations (via softmax probs)
        activity = softmax(neighbor_weights * self.resist_scale)
        self.outputs.set(activity)

    @compilable
    def evolve(self, t, dt): ## competitive Hebbian update step of SOM
        """
        Plasticity step: decay the radius and learning rate (one Euler step
        of exponential decay each), then apply the competitive Hebbian update
        dW = delta * n(BMU) * eta to the synaptic matrix. Relies on `delta`
        and `neighbor_weights` computed by the preceding `advance_state`.

        Args:
            t: current simulation time (unused here)

            dt: integration time step (unused here; decay uses unit steps)
        """
        #bmu = self.bmu.get() ## best-matching unit
        delta = self.delta.get() ## deltas/differences between input & all SOM templates
        neighbor_weights = self.neighbor_weights.get() ## get neighborhood weight values

        ## exponential decay -> dz/dt = -kz has sol'n: z0 exp(-k t)
        #t = self.i_tick.get()
        ## update radius (Euler step of dr/dt = -r / tau_radius with dt = 1)
        r = self.radius.get()
        r = r + (-r) * (1./self.tau_radius)
        self.radius.set(r)
        ## update learning rate alpha (Euler step of da/dt = -a / tau_eta)
        a = self.eta.get()
        a = a + (-a) * (1./self.tau_eta)
        self.eta.set(a)
        # self.radius.set(self.initial_radius * jnp.exp(-self.i_tick.get() / self.C)) ## update radius
        # self.eta.set(self.initial_eta * jnp.exp(-self.i_tick.get() / self.iterations)) ## update learning rate alpha

        dWeights = delta * neighbor_weights * self.eta.get() ## calculate change-in-synapses
        self.dWeights.set(dWeights)
        _W = self.weights.get() + dWeights ## update via competitive Hebbian rule
        self.weights.set(_W)

        ## advance internal tick marker
        self.i_tick.set(self.i_tick.get() + 1)

    @compilable
    def reset(self):
        """
        Reset transient compartments (inputs/outputs, deltas, BMU, and
        neighborhood weights) to zeros; synaptic weights, eta, radius, and
        the tick counter are left untouched.
        """
        ## NOTE(review): `self.shape` is assigned a plain tuple in __init__ and
        ## `batch_size` comes from the base class, yet both are read via
        ## .get() here — presumably the compilable machinery wraps these
        ## attributes; verify against ngcsimlib/DenseSynapse
        preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
        postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))

        if not self.inputs.targeted: ## only clear inputs not driven externally
            self.inputs.set(preVals)
        self.outputs.set(postVals)
        self.dWeights.set(jnp.zeros(self.shape.get()))
        self.delta.set(jnp.zeros(self.shape.get()))
        self.bmu.set(jnp.zeros((1, 1)))
        self.neighbor_weights.set(jnp.zeros((1, self.shape.get()[1])))

    @classmethod
    def help(cls): ## component help function
        """
        Return a dictionary describing this component's properties,
        compartments, dynamics, and hyper-parameters (for documentation /
        introspection tooling).
        """
        properties = {
            "synapse_type": "SOMSynapse - performs an adaptable synaptic transformation of inputs to produce output "
                            "signals; synapses are adjusted via competitive Hebbian learning in accordance with a "
                            "Kohonen map"
        }
        compartment_props = {
            "input_compartments":
                {"inputs": "Takes in external input signal values",
                 "key": "JAX PRNG key"},
            "parameter_compartments":
                {"weights": "Synapse efficacy/strength parameter values"},
            "output_compartments":
                {"outputs": "Output of synaptic transformation",
                 "bmu": "Best-matching unit (BMU)"},
        }
        hyperparams = {
            "shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
            "batch_size": "Batch size dimension of this component",
            "weight_init": "Initialization conditions for synaptic weight (W) values",
            "resist_scale": "Resistance level scaling factor (applied to output of transformation)",
            "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
            "eta": "Global learning rate",
            "radius": "Radius parameter to control influence of neighborhood function",
            "distance_function": "Distance function used to compute BMU"
        }
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "outputs = [W * alpha(bmu)] ;"
                            "dW = SOM competitive Hebbian update",
                "hyperparameters": hyperparams}
        return info
| 292 | + |
| 293 | +# if __name__ == '__main__': |
| 294 | +# from ngcsimlib.context import Context |
| 295 | +# with Context("Bar") as bar: |
| 296 | +# Wab = SOMSynapse("Wab", (2, 3), 4, 4, 1.) |
| 297 | +# print(Wab) |