@@ -458,7 +458,8 @@ def d_relu6(x):
458458 """
459459 # df/dx = 1 if 0<x<6 else 0
460460 # I_x = (z >= a_min) *@ (z <= b_max) //create an indicator function a = 0 b = 6
461- Ix1 = (x > 0. ).astype (jnp .float32 ) #tf.cast(tf.math.greater_equal(x, 0.0),dtype=tf.float32)
461+ #Ix1 = (x > 0.).astype(jnp.float32) #tf.cast(tf.math.greater_equal(x, 0.0),dtype=tf.float32)
462+ Ix1 = (x >= 0. ).astype (jnp .float32 )
462463 Ix2 = (x <= 6. ).astype (jnp .float32 ) #tf.cast(tf.math.less_equal(x, 6.0),dtype=tf.float32)
463464 Ix = Ix1 * Ix2
464465 return Ix
@@ -726,34 +727,37 @@ def d_clip(x, min_val, max_val):
726727 return jnp .where ((x < min_val ) | (x > max_val ), 0.0 , 1.0 )
727728
728729
729- @partial (jit , static_argnums = [2 ])
730- def lkwta (x , group_masks , nWTA = (20 ,)): ## local k-WTA
730+ @partial (jit , static_argnums = [2 , 3 ])
731+ def lkwta (x , group_masks , nWTA = (20 ,), clipval = - 1. ): ## local k-WTA
731732 out = 0.
732733 for g in range (len (group_masks )):
733734 m = group_masks [g ]
734- x_g = kwta (x , m , nWTA [g ])
735+ x_g = kwta (x , m , nWTA [g ], clipval )
735736 out = x_g + out
736737 return out
737738
738- @partial (jit , static_argnums = [2 ])
739- def d_lkwta (x , group_masks , nWTA = (20 ,)): ## d(lkwta(x))/dx
739+ @partial (jit , static_argnums = [2 , 3 ])
740+ def d_lkwta (x , group_masks , nWTA = (20 ,), clipval = - 1. ): ## d(lkwta(x))/dx
740741 out = 0.
741742 for g in range (len (group_masks )):
742743 m = group_masks [g ]
743- x_g = d_kwta (x , m , nWTA [g ])
744+ x_g = d_kwta (x , m , nWTA [g ], clipval )
744745 out = x_g + out
745746 return out
746747
747- @partial (jit , static_argnums = [2 ])
748- def kwta (x , m , nWTA = 20 ):
748+ @partial (jit , static_argnums = [2 , 3 ])
749+ def kwta (x , m , nWTA = 20 , clipval = - 1. ):
749750 _x = x * m + (1. - m ) * (jnp .amin (x ) - 1. )
750751 values , indices = lax .top_k (_x , nWTA ) # Note: we do not care to sort the indices
751752 kth = jnp .expand_dims (jnp .min (values ,axis = 1 ),axis = 1 ) # must do comparison per sample in potential mini-batch
752753 topK = jnp .greater_equal (_x , kth ).astype (jnp .float32 ) # cast booleans to floats
753- return topK * x
754+ topK = topK * x
755+ if clipval > 0. :
756+ topK = jnp .clip (topK , - clipval , clipval )
757+ return topK
754758
755- @partial (jit , static_argnums = [2 ])
756- def d_kwta (x , m , nWTA = 20 ): ## d(kwta(x))/dx
759+ @partial (jit , static_argnums = [2 , 3 ])
760+ def d_kwta (x , m , nWTA = 20 , clipval = - 1. ): ## d(kwta(x))/dx
757761 _x = x * m + (1. - m ) * (jnp .amin (x ) - 1. )
758762 values , indices = lax .top_k (_x , nWTA ) # Note: we do not care to sort the indices
759763 kth = jnp .expand_dims (jnp .min (values ,axis = 1 ),axis = 1 ) # must do comparison per sample in potential mini-batch
0 commit comments