@@ -54,18 +54,30 @@ def posenet(
5454 features = 256 ,
5555 bias = False ,
5656):
57- """Mouse pose estimation architecture.
58- Extended description of function.
57+ """Model that implements SIPEC:PoseNet architecture.
58+
59+ This model uses an EfficientNet backbone and deconvolves generated features into landmarks in imagespace.
60+ It operates on single images and can be used in conjunction with SIPEC:SegNet to perform top-down pose estimation.
61+
5962 Parameters
6063 ----------
61- arg1 : np.ndarray
62- Input shape for mouse pose estimation network.
63- arg2 : int
64- Number of classes/landmarks.
64+ input_shape : keras compatible input shape (W,H,Channels)
65+ keras compatible input shape (features,)
66+ num_classes : int
67+ Number of joints/landmarks to detect.
68+ backbone : str
69+ Backbone/feature detector to use, default is EfficientNetB5. Choose a smaller/bigger backbone depending on GPU memory.
70+ gaussian_noise : float
71+ Kernel size of gaussian noise layers to use.
72+ features : int
73+ Number of feature maps to generate at each level.
74+ bias : bool
75+ Use bias for deconvolutional layers.
76+
6577 Returns
6678 -------
6779 keras.model
68- model
80+ SIPEC:PoseNet
6981 """
7082 if backbone == "efficientnetb5" :
7183 recognition_model = EfficientNetB5 (
@@ -334,9 +346,20 @@ def classification_small(input_shape, num_classes):
334346
335347def dlc_model_sturman (input_shape , num_classes ):
336348 """Model that implements behavioral classification based on Deeplabcut generated features as in Sturman et al.
337- Args:
338- input_shape:
339- num_classes:Number of behaviors to classify.
349+
350+ Reimplementation of the model used in the publication Sturman et al. that performs action recognition on top of pose estimation.
351+
352+ Parameters
353+ ----------
354+ input_shape : keras compatible input shape (W,H,Channels)
355+ keras compatible input shape (features,)
356+ num_classes : int
357+ Number of behaviors to classify.
358+
359+ Returns
360+ -------
361+ keras.model
362+ Sturman et al. model
340363 """
341364 model = Sequential ()
342365
@@ -360,10 +383,21 @@ def dlc_model_sturman(input_shape, num_classes):
360383
361384
362385def dlc_model (input_shape , num_classes ):
363- """
364- Args:
365- input_shape:
366- num_classes:
386+ """Model for classification on top of pose estimation.
387+
388+ Classification model for behavior, operating on pose estimation. This model has more free parameters than Sturman et al.
389+
390+ Parameters
391+ ----------
392+ input_shape : keras compatible input shape (W,H,Channels)
393+ keras compatible input shape (features,)
394+ num_classes : int
395+ Number of behaviors to classify.
396+
397+ Returns
398+ -------
399+ keras.model
400+ behavior (from pose estimates) model
367401 """
368402 dropout = 0.3
369403
@@ -450,14 +484,15 @@ def recurrent_model_tcn(
450484 recurrent_input_shape ,
451485 classes = 4 ,
452486):
453- """BehaviorNet architecture for behavioral classification based on temporal convolution architecture (TCN).
487+ """Recurrent architecture for classification of temporal sequences of images based on temporal convolution architecture (TCN).
488+ This architecture is used for BehaviorNet in SIPEC.
454489
455490 Parameters
456491 ----------
457492 recognition_model : keras.model
458493 Pretrained recognition model that extracts features for individual frames.
459- recurrent_input_shape : np.ndarray
460- Number of classes/landmarks .
494+ recurrent_input_shape : np.ndarray - (Time, Width, Height, Channels)
495+ Shape of the images over time .
461496 classes : int
462497 Number of behaviors to recognise.
463498
@@ -547,12 +582,24 @@ def recurrent_model_tcn(
547582def recurrent_model_lstm (
548583 recognition_model , recurrent_input_shape , classes = 4 , recurrent_dropout = None
549584):
550- """
551- Args:
552- recognition_model:
553- recurrent_input_shape:
554- classes:
555- recurrent_dropout:
585+ """Recurrent architecture for classification of temporal sequences of images based on LSTMs or GRUs.
586+ This architecture is used for IdNet in SIPEC.
587+
588+ Parameters
589+ ----------
590+ recognition_model : keras.model
591+ Pretrained recognition model that extracts features for individual frames.
592+ recurrent_input_shape : np.ndarray - (Time, Width, Height, Channels)
593+ Shape of the images over time.
594+ classes : int
595+ Number of behaviors to recognise.
596+ recurrent_dropout : float
597+ Recurrent dropout factor to use.
598+
599+ Returns
600+ -------
601+ keras.model
602+ IdNet
556603 """
557604 input_sequences = Input (shape = recurrent_input_shape )
558605 sequential_model_helper = TimeDistributed (recognition_model )(input_sequences )
@@ -599,14 +646,26 @@ def recurrent_model_lstm(
599646 return sequential_model
600647
601648
649+ # TODO: adaptive size
602650def pretrained_recognition (model_name , input_shape , num_classes , fix_layers = True ):
603- # TODO: adaptiv size
604- """
605- Args:
606- model_name:
607- input_shape:
608- num_classes:
609- fix_layers:
651+ """This returns the model architecture for a model that operates on images and is pretrained with imagenet weights.
652+ This architecture is used for IdNet and BehaviorNet as backbone in SIPEC and is referred to as RecognitionNet.
653+
654+ Parameters
655+ ----------
656+ model_name : str
657+ Name of the pretrained recognition model to use (names include: "xception", "resnet", "densenet")
658+ input_shape : np.ndarray - (Time, Width, Height, Channels)
659+ Shape of the images over time.
660+ num_classes : int
661+ Number of behaviors to recognise.
662+ fix_layers : bool
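663+ Whether to fix the layers of the pretrained model (presumably freezing their weights during training; TODO confirm against the function body).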
664+
665+ Returns
666+ -------
667+ keras.model
668+ RecognitionNet
610669 """
611670 if model_name == "xception" :
612671 recognition_model = Xception (
@@ -716,11 +775,21 @@ def pretrained_recognition(model_name, input_shape, num_classes, fix_layers=True
716775
717776
718777def idtracker_ai (input_shape , classes ):
778+ """Implementation of the idtracker.ai identification module as described in the supplementary of Romero-Ferrero et al.
779+
780+ Parameters
781+ ----------
782+ input_shape : keras compatible input shape (W,H,Channels)
783+ keras compatible input shape (features,)
784+ classes : int
785+ Number of classes to classify.
786+
787+ Returns
788+ -------
789+ keras.model
790+ idtracker.ai identification module
719791 """
720- Args:
721- input_shape:
722- classes:
723- """
792+
724793 activation = "tanh"
725794 dropout = 0.2
726795 # conv model
0 commit comments