Skip to content

Commit d5874a3

Browse files
authored
Merge pull request #22 from SIPEC-Animal-Data-Analysis/devel
Devel to main after incorporating the new behaviour updates
2 parents 94f5486 + 8c78dca commit d5874a3

12 files changed

Lines changed: 568 additions & 203 deletions

File tree

Dockerfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ RUN git clone https://github.com/SIPEC-Animal-Data-Analysis/SIPEC.git
4747

4848
WORKDIR /home/user/SIPEC
4949

50-
RUN git checkout main
50+
RUN git checkout devel
5151

5252
ENV VIRTUAL_ENV=/home/user/SIPEC/env
5353

@@ -65,7 +65,7 @@ RUN mkdir /home/user/data
6565

6666
WORKDIR /home/user/data/
6767

68-
RUN wget -O pretrained_networks.zip https://www.dropbox.com/s/38adfecf6741gm6/pretrained_networks.zip?dl=0 && unzip pretrained_networks.zip && rm pretrained_networks.zip
68+
RUN wget -O pretrained_networks.zip https://www.dropbox.com/s/38adfecf6741gm6/pretrained_networks.zip?dl=0 && unzip pretrained_networks.zip -x / -d pretrained_networks && rm pretrained_networks.zip
6969

7070
RUN wget -O mouse_segmentation_4plex_merged.zip https://www.dropbox.com/s/0c4m60zg5kx3nqq/mouse_segmentation_4plex_merged.zip?dl=0 && unzip mouse_segmentation_4plex_merged.zip && rm mouse_segmentation_4plex_merged.zip
7171

SwissKnife/behavior.py

Lines changed: 85 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,18 @@
11
# SIPEC
22
# MARKUS MARKS
33
# Behavioral Classification
4-
import sys
5-
import os
64

7-
# from scipy.misc import imresize
8-
from sklearn.externals._pilutil import imresize
95
from tqdm import tqdm
106
import pandas as pd
117
import random
128
from datetime import datetime
13-
14-
from SwissKnife.architectures import classification_small
15-
169
from argparse import ArgumentParser
17-
import tensorflow.keras.backend as K
1810
import numpy as np
1911

2012
from sklearn import metrics
2113
from scipy.stats import pearsonr
2214
from sklearn.model_selection import StratifiedKFold
15+
from sklearn.externals._pilutil import imresize
2316

2417
from SwissKnife.utils import (
2518
setGPU,
@@ -28,21 +21,21 @@
2821
load_vgg_labels,
2922
loadVideo,
3023
load_config,
31-
check_directory,
24+
check_directory, callbacks_learningRate_plateau,
3225
)
33-
from SwissKnife.dataloader import Dataloader
26+
from SwissKnife.dataloader import Dataloader, DataGenerator
3427
from SwissKnife.model import Model
28+
from SwissKnife.architectures import classification_small
3529

3630

3731
def train_behavior(
38-
dataloader,
39-
config,
40-
num_classes,
41-
encode_labels=True,
42-
class_weights=None,
43-
# results_sink=results_sink,
32+
dataloader,
33+
config,
34+
num_classes,
35+
encode_labels=True,
36+
class_weights=None,
37+
# results_sink=results_sink,
4438
):
45-
4639
print("data prepared!")
4740

4841
our_model = Model()
@@ -61,23 +54,32 @@ def train_behavior(
6154
our_model.set_lr_scheduler()
6255
else:
6356
# use standard training callback
64-
CB_es, CB_lr = get_callbacks()
57+
CB_es, CB_lr = callbacks_learningRate_plateau()
6558
our_model.add_callbacks([CB_es, CB_lr])
6659

6760
# add sklearn metrics for tracking in training
68-
my_metrics = Metrics(validation_data=(dataloader.x_test,dataloader.y_test))
61+
my_metrics = Metrics()
6962
my_metrics.setModel(our_model.recognition_model)
63+
my_metrics.validation_data = (dataloader.x_test, dataloader.y_test)
7064
our_model.add_callbacks([my_metrics])
7165

7266
if config["train_recognition_model"]:
67+
if dataloader.config["use_generator"]:
68+
dataloader.training_generator = DataGenerator(
69+
x_train=dataloader.x_train, y_train=dataloader.y_train, look_back=dataloader.config["look_back"], batch_size=32,
70+
type='recognition'
71+
)
7372
our_model.recognition_model_epochs = config["recognition_model_epochs"]
7473
our_model.recognition_model_batch_size = config["recognition_model_batch_size"]
75-
print(config["recognition_model_batch_size"])
76-
print(dataloader.y_test)
7774
our_model.train_recognition_network(dataloader=dataloader)
7875
print(config)
7976

8077
if config["train_sequential_model"]:
78+
if dataloader.config["use_generator"]:
79+
dataloader.training_generator = DataGenerator(
80+
x_train=dataloader.x_train, y_train=dataloader.y_train, look_back=dataloader.config["look_back"], batch_size=32,
81+
type='sequential'
82+
)
8183
# if False:
8284
if config["recognition_model_fix"]:
8385
our_model.fix_recognition_layers()
@@ -88,8 +90,8 @@ def train_behavior(
8890
input_shape=dataloader.get_input_shape(recurrent=True),
8991
num_classes=num_classes,
9092
)
91-
my_metrics = Metrics(validation_data=(dataloader.x_test_recurrent,dataloader.y_test_recurrent))
9293
my_metrics.setModel(our_model.sequential_model)
94+
my_metrics.validation_data = (dataloader.x_test_recurrent, dataloader.y_test_recurrent)
9395
our_model.add_callbacks([my_metrics])
9496
our_model.set_optimizer(
9597
config["sequential_model_optimizer"], lr=config["sequential_model_lr"],
@@ -106,17 +108,47 @@ def train_behavior(
106108

107109
print(config)
108110

109-
res = our_model.predict(dataloader.x_test, model="recognition")
110-
acc = metrics.balanced_accuracy_score(res, np.argmax(dataloader.y_test, axis=-1))
111-
f1 = metrics.f1_score(res, np.argmax(dataloader.y_test, axis=-1), average="macro")
112-
113-
corr = pearsonr(res, np.argmax(dataloader.y_test, axis=-1))[0]
114-
report = metrics.classification_report(res, np.argmax(dataloader.y_test, axis=-1),)
115-
return [acc, f1, corr], report
116-
111+
print('evaluating')
112+
res = []
113+
batches = len(dataloader.x_test)
114+
batches = int(batches/config["sequential_model_batch_size"])
115+
test_gt = []
116+
#TODO: fix -1 to really use all VAL data
117+
for idx in tqdm(range(batches-1)):
118+
if config["train_sequential_model"]:
119+
eval_batch = []
120+
for i in range(config["sequential_model_batch_size"]):
121+
new_idx = (idx*config["sequential_model_batch_size"]) + i + dataloader.look_back
122+
data = dataloader.x_test[new_idx - dataloader.look_back: new_idx + dataloader.look_back]
123+
eval_batch.append(data)
124+
test_gt.append(dataloader.y_test[new_idx])
125+
eval_batch = np.asarray(eval_batch)
126+
prediction = our_model.predict(eval_batch, model="sequential")
127+
else:
128+
eval_batch = []
129+
for i in range(config["recognition_model_batch_size"]):
130+
new_idx = (idx*config["recognition_model_batch_size"]) + i
131+
data = dataloader.x_test[new_idx]
132+
eval_batch.append(data)
133+
test_gt.append(dataloader.y_test[new_idx])
134+
eval_batch = np.asarray(eval_batch)
135+
prediction = our_model.predict(eval_batch, model="recognition")
136+
for idx, el in enumerate(prediction):
137+
res.append(el)
138+
139+
res = np.asarray(res)
140+
test_gt = np.asarray(test_gt)
141+
142+
acc = metrics.balanced_accuracy_score(res, np.argmax(test_gt, axis=-1))
143+
f1 = metrics.f1_score(res, np.argmax(test_gt, axis=-1), average="macro")
144+
#
145+
corr = pearsonr(res, np.argmax(test_gt, axis=-1))[0]
146+
report = metrics.classification_report(res, np.argmax(test_gt, axis=-1), )
147+
148+
print(report)
149+
return our_model, [acc, f1, corr], report
117150

118151
def train_primate(config, results_sink, shuffle):
119-
120152
basepath = "/media/nexus/storage5/swissknife_data/primate/behavior/"
121153

122154
vids = [
@@ -172,11 +204,11 @@ def train_primate(config, results_sink, shuffle):
172204
global groups
173205

174206
groups = (
175-
[0] * len(labels_idxs[0])
176-
+ [0] * len(labels_idxs[1])
177-
+ [3] * len(labels_idxs[2])
178-
+ [4] * len(labels_idxs[3])
179-
+ [4] * len(labels_idxs[4])
207+
[0] * len(labels_idxs[0])
208+
+ [0] * len(labels_idxs[1])
209+
+ [3] * len(labels_idxs[2])
210+
+ [4] * len(labels_idxs[3])
211+
+ [4] * len(labels_idxs[4])
180212
)
181213

182214
groups = groups
@@ -224,15 +256,15 @@ def train_primate(config, results_sink, shuffle):
224256
x_test = vid[tt_idx]
225257

226258
dataloader = Dataloader(
227-
x_train, y_train, x_test, y_test, look_back=config["look_back"]
259+
x_train, y_train, x_test, y_test, config=config
228260
)
229261

230262
# config_name = 'primate_' + str(1)
231263
#
232264
# config = load_config("../configs/behavior/primate/" + config_name)
233265
config["recognition_model_batch_size"] = 128
234266
config["backbone"] = "imagenet"
235-
config["encode_labels"]= True
267+
config["encode_labels"] = True
236268
print(config)
237269

238270
num_classes = config["num_classes"]
@@ -247,7 +279,7 @@ def train_primate(config, results_sink, shuffle):
247279

248280
if config["normalize_data"]:
249281
dataloader.normalize_data()
250-
if encode_labels:
282+
if config["encode_labels"]:
251283
dataloader.encode_labels()
252284
print("labels encoded")
253285

@@ -328,15 +360,9 @@ def main():
328360
shuffle = args.shuffle
329361
annotations = args.annotations
330362
video = args.video
331-
output_path = args.output_path
363+
results_sink = args.results_sink
332364

333365
setGPU(gpu_name)
334-
335-
output_path = "/home/user/results"
336-
337-
results_sink = (
338-
os.path.join(output_path, "primate/behavior-{}-{}/".format(network, datetime.now().strftime("%Y-%m-%d-%H_%M")))
339-
)
340366
check_directory(results_sink)
341367

342368
if annotations:
@@ -358,22 +384,19 @@ def main():
358384
# load cfg
359385
config = load_config("../configs/behavior/shared_config")
360386
beh_config = load_config(
361-
"../configs/behavior/primate/primate_final"
387+
"../configs/behavior/default"
362388
)
363389
config.update(beh_config)
364-
365390
print(config)
366391

367-
config["encode_labels"]= True
368392
num_classes = len(np.unique(annotation))
369-
370393
dataloader = Dataloader(
371394
x_train, y_train, x_test, y_test, config
372395
)
373-
374396
dataloader.prepare_data()
375397
train_behavior(dataloader=dataloader, num_classes=num_classes, config=config)
376-
elif operation("train_primate"):
398+
399+
elif operation == "train_primate":
377400
config_name = "primate_final"
378401
config = load_config("../configs/behavior/primate/" + config_name)
379402
train_primate(config=config, results_sink=results_sink, shuffle=shuffle)
@@ -405,7 +428,7 @@ def main():
405428
action="store",
406429
dest="config_name",
407430
type=str,
408-
default="behavior_config_baseline",
431+
default="default",
409432
help="behavioral config to use",
410433
)
411434
parser.add_argument(
@@ -416,6 +439,7 @@ def main():
416439
default="ours",
417440
help="which network used for training",
418441
)
442+
#TODO: check if folder and then load all files in folder, similar for vid files
419443
parser.add_argument(
420444
"--annotations",
421445
action="store",
@@ -433,20 +457,18 @@ def main():
433457
help="path to folder with annotated video",
434458
)
435459
parser.add_argument(
436-
"--shuffle", action="store", dest="shuffle", type=bool, default=False,
460+
"--results_sink",
461+
action="store",
462+
dest="results_sink",
463+
type=str,
464+
default='./results/behavior5/',
465+
help="path to results",
437466
)
438-
439467
parser.add_argument(
440-
"--output_path",
441-
action="store",
442-
dest="output_path",
443-
type=str,
444-
default=None,
445-
help="Path to the folder where the ouput should be written"
468+
"--shuffle", action="store", dest="shuffle", type=bool, default=False,
446469
)
447470

448-
449471
# example usage
450472
# python behavior.py --annotations "/media/nexus/storage5/swissknife_data/primate/behavior/20180124T113800-20180124T115800_0.csv" --video "/media/nexus/storage5/swissknife_data/primate/behavior/fullvids_20180124T113800-20180124T115800_%T1_0.mp4" --gpu 2
451473
if __name__ == "__main__":
452-
main()
474+
main()

0 commit comments

Comments
 (0)