Skip to content

Commit efad9e0

Browse files
committed
- initial commit
1 parent 119ab9d commit efad9e0

1,412 files changed

Lines changed: 22877 additions & 0 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

AE_dyna_results.py

Lines changed: 192 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,192 @@
1+
import os
2+
import pickle
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
6+
# Experiment selector: "ME-TRPO" picks the ME-TRPO run, any other value picks
# the AE-Dyna run. The value is also used as the plot title and as the prefix
# of the saved figure names further down in this script.
label = "AE-DYNA"

if label == "ME-TRPO":
    # ME-TRPO results
    project_directory = 'Data_Experiments/2020_10_06_ME_TRPO_stable@FERMI/run2/'
else:
    # AE-Dyna results
    project_directory = 'Data_Experiments/2020_11_05_AE_Dyna@FERMI/-nr_steps_25-cr_lr-n_ep_13-m_bs_100-sim_steps_3000-m_iter_35-ensnr_3-init_200/'
14+
15+
def read_rewards(rewards):
    """Summarise a collection of per-episode reward traces.

    Empty episodes are skipped. For every remaining episode the first reward
    entry is left out of the sum and of the standard deviation (both are
    computed on ``episode[1:]``), but it still counts towards the episode
    length.

    Args:
        rewards: Sequence of per-episode reward sequences.

    Returns:
        A tuple of four 1-D numpy arrays with one entry per non-empty
        episode, in this order: number of reward entries, last reward, sum of
        the rewards after the first entry, and standard deviation of those
        same rewards. An episode with a single entry yields a sum of 0 and a
        NaN standard deviation (empty slice).
    """
    iterations = []
    final_rews = []
    # NOTE: the name is historical - the stored value is the *sum* of the
    # rewards after the first entry, not their mean.
    mean_rews = []
    stds = []

    for episode in rewards:
        if len(episode) == 0:
            continue
        # Everything after the first entry; presumably the first entry is the
        # reward of the initial state - TODO confirm with the data producer.
        tail = episode[1:]
        iterations.append(len(episode))
        final_rews.append(episode[-1])
        mean_rews.append(np.sum(tail))
        stds.append(np.std(tail))

    return np.array(iterations), np.array(final_rews), np.array(mean_rews), np.array(stds)
41+
42+
43+
44+
def plot_results(data, label='Verification', **kwargs):
    """Draw a two-panel summary figure for the episodes stored in ``data``.

    The upper panel shows the first array returned by ``read_rewards`` (the
    number of reward entries per episode). The lower panel shows the second
    array (last reward of each episode) and, on a twin y axis, the third
    array with a shaded band of plus/minus the fourth array around it.

    Args:
        data: Mapping with a ``'rews'`` entry that is passed to
            ``read_rewards``.
        label: Title of the upper panel.
        **kwargs: Optional extras.
            ``data_number``: if present, plotted against the episode index on
            a twin y axis of the upper panel.
            ``save_name``: if present, the figure is written to
            ``save_name + '.pdf'`` and ``save_name + '.png'``.

    The figure is always displayed with ``plt.show()`` at the end.
    """

    def _colour_y_axis(axis, text, colour):
        # Set the y label and tint the y tick labels in the same colour.
        axis.set_ylabel(text, color=colour)
        axis.tick_params(axis='y', labelcolor=colour)

    episode_lengths, last_rewards, reward_sums, reward_stds = read_rewards(data['rews'])
    episodes = range(len(episode_lengths))

    figure, (top, bottom) = plt.subplots(2, 1, sharex=True)

    # Upper panel: episode lengths, optionally with an extra series on a
    # second y axis.
    top.plot(episodes, episode_lengths)
    top.set_ylabel('Iterations (1)')
    top.set_title(label)
    if 'data_number' in kwargs:
        top_twin = top.twinx()
        _colour_y_axis(top_twin, 'Mean reward', 'lime')
        top_twin.plot(episodes, kwargs['data_number'], color='lime')

    # Lower panel: last reward of every episode on the left axis.
    _colour_y_axis(bottom, 'Final reward', 'blue')
    bottom.plot(episodes, last_rewards, color='blue')
    bottom.set_title('Final reward per episode')
    bottom.set_xlabel('Episodes (1)')

    # Right axis of the lower panel: aggregated reward with its spread band.
    bottom_twin = bottom.twinx()
    _colour_y_axis(bottom_twin, 'Mean reward', 'lime')
    bottom_twin.fill_between(episodes, reward_sums - reward_stds, reward_sums + reward_stds,
                             alpha=0.5, edgecolor='lime', facecolor='#FF9848')
    bottom_twin.plot(episodes, reward_sums, color='lime')
    figure.align_labels()

    if 'save_name' in kwargs:
        target = kwargs['save_name']
        for extension in ('.pdf', '.png'):
            plt.savefig(target + extension)
    plt.show()
105+
106+
def plot_observables(data, label='Experiment', **kwargs):
    """Plot the observables recorded during training in a two-panel figure.

    Upper panel: row 0 of ``batch_rews_all`` as a step line with a band of
    plus/minus row 1 around it, and ``step_counts_all`` on a twin y axis.
    Lower panel: row 0 of ``sim_rewards_all`` (dotted) with a band of
    plus/minus row 1, the same for ``tests_all`` when it is usable, a
    horizontal line at the maximum of row 0 of ``tests_all``, and
    ``entropy_all`` on a twin y axis when it is present and non-empty.
    Rows 0 and 1 are presumably mean and standard deviation - TODO confirm
    against the code that writes the pickle.

    Args:
        data: Mapping holding the entries ``'sim_rewards_all'``,
            ``'step_counts_all'``, ``'batch_rews_all'``, and optionally
            ``'tests_all'`` and ``'entropy_all'``.
        label: Title of the upper panel.
        **kwargs: Optional extras.
            ``save_name``: if present, the figure is written to
            ``save_name + '.pdf'`` and ``save_name + '.png'``.

    The figure is always displayed with ``plt.show()`` at the end.
    """
    sim_rewards_all = np.array(data.get('sim_rewards_all'))
    step_counts_all = np.array(data.get('step_counts_all'))
    batch_rews_all = np.array(data.get('batch_rews_all'))
    tests_all = np.array(data.get('tests_all'))
    # Read from the argument, not from a module-level global, so the function
    # works for any mapping handed in by the caller.
    length_all = data.get('entropy_all')

    fig, axs = plt.subplots(2, 1, sharex=True)
    x = np.arange(len(batch_rews_all[0]))

    ax = axs[0]
    ax.step(x, batch_rews_all[0])
    ax.fill_between(x, batch_rews_all[0] - batch_rews_all[1], batch_rews_all[0] + batch_rews_all[1],
                    alpha=0.5)
    ax.set_ylabel('rews per batch')
    ax.set_title(label)

    ax2 = ax.twinx()
    color = 'lime'
    ax2.set_ylabel('data points', color=color)
    ax2.tick_params(axis='y', labelcolor=color)
    ax2.step(x, step_counts_all, color=color)

    ax = axs[1]
    ax.plot(sim_rewards_all[0], ls=':')
    ax.fill_between(x, sim_rewards_all[0] - sim_rewards_all[1], sim_rewards_all[0] + sim_rewards_all[1],
                    alpha=0.5)
    try:
        # Best effort: the test data may be missing or have an unexpected
        # shape, in which case this part of the plot is simply left out.
        ax.plot(tests_all[0])
        ax.fill_between(x, tests_all[0] - tests_all[1], tests_all[0] + tests_all[1],
                        alpha=0.5)
        ax.axhline(y=np.max(tests_all[0]), c='orange')
    except (IndexError, TypeError, ValueError):
        pass
    ax.set_ylabel('rewards tests')
    ax.grid(True)

    # Explicit emptiness test: plain truthiness raises for numpy arrays with
    # more than one element.
    if length_all is not None and len(length_all) > 0:
        ax2 = ax.twinx()
        color = 'lime'
        ax2.set_ylabel(r'- log(std($p_\pi$))', color=color)
        ax2.tick_params(axis='y', labelcolor=color)
        ax2.plot(length_all, color=color)
    fig.align_labels()

    if 'save_name' in kwargs:
        plt.savefig(kwargs.get('save_name') + '.pdf')
        plt.savefig(kwargs.get('save_name') + '.png')
    plt.show()
158+
159+
def _load_latest_pickle(directory, keyword):
    """Unpickle the lexicographically last file in ``directory`` matching ``keyword``.

    Args:
        directory: Directory path; it is joined to the file name by plain
            string concatenation, so it has to end with a path separator.
        keyword: Substring that the file name must contain.

    Returns:
        The object stored in the selected pickle file.

    Raises:
        FileNotFoundError: If no file name in ``directory`` contains
            ``keyword``.
    """
    candidates = sorted(name for name in os.listdir(directory) if keyword in name)
    if not candidates:
        raise FileNotFoundError(
            "no file containing '" + keyword + "' found in " + directory)
    filename = candidates[-1]
    print(filename)
    # NOTE(review): pickle.load can execute arbitrary code stored in the file;
    # only point this script at trusted experiment data.
    with open(directory + filename, 'rb') as filehandler:
        return pickle.load(filehandler)


# The module-level name `object` shadows the builtin. It is kept on purpose:
# renaming it is only safe once nothing else in this file reads that global.

# plot verification
object = _load_latest_pickle(project_directory, 'final')
save_name = 'Figures/' + label + '_verification'
plot_results(object, label=label, save_name=save_name)

# plot observables
object = _load_latest_pickle(project_directory, 'training_observables')
save_name = 'Figures/' + label + '_observables'
plot_observables(object, label=label, save_name=save_name)
192+

0 commit comments

Comments
 (0)