Skip to content

Commit eb89be5

Browse files
authored
Add retinal ganglion cell input encoder (#137)
* Add RetinalGanglionCell component with filtering methods Implement RetinalGanglionCell with Gaussian filtering and patch extraction. * Add RetinalGanglionCell import to input_encoders * Add RetinalGanglionCell to input encoders * Enhance filter functions in ganglionCell.py Refactor Gaussian filter creation and add Difference of Gaussian filter functionality.
1 parent 1ecadf8 commit eb89be5

3 files changed

Lines changed: 234 additions & 0 deletions

File tree

ngclearn/components/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
## point to input encoder component types
2828
from .input_encoders.bernoulliCell import BernoulliCell
2929
from .input_encoders.poissonCell import PoissonCell
30+
from .input_encoders.ganglionCell import RetinalGanglionCell
3031
from .input_encoders.latencyCell import LatencyCell
3132
from .input_encoders.phasorCell import PhasorCell
3233

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from .bernoulliCell import BernoulliCell
22
from .poissonCell import PoissonCell
33
from .latencyCell import LatencyCell
4+
from .ganglionCell import RetinalGanglionCell
45
from .phasorCell import PhasorCell
56

7+
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
from ngclearn.components.jaxComponent import JaxComponent
2+
from jax import numpy as jnp, random
3+
from ngclearn import compilable
4+
from ngclearn import Compartment
5+
import jax
6+
from typing import Union, Tuple
7+
8+
9+
10+
11+
def create_gaussian_filter(patch_shape, sigma):
12+
"""
13+
Create a 2D Gaussian kernel centered on patch_shape with given sigma.
14+
"""
15+
px, py = patch_shape
16+
17+
x_ = jnp.linspace(0, px - 1, px)
18+
y_ = jnp.linspace(0, py - 1, py)
19+
20+
x, y = jnp.meshgrid(x_, y_)
21+
22+
xc = px // 2
23+
yc = py // 2
24+
25+
filter = jnp.exp(-((x - xc) ** 2 + (y - yc) ** 2) / (2 * (sigma ** 2)))
26+
return filter / jnp.sum(filter)
27+
28+
29+
30+
31+
32+
def create_dog_filter(patch_shape, sigma, k=1.6, lmbda=1):
33+
g1 = create_gaussian_filter(patch_shape, sigma=sigma)
34+
g2 = create_gaussian_filter(patch_shape, sigma=sigma * k)
35+
36+
dog = g1 - lmbda * g2
37+
38+
return dog #- jnp.mean(dog)
39+
40+
41+
42+
43+
44+
def create_patches(obs, patch_shape, step_shape):
45+
"""
46+
Extract 2D patches from a batch of images using a sliding window.
47+
48+
Inputs:
49+
obs: Input array (B, ix, iy)
50+
patch_shape: Patch size (px, py)
51+
step_shape: Stride (sx, sy) -- use 0 for full-overlap
52+
53+
Output:
54+
Patches array (B, n_cells, px, py)
55+
"""
56+
57+
B, ix, iy = obs.shape
58+
px, py = patch_shape
59+
sx, sy = step_shape
60+
61+
if sx == 0:
62+
n_x = ix // px
63+
else:
64+
n_x = (ix - px) // sx + 1
65+
66+
if sy == 0:
67+
n_y = iy // py
68+
else:
69+
n_y = (iy - py) // sy + 1
70+
71+
patches = jnp.stack([
72+
obs[:,
73+
i * sx:i * sx + px, j * sy:j * sy + py
74+
] for i in range(n_x)
75+
for j in range(n_y)
76+
], axis=1)
77+
78+
return patches
79+
80+
81+
82+
83+
84+
class RetinalGanglionCell(JaxComponent):
85+
"""
86+
A groupd of retinal ganglion cell that senses the input
87+
stimuli and sends out the filtered signal to the brain.
88+
89+
| --- Cell Input Compartments: ---
90+
| inputs - input (takes in external signals)
91+
| --- Cell State Compartments: ---
92+
| filter - filter (function applied to input)
93+
| --- Cell Output Compartments: ---
94+
| outputs - output
95+
96+
Args:
97+
name: the string name of this cell
98+
99+
filter_type: string name of filter function (Default: identity)
100+
101+
:Note: supported filters include "gaussian", "difference_of_gaussian"
102+
103+
sigma: standard deviation of gaussian kernel
104+
105+
area_shape: receptive field area of ganglion cells in this module all together
106+
107+
n_cells: number of ganglion cells in this module
108+
109+
patch_shape: each ganglion cell receptive field area
110+
111+
step_shape: the non-overlapping area between each two ganglion cells
112+
113+
batch_size: batch size dimension of this cell (Default: 1)
114+
"""
115+
116+
def __init__(self, name: str,
117+
filter_type: str,
118+
area_shape: Tuple[int, int],
119+
n_cells: int,
120+
patch_shape: Tuple[int, int],
121+
step_shape: Tuple[int, int],
122+
batch_size: int = 1,
123+
sigma: float = 1.0,
124+
key: Union[jax.Array, None] = None,
125+
**kwargs):
126+
super().__init__(name=name, key=key)
127+
128+
129+
## Layer Size Setup
130+
self.filter_type = filter_type
131+
self.n_cells = n_cells
132+
self.sigma = sigma
133+
134+
self.batch_size = batch_size
135+
self.area_shape = area_shape
136+
self.patch_shape = patch_shape
137+
self.step_shape = step_shape
138+
139+
filter = jnp.ones(self.patch_shape)
140+
141+
if filter_type == 'gaussian':
142+
filter = create_gaussian_filter(patch_shape=self.patch_shape, sigma=self.sigma)
143+
elif filter_type == 'difference_of_gaussian':
144+
filter = create_dog_filter(patch_shape=self.patch_shape, sigma=sigma)
145+
146+
# ═════════════════ compartments initial values ════════════════════
147+
in_restVals = jnp.zeros((self.batch_size,
148+
*self.area_shape)) ## input: (B | ix | iy)
149+
150+
out_restVals = jnp.zeros((self.batch_size, ## output.shape: (B | n_cells * px * py)
151+
self.n_cells * self.patch_shape[0] * self.patch_shape[1]))
152+
153+
# ═══════════════════ set compartments ══════════════════════
154+
self.inputs = Compartment(in_restVals, display_name="Input Stimulus") # input compartment
155+
self.filter = Compartment(filter, display_name="Filter") # Filter compartment
156+
self.outputs = Compartment(out_restVals, display_name="Output Signal") # output compartment
157+
158+
@compilable
159+
def advance_state(self, t):
160+
inputs = self.inputs.get()
161+
filter = self.filter.get()
162+
px, py = self.patch_shape
163+
164+
# ═══════════════════ extract pathches for filters ══════════════════
165+
input_patches = create_patches(inputs, patch_shape=self.patch_shape, step_shape=self.step_shape)
166+
167+
# ═══════════════════ apply filter to all pathches ══════════════════
168+
filtered_input = input_patches * filter ## shape: (B | n_cells | px | py)
169+
170+
# ════════════ reshape all cells responses to a single input to brain ════════════
171+
filtered_input = filtered_input.reshape(-1, self.n_cells * (px * py)) ## shape: (B | n_cells * px * py)
172+
173+
# ═══════════════════ normalize filtered signals ══════════════════
174+
outputs = filtered_input - jnp.mean(filtered_input, axis=1, keepdims=True) ## shape: (B | n_cells * px * py)
175+
176+
self.outputs.set(outputs)
177+
178+
@compilable
179+
def reset(self):
180+
in_restVals = jnp.zeros((self.batch_size,
181+
*self.area_shape)) ## input: (B | ix | iy)
182+
183+
out_restVals = jnp.zeros((self.batch_size, ## output.shape: (B | n_cells * px * py)
184+
self.n_cells * self.patch_shape[0] * self.patch_shape[1]))
185+
186+
self.inputs.set(in_restVals)
187+
self.outputs.set(out_restVals)
188+
189+
@classmethod
190+
def help(cls): ## component help function
191+
properties = {
192+
"cell_type": "RetinalGanglionCell - filters the input stimuli, "
193+
}
194+
compartment_props = {
195+
"inputs":
196+
{"inputs": "Takes in external input signal values"},
197+
"states":
198+
{"filter": "Preprocessing function applies to input)"},
199+
"outputs":
200+
{"outputs": "Preprocessed signal values emitted at time t"},
201+
}
202+
hyperparams = {
203+
"filter_type": "Type of the filter for preprocessing the input",
204+
"sigma": "Standard deviation of gaussian kernel",
205+
"area_shape": "Effective receptive field area shape of ganglion cells in this module",
206+
"n_cells": "Number of Retinal Ganglion (center-surround) cells to model in this layer",
207+
"patch_shape": "Classical Receptive field area shape of individual ganglion cells in this module",
208+
"step_shape": "Extra-Classical Receptive field area shape each ganglion cell in this module",
209+
"batch_size": "Batch size dimension of this component"
210+
}
211+
info = {cls.__name__: properties,
212+
"compartments": compartment_props,
213+
"dynamics": "~ Gaussian(x)",
214+
"hyperparameters": hyperparams}
215+
return info
216+
217+
if __name__ == '__main__':
218+
from ngcsimlib.context import Context
219+
with Context("Bar") as bar:
220+
X = RetinalGanglionCell("RGC", filter_type="gaussian",
221+
sigma=2.3,
222+
area_shape=(16, 26),
223+
n_cells = 3,
224+
patch_shape=(16, 16),
225+
step_shape=(0, 5)
226+
)
227+
print(X)
228+
229+
230+
231+

0 commit comments

Comments
 (0)