|
| 1 | +from ngclearn.components.jaxComponent import JaxComponent |
| 2 | +from jax import numpy as jnp, random |
| 3 | +from ngclearn import compilable |
| 4 | +from ngclearn import Compartment |
| 5 | +import jax |
| 6 | +from typing import Union, Tuple |
| 7 | + |
| 8 | + |
| 9 | + |
| 10 | + |
| 11 | +def create_gaussian_filter(patch_shape, sigma): |
| 12 | + """ |
| 13 | + Create a 2D Gaussian kernel centered on patch_shape with given sigma. |
| 14 | + """ |
| 15 | + px, py = patch_shape |
| 16 | + |
| 17 | + x_ = jnp.linspace(0, px - 1, px) |
| 18 | + y_ = jnp.linspace(0, py - 1, py) |
| 19 | + |
| 20 | + x, y = jnp.meshgrid(x_, y_) |
| 21 | + |
| 22 | + xc = px // 2 |
| 23 | + yc = py // 2 |
| 24 | + |
| 25 | + filter = jnp.exp(-((x - xc) ** 2 + (y - yc) ** 2) / (2 * (sigma ** 2))) |
| 26 | + return filter / jnp.sum(filter) |
| 27 | + |
| 28 | + |
| 29 | + |
| 30 | + |
| 31 | + |
| 32 | +def create_dog_filter(patch_shape, sigma, k=1.6, lmbda=1): |
| 33 | + g1 = create_gaussian_filter(patch_shape, sigma=sigma) |
| 34 | + g2 = create_gaussian_filter(patch_shape, sigma=sigma * k) |
| 35 | + |
| 36 | + dog = g1 - lmbda * g2 |
| 37 | + |
| 38 | + return dog #- jnp.mean(dog) |
| 39 | + |
| 40 | + |
| 41 | + |
| 42 | + |
| 43 | + |
| 44 | +def create_patches(obs, patch_shape, step_shape): |
| 45 | + """ |
| 46 | + Extract 2D patches from a batch of images using a sliding window. |
| 47 | +
|
| 48 | + Inputs: |
| 49 | + obs: Input array (B, ix, iy) |
| 50 | + patch_shape: Patch size (px, py) |
| 51 | + step_shape: Stride (sx, sy) -- use 0 for full-overlap |
| 52 | +
|
| 53 | + Output: |
| 54 | + Patches array (B, n_cells, px, py) |
| 55 | + """ |
| 56 | + |
| 57 | + B, ix, iy = obs.shape |
| 58 | + px, py = patch_shape |
| 59 | + sx, sy = step_shape |
| 60 | + |
| 61 | + if sx == 0: |
| 62 | + n_x = ix // px |
| 63 | + else: |
| 64 | + n_x = (ix - px) // sx + 1 |
| 65 | + |
| 66 | + if sy == 0: |
| 67 | + n_y = iy // py |
| 68 | + else: |
| 69 | + n_y = (iy - py) // sy + 1 |
| 70 | + |
| 71 | + patches = jnp.stack([ |
| 72 | + obs[:, |
| 73 | + i * sx:i * sx + px, j * sy:j * sy + py |
| 74 | + ] for i in range(n_x) |
| 75 | + for j in range(n_y) |
| 76 | + ], axis=1) |
| 77 | + |
| 78 | + return patches |
| 79 | + |
| 80 | + |
| 81 | + |
| 82 | + |
| 83 | + |
| 84 | +class RetinalGanglionCell(JaxComponent): |
| 85 | + """ |
| 86 | + A groupd of retinal ganglion cell that senses the input |
| 87 | + stimuli and sends out the filtered signal to the brain. |
| 88 | +
|
| 89 | + | --- Cell Input Compartments: --- |
| 90 | + | inputs - input (takes in external signals) |
| 91 | + | --- Cell State Compartments: --- |
| 92 | + | filter - filter (function applied to input) |
| 93 | + | --- Cell Output Compartments: --- |
| 94 | + | outputs - output |
| 95 | +
|
| 96 | + Args: |
| 97 | + name: the string name of this cell |
| 98 | +
|
| 99 | + filter_type: string name of filter function (Default: identity) |
| 100 | +
|
| 101 | + :Note: supported filters include "gaussian", "difference_of_gaussian" |
| 102 | +
|
| 103 | + sigma: standard deviation of gaussian kernel |
| 104 | +
|
| 105 | + area_shape: receptive field area of ganglion cells in this module all together |
| 106 | +
|
| 107 | + n_cells: number of ganglion cells in this module |
| 108 | +
|
| 109 | + patch_shape: each ganglion cell receptive field area |
| 110 | +
|
| 111 | + step_shape: the non-overlapping area between each two ganglion cells |
| 112 | +
|
| 113 | + batch_size: batch size dimension of this cell (Default: 1) |
| 114 | + """ |
| 115 | + |
| 116 | + def __init__(self, name: str, |
| 117 | + filter_type: str, |
| 118 | + area_shape: Tuple[int, int], |
| 119 | + n_cells: int, |
| 120 | + patch_shape: Tuple[int, int], |
| 121 | + step_shape: Tuple[int, int], |
| 122 | + batch_size: int = 1, |
| 123 | + sigma: float = 1.0, |
| 124 | + key: Union[jax.Array, None] = None, |
| 125 | + **kwargs): |
| 126 | + super().__init__(name=name, key=key) |
| 127 | + |
| 128 | + |
| 129 | + ## Layer Size Setup |
| 130 | + self.filter_type = filter_type |
| 131 | + self.n_cells = n_cells |
| 132 | + self.sigma = sigma |
| 133 | + |
| 134 | + self.batch_size = batch_size |
| 135 | + self.area_shape = area_shape |
| 136 | + self.patch_shape = patch_shape |
| 137 | + self.step_shape = step_shape |
| 138 | + |
| 139 | + filter = jnp.ones(self.patch_shape) |
| 140 | + |
| 141 | + if filter_type == 'gaussian': |
| 142 | + filter = create_gaussian_filter(patch_shape=self.patch_shape, sigma=self.sigma) |
| 143 | + elif filter_type == 'difference_of_gaussian': |
| 144 | + filter = create_dog_filter(patch_shape=self.patch_shape, sigma=sigma) |
| 145 | + |
| 146 | + # ═════════════════ compartments initial values ════════════════════ |
| 147 | + in_restVals = jnp.zeros((self.batch_size, |
| 148 | + *self.area_shape)) ## input: (B | ix | iy) |
| 149 | + |
| 150 | + out_restVals = jnp.zeros((self.batch_size, ## output.shape: (B | n_cells * px * py) |
| 151 | + self.n_cells * self.patch_shape[0] * self.patch_shape[1])) |
| 152 | + |
| 153 | + # ═══════════════════ set compartments ══════════════════════ |
| 154 | + self.inputs = Compartment(in_restVals, display_name="Input Stimulus") # input compartment |
| 155 | + self.filter = Compartment(filter, display_name="Filter") # Filter compartment |
| 156 | + self.outputs = Compartment(out_restVals, display_name="Output Signal") # output compartment |
| 157 | + |
| 158 | + @compilable |
| 159 | + def advance_state(self, t): |
| 160 | + inputs = self.inputs.get() |
| 161 | + filter = self.filter.get() |
| 162 | + px, py = self.patch_shape |
| 163 | + |
| 164 | + # ═══════════════════ extract pathches for filters ══════════════════ |
| 165 | + input_patches = create_patches(inputs, patch_shape=self.patch_shape, step_shape=self.step_shape) |
| 166 | + |
| 167 | + # ═══════════════════ apply filter to all pathches ══════════════════ |
| 168 | + filtered_input = input_patches * filter ## shape: (B | n_cells | px | py) |
| 169 | + |
| 170 | + # ════════════ reshape all cells responses to a single input to brain ════════════ |
| 171 | + filtered_input = filtered_input.reshape(-1, self.n_cells * (px * py)) ## shape: (B | n_cells * px * py) |
| 172 | + |
| 173 | + # ═══════════════════ normalize filtered signals ══════════════════ |
| 174 | + outputs = filtered_input - jnp.mean(filtered_input, axis=1, keepdims=True) ## shape: (B | n_cells * px * py) |
| 175 | + |
| 176 | + self.outputs.set(outputs) |
| 177 | + |
| 178 | + @compilable |
| 179 | + def reset(self): |
| 180 | + in_restVals = jnp.zeros((self.batch_size, |
| 181 | + *self.area_shape)) ## input: (B | ix | iy) |
| 182 | + |
| 183 | + out_restVals = jnp.zeros((self.batch_size, ## output.shape: (B | n_cells * px * py) |
| 184 | + self.n_cells * self.patch_shape[0] * self.patch_shape[1])) |
| 185 | + |
| 186 | + self.inputs.set(in_restVals) |
| 187 | + self.outputs.set(out_restVals) |
| 188 | + |
| 189 | + @classmethod |
| 190 | + def help(cls): ## component help function |
| 191 | + properties = { |
| 192 | + "cell_type": "RetinalGanglionCell - filters the input stimuli, " |
| 193 | + } |
| 194 | + compartment_props = { |
| 195 | + "inputs": |
| 196 | + {"inputs": "Takes in external input signal values"}, |
| 197 | + "states": |
| 198 | + {"filter": "Preprocessing function applies to input)"}, |
| 199 | + "outputs": |
| 200 | + {"outputs": "Preprocessed signal values emitted at time t"}, |
| 201 | + } |
| 202 | + hyperparams = { |
| 203 | + "filter_type": "Type of the filter for preprocessing the input", |
| 204 | + "sigma": "Standard deviation of gaussian kernel", |
| 205 | + "area_shape": "Effective receptive field area shape of ganglion cells in this module", |
| 206 | + "n_cells": "Number of Retinal Ganglion (center-surround) cells to model in this layer", |
| 207 | + "patch_shape": "Classical Receptive field area shape of individual ganglion cells in this module", |
| 208 | + "step_shape": "Extra-Classical Receptive field area shape each ganglion cell in this module", |
| 209 | + "batch_size": "Batch size dimension of this component" |
| 210 | + } |
| 211 | + info = {cls.__name__: properties, |
| 212 | + "compartments": compartment_props, |
| 213 | + "dynamics": "~ Gaussian(x)", |
| 214 | + "hyperparameters": hyperparams} |
| 215 | + return info |
| 216 | + |
| 217 | +if __name__ == '__main__': |
| 218 | + from ngcsimlib.context import Context |
| 219 | + with Context("Bar") as bar: |
| 220 | + X = RetinalGanglionCell("RGC", filter_type="gaussian", |
| 221 | + sigma=2.3, |
| 222 | + area_shape=(16, 26), |
| 223 | + n_cells = 3, |
| 224 | + patch_shape=(16, 16), |
| 225 | + step_shape=(0, 5) |
| 226 | + ) |
| 227 | + print(X) |
| 228 | + |
| 229 | + |
| 230 | + |
| 231 | + |
0 commit comments