Commit dbd1029

Alexander Ororbia committed
added working hopfield-syn/modern-hopfield-syn
1 parent 1569e31 commit dbd1029

3 files changed

Lines changed: 234 additions & 2 deletions

File tree

ngclearn/components/synapses/competitive/SOMSynapse.py

Lines changed: 3 additions & 1 deletion
@@ -113,7 +113,9 @@ def __init__(
         **kwargs
     ):
         shape = (n_inputs, n_units_x * n_units_y)
-        super().__init__(name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs)
+        super().__init__(
+            name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs
+        )
 
         ### build (rectangular) topology coordinates
         coords = []
Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 from .SOMSynapse import SOMSynapse
-
+from .hopfieldSynapse import HopfieldSynapse
Lines changed: 230 additions & 0 deletions
@@ -0,0 +1,230 @@
+from jax import random, numpy as jnp, jit, vmap
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
+from ngclearn.utils.model_utils import softmax, bkwta
+
+from ngclearn.components.synapses.denseSynapse import DenseSynapse
+
+class HopfieldSynapse(DenseSynapse): # (Modern) Hopfield synaptic cable
+    """
+    A synaptic cable that emulates a modern Hopfield network (MHN). Note that this model has been generalized a
+    bit, à la NAC-Lab style, and comes equipped with two non-standard local plasticity update rules to alter the
+    underlying memory matrix W from scratch (or to fine-tune an existing preloaded one); note that a mixed
+    MHN can be created (one where initial patterns are stored but portions / elements of the
+    memory matrix are further adapted in accordance with a local adjustment rule). This model currently only
+    implements the exponential coupling/energy function.
+
+    | --- Synapse Compartments: ---
+    | inputs - input probe (takes in external signals)
+    | outputs - output signals (retrieved memory / updated probe)
+    | weights - current value matrix of synaptic efficacies
+    | similarities - current raw similarity scores computed (pre-softmax)
+    | memory_weights - current similarity scores computed (post-softmax)
+    | i_tick - current internal tick / marker (gets incremented by 1 for each call to `advance_state`)
+    | energy - current energy functional reading (given current clamped input probe)
+    | key - JAX PRNG key
+    | --- Synaptic Plasticity Compartments: ---
+    | dWeights - current delta matrix containing changes to be applied to synapses
+
+    | References:
+    | Movellan, Javier R. "Contrastive Hebbian learning in the continuous Hopfield model." Connectionist models.
+    | Morgan Kaufmann, 1991. 10-17.
+    |
+    | Krotov, Dmitry, and John Hopfield. "Large associative memory problem in neurobiology and machine learning."
+    | arXiv preprint arXiv:2008.06996 (2020).
+    |
+    | Hintzman, Douglas L. "MINERVA 2: A simulation model of human memory." Behavior Research Methods, Instruments,
+    | & Computers 16.2 (1984): 96-101.
+
+    Args:
+        name: the string name of this cell
+
+        shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
+            with number of inputs by number of outputs)
+
+        eta: (initial) learning rate / step-size for this Hopfield synapse (initial condition value for `eta`)
+
+        reg_lambda: weight decay coefficient applied to the Hebbian update
+
+        beta: (inverse) temperature to control the sharpness of the memory similarity calculation
+
+        initial_patterns: seed patterns to store within the memory matrix (Default: None)
+
+        update_rule: local plasticity rule to use to adjust/update the memory matrix (Default: "delta";
+            currently, two working rules are encoded - a custom delta rule (prescribed error rule) and
+            a custom contrastive Hebbian rule (Movellan/NAC-Lab-style))
+
+        weight_init: a kernel to drive initialization of this synaptic cable's values;
+            typically a tuple with 1st element as a string calling the name of
+            initialization to use
+
+        resist_scale: a fixed scaling factor to apply to the synaptic transform
+            (Default: 1.), i.e., yields: out = ((W * Rscale) * in)
+
+        p_conn: probability of a connection existing (default: 1.); setting
+            this to < 1. will result in a sparser synaptic structure
+    """
+
+    def __init__(
+        self,
+        name,
+        shape,
+        eta,
+        reg_lambda=0.,
+        beta=8.,
+        initial_patterns=None,
+        update_rule="delta",  ## memory plasticity rule
+        weight_init=None,
+        resist_scale=1.,
+        p_conn=1.,
+        batch_size=1,
+        **kwargs
+    ):
+        super().__init__(
+            name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs
+        )
+
+        ### Synapse and Hopfield hyper-parameters
+        self.eta = eta
+        self.reg_lambda = reg_lambda #0.0001 ## regularization coefficient
+        self.l1_lambda = 0. #0.0001 ## coefficient for L1 decay
+        self.beta = beta
+        if initial_patterns is not None: ## preload memory synaptic matrix
+            W = self.weights.get()
+            D, H = W.shape
+            tmp_key, *subkeys = random.split(self.key.get(), 3)
+            if initial_patterns.shape[1] < H: ## randomly populate portions of memory with stored patterns/templates
+                ptrs = random.permutation(subkeys[0], H)
+                W = jnp.concat([initial_patterns, W[:, 0:(H - initial_patterns.shape[1])]], axis=1)
+                W = W[:, ptrs] ## shuffle memories
+                self.weights.set(W)
+            else: ## memory is exactly the set of stored patterns/templates
+                self.weights.set(initial_patterns)
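+        ## note: after preloading, the columns of W hold the seed patterns (shuffled
+        ## into random column positions when initial_patterns supplies fewer columns
+        ## than W has); any leftover columns keep their random initialization and
+        ## remain adaptable, which is what yields a "mixed" MHN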
+        self.rule_fx = 2 ## Default: delta-rule
+        if update_rule == "contrastive":
+            self.rule_fx = 1
+
+        ## Hopfield Compartment setup
+        inputVals = jnp.zeros((self.batch_size, shape[0]))
+        simVals = jnp.zeros((self.batch_size, shape[1])) ## one similarity slot per stored memory
+        self.inputs = Compartment(inputVals) ## input shape = output shape
+        self.outputs = Compartment(inputVals) ## output shape = input shape
+        self.similarities = Compartment(simVals) ## "hidden layer"
+        self.memory_weights = Compartment(simVals)
+
+        self.energy = Compartment(jnp.zeros((1, 1)), display_name="Energy")
+        self.i_tick = Compartment(jnp.zeros((1, 1)))
+        self.dWeights = Compartment(self.weights.get() * 0)
+
+    @compilable
+    def advance_state(self): ## forward-inference step of the MHN
+        WX = self.weights.get()
+        probe_t = self.inputs.get()
+
+        ## TODO: what about power/quadratic functions instead? (integrate Minerva power coupling)
+        sims = jnp.matmul(probe_t, WX) ## similarities (w/ xn as probe)
+        sims_max = jnp.max(sims, axis=1, keepdims=True)
+        sims = sims - sims_max ## shift by per-row max for numerical stability
+        self.similarities.set(sims) ## similarities = "hidden layer"
+        memory_weights = softmax(sims * self.beta)
+        self.memory_weights.set(memory_weights)
+        z = memory_weights
+        probe_tp1 = jnp.matmul(z, WX.T) ## calc probe update
+        self.outputs.set(probe_tp1)
+
+        ## Calculate (modern) Hopfield energy functional
+        N = WX.shape[1] ## how many neural memories there are
+        max_sim_value = jnp.max(self.beta * sims, axis=1, keepdims=True)
+        lse = max_sim_value + jnp.log(jnp.sum(jnp.exp(self.beta * sims - max_sim_value), axis=1, keepdims=True))
+        term1 = -(1. / self.beta) * lse
+        term2 = 0.5 * jnp.expand_dims(jnp.diag(jnp.matmul(probe_t, probe_tp1.T)), axis=1)
+        term3 = (1. / self.beta) * jnp.log(N) + 0.5 * jnp.max(jnp.linalg.norm(WX, ord=2, axis=1) ** 2) ## C
+        Ex = jnp.mean(term1 + term2 + term3, axis=0, keepdims=True) #* (1. / probe_t.shape[0]) ## calc batch avg energy
+        self.energy.set(Ex)
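+        ## for reference: with s = probe_t @ W (the max-shift above is only a
+        ## log-sum-exp stabilization), the quantity computed here follows the
+        ## exponential-coupling MHN energy form
+        ##   E = -(1/beta) * logsumexp(beta * s) + 0.5 * <probe_t, probe_tp1>
+        ##       + (1/beta) * log(N) + 0.5 * max-squared-norm(W)
+        ## (note the quadratic term couples the probe with its one-step update)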
+
+        self.i_tick.set(self.i_tick.get() + 1.) ## march internal tick forward
+
+    @compilable
+    def evolve(self, t, dt): ## plasticity rule for changing this Hopfield network's memory matrix
+        x = self.inputs.get()
+        x_hat = self.outputs.get()
+        s = self.memory_weights.get()
+        W = self.weights.get()
+        beta = self.beta
+
+        ## TODO: make updates noisy? (perturbative)
+        ## TODO: also, make a perturbation-based update synapse?
+        if self.rule_fx == 1: ## contrastive (Movellan) Hebbian style plasticity
+            ## TODO: add a loop to iterate over the negative term several times
+            ## we propagate the updated probe (negative) through memory to get a negative weighted state
+            sims_hat = jnp.matmul(x_hat, W)
+            s_hat = softmax((sims_hat - jnp.max(sims_hat, axis=1, keepdims=True)) * beta) #s_hat = bkwta(s_hat, nWTA=1)
+            ## positive Hebbian prod of probe+pos-state against negative Hebbian prod of updated-probe+neg-state
+            term1 = (x.T @ s)
+            term2 = -(x_hat.T @ s_hat)
+            dW = term1 + term2
+        #elif self.rule_fx == XX: ## deriv of energy w.r.t. memory W rule
+        #    dW = x.T @ -s
+        else: ## delta-rule (prescribed error rule) is the default
+            dW = (x - x_hat).T @ s ## (deriv of MSE w.r.t. x_hat/updated probe)
+        Ns = x.shape[0] ## get batch size
+        dW = dW * (1. / Ns) ## we average batch updates
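+        ## shape check: x and x_hat are (B x D) while s is (B x H), so dW is (D x H)
+        ## and matches W; the delta rule thus spreads each probe's reconstruction
+        ## error (x - x_hat) across memory columns in proportion to their softmax weights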
+
+        ## TODO: add a term that checks if we need to append to memory W
+        W = W + dW * self.eta - W * self.reg_lambda - jnp.sign(W) * self.l1_lambda ## actually adjust synaptic efficacies
+
+        self.dWeights.set(dW)
+        self.weights.set(W)
+
+    @compilable
+    def reset(self):
+        inputVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
+        simVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))
+
+        if not self.inputs.targeted:
+            self.inputs.set(inputVals)
+        self.outputs.set(inputVals)
+        self.similarities.set(simVals)
+        self.memory_weights.set(simVals)
+        self.energy.set(self.energy.get() * 0)
+        self.dWeights.set(jnp.zeros(self.shape.get()))
+
+    @classmethod
+    def help(cls): ## component help function
+        properties = {
+            "synapse_type": "HopfieldSynapse - performs an adaptable synaptic transformation of inputs to produce "
+                            "output signals; synapses are adjusted via Hebbian learning in accordance with a "
+                            "Hopfield network"
+        }
+        compartment_props = {
+            "input_compartments":
+                {"inputs": "Takes in external input signal values",
+                 "key": "JAX PRNG key"},
+            "parameter_compartments":
+                {"weights": "Synapse efficacy/strength parameter values"},
+            "output_compartments":
+                {"outputs": "Output of synaptic transformation"}
+        }
+        hyperparams = {
+            "shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
+            "batch_size": "Batch size dimension of this component",
+            "weight_init": "Initialization conditions for synaptic weight (W) values",
+            "resist_scale": "Resistance level scaling factor (applied to output of transformation)",
+            "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
+            "eta": "Global learning rate (to control update to memory matrix)",
+            "beta": "Inverse temperature (controls softmax sharpness)",
+            "reg_lambda": "Weight decay coefficient to apply to local memory matrix updates",
+            "update_rule": "What type of rule to use to update memory matrix (locally)",
+            "initial_patterns": "Matrix containing a series of concatenated vectors to store into memory explicitly",
+        }
+        info = {cls.__name__: properties,
+                "compartments": compartment_props,
+                "dynamics": "outputs = Hopfield memory retrieval; "
+                            "dW = Hopfield Hebbian update",
+                "hyperparameters": hyperparams}
+        return info
+
+# if __name__ == '__main__':
+#     from ngcsimlib.context import Context
+#     with Context("Bar") as bar:
+#         Wab = HopfieldSynapse("Wab", (2, 3), 4, 4, 1.)
+#         print(Wab)
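The commented-out driver above can be expanded into a minimal usage sketch along the following lines (hypothetical and untested: it assumes the package `__init__` edited above re-exports `HopfieldSynapse` from `ngclearn.components.synapses.competitive`, and that the `@compilable` methods can be invoked directly in uncompiled/dynamic mode; the probe values and names are illustrative):

import jax.numpy as jnp
from ngcsimlib.context import Context
from ngclearn.components.synapses.competitive import HopfieldSynapse

with Context("demo") as demo:
    ## 4-dim probes mapped into a memory matrix with 8 storage slots
    Wab = HopfieldSynapse("Wab", shape=(4, 8), eta=0.01, beta=8., update_rule="delta")
    probe = jnp.asarray([[0.2, 0.8, 0.1, 0.5]]) ## a (batch_size x n_inputs) probe
    Wab.inputs.set(probe) ## clamp the input compartment
    Wab.advance_state() ## retrieve: outputs <- softmax(beta * probe W) W^T
    Wab.evolve(t=0., dt=1.) ## apply the local delta-rule adjustment to W
    print(Wab.outputs.get(), Wab.energy.get())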
