-
Notifications
You must be signed in to change notification settings - Fork 366
Expand file tree
/
Copy pathresults.py
More file actions
392 lines (359 loc) · 14.9 KB
/
results.py
File metadata and controls
392 lines (359 loc) · 14.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import itertools
from collections.abc import Sequence
from typing import final
from ax.adapter.base import Adapter
from ax.analysis.analysis import Analysis
from ax.analysis.best_arms import BestArms
from ax.analysis.plotly.arm_effects import ArmEffectsPlot
from ax.analysis.plotly.bandit_rollout import BanditRollout
from ax.analysis.plotly.progression import (
PROGRESSION_CARDGROUP_SUBTITLE,
PROGRESSION_CARDGROUP_TITLE,
ProgressionPlot,
)
from ax.analysis.plotly.scatter import (
SCATTER_CARDGROUP_SUBTITLE,
SCATTER_CARDGROUP_TITLE,
ScatterPlot,
)
from ax.analysis.plotly.utility_progression import UtilityProgressionAnalysis
from ax.analysis.summary import Summary
from ax.analysis.utils import extract_relevant_adapter, validate_experiment
from ax.core.analysis_card import AnalysisCardGroup
from ax.core.arm import Arm
from ax.core.batch_trial import BatchTrial
from ax.core.data import MAP_KEY
from ax.core.experiment import Experiment
from ax.core.map_metric import MapMetric
from ax.core.outcome_constraint import ScalarizedOutcomeConstraint
from ax.core.trial_status import TrialStatus
from ax.core.utils import is_bandit_experiment
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from pyre_extensions import none_throws, override
# Title/subtitle for the top-level card group produced by
# ResultsAnalysis.compute (see below).
RESULTS_CARDGROUP_TITLE = "Results Analysis"
RESULTS_CARDGROUP_SUBTITLE = (
    "Result Analyses provide a high-level overview of the results of the optimization "
    "process so far with respect to the metrics specified in experiment design."
)
# Title/subtitle for the card group produced by ArmEffectsPair.compute,
# which shows predicted (left) and observed (right) effects side by side.
ARM_EFFECTS_PAIR_CARDGROUP_TITLE = (
    "Metric Effects: Predicted and observed effects for all arms in the experiment"
)
ARM_EFFECTS_PAIR_CARDGROUP_SUBTITLE = (
    "These pair of plots visualize the metric effects for each arm, with the Ax "
    "model predictions on the left and the raw observed data on the right. The "
    "predicted effects apply shrinkage for noise and adjust for non-stationarity "
    "in the data, so they are more representative of the reproducible effects that "
    "will manifest in a long-term validation experiment. "
)
@final
class ResultsAnalysis(Analysis):
    """
    An Analysis that provides a high-level overview of the results of the optimization
    process so far, e.g. effects on all arms. It produces an analysis card group.
    """
    @override
    def validate_applicable_state(
        self,
        experiment: Experiment | None = None,
        generation_strategy: GenerationStrategy | None = None,
        adapter: Adapter | None = None,
    ) -> str | None:
        """Check whether this Analysis can run in the given state.

        Delegates to ``validate_experiment``, requiring that an Experiment is
        provided and that it has both trials and data. Per the return
        contract, a non-None string describes why the analysis is
        inapplicable; ``generation_strategy`` and ``adapter`` are unused here.
        """
        return validate_experiment(
            experiment=experiment,
            require_trials=True,
            require_data=True,
        )
    @override
    def compute(
        self,
        experiment: Experiment | None = None,
        generation_strategy: GenerationStrategy | None = None,
        adapter: Adapter | None = None,
    ) -> AnalysisCardGroup:
        """Assemble the top-level "Results Analysis" card group.

        Builds each applicable sub-analysis (arm-effects pairs, objective and
        constraint scatter plots, bandit rollout, best arms, utility
        progression, per-metric progression plots, and a summary), skipping
        any whose preconditions are not met, and returns them as a single
        AnalysisCardGroup in a fixed display order.

        Raises:
            Whatever ``none_throws`` raises if ``experiment`` is None
            (callers are expected to have passed validation first).
        """
        experiment = none_throws(experiment)
        # If the Experiment has an OptimizationConfig set, extract the objective and
        # constraint names.
        objective_names = []
        constraint_names = []
        if (optimization_config := experiment.optimization_config) is not None:
            objective_names = optimization_config.objective.metric_names
            for oc in optimization_config.outcome_constraints:
                constraint_names.extend(oc.metric_names)
        # Resolve which Adapter to use for model-based cards from the
        # (experiment, generation_strategy, adapter) inputs.
        relevant_adapter = extract_relevant_adapter(
            experiment=experiment,
            generation_strategy=generation_strategy,
            adapter=adapter,
        )
        # Check if there are BatchTrials present.
        has_batch_trials = any(
            isinstance(trial, BatchTrial) for trial in experiment.trials.values()
        )
        # Relativize the effects if the status quo is set and there are BatchTrials
        # present.
        relativize = experiment.status_quo is not None and has_batch_trials
        # Compute both observed and modeled effects for each objective and constraint.
        # Note: only produced when there is at least one objective, even though
        # constraint metrics are also included in the plot.
        arm_effect_pair_group = (
            ArmEffectsPair(
                metric_names=[*objective_names, *constraint_names],
                relativize=relativize,
            ).compute_or_error_card(
                experiment=experiment,
                generation_strategy=generation_strategy,
                adapter=relevant_adapter,
            )
            if len(objective_names) > 0
            else None
        )
        # If there are multiple objectives, compute scatter plots of each combination
        # of two objectives. For MOO experiments, show the Pareto frontier line.
        objective_scatter_group = (
            AnalysisCardGroup(
                name="Objective Scatter Plots",
                title=SCATTER_CARDGROUP_TITLE + " (Objectives)",
                subtitle=SCATTER_CARDGROUP_SUBTITLE,
                children=[
                    ScatterPlot(
                        x_metric_name=x,
                        y_metric_name=y,
                        relativize=relativize,
                        show_pareto_frontier=True,
                    ).compute_or_error_card(
                        experiment=experiment,
                        generation_strategy=generation_strategy,
                        adapter=relevant_adapter,
                    )
                    for x, y in itertools.combinations(objective_names, 2)
                ],
            )
            if len(objective_names) > 1
            else None
        )
        # If there are objectives and constraints, compute scatter plots of each
        # objective versus each constraint. Objectives are always plotted on the x-
        # axis and constraints on the y-axis.
        constraint_scatter_group = (
            AnalysisCardGroup(
                name="Constraint Scatter Plots",
                title=SCATTER_CARDGROUP_TITLE + " (Constraints)",
                subtitle=SCATTER_CARDGROUP_SUBTITLE,
                children=[
                    ScatterPlot(
                        x_metric_name=objective_name,
                        y_metric_name=constraint_name,
                        relativize=relativize,
                    ).compute_or_error_card(
                        experiment=experiment,
                        generation_strategy=generation_strategy,
                        adapter=relevant_adapter,
                    )
                    for objective_name in objective_names
                    for constraint_name in constraint_names
                ],
            )
            if len(objective_names) > 0 and len(constraint_names) > 0
            else None
        )
        # Produce a parallel coordinates plot for each objective.
        # TODO: mpolson mgarrard bring back parallel coordinates after fixing
        # objective_parallel_coordinates_group = (
        #     AnalysisCardGroup(
        #         name="Objective Parallel Coordinates Plots",
        #         children=[
        #             ParallelCoordinatesPlot(
        #                 metric_name=metric_name
        #             ).compute_or_error_card(
        #                 experiment=experiment,
        #                 generation_strategy=generation_strategy,
        #                 adapter=adapter,
        #             )
        #             for metric_name in objective_names
        #         ],
        #     )
        #     if len(objective_names) > 0
        #     else None
        # )
        # Add BanditRollout for experiments with specific generation strategy
        # NOTE(review): this (and the cards below) pass the raw ``adapter``
        # rather than ``relevant_adapter`` — confirm this is intentional.
        bandit_rollout_card = (
            BanditRollout().compute_or_error_card(
                experiment=experiment,
                generation_strategy=generation_strategy,
                adapter=adapter,
            )
            if generation_strategy
            and is_bandit_experiment(generation_strategy_name=generation_strategy.name)
            else None
        )
        # Compute best trials, skip for experiments with ScalarizedOutcomeConstraints
        has_scalarized_outcome_constraints = optimization_config is not None and any(
            isinstance(oc, ScalarizedOutcomeConstraint)
            for oc in optimization_config.outcome_constraints
        )
        best_arms_card = (
            BestArms().compute_or_error_card(
                experiment=experiment,
                generation_strategy=generation_strategy,
                adapter=adapter,
            )
            if not has_scalarized_outcome_constraints
            else None
        )
        # Add utility progression if there are objectives
        # Skip for experiments with ScalarizedOutcomeConstraint as feasibility
        # evaluation for scalarized outcome constraints is not yet implemented
        # Skip for online experiments (those with BatchTrials)
        utility_progression_card = (
            UtilityProgressionAnalysis().compute_or_error_card(
                experiment=experiment,
                generation_strategy=generation_strategy,
                adapter=adapter,
            )
            if len(objective_names) > 0
            and not has_scalarized_outcome_constraints
            and not has_batch_trials
            else None
        )
        # The summary card is always included, unconditionally.
        summary = Summary().compute_or_error_card(
            experiment=experiment,
            generation_strategy=generation_strategy,
            adapter=adapter,
        )
        # Compute progression plots if there is curve data.
        progression_group = None
        data = experiment.lookup_data()
        metrics = experiment.metrics.values()
        map_metrics = [m for m in metrics if isinstance(m, MapMetric)]
        # Only plot progressions when the data actually contains non-null
        # step values AND the experiment declares MapMetrics.
        if (
            data.has_step_column
            and data.full_df[MAP_KEY].notna().any()
            and len(map_metrics) > 0
        ):
            # Two plots per MapMetric: by progression step and by wallclock time.
            progression_cards = [
                ProgressionPlot(
                    metric_name=m.name, by_wallclock_time=by_wallclock_time
                ).compute_or_error_card(
                    experiment=experiment,
                    generation_strategy=generation_strategy,
                    adapter=adapter,
                )
                for m in map_metrics
                for by_wallclock_time in (False, True)
            ]
            if progression_cards:
                progression_group = AnalysisCardGroup(
                    name="ProgressionAnalysis",
                    title=PROGRESSION_CARDGROUP_TITLE,
                    subtitle=PROGRESSION_CARDGROUP_SUBTITLE,
                    children=progression_cards,
                )
        # Assemble in display order, dropping any sub-analyses that were
        # skipped above (None entries).
        return self._create_analysis_card_group(
            title=RESULTS_CARDGROUP_TITLE,
            subtitle=RESULTS_CARDGROUP_SUBTITLE,
            children=[
                child
                for child in (
                    arm_effect_pair_group,
                    objective_scatter_group,
                    constraint_scatter_group,
                    bandit_rollout_card,
                    utility_progression_card,
                    progression_group,
                    best_arms_card,
                    summary,
                )
                if child is not None
            ],
        )
@final
class ArmEffectsPair(Analysis):
    """
    Compute two ArmEffectsPlots in a single AnalysisCardGroup, one plotting model
    predictions and one plotting raw observed data.
    """
    def __init__(
        self,
        metric_names: Sequence[str] | None = None,
        relativize: bool = False,
        trial_index: int | None = None,
        trial_statuses: Sequence[TrialStatus] | None = None,
        additional_arms: Sequence[Arm] | None = None,
        label: str | None = None,
    ) -> None:
        """
        Args:
            metric_names: The names of the metrics to include in the plot. If not
                specified, all metrics in the experiment will be used.
            relativize: Whether to relativize the effects of each arm against the status
                quo arm. If multiple status quo arms are present, relativize each arm
                against the status quo arm from the same trial.
            trial_index: If present, only use arms from the trial with the given index.
            trial_statuses: If present, only use arms from trials with these statuses.
            additional_arms: If present, include these arms in the plot in addition to
                the arms in the experiment. These arms will be marked as belonging to a
                trial with index -1.
            label: A label to use in the plot in place of the metric name.
        """
        self.metric_names = metric_names
        self.relativize = relativize
        self.trial_index = trial_index
        self.trial_statuses = trial_statuses
        self.additional_arms = additional_arms
        self.label = label
    @override
    def compute(
        self,
        experiment: Experiment | None = None,
        generation_strategy: GenerationStrategy | None = None,
        adapter: Adapter | None = None,
    ) -> AnalysisCardGroup:
        """Build one predicted/observed plot pair per metric.

        For every requested metric (or every experiment metric when none were
        requested), computes one model-predicted ArmEffectsPlot and one
        raw-data ArmEffectsPlot and groups the two into a card group; all
        pairs are then wrapped in a single parent card group.

        Raises:
            UserInputError: If no Experiment is provided.
        """
        if experiment is None:
            raise UserInputError("ArmEffectsPlot requires an Experiment.")
        relevant_adapter = extract_relevant_adapter(
            experiment=experiment,
            generation_strategy=generation_strategy,
            adapter=adapter,
        )
        # Settings shared by both plots in every pair.
        shared_settings = dict(
            relativize=self.relativize,
            trial_index=self.trial_index,
            trial_statuses=self.trial_statuses,
            additional_arms=self.additional_arms,
            label=self.label,
        )

        def build_pair(metric_name: str) -> AnalysisCardGroup:
            # One pair of cards for a single metric: predictions first, then
            # raw observations. Only the predicted plot needs an adapter.
            # TODO: Test for no effects and render a message instead of a flat line.
            predicted_card = ArmEffectsPlot(
                metric_name=metric_name,
                use_model_predictions=True,
                **shared_settings,
            ).compute_or_error_card(
                experiment=experiment,
                generation_strategy=generation_strategy,
                adapter=relevant_adapter,
            )
            observed_card = ArmEffectsPlot(
                metric_name=metric_name,
                use_model_predictions=False,
                **shared_settings,
            ).compute_or_error_card(experiment=experiment)
            return AnalysisCardGroup(
                name=f"ArmEffects Pair {metric_name}",
                title=f"Metric Effects Pair for {metric_name}",
                subtitle=None,
                children=[predicted_card, observed_card],
            )

        # Fall back to every metric on the experiment when none were given.
        return self._create_analysis_card_group(
            title=ARM_EFFECTS_PAIR_CARDGROUP_TITLE,
            subtitle=ARM_EFFECTS_PAIR_CARDGROUP_SUBTITLE,
            children=[
                build_pair(metric_name)
                for metric_name in (
                    self.metric_names or list(experiment.metrics.keys())
                )
            ],
        )