Skip to content

Commit f7ce618

Browse files
committed
extending documentation #1
1 parent 9ee38c4 commit f7ce618

4 files changed

Lines changed: 179 additions & 64 deletions

File tree

SwissKnife/architectures.py

Lines changed: 106 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
Conv2DTranspose,
3434
UpSampling2D,
3535
Reshape,
36-
LeakyReLU
36+
LeakyReLU,
3737
)
3838
from tensorflow.keras.models import Sequential
3939

@@ -56,7 +56,10 @@ def posenet_mouse(input_shape, num_classes):
5656
model
5757
"""
5858
recognition_model = Xception(
59-
include_top=False, input_shape=input_shape, pooling="avg", weights="imagenet",
59+
include_top=False,
60+
input_shape=input_shape,
61+
pooling="avg",
62+
weights="imagenet",
6063
)
6164

6265
new_input = Input(
@@ -174,7 +177,10 @@ def posenet_primate(input_shape, num_classes): # recognition_model = DenseNet20
174177
num_classes:Number of classes for recognition or number of landmarks.
175178
"""
176179
recognition_model = ResNet101(
177-
include_top=False, input_shape=input_shape, pooling="avg", weights="imagenet",
180+
include_top=False,
181+
input_shape=input_shape,
182+
pooling="avg",
183+
weights="imagenet",
178184
)
179185

180186
new_input = Input(
@@ -532,9 +538,20 @@ def classification_small(input_shape, num_classes):
532538

533539
def dlc_model_sturman(input_shape, num_classes):
534540
"""Model that implements behavioral classification based on Deeplabcut generated features as in Sturman et al.
535-
Args:
536-
input_shape:
537-
num_classes:Number of behaviors to classify.
541+
542+
Reimplementation of the model used in the publication Sturman et al. that performs action recognition on top of pose estimation
543+
544+
Parameters
545+
----------
546+
input_shape : keras compatible input shape (features,)
547+
Shape of the Deeplabcut-generated feature input.
548+
num_classes : int
549+
Number of behaviors to classify.
550+
551+
Returns
552+
-------
553+
keras.model
554+
Sturman et al. model
538555
"""
539556
model = Sequential()
540557

@@ -558,10 +575,21 @@ def dlc_model_sturman(input_shape, num_classes):
558575

559576

560577
def dlc_model(input_shape, num_classes):
561-
"""
562-
Args:
563-
input_shape:
564-
num_classes:
578+
"""Model for classification on top of pose estimation.
579+
580+
Classification model for behavior, operating on pose estimation. This model has more free parameters than Sturman et al.
581+
582+
Parameters
583+
----------
584+
input_shape : keras compatible input shape (features,)
585+
Shape of the pose-estimation feature input.
586+
num_classes : int
587+
Number of behaviors to classify.
588+
589+
Returns
590+
-------
591+
keras.model
592+
behavior (from pose estimates) model
565593
"""
566594
dropout = 0.3
567595

@@ -644,16 +672,19 @@ def recurrent_model_old(
644672

645673

646674
def recurrent_model_tcn(
647-
recognition_model, recurrent_input_shape, classes=4,
675+
recognition_model,
676+
recurrent_input_shape,
677+
classes=4,
648678
):
649-
"""BehaviorNet architecture for behavioral classification based on temporal convolution architecture (TCN).
679+
"""Recurrent architecture for classification of temporal sequences of images based on temporal convolution architecture (TCN).
680+
This architecture is used for BehaviorNet in SIPEC.
650681
651682
Parameters
652683
----------
653684
recognition_model : keras.model
654685
Pretrained recognition model that extracts features for individual frames.
655-
recurrent_input_shape : np.ndarray
656-
Number of classes/landmarks.
686+
recurrent_input_shape : np.ndarray - (Time, Width, Height, Channels)
687+
Shape of the images over time.
657688
classes : int
658689
Number of behaviors to recognise.
659690
@@ -743,12 +774,24 @@ def recurrent_model_tcn(
743774
def recurrent_model_lstm(
744775
recognition_model, recurrent_input_shape, classes=4, recurrent_dropout=None
745776
):
746-
"""
747-
Args:
748-
recognition_model:
749-
recurrent_input_shape:
750-
classes:
751-
recurrent_dropout:
777+
"""Recurrent architecture for classification of temporal sequences of images based on LSTMs or GRUs.
778+
This architecture is used for IdNet in SIPEC.
779+
780+
Parameters
781+
----------
782+
recognition_model : keras.model
783+
Pretrained recognition model that extracts features for individual frames.
784+
recurrent_input_shape : np.ndarray - (Time, Width, Height, Channels)
785+
Shape of the images over time.
786+
classes : int
787+
Number of classes (e.g. individual identities) to recognise.
788+
recurrent_dropout : float
789+
Recurrent dropout factor to use.
790+
791+
Returns
792+
-------
793+
keras.model
794+
IdNet
752795
"""
753796
input_sequences = Input(shape=recurrent_input_shape)
754797
sequential_model_helper = TimeDistributed(recognition_model)(input_sequences)
@@ -795,14 +838,26 @@ def recurrent_model_lstm(
795838
return sequential_model
796839

797840

841+
# TODO: adaptiv size
798842
def pretrained_recognition(model_name, input_shape, num_classes, fix_layers=True):
799-
# TODO: adaptiv size
800-
"""
801-
Args:
802-
model_name:
803-
input_shape:
804-
num_classes:
805-
fix_layers:
843+
"""This returns the model architecture for a model that operates on images and is pretrained with imagenet weights.
844+
This architecture is used for IdNet and BehaviorNet as backbone in SIPEC and is referred to as RecognitionNet.
845+
846+
Parameters
847+
----------
848+
model_name : str
849+
Name of the pretrained recognition model to use (names include: "xception", "resnet", "densenet")
850+
input_shape : np.ndarray - (Width, Height, Channels)
851+
Shape of the input images.
852+
num_classes : int
853+
Number of behaviors to recognise.
854+
fix_layers : bool
855+
Whether to fix (freeze) the weights of the pretrained layers.
856+
857+
Returns
858+
-------
859+
keras.model
860+
RecognitionNet
806861
"""
807862
if model_name == "xception":
808863
recognition_model = Xception(
@@ -912,11 +967,21 @@ def pretrained_recognition(model_name, input_shape, num_classes, fix_layers=True
912967

913968

914969
def idtracker_ai(input_shape, classes):
970+
"""Implementation of the idtracker.ai identification module as described in the supplementary of Romero-Ferrero et al.
971+
972+
Parameters
973+
----------
974+
input_shape : keras compatible input shape (W,H,Channels)
975+
Shape of the input images.
976+
classes : int
977+
Number of classes (animal identities) to distinguish.
978+
979+
Returns
980+
-------
981+
keras.model
982+
idtracker.ai identification module
915983
"""
916-
Args:
917-
input_shape:
918-
classes:
919-
"""
984+
920985
activation = "tanh"
921986
dropout = 0.2
922987
# conv model
@@ -934,7 +999,11 @@ def idtracker_ai(input_shape, classes):
934999
)
9351000
model.add(Activation("relu"))
9361001

937-
model.add(MaxPooling2D(strides=(2, 2),))
1002+
model.add(
1003+
MaxPooling2D(
1004+
strides=(2, 2),
1005+
)
1006+
)
9381007

9391008
model.add(
9401009
Conv2D(
@@ -947,7 +1016,11 @@ def idtracker_ai(input_shape, classes):
9471016
)
9481017
model.add(Activation("relu"))
9491018

950-
model.add(MaxPooling2D(strides=(2, 2),))
1019+
model.add(
1020+
MaxPooling2D(
1021+
strides=(2, 2),
1022+
)
1023+
)
9511024

9521025
model.add(
9531026
Conv2D(

SwissKnife/dataloader.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,21 @@
1111

1212

1313
def create_dataset(dataset, look_back=5, oneD=False):
14-
# """Create a recurrent dataset from array.
15-
# Args:
16-
# dataset: Numpy/List of dataset.
17-
# look_back: Number of future/past timepoints to add to current timepoint.
18-
# oneD: Boolean flag whether data is one dimensional or not.
19-
# """
20-
"""Summary line.
21-
22-
Extended description of function.
14+
"""Create a recurrent dataset from array.
2315
2416
Parameters
2517
----------
26-
arg1 : int
27-
Description of arg1
28-
arg2 : str
29-
Description of arg2
18+
dataset : np.ndarray
19+
numpy array of dataset to make recurrent
20+
look_back : int
21+
Number of timesteps to look into the past and future.
22+
oneD : bool
23+
Boolean that indicates if the current dataset is one dimensional.
3024
3125
Returns
3226
-------
33-
bool
34-
dataset
27+
np.ndarray
28+
recurrent dataset
3529
"""
3630
dataX = []
3731
print("creating recurrency")
@@ -196,7 +190,6 @@ def create_recurrent_data(self, oneD=False, recurrent_labels=True):
196190
if recurrent_labels:
197191
self.create_recurrent_labels()
198192

199-
200193
def create_recurrent_data_dlc(self, recurrent_labels=True):
201194

202195
self.dlc_train_recurrent = create_dataset(self.dlc_train, self.look_back)
@@ -209,7 +202,6 @@ def create_recurrent_data_dlc(self, recurrent_labels=True):
209202
if recurrent_labels:
210203
self.create_recurrent_labels()
211204

212-
213205
# TODO: redo all like this, i.e. gettters instead of changing data
214206
def expand_dims(self):
215207
self.x_train = np.expand_dims(self.x_train, axis=-1)
@@ -250,7 +242,7 @@ def decimate_labels(self, percentage, balanced=False):
250242
raise NotImplementedError
251243
if self.x_train_recurrent is not None:
252244
num_labels = int(len(self.x_train_recurrent) * percentage)
253-
indices = np.arange(0, len(self.x_train_recurrent)-1)
245+
indices = np.arange(0, len(self.x_train_recurrent) - 1)
254246
random_idxs = np.random.choice(indices, size=num_labels, replace=False)
255247
self.x_train = self.x_train[random_idxs]
256248
self.y_train = self.y_train[random_idxs]
@@ -392,7 +384,6 @@ def downscale_frames(self, factor=0.5):
392384
im_re.append(imresize(el, factor))
393385
self.x_test = np.asarray(im_re)
394386

395-
396387
def prepare_data(self, downscale=0.5, remove_behaviors=[], flatten=False):
397388
print("preparing data")
398389
self.change_dtype()

SwissKnife/full_inference.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,33 @@ def full_inference(
3737
lookback=100,
3838
max_ids=4,
3939
):
40+
"""Performs full inference on a given video using available SIPEC modules.
41+
42+
Parameters
43+
----------
44+
videodata : np.ndarray
45+
numpy array of read-in videodata.
46+
results_sink : str
47+
Path to where data will be saved.
48+
networks : dict
49+
Dictionary containing SIPEC modules to be used for full inference ("SegNet", "PoseNet", "BehaveNet", IdNet")
50+
mask_matching : bool
51+
Use greedy-mask-matching
52+
id_matching : bool
53+
Correct/smooth SIPEC:IdNet identity using identities based on temporal tracking (greedy-mask-matching)
54+
mask_size : int
55+
Mask size used for the cutout of animals.
56+
lookback : int
57+
Number of timesteps to look back into the past for id_matching.
58+
max_ids : int
59+
Number of maximum ids / maximum number of animals in any FOV.
60+
61+
62+
Returns
63+
-------
64+
list
65+
Outputs of all the provided SIPEC modules for each video frame.
66+
"""
4067
maskmatcher = MaskMatcher()
4168
maskmatcher.max_ids = max_ids
4269
classes = id_classes

0 commit comments

Comments
 (0)