@@ -234,10 +234,13 @@ def forward(self, x, self_idx=-1):
234234 K = self .split_shape [i ][0 ]
235235 L = self .split_shape [i ][1 ]
236236 for j in range (K ):
237- torch .cat ((x [i ][:, (L * j ) : (L * j + L )], self_x ), dim = - 1 )
238- exec ("x1.append(self.fc_{}(temp))" .format (i ))
239- x [self_idx ]
240- exec ("x1.append(self.fc_{}(temp))" .format (N - 1 ))
237+ # torch.cat((x[i][:, (L * j) : (L * j + L)], self_x), dim=-1)
238+ # exec("x1.append(self.fc_{}(temp))".format(i))
239+ temp = torch .cat ((x [i ][:, (L * j ) : (L * j + L )], self_x ), dim = - 1 )
240+ x1 .append (getattr (self , "fc_" + str (i ))(temp ))
241+ x1 .append (getattr (self , "fc_" + str (N - 1 ))(self_x ))
242+ # x[self_idx]
243+ # exec("x1.append(self.fc_{}(temp))".format(N - 1))
241244
242245 out = torch .stack (x1 , 1 )
243246
@@ -278,8 +281,10 @@ def forward(self, x, self_idx=None):
278281 K = self .split_shape [i ][0 ]
279282 L = self .split_shape [i ][1 ]
280283 for j in range (K ):
281- x [i ][:, (L * j ) : (L * j + L )]
282- exec ("x1.append(self.fc_{}(temp))" .format (i ))
284+ # x[i][:, (L * j) : (L * j + L)]
285+ # exec("x1.append(self.fc_{}(temp))".format(i))
286+ temp = x [i ][:, (L * j ) : (L * j + L )]
287+ x1 .append (getattr (self , "fc_" + str (i ))(temp ))
283288
284289 out = torch .stack (x1 , 1 )
285290
0 commit comments