-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathTrainSession.py
More file actions
68 lines (53 loc) · 1.64 KB
/
TrainSession.py
File metadata and controls
68 lines (53 loc) · 1.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import tensorflow as tf
from BotTrainer import get_data, create_network, train_network
import os
import sys
from pickle import Pickler
# File name of the model checkpoint; joined with weights_dir when saving.
weights_file = 'model.ckpt'
# Directory the checkpoint is written to (train_session creates it if missing).
weights_dir = 'tools/weights'
# Directory and file name of the pickle file that train_session appends the
# last training loss to.
pickle_dir = 'pickle_files'
pickle_file_loss = 'loss.p'
def train_session(session_mode):
    """Run one training session and persist the weights and the last loss.

    Fetches batches with ``get_data``, builds the network with
    ``create_network`` and, when more than one batch is available, trains it
    with ``train_network``. After training, the session is saved to
    ``weights_dir/weights_file`` and the value returned by ``train_network``
    is pickled and appended to ``pickle_dir/pickle_file_loss``. With one
    batch or fewer nothing is trained and nothing is written.

    Args:
        session_mode: Mode string forwarded unchanged to ``get_data`` and
            ``train_network``.

    Returns:
        None.
    """
    # Fetch data
    print('Fetching data...')
    batches = get_data(session_mode)

    # Define Session
    sess = tf.InteractiveSession()
    try:
        # Create Network
        print('Creating network...')
        q_s, s, keep_prob = create_network()

        # Train Network
        print('Training...')
        if len(batches) <= 1:
            # No training happened, so there is no loss value and no new
            # weights to persist; bail out instead of hitting an unbound
            # last_loss further down.
            return

        last_loss = train_network(
            q_s, s, sess, batches, keep_prob, session_mode)

        # Save the weights
        if not os.path.exists(weights_dir):
            os.makedirs(weights_dir)
            print('directory created')
        saver = tf.train.Saver()
        path = os.path.join(weights_dir, weights_file)
        saver.save(sess, path)
        print('weights saved')

        # Save loss value. The target directory is created first, mirroring
        # the weights directory handling; otherwise open() fails on a fresh
        # checkout where pickle_dir does not exist yet.
        if not os.path.exists(pickle_dir):
            os.makedirs(pickle_dir)
        loss_path = os.path.join(pickle_dir, pickle_file_loss)
        with open(loss_path, 'ab+') as f:
            Pickler(f).dump(last_loss)
    finally:
        # Release the session's resources even if training or saving fails.
        sess.close()
def main():
    """Read the session mode from the command line and start training.

    The first command-line argument selects the mode. ``'observing'`` and
    ``'debug'`` are passed through to ``train_session``; a missing or
    unrecognized argument falls back to ``'training'``.
    """
    # get command line arguments; default is 'training'
    session_mode = sys.argv[1] if len(sys.argv) > 1 else 'training'

    # decide session mode: anything unrecognized is treated as 'training'
    if session_mode not in ('observing', 'debug'):
        session_mode = 'training'
    print('Session mode: ' + session_mode)

    # start training in the mode that was announced above (previously the
    # call hard-coded 'training' and ignored the parsed argument)
    train_session(session_mode=session_mode)


# Run the trainer only when this file is executed as a script.
if __name__ == '__main__':
    main()