@@ -80,7 +80,7 @@ def __init__(
         active_func = [nn.Tanh(), nn.ReLU(), nn.LeakyReLU(), nn.ELU()][activation_id]
         init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
         gain = nn.init.calculate_gain(
-            ["tanh", "relu", "leaky_relu", "leaky_relu"][activation_id]
+            ["tanh", "relu", "leaky_relu", "selu"][activation_id]
         )

         def init_(m):
@@ -194,7 +194,7 @@ def __init__(self, split_shape, d_model, use_orthogonal=True, activation_id=1):
         active_func = [nn.Tanh(), nn.ReLU(), nn.LeakyReLU(), nn.ELU()][activation_id]
         init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
         gain = nn.init.calculate_gain(
-            ["tanh", "relu", "leaky_relu", "leaky_relu"][activation_id]
+            ["tanh", "relu", "leaky_relu", "selu"][activation_id]
         )

         def init_(m):
@@ -252,7 +252,7 @@ def __init__(self, split_shape, d_model, use_orthogonal=True, activation_id=1):
         active_func = [nn.Tanh(), nn.ReLU(), nn.LeakyReLU(), nn.ELU()][activation_id]
         init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
         gain = nn.init.calculate_gain(
-            ["tanh", "relu", "leaky_relu", "leaky_relu"][activation_id]
+            ["tanh", "relu", "leaky_relu", "selu"][activation_id]
         )

         def init_(m):
0 commit comments