|
1 | 1 | from jax import numpy as jnp, random, jit |
2 | | -from ngcsimlib.logger import info |
| 2 | +#from ngcsimlib.logger import info |
| 3 | +from ngcsimlib import deprecate_args |
3 | 4 | from ngclearn import compilable #from ngcsimlib.parser import compilable |
4 | 5 | from ngclearn import Compartment #from ngcsimlib.compartment import Compartment |
5 | 6 | from ngclearn.components.jaxComponent import JaxComponent |
@@ -65,6 +66,7 @@ class LeakyNoiseCell(JaxComponent): ## Real-valued, leaky noise cell |
65 | 66 | leak_scale: degree to which membrane leak should be scaled (Default: 1) |
66 | 67 | """ |
67 | 68 |
|
| 69 | + @deprecate_args(sigma_rec="sigma_pre") |
68 | 70 | def __init__( |
69 | 71 | self, name, n_units, tau_x, act_fx="relu", integration_type="euler", batch_size=1, sigma_pre=0.1, |
70 | 72 | sigma_post=0.1, leak_scale=1., shape=None, **kwargs |
@@ -127,13 +129,19 @@ def advance_state(self, t, dt): |
127 | 129 | self.r_prime.set(r_prime) |
128 | 130 |
|
129 | 131 | @compilable |
130 | | - def reset(self): |
131 | | - _shape = (self.batch_size, self.shape[0]) |
| 132 | + def reset(self): ## reset core components/statistics |
| 133 | + self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member |
| 134 | + |
| 135 | + @compilable |
| 136 | + def batched_reset(self, batch_size): |
| 137 | + _shape = (batch_size, self.shape[0]) |
132 | 138 | if len(self.shape) > 1: |
133 | | - _shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2]) |
| 139 | + _shape = (batch_size, self.shape[0], self.shape[1], self.shape[2]) |
134 | 140 | restVals = jnp.zeros(_shape) |
135 | | - self.j_input.set(restVals) |
136 | | - self.j_recurrent.set(restVals) |
| 141 | + if not self.j_input.targeted: |
| 142 | + self.j_input.set(restVals) |
| 143 | + if not self.j_recurrent.targeted: |
| 144 | + self.j_recurrent.set(restVals) |
137 | 145 | self.x.set(restVals) |
138 | 146 | self.r.set(restVals) |
139 | 147 | self.r_prime.set(restVals) |
|
0 commit comments