11# SIPEC
22# MARKUS MARKS
33# Behavioral Classification
4- import sys
5- import os
64
7- # from scipy.misc import imresize
8- from sklearn .externals ._pilutil import imresize
95from tqdm import tqdm
106import pandas as pd
117import random
128from datetime import datetime
13-
14- from SwissKnife .architectures import classification_small
15-
169from argparse import ArgumentParser
17- import tensorflow .keras .backend as K
1810import numpy as np
1911
2012from sklearn import metrics
2113from scipy .stats import pearsonr
2214from sklearn .model_selection import StratifiedKFold
15+ from sklearn .externals ._pilutil import imresize
2316
2417from SwissKnife .utils import (
2518 setGPU ,
2821 load_vgg_labels ,
2922 loadVideo ,
3023 load_config ,
31- check_directory ,
24+ check_directory , callbacks_learningRate_plateau ,
3225)
33- from SwissKnife .dataloader import Dataloader
26+ from SwissKnife .dataloader import Dataloader , DataGenerator
3427from SwissKnife .model import Model
28+ from SwissKnife .architectures import classification_small
3529
3630
3731def train_behavior (
38- dataloader ,
39- config ,
40- num_classes ,
41- encode_labels = True ,
42- class_weights = None ,
43- # results_sink=results_sink,
32+ dataloader ,
33+ config ,
34+ num_classes ,
35+ encode_labels = True ,
36+ class_weights = None ,
37+ # results_sink=results_sink,
4438):
45-
4639 print ("data prepared!" )
4740
4841 our_model = Model ()
@@ -61,23 +54,32 @@ def train_behavior(
6154 our_model .set_lr_scheduler ()
6255 else :
6356 # use standard training callback
64- CB_es , CB_lr = get_callbacks ()
57+ CB_es , CB_lr = callbacks_learningRate_plateau ()
6558 our_model .add_callbacks ([CB_es , CB_lr ])
6659
6760 # add sklearn metrics for tracking in training
68- my_metrics = Metrics (validation_data = ( dataloader . x_test , dataloader . y_test ) )
61+ my_metrics = Metrics ()
6962 my_metrics .setModel (our_model .recognition_model )
63+ my_metrics .validation_data = (dataloader .x_test , dataloader .y_test )
7064 our_model .add_callbacks ([my_metrics ])
7165
7266 if config ["train_recognition_model" ]:
67+ if dataloader .config ["use_generator" ]:
68+ dataloader .training_generator = DataGenerator (
69+ x_train = dataloader .x_train , y_train = dataloader .y_train , look_back = dataloader .config ["look_back" ], batch_size = 32 ,
70+ type = 'recognition'
71+ )
7372 our_model .recognition_model_epochs = config ["recognition_model_epochs" ]
7473 our_model .recognition_model_batch_size = config ["recognition_model_batch_size" ]
75- print (config ["recognition_model_batch_size" ])
76- print (dataloader .y_test )
7774 our_model .train_recognition_network (dataloader = dataloader )
7875 print (config )
7976
8077 if config ["train_sequential_model" ]:
78+ if dataloader .config ["use_generator" ]:
79+ dataloader .training_generator = DataGenerator (
80+ x_train = dataloader .x_train , y_train = dataloader .y_train , look_back = dataloader .config ["look_back" ], batch_size = 32 ,
81+ type = 'sequential'
82+ )
8183 # if False:
8284 if config ["recognition_model_fix" ]:
8385 our_model .fix_recognition_layers ()
@@ -88,8 +90,8 @@ def train_behavior(
8890 input_shape = dataloader .get_input_shape (recurrent = True ),
8991 num_classes = num_classes ,
9092 )
91- my_metrics = Metrics (validation_data = (dataloader .x_test_recurrent ,dataloader .y_test_recurrent ))
9293 my_metrics .setModel (our_model .sequential_model )
94+ my_metrics .validation_data = (dataloader .x_test_recurrent , dataloader .y_test_recurrent )
9395 our_model .add_callbacks ([my_metrics ])
9496 our_model .set_optimizer (
9597 config ["sequential_model_optimizer" ], lr = config ["sequential_model_lr" ],
@@ -106,17 +108,47 @@ def train_behavior(
106108
107109 print (config )
108110
109- res = our_model .predict (dataloader .x_test , model = "recognition" )
110- acc = metrics .balanced_accuracy_score (res , np .argmax (dataloader .y_test , axis = - 1 ))
111- f1 = metrics .f1_score (res , np .argmax (dataloader .y_test , axis = - 1 ), average = "macro" )
112-
113- corr = pearsonr (res , np .argmax (dataloader .y_test , axis = - 1 ))[0 ]
114- report = metrics .classification_report (res , np .argmax (dataloader .y_test , axis = - 1 ),)
115- return [acc , f1 , corr ], report
116-
111+ print ('evaluating' )
112+ res = []
113+ batches = len (dataloader .x_test )
114+ batches = int (batches / config ["sequential_model_batch_size" ])
115+ test_gt = []
116+ #TODO: fix -1 to really use all VAL data
117+ for idx in tqdm (range (batches - 1 )):
118+ if config ["train_sequential_model" ]:
119+ eval_batch = []
120+ for i in range (config ["sequential_model_batch_size" ]):
121+ new_idx = (idx * config ["sequential_model_batch_size" ]) + i + dataloader .look_back
122+ data = dataloader .x_test [new_idx - dataloader .look_back : new_idx + dataloader .look_back ]
123+ eval_batch .append (data )
124+ test_gt .append (dataloader .y_test [new_idx ])
125+ eval_batch = np .asarray (eval_batch )
126+ prediction = our_model .predict (eval_batch , model = "sequential" )
127+ else :
128+ eval_batch = []
129+ for i in range (config ["recognition_model_batch_size" ]):
130+ new_idx = (idx * config ["recognition_model_batch_size" ]) + i
131+ data = dataloader .x_test [new_idx ]
132+ eval_batch .append (data )
133+ test_gt .append (dataloader .y_test [new_idx ])
134+ eval_batch = np .asarray (eval_batch )
135+ prediction = our_model .predict (eval_batch , model = "recognition" )
136+ for idx , el in enumerate (prediction ):
137+ res .append (el )
138+
139+ res = np .asarray (res )
140+ test_gt = np .asarray (test_gt )
141+
142+ acc = metrics .balanced_accuracy_score (res , np .argmax (test_gt , axis = - 1 ))
143+ f1 = metrics .f1_score (res , np .argmax (test_gt , axis = - 1 ), average = "macro" )
144+ #
145+ corr = pearsonr (res , np .argmax (test_gt , axis = - 1 ))[0 ]
146+ report = metrics .classification_report (res , np .argmax (test_gt , axis = - 1 ), )
147+
148+ print (report )
149+ return our_model , [acc , f1 , corr ], report
117150
118151def train_primate (config , results_sink , shuffle ):
119-
120152 basepath = "/media/nexus/storage5/swissknife_data/primate/behavior/"
121153
122154 vids = [
@@ -172,11 +204,11 @@ def train_primate(config, results_sink, shuffle):
172204 global groups
173205
174206 groups = (
175- [0 ] * len (labels_idxs [0 ])
176- + [0 ] * len (labels_idxs [1 ])
177- + [3 ] * len (labels_idxs [2 ])
178- + [4 ] * len (labels_idxs [3 ])
179- + [4 ] * len (labels_idxs [4 ])
207+ [0 ] * len (labels_idxs [0 ])
208+ + [0 ] * len (labels_idxs [1 ])
209+ + [3 ] * len (labels_idxs [2 ])
210+ + [4 ] * len (labels_idxs [3 ])
211+ + [4 ] * len (labels_idxs [4 ])
180212 )
181213
182214 groups = groups
@@ -224,15 +256,15 @@ def train_primate(config, results_sink, shuffle):
224256 x_test = vid [tt_idx ]
225257
226258 dataloader = Dataloader (
227- x_train , y_train , x_test , y_test , look_back = config [ "look_back" ]
259+ x_train , y_train , x_test , y_test , config = config
228260 )
229261
230262 # config_name = 'primate_' + str(1)
231263 #
232264 # config = load_config("../configs/behavior/primate/" + config_name)
233265 config ["recognition_model_batch_size" ] = 128
234266 config ["backbone" ] = "imagenet"
235- config ["encode_labels" ]= True
267+ config ["encode_labels" ] = True
236268 print (config )
237269
238270 num_classes = config ["num_classes" ]
@@ -247,7 +279,7 @@ def train_primate(config, results_sink, shuffle):
247279
248280 if config ["normalize_data" ]:
249281 dataloader .normalize_data ()
250- if encode_labels :
282+ if config [ " encode_labels" ] :
251283 dataloader .encode_labels ()
252284 print ("labels encoded" )
253285
@@ -328,15 +360,9 @@ def main():
328360 shuffle = args .shuffle
329361 annotations = args .annotations
330362 video = args .video
331- output_path = args .output_path
363+ results_sink = args .results_sink
332364
333365 setGPU (gpu_name )
334-
335- output_path = "/home/user/results"
336-
337- results_sink = (
338- os .path .join (output_path , "primate/behavior-{}-{}/" .format (network , datetime .now ().strftime ("%Y-%m-%d-%H_%M" )))
339- )
340366 check_directory (results_sink )
341367
342368 if annotations :
@@ -358,22 +384,19 @@ def main():
358384 # load cfg
359385 config = load_config ("../configs/behavior/shared_config" )
360386 beh_config = load_config (
361- "../configs/behavior/primate/primate_final "
387+ "../configs/behavior/default "
362388 )
363389 config .update (beh_config )
364-
365390 print (config )
366391
367- config ["encode_labels" ]= True
368392 num_classes = len (np .unique (annotation ))
369-
370393 dataloader = Dataloader (
371394 x_train , y_train , x_test , y_test , config
372395 )
373-
374396 dataloader .prepare_data ()
375397 train_behavior (dataloader = dataloader , num_classes = num_classes , config = config )
376- elif operation ("train_primate" ):
398+
399+ elif operation == "train_primate" :
377400 config_name = "primate_final"
378401 config = load_config ("../configs/behavior/primate/" + config_name )
379402 train_primate (config = config , results_sink = results_sink , shuffle = shuffle )
@@ -405,7 +428,7 @@ def main():
405428 action = "store" ,
406429 dest = "config_name" ,
407430 type = str ,
408- default = "behavior_config_baseline " ,
431+ default = "default " ,
409432 help = "behavioral config to use" ,
410433)
411434parser .add_argument (
@@ -416,6 +439,7 @@ def main():
416439 default = "ours" ,
417440 help = "which network used for training" ,
418441)
442+ #TODO: check if folder and then load all files in folder, similar for vid files
419443parser .add_argument (
420444 "--annotations" ,
421445 action = "store" ,
@@ -433,20 +457,18 @@ def main():
433457 help = "path to folder with annotated video" ,
434458)
435459parser .add_argument (
436- "--shuffle" , action = "store" , dest = "shuffle" , type = bool , default = False ,
460+ "--results_sink" ,
461+ action = "store" ,
462+ dest = "results_sink" ,
463+ type = str ,
464+ default = './results/behavior5/' ,
465+ help = "path to results" ,
437466)
438-
439467parser .add_argument (
440- "--output_path" ,
441- action = "store" ,
442- dest = "output_path" ,
443- type = str ,
444- default = None ,
445- help = "Path to the folder where the ouput should be written"
468+ "--shuffle" , action = "store" , dest = "shuffle" , type = bool , default = False ,
446469)
447470
448-
449471# example usage
450472# python behavior.py --annotations "/media/nexus/storage5/swissknife_data/primate/behavior/20180124T113800-20180124T115800_0.csv" --video "/media/nexus/storage5/swissknife_data/primate/behavior/fullvids_20180124T113800-20180124T115800_%T1_0.mp4" --gpu 2
451473if __name__ == "__main__" :
452- main ()
474+ main ()
0 commit comments