@@ -179,28 +179,23 @@ def _calc_bmu(self): ## obtain index of best-matching unit (BMU)
179179 return bmu_idx , delta
180180
181181 def _calc_neighborhood_weights (self ): ## neighborhood function
182- bmu = self .bmu .get () ## get best-matching unit
183- bmu = bmu [0 , 0 ]
182+ bmu = self .bmu .get ()[0 , 0 ] ## get best-matching unit
184183 coords = self .coords ## constant coordinate array
185- radius = self .radius .get ()
184+ radius = self .radius .get () ## get current neighborhood radius value
186185 coord_bmu = coords [bmu :bmu + 1 , :] ## TODO: might need to one-hot mask + sum
187- delta = coords - coord_bmu ## raw differences (delta)
186+ delta = coords - coord_bmu ## raw coordinate differences (delta)
188187
188+ ### neighborhood-weighting computation note:
189+ ### internally, calculation of neighborhood weighting depends on 1st calculating
190+ ### L2 distance in Cartesian coordinate-space, then applying the neighborhood
191+ ### over these coordinate distance values
189192 bmu_dist = jnp .linalg .norm (delta , axis = 1 , keepdims = True )
190- if self .neighbor_fx == 1 :
193+ if self .neighbor_fx == 1 : ## apply Mexican-hat kernel
191194 neighbor_weights = _ricker_marr_kernel (bmu_dist , sigma = radius )
192- else :
195+ else : ## apply Gaussian kernel
193196 neighbor_weights = _gaussian_kernel (bmu_dist , sigma = radius )
197+ ## TODO: add in triangular, bubble, & laplacian kernels
194198
195- '''
196- ## calc distance values
197- bmu_distance = jnp.sqrt(jnp.sum(jnp.square(delta), axis=1, keepdims=True))
198- ## apply kernel weighting function (below); e.g., Gaussian, Mexican-hat, triangular, etc.
199- if self.neighbor_fx == 1:
200- neighbor_weights = _ricker_marr_kernel(bmu_distance, sigma=radius)
201- else:
202- neighbor_weights = _gaussian_kernel(bmu_distance, sigma=radius)
203- '''
204199 return neighbor_weights .T ## transpose to (1 x n_units)
205200
206201 @compilable
0 commit comments