Skip to content

Commit 1569e31

Browse files
author
Alexander Ororbia
committed
minor cleanup
1 parent 64eb27d commit 1569e31

4 files changed

Lines changed: 20 additions & 16 deletions

File tree

ngclearn/utils/diffeq/ode_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def _leapfrog(carry, dfq, dt, params):
127127
return new_carry, (new_carry, carry)
128128

129129
@partial(jit, static_argnums=(3, 4))
130-
def leapfrog(t_curr, q_curr, p_curr, dfq, L, step_size, params):
130+
def leapfrog(t_curr, q_curr, p_curr, dfq, L, step_size, params): ## leapfrog estimator step
131131
t = t_curr + 0.
132132
q = q_curr + 0.
133133
p = p_curr + 0.

ngclearn/utils/model_utils.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,8 @@ def d_relu6(x):
458458
"""
459459
# df/dx = 1 if 0<x<6 else 0
460460
# I_x = (z >= a_min) *@ (z <= b_max) //create an indicator function a = 0 b = 6
461-
Ix1 = (x > 0.).astype(jnp.float32) #tf.cast(tf.math.greater_equal(x, 0.0),dtype=tf.float32)
461+
#Ix1 = (x > 0.).astype(jnp.float32) #tf.cast(tf.math.greater_equal(x, 0.0),dtype=tf.float32)
462+
Ix1 = (x >= 0.).astype(jnp.float32)
462463
Ix2 = (x <= 6.).astype(jnp.float32) #tf.cast(tf.math.less_equal(x, 6.0),dtype=tf.float32)
463464
Ix = Ix1 * Ix2
464465
return Ix
@@ -726,34 +727,37 @@ def d_clip(x, min_val, max_val):
726727
return jnp.where((x < min_val) | (x > max_val), 0.0, 1.0)
727728

728729

729-
@partial(jit, static_argnums=[2])
730-
def lkwta(x, group_masks, nWTA=(20,)): ## local k-WTA
730+
@partial(jit, static_argnums=[2, 3])
731+
def lkwta(x, group_masks, nWTA=(20,), clipval=-1.): ## local k-WTA
731732
out = 0.
732733
for g in range(len(group_masks)):
733734
m = group_masks[g]
734-
x_g = kwta(x, m, nWTA[g])
735+
x_g = kwta(x, m, nWTA[g], clipval)
735736
out = x_g + out
736737
return out
737738

738-
@partial(jit, static_argnums=[2])
739-
def d_lkwta(x, group_masks, nWTA=(20,)): ## d(lkwta(x))/dx
739+
@partial(jit, static_argnums=[2, 3])
740+
def d_lkwta(x, group_masks, nWTA=(20,), clipval=-1.): ## d(lkwta(x))/dx
740741
out = 0.
741742
for g in range(len(group_masks)):
742743
m = group_masks[g]
743-
x_g = d_kwta(x, m, nWTA[g])
744+
x_g = d_kwta(x, m, nWTA[g], clipval)
744745
out = x_g + out
745746
return out
746747

747-
@partial(jit, static_argnums=[2])
748-
def kwta(x, m, nWTA=20):
748+
@partial(jit, static_argnums=[2, 3])
749+
def kwta(x, m, nWTA=20, clipval=-1.):
749750
_x = x * m + (1. - m) * (jnp.amin(x) - 1.)
750751
values, indices = lax.top_k(_x, nWTA) # Note: we do not care to sort the indices
751752
kth = jnp.expand_dims(jnp.min(values,axis=1),axis=1) # must do comparison per sample in potential mini-batch
752753
topK = jnp.greater_equal(_x, kth).astype(jnp.float32) # cast booleans to floats
753-
return topK * x
754+
topK = topK * x
755+
if clipval > 0.:
756+
topK = jnp.clip(topK, -clipval, clipval)
757+
return topK
754758

755-
@partial(jit, static_argnums=[2])
756-
def d_kwta(x, m, nWTA=20): ## d(kwta(x))/dx
759+
@partial(jit, static_argnums=[2, 3])
760+
def d_kwta(x, m, nWTA=20, clipval=-1.): ## d(kwta(x))/dx
757761
_x = x * m + (1. - m) * (jnp.amin(x) - 1.)
758762
values, indices = lax.top_k(_x, nWTA) # Note: we do not care to sort the indices
759763
kth = jnp.expand_dims(jnp.min(values,axis=1),axis=1) # must do comparison per sample in potential mini-batch

ngclearn/utils/patch_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def generate_patch_set(x_batch, patch_size=(8, 8), max_patches=50, center=True,
148148
for s in range(_x_batch.shape[0]):
149149
xs = _x_batch[s, :]
150150
xs = xs.reshape(px, py)
151-
patches = extract_patches_2d(xs, patch_size, max_patches=max_patches, random_state=seed, extraction_step=step_size)#, random_state=69)
151+
patches = extract_patches_2d(xs, patch_size, max_patches=max_patches, random_state=seed) #, extraction_step=step_size)#, random_state=69)
152152
patches = np.reshape(patches, (len(patches), -1)) # flatten each patch in set
153153
if s > 0:
154154
p_batch = np.concatenate((p_batch,patches),axis=0)

tests/components/synapses/hebbian/test_BCMSynapse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ def test_BCMSynapse1():
1919
a = BCMSynapse(
2020
name="a", shape=(1,1), tau_w=40., tau_theta=20., key=subkeys[0]
2121
)
22-
2322
evolve_process = (MethodProcess("evolve_process")
2423
>> a.evolve)
2524

@@ -42,4 +41,5 @@ def test_BCMSynapse1():
4241
# print(truth)
4342
assert_array_equal(a.dWeights.get(), truth)
4443

45-
test_BCMSynapse1()
44+
#test_BCMSynapse1()
45+

0 commit comments

Comments
 (0)