Skip to content

Commit 4c355a8

Browse files
Author: Alexander Ororbia (committed)
cleaned up graded/patched comps with inner batched_reset formulation
1 parent a9cf886 commit 4c355a8

9 files changed

Lines changed: 69 additions & 31 deletions

File tree

ngclearn/components/neurons/graded/bernoulliErrorCell.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,15 @@ def advance_state(self, dt): ## compute Bernoulli error cell output
105105
self.L.set(jnp.squeeze(L))
106106
self.mask.set(mask)
107107

108+
@compilable
109+
def reset(self): ## reset core components/statistics
110+
self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member
108111

109-
# @transition(output_compartments=["dp", "dtarget", "target", "p", "modulator", "L", "mask"])
110112
@compilable
111-
def reset(self): ## reset core components/statistics
112-
_shape = (self.batch_size, self.shape[0])
113+
def batched_reset(self, batch_size):
114+
_shape = (batch_size, self.shape[0])
113115
if len(self.shape) > 1:
114-
_shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2])
116+
_shape = (batch_size, self.shape[0], self.shape[1], self.shape[2])
115117
restVals = jnp.zeros(_shape) ## "rest"/reset values
116118
dp = restVals
117119
dtarget = restVals
@@ -130,7 +132,6 @@ def reset(self): ## reset core components/statistics
130132
self.L.set(L)
131133
self.mask.set(mask)
132134

133-
134135
@classmethod
135136
def help(cls): ## component help function
136137
properties = {

ngclearn/components/neurons/graded/gaussianErrorCell.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs):
4747
sigma_shape = jnp.array(sigma).shape
4848
self.sigma_shape = sigma_shape
4949
self.shape = shape
50+
self.batch_size = batch_size
5051
self.n_units = n_units
5152

5253
## Convolution shape setup
@@ -104,10 +105,12 @@ def advance_state(self, dt): ## compute Gaussian error cell output
104105
self.L.set(jnp.squeeze(L))
105106
self.mask.set(mask)
106107

107-
# @transition(output_compartments=["dmu", "dtarget", "dSigma", "target", "mu", "modulator", "L", "mask"])
108-
# @staticmethod
109108
@compilable
110-
def reset(self, batch_size): ## reset core components/statistics
109+
def reset(self): ## reset core components/statistics
110+
self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member
111+
112+
@compilable
113+
def batched_reset(self, batch_size):
111114
_shape = (batch_size, self.shape[0])
112115
if len(self.shape) > 1:
113116
_shape = (batch_size, self.shape[0], self.shape[1], self.shape[2])

ngclearn/components/neurons/graded/laplacianErrorCell.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,16 +100,20 @@ def advance_state(self, dt): ## compute Laplacian error cell output
100100
self.mask.set(mask)
101101

102102
@compilable
103-
def reset(self): ## reset core components/statistics
104-
restVals = jnp.zeros((self.batch_size, self.n_units))
103+
def reset(self): ## reset core components/statistics
104+
self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member
105+
106+
@compilable
107+
def batched_reset(self, batch_size):
108+
restVals = jnp.zeros((batch_size, self.n_units))
105109
dshift = restVals
106110
dtarget = restVals
107111
dScale = jnp.zeros(self.scale_shape)
108112
target = restVals
109113
shift = restVals
110114
modulator = shift + 1.
111115
L = 0.
112-
mask = jnp.ones((self.batch_size, self.n_units))
116+
mask = jnp.ones((batch_size, self.n_units))
113117

114118
self.dshift.set(dshift)
115119
self.dtarget.set(dtarget)

ngclearn/components/neurons/graded/leakyNoiseCell.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from jax import numpy as jnp, random, jit
2-
from ngcsimlib.logger import info
2+
#from ngcsimlib.logger import info
3+
from ngcsimlib import deprecate_args
34
from ngclearn import compilable #from ngcsimlib.parser import compilable
45
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
56
from ngclearn.components.jaxComponent import JaxComponent
@@ -65,6 +66,7 @@ class LeakyNoiseCell(JaxComponent): ## Real-valued, leaky noise cell
6566
leak_scale: degree to which membrane leak should be scaled (Default: 1)
6667
"""
6768

69+
@deprecate_args(sigma_rec="sigma_pre")
6870
def __init__(
6971
self, name, n_units, tau_x, act_fx="relu", integration_type="euler", batch_size=1, sigma_pre=0.1,
7072
sigma_post=0.1, leak_scale=1., shape=None, **kwargs
@@ -127,13 +129,19 @@ def advance_state(self, t, dt):
127129
self.r_prime.set(r_prime)
128130

129131
@compilable
130-
def reset(self):
131-
_shape = (self.batch_size, self.shape[0])
132+
def reset(self): ## reset core components/statistics
133+
self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member
134+
135+
@compilable
136+
def batched_reset(self, batch_size):
137+
_shape = (batch_size, self.shape[0])
132138
if len(self.shape) > 1:
133-
_shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2])
139+
_shape = (batch_size, self.shape[0], self.shape[1], self.shape[2])
134140
restVals = jnp.zeros(_shape)
135-
self.j_input.set(restVals)
136-
self.j_recurrent.set(restVals)
141+
if not self.j_input.targeted:
142+
self.j_input.set(restVals)
143+
if not self.j_recurrent.targeted:
144+
self.j_recurrent.set(restVals)
137145
self.x.set(restVals)
138146
self.r.set(restVals)
139147
self.r_prime.set(restVals)

ngclearn/components/neurons/graded/rateCell.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,11 @@ def advance_state(self, dt):
252252
self.zF.set(zF)
253253

254254
@compilable
255-
def reset(self, batch_size):
255+
def reset(self): ## reset core components/statistics
256+
self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member
257+
258+
@compilable
259+
def batched_reset(self, batch_size):
256260
_shape = (batch_size, self.shape[0])
257261
if len(self.shape) > 1:
258262
_shape = (batch_size, self.shape[0], self.shape[1], self.shape[2])

ngclearn/components/synapses/patched/hebbianPatchedSynapse.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -281,20 +281,24 @@ def evolve(self):
281281
self.dBiases.set(dBiases)
282282

283283
@compilable
284-
def reset(self, batch_size):
284+
def reset(self): ## closed, no-batch argument reset
285+
## write reset command to call inner batched_reset command
286+
self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member
287+
288+
@compilable
289+
def batched_reset(self, batch_size): ## open, batch argument reset
285290
preVals = jnp.zeros((batch_size, self.shape[0]))
286291
postVals = jnp.zeros((batch_size, self.shape[1]))
287292
# BUG: the self.inputs here does not have the targeted field
288293
# NOTE: Quick workaround is to check if targeted is in the input or not
289-
hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(preVals) # inputs
290-
self.outputs.set(postVals) # outputs
291-
self.post_in.set(postVals) # post_in
292-
self.pre_out.set(preVals) # pre_out
293-
self.pre.set(preVals) # pre
294-
self.post.set(postVals) # post
295-
self.dWeights.set(jnp.zeros(self.shape)) # dW
296-
self.dBiases.set(jnp.zeros(self.shape[1])) # db
297-
294+
hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(preVals) # inputs
295+
self.outputs.set(postVals) # outputs
296+
self.post_in.set(postVals) # post_in
297+
self.pre_out.set(preVals) # pre_out
298+
self.pre.set(preVals) # pre
299+
self.post.set(postVals) # post
300+
self.dWeights.set(jnp.zeros(self.shape)) # dW
301+
self.dBiases.set(jnp.zeros(self.shape[1])) # db
298302

299303
@classmethod
300304
def help(cls): ## component help function

ngclearn/components/synapses/patched/patchedSynapse.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,10 +167,15 @@ def advance_state(self):
167167
self.pre_out.set(pre_out)
168168

169169
@compilable
170-
def reset(self, batch_size):
170+
def reset(self): ## closed, no-batch argument reset
171+
## write reset command to call inner batched_reset command
172+
self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member
173+
174+
@compilable
175+
def batched_reset(self, batch_size): ## open, batch argument reset
171176
preVals = jnp.zeros((batch_size, self.shape[0]))
172177
postVals = jnp.zeros((batch_size, self.shape[1]))
173-
178+
174179
# BUG: the self.inputs here does not have the targeted field
175180
# NOTE: Quick workaround is to check if targeted is in the input or not
176181
hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(preVals)

tests/components/neurons/graded/test_leakyNoiseCell.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,14 @@ def test_LeakyNoiseCell1():
1616
dt = 1. # ms
1717
with Context(name) as ctx:
1818
a = LeakyNoiseCell(
19-
name="a", n_units=1, tau_x=50., act_fx="identity", integration_type="euler", batch_size=1, sigma_rec=0.,
19+
name="a",
20+
n_units=1,
21+
tau_x=50.,
22+
act_fx="identity",
23+
integration_type="euler",
24+
batch_size=1,
25+
sigma_pre=0.,
26+
sigma_post=0.,
2027
leak_scale=0.
2128
)
2229
advance_process = (MethodProcess("advance_proc") >> a.advance_state)

tests/components/synapses/patched/test_patchedSynapse.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def test_patchedSynapse():
3636

3737
advance_process = (MethodProcess("advance_proc") >> a.advance_state)
3838
reset_process = (MethodProcess("reset_proc") >> a.reset)
39+
#batch_reset_process = (MethodProcess("batched_reset_proc") >> a.batched_reset)
3940

4041
def clamp_inputs(x):
4142
a.inputs.set(x)
@@ -46,6 +47,7 @@ def clamp_inputs(x):
4647
expected_outputs = (jnp.matmul(inputs_seq, weights) * resist_scale) + biases
4748
outputs_outs = []
4849
reset_process.run()
50+
#batch_reset_process.run(batch_size=batch_size)
4951
clamp_inputs(inputs_seq)
5052
advance_process.run(t=0., dt=dt)
5153
outputs_outs.append(a.outputs.get())

0 commit comments

Comments (0)