3333 Conv2DTranspose ,
3434 UpSampling2D ,
3535 Reshape ,
36- LeakyReLU
36+ LeakyReLU ,
3737)
3838from tensorflow .keras .models import Sequential
3939
@@ -56,7 +56,10 @@ def posenet_mouse(input_shape, num_classes):
5656 model
5757 """
5858 recognition_model = Xception (
59- include_top = False , input_shape = input_shape , pooling = "avg" , weights = "imagenet" ,
59+ include_top = False ,
60+ input_shape = input_shape ,
61+ pooling = "avg" ,
62+ weights = "imagenet" ,
6063 )
6164
6265 new_input = Input (
@@ -174,7 +177,10 @@ def posenet_primate(input_shape, num_classes): # recognition_model = DenseNet20
174177 num_classes:Number of classes for recognition or number of landmarks.
175178 """
176179 recognition_model = ResNet101 (
177- include_top = False , input_shape = input_shape , pooling = "avg" , weights = "imagenet" ,
180+ include_top = False ,
181+ input_shape = input_shape ,
182+ pooling = "avg" ,
183+ weights = "imagenet" ,
178184 )
179185
180186 new_input = Input (
@@ -532,9 +538,20 @@ def classification_small(input_shape, num_classes):
532538
533539def dlc_model_sturman (input_shape , num_classes ):
534540 """Model that implements behavioral classification based on Deeplabcut generated features as in Sturman et al.
535- Args:
536- input_shape:
537- num_classes:Number of behaviors to classify.
541+
542+ Reimplementation of the model used in the publication Sturman et al. that performs action recognition on top of pose estimation.
543+
544+ Parameters
545+ ----------
546+ input_shape : keras compatible input shape (W,H,Channels)
547+ keras compatible input shape (features,)
548+ num_classes : int
549+ Number of behaviors to classify.
550+
551+ Returns
552+ -------
553+ keras.model
554+ Sturman et al. model
538555 """
539556 model = Sequential ()
540557
@@ -558,10 +575,21 @@ def dlc_model_sturman(input_shape, num_classes):
558575
559576
560577def dlc_model (input_shape , num_classes ):
561- """
562- Args:
563- input_shape:
564- num_classes:
578+ """Model for classification on top of pose estimation.
579+
580+ Classification model for behavior, operating on pose estimation. This model has more free parameters than Sturman et al.
581+
582+ Parameters
583+ ----------
584+ input_shape : keras compatible input shape (W,H,Channels)
585+ keras compatible input shape (features,)
586+ num_classes : int
587+ Number of behaviors to classify.
588+
589+ Returns
590+ -------
591+ keras.model
592+ behavior (from pose estimates) model
565593 """
566594 dropout = 0.3
567595
@@ -644,16 +672,19 @@ def recurrent_model_old(
644672
645673
646674def recurrent_model_tcn (
647- recognition_model , recurrent_input_shape , classes = 4 ,
675+ recognition_model ,
676+ recurrent_input_shape ,
677+ classes = 4 ,
648678):
649- """BehaviorNet architecture for behavioral classification based on temporal convolution architecture (TCN).
679+ """Recurrent architecture for classification of temporal sequences of images based on temporal convolution architecture (TCN).
680+ This architecture is used for BehaviorNet in SIPEC.
650681
651682 Parameters
652683 ----------
653684 recognition_model : keras.model
654685 Pretrained recognition model that extracts features for individual frames.
655- recurrent_input_shape : np.ndarray
656- Number of classes/landmarks .
686+ recurrent_input_shape : np.ndarray - (Time, Width, Height, Channels)
687+ Shape of the images over time .
657688 classes : int
658689 Number of behaviors to recognise.
659690
@@ -743,12 +774,24 @@ def recurrent_model_tcn(
743774def recurrent_model_lstm (
744775 recognition_model , recurrent_input_shape , classes = 4 , recurrent_dropout = None
745776):
746- """
747- Args:
748- recognition_model:
749- recurrent_input_shape:
750- classes:
751- recurrent_dropout:
777+ """Recurrent architecture for classification of temporal sequences of images based on LSTMs or GRUs.
778+ This architecture is used for IdNet in SIPEC.
779+
780+ Parameters
781+ ----------
782+ recognition_model : keras.model
783+ Pretrained recognition model that extracts features for individual frames.
784+ recurrent_input_shape : np.ndarray - (Time, Width, Height, Channels)
785+ Shape of the images over time.
786+ classes : int
787+ Number of behaviors to recognise.
788+ recurrent_dropout : float
789+ Recurrent dropout factor to use.
790+
791+ Returns
792+ -------
793+ keras.model
794+ IdNet
752795 """
753796 input_sequences = Input (shape = recurrent_input_shape )
754797 sequential_model_helper = TimeDistributed (recognition_model )(input_sequences )
@@ -795,14 +838,26 @@ def recurrent_model_lstm(
795838 return sequential_model
796839
797840
841+ # TODO: adaptive size
798842def pretrained_recognition (model_name , input_shape , num_classes , fix_layers = True ):
799- # TODO: adaptiv size
800- """
801- Args:
802- model_name:
803- input_shape:
804- num_classes:
805- fix_layers:
843+ """This returns the model architecture for a model that operates on images and is pretrained with imagenet weights.
844+ This architecture is used for IdNet and BehaviorNet as backbone in SIPEC and is referred to as RecognitionNet.
845+
846+ Parameters
847+ ----------
848+ model_name : str
849+ Name of the pretrained recognition model to use (names include: "xception", "resnet", "densenet")
850+ input_shape : np.ndarray - (Width, Height, Channels)
851+ Shape of a single input image (presumably without a time axis, since this model operates on individual images).
852+ num_classes : int
853+ Number of behaviors to recognise.
854+ fix_layers : bool
855+ Whether to fix the layers of the pretrained model (presumably freezing their weights during training).
856+
857+ Returns
858+ -------
859+ keras.model
860+ RecognitionNet
806861 """
807862 if model_name == "xception" :
808863 recognition_model = Xception (
@@ -912,11 +967,21 @@ def pretrained_recognition(model_name, input_shape, num_classes, fix_layers=True
912967
913968
914969def idtracker_ai (input_shape , classes ):
970+ """Implementation of the idtracker.ai identification module as described in the supplementary of Romero-Ferrero et al.
971+
972+ Parameters
973+ ----------
974+ input_shape : keras compatible input shape (W,H,Channels)
975+ keras compatible input shape (features,)
976+ classes : int
977+ Number of classes to classify.
978+
979+ Returns
980+ -------
981+ keras.model
982+ idtracker.ai identification module
915983 """
916- Args:
917- input_shape:
918- classes:
919- """
984+
920985 activation = "tanh"
921986 dropout = 0.2
922987 # conv model
@@ -934,7 +999,11 @@ def idtracker_ai(input_shape, classes):
934999 )
9351000 model .add (Activation ("relu" ))
9361001
937- model .add (MaxPooling2D (strides = (2 , 2 ),))
1002+ model .add (
1003+ MaxPooling2D (
1004+ strides = (2 , 2 ),
1005+ )
1006+ )
9381007
9391008 model .add (
9401009 Conv2D (
@@ -947,7 +1016,11 @@ def idtracker_ai(input_shape, classes):
9471016 )
9481017 model .add (Activation ("relu" ))
9491018
950- model .add (MaxPooling2D (strides = (2 , 2 ),))
1019+ model .add (
1020+ MaxPooling2D (
1021+ strides = (2 , 2 ),
1022+ )
1023+ )
9511024
9521025 model .add (
9531026 Conv2D (
0 commit comments