Skip to content

Commit 2ac03fc

Browse files
authored
Merge pull request #16 from SIPEC-Animal-Data-Analysis/docs
extending documentation #1
2 parents f7c2dc4 + 87df0ea commit 2ac03fc

4 files changed

Lines changed: 172 additions & 64 deletions

File tree

SwissKnife/architectures.py

Lines changed: 103 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,30 @@ def posenet(
5454
features=256,
5555
bias=False,
5656
):
57-
"""Mouse pose estimation architecture.
58-
Extended description of function.
57+
"""Model that implements SIPEC:PoseNet architecture.
58+
59+
This model uses an EfficientNet backbone and deconvolves generated features into landmarks in imagespace.
60+
It operates on single images and can be used in conjunction with SIPEC:SegNet to perform top-down pose estimation.
61+
5962
Parameters
6063
----------
61-
arg1 : np.ndarray
62-
Input shape for mouse pose estimation network.
63-
arg2 : int
64-
Number of classes/landmarks.
64+
input_shape : keras compatible input shape (W,H,Channels)
65+
keras compatible input shape (features,)
66+
num_classes : int
67+
Number of joints/landmarks to detect.
68+
backbone : str
69+
Backbone/feature detector to use, default is EfficientNet5. Choose smaller/bigger backbone depending on GPU memory.
70+
gaussian_noise : float
71+
Standard deviation of the Gaussian noise layers to use.
72+
features : int
73+
Number of feature maps to generate at each level.
74+
bias : bool
75+
Use bias for deconvolutional layers.
76+
6577
Returns
6678
-------
6779
keras.model
68-
model
80+
SIPEC:PoseNet
6981
"""
7082
if backbone == "efficientnetb5":
7183
recognition_model = EfficientNetB5(
@@ -334,9 +346,20 @@ def classification_small(input_shape, num_classes):
334346

335347
def dlc_model_sturman(input_shape, num_classes):
336348
"""Model that implements behavioral classification based on Deeplabcut generated features as in Sturman et al.
337-
Args:
338-
input_shape:
339-
num_classes:Number of behaviors to classify.
349+
350+
Reimplementation of the model used in the publication Sturman et al. that performs action recognition on top of pose estimation
351+
352+
Parameters
353+
----------
354+
input_shape : keras compatible input shape (W,H,Channels)
355+
keras compatible input shape (features,)
356+
num_classes : int
357+
Number of behaviors to classify.
358+
359+
Returns
360+
-------
361+
keras.model
362+
Sturman et al. model
340363
"""
341364
model = Sequential()
342365

@@ -360,10 +383,21 @@ def dlc_model_sturman(input_shape, num_classes):
360383

361384

362385
def dlc_model(input_shape, num_classes):
363-
"""
364-
Args:
365-
input_shape:
366-
num_classes:
386+
"""Model for classification on top of pose estimation.
387+
388+
Classification model for behavior, operating on pose estimation. This model has more free parameters than Sturman et al.
389+
390+
Parameters
391+
----------
392+
input_shape : keras compatible input shape (W,H,Channels)
393+
keras compatible input shape (features,)
394+
num_classes : int
395+
Number of behaviors to classify.
396+
397+
Returns
398+
-------
399+
keras.model
400+
behavior (from pose estimates) model
367401
"""
368402
dropout = 0.3
369403

@@ -450,14 +484,15 @@ def recurrent_model_tcn(
450484
recurrent_input_shape,
451485
classes=4,
452486
):
453-
"""BehaviorNet architecture for behavioral classification based on temporal convolution architecture (TCN).
487+
"""Recurrent architecture for classification of temporal sequences of images based on temporal convolution architecture (TCN).
488+
This architecture is used for BehaviorNet in SIPEC.
454489
455490
Parameters
456491
----------
457492
recognition_model : keras.model
458493
Pretrained recognition model that extracts features for individual frames.
459-
recurrent_input_shape : np.ndarray
460-
Number of classes/landmarks.
494+
recurrent_input_shape : np.ndarray - (Time, Width, Height, Channels)
495+
Shape of the images over time.
461496
classes : int
462497
Number of behaviors to recognise.
463498
@@ -547,12 +582,24 @@ def recurrent_model_tcn(
547582
def recurrent_model_lstm(
548583
recognition_model, recurrent_input_shape, classes=4, recurrent_dropout=None
549584
):
550-
"""
551-
Args:
552-
recognition_model:
553-
recurrent_input_shape:
554-
classes:
555-
recurrent_dropout:
585+
"""Recurrent architecture for classification of temporal sequences of images based on LSTMs or GRUs.
586+
This architecture is used for IdNet in SIPEC.
587+
588+
Parameters
589+
----------
590+
recognition_model : keras.model
591+
Pretrained recognition model that extracts features for individual frames.
592+
recurrent_input_shape : np.ndarray - (Time, Width, Height, Channels)
593+
Shape of the images over time.
594+
classes : int
595+
Number of behaviors to recognise.
596+
recurrent_dropout : float
597+
Recurrent dropout factor to use.
598+
599+
Returns
600+
-------
601+
keras.model
602+
IdNet
556603
"""
557604
input_sequences = Input(shape=recurrent_input_shape)
558605
sequential_model_helper = TimeDistributed(recognition_model)(input_sequences)
@@ -599,14 +646,26 @@ def recurrent_model_lstm(
599646
return sequential_model
600647

601648

649+
# TODO: adaptive size
602650
def pretrained_recognition(model_name, input_shape, num_classes, fix_layers=True):
603-
# TODO: adaptiv size
604-
"""
605-
Args:
606-
model_name:
607-
input_shape:
608-
num_classes:
609-
fix_layers:
651+
"""This returns the model architecture for a model that operates on images and is pretrained with imagenet weights.
652+
This architecture is used for IdNet and BehaviorNet as backbone in SIPEC and is referred to as RecognitionNet.
653+
654+
Parameters
655+
----------
656+
model_name : keras.model
657+
Name of the pretrained recognition model to use (names include: "xception", "resnet", "densenet")
658+
input_shape : np.ndarray - (Time, Width, Height, Channels)
659+
Shape of the images over time.
660+
num_classes : int
661+
Number of behaviors to recognise.
662+
fix_layers : bool
663+
Whether to fix (freeze) the weights of the pretrained backbone layers.
664+
665+
Returns
666+
-------
667+
keras.model
668+
RecognitionNet
610669
"""
611670
if model_name == "xception":
612671
recognition_model = Xception(
@@ -716,11 +775,21 @@ def pretrained_recognition(model_name, input_shape, num_classes, fix_layers=True
716775

717776

718777
def idtracker_ai(input_shape, classes):
778+
"""Implementation of the idtracker.ai identification module as described in the supplementary of Romero-Ferrero et al.
779+
780+
Parameters
781+
----------
782+
input_shape : keras compatible input shape (W,H,Channels)
783+
keras compatible input shape (features,)
784+
classes : int
785+
Number of behaviors to classify.
786+
787+
Returns
788+
-------
789+
keras.model
790+
idtracker.ai identification module
719791
"""
720-
Args:
721-
input_shape:
722-
classes:
723-
"""
792+
724793
activation = "tanh"
725794
dropout = 0.2
726795
# conv model

SwissKnife/dataloader.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,21 @@
1111

1212

1313
def create_dataset(dataset, look_back=5, oneD=False):
14-
# """Create a recurrent dataset from array.
15-
# Args:
16-
# dataset: Numpy/List of dataset.
17-
# look_back: Number of future/past timepoints to add to current timepoint.
18-
# oneD: Boolean flag whether data is one dimensional or not.
19-
# """
20-
"""Summary line.
21-
22-
Extended description of function.
14+
"""Create a recurrent dataset from array.
2315
2416
Parameters
2517
----------
26-
arg1 : int
27-
Description of arg1
28-
arg2 : str
29-
Description of arg2
18+
dataset : np.ndarray
19+
numpy array of dataset to make recurrent
20+
look_back : int
21+
Number of timesteps to look into the past and future.
22+
oneD : bool
23+
Boolean that indicates if the current dataset is one dimensional.
3024
3125
Returns
3226
-------
33-
bool
34-
dataset
27+
np.ndarray
28+
recurrent dataset
3529
"""
3630
dataX = []
3731
print("creating recurrency")
@@ -196,7 +190,6 @@ def create_recurrent_data(self, oneD=False, recurrent_labels=True):
196190
if recurrent_labels:
197191
self.create_recurrent_labels()
198192

199-
200193
def create_recurrent_data_dlc(self, recurrent_labels=True):
201194

202195
self.dlc_train_recurrent = create_dataset(self.dlc_train, self.look_back)
@@ -209,7 +202,6 @@ def create_recurrent_data_dlc(self, recurrent_labels=True):
209202
if recurrent_labels:
210203
self.create_recurrent_labels()
211204

212-
213205
# TODO: redo all like this, i.e. gettters instead of changing data
214206
def expand_dims(self):
215207
self.x_train = np.expand_dims(self.x_train, axis=-1)
@@ -250,7 +242,7 @@ def decimate_labels(self, percentage, balanced=False):
250242
raise NotImplementedError
251243
if self.x_train_recurrent is not None:
252244
num_labels = int(len(self.x_train_recurrent) * percentage)
253-
indices = np.arange(0, len(self.x_train_recurrent)-1)
245+
indices = np.arange(0, len(self.x_train_recurrent) - 1)
254246
random_idxs = np.random.choice(indices, size=num_labels, replace=False)
255247
self.x_train = self.x_train[random_idxs]
256248
self.y_train = self.y_train[random_idxs]
@@ -392,7 +384,6 @@ def downscale_frames(self, factor=0.5):
392384
im_re.append(imresize(el, factor))
393385
self.x_test = np.asarray(im_re)
394386

395-
396387
def prepare_data(self, downscale=0.5, remove_behaviors=[], flatten=False):
397388
print("preparing data")
398389
self.change_dtype()

SwissKnife/full_inference.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,33 @@ def full_inference(
4040
mold_dimension=1024,
4141
max_ids=4,
4242
):
43+
"""Performs full inference on a given video using available SIPEC modules.
44+
45+
Parameters
46+
----------
47+
videodata : np.ndarray
48+
numpy array of read-in videodata.
49+
results_sink : str
50+
Path to where data will be saved.
51+
networks : dict
52+
Dictionary containing SIPEC modules to be used for full inference ("SegNet", "PoseNet", "BehaveNet", "IdNet")
53+
mask_matching : bool
54+
Use greedy-mask-matching
55+
id_matching : bool
56+
Correct/smooth SIPEC:IdNet identity using identities based on temporal tracking (greedy-mask-matching)
57+
mask_size : int
58+
Mask size used for the cutout of animals.
59+
lookback : int
60+
Number of timesteps to look back into the past for id_matching.
61+
max_ids : int
62+
Number of maximum ids / maximum number of animals in any FOV.
63+
64+
65+
Returns
66+
-------
67+
list
68+
Outputs of all the provided SIPEC modules for each video frame.
69+
"""
4370
maskmatcher = MaskMatcher()
4471
maskmatcher.max_ids = max_ids
4572
classes = id_classes

SwissKnife/segmentation.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
# SEGMENTATION PART
1010
# This code is optimized from the Mask RCNN (Waleed Abdulla, (c) 2017 Matterport, Inc.) repository
1111

12-
#TODO: Look at the warnings and resolve them
12+
# TODO: Look at the warnings and resolve them
1313
import warnings
14+
1415
warnings.filterwarnings("ignore")
1516

1617
import sys
@@ -546,6 +547,7 @@ def evaluate_network(model_path, species, filter_masks=False, cv_folds=0):
546547

547548

548549
# TODO: change cv folds to None default
550+
# TODO: make default species
549551
def train_on_data_once(
550552
model_path,
551553
cv_folds=0,
@@ -557,17 +559,36 @@ def train_on_data_once(
557559
perform_evaluation=True,
558560
debug=0,
559561
):
562+
"""Performs training for the segmentation moduel of SIPEC (SIPEC:SegNet).
560563
561-
"""
562-
Args:
563-
model_path:
564-
cv_folds:
565-
frames_path:
566-
annotations_path:
567-
species:
568-
fold:
569-
fraction:
570-
debug:
564+
Parameters
565+
----------
566+
model_path : str
567+
Path to model, can be either where a new model should be stored or a path to an existing model to be retrained.
568+
cv_folds : int
569+
Number of cross_validation folds, use 0 for a normal train/test split.
570+
frames_path : str
571+
Path to the frames used for training.
572+
annotations_path : str
573+
Path to the annotations used for training.
574+
species : str
575+
Species to perform segmentation on (can be any species, but "mouse" or "primate" have more specialised parameters). If your species is neither "mouse" nor "primate", use "default".
576+
fold : int
577+
If cv_folds > 1, fold is the number of fold to be tested on.
578+
fraction : float
579+
Factor by which to decimate the training data points.
580+
perform_evaluation : bool
581+
Perform subsequent evaluation of the model
582+
debug : bool
583+
Debug verbosity.
584+
585+
586+
Returns
587+
-------
588+
model
589+
SIPEC:SegNet model
590+
mean_ap
591+
Mean average precision score achieved by this model
571592
"""
572593
dataset_train, dataset_val = get_segmentation_data(
573594
frames_path=frames_path,

0 commit comments

Comments
 (0)