-
Notifications
You must be signed in to change notification settings - Fork 366
Expand file tree
/
Copy pathoverview.py
More file actions
291 lines (267 loc) · 10.9 KB
/
overview.py
File metadata and controls
291 lines (267 loc) · 10.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from collections.abc import Callable
from typing import Any, final
from ax.adapter.base import Adapter
from ax.analysis.analysis import Analysis, ErrorAnalysisCard
from ax.analysis.diagnostics import DiagnosticAnalysis
from ax.analysis.healthcheck.baseline_improvement import BaselineImprovementAnalysis
from ax.analysis.healthcheck.can_generate_candidates import (
CanGenerateCandidatesAnalysis,
)
from ax.analysis.healthcheck.complexity_rating import ComplexityRatingAnalysis
from ax.analysis.healthcheck.constraints_feasibility import (
ConstraintsFeasibilityAnalysis,
)
from ax.analysis.healthcheck.early_stopping_healthcheck import EarlyStoppingAnalysis
from ax.analysis.healthcheck.healthcheck_analysis import HealthcheckAnalysisCard
from ax.analysis.healthcheck.metric_fetching_errors import MetricFetchingErrorsAnalysis
from ax.analysis.healthcheck.predictable_metrics import PredictableMetricsAnalysis
from ax.analysis.healthcheck.search_space_analysis import SearchSpaceAnalysis
from ax.analysis.healthcheck.transfer_learning_analysis import TransferLearningAnalysis
from ax.analysis.insights import InsightsAnalysis
from ax.analysis.results import ResultsAnalysis
from ax.analysis.trials import AllTrialsAnalysis
from ax.analysis.utils import validate_experiment
from ax.core.analysis_card import AnalysisCardGroup
from ax.core.batch_trial import BatchTrial
from ax.core.experiment import Experiment
from ax.core.map_metric import MapMetric
from ax.core.trial_status import TrialStatus
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.service.orchestrator import OrchestratorOptions
from pyre_extensions import override
# Title / subtitle for the card group that collects the healthcheck cards.
HEALTH_CHECK_CARDGROUP_TITLE = "Health Checks"
HEALTH_CHECK_CARDGROUP_SUBTITLE = (
    "Comprehensive health checks designed to identify potential issues "
    "in the Ax experiment. These checks cover areas such as metric "
    "fetching, search space configuration, and candidate generation, "
    "with the aim of flagging areas where user intervention may be "
    "necessary to ensure the experiment's robustness and success."
)

# Title / subtitle for the top-level overview card group.
# NOTE: the title intentionally keeps its original trailing space byte-for-byte.
OVERVIEW_CARDGROUP_TITLE = "Overview of the Entire Optimization Process "
OVERVIEW_CARDGROUP_SUBTITLE = (
    "This analysis provides an overview of the entire optimization "
    "process. It includes visualizations of the results obtained so "
    "far, insights into the parameter and metric relationships learned "
    "by the Ax model, diagnostics such as model fit, and health checks "
    "to assess the overall health of the experiment."
)
@final
class OverviewAnalysis(Analysis):
    """
    Top-level Analysis that provides an overview of the entire optimization process,
    including results, insights, and diagnostics. OverviewAnalysis examines the
    Experiment and GenerationStrategy's configuration and their respective current
    states to heuristically determine which Analyses to compute under the hood.

    AnalysisCards will be returned in the following groups:
        * Overview
            * Results
                * Pairs of Modeled and Raw ArmEffectsPlots for objectives and
                    constraints
                * Modeled ScatterPlots for objectives versus objectives and
                    objectives versus constraints
                * ParallelCoordinatesPlot for objectives
                * BanditRollout
                * UtilityProgressionAnalysis
                * ProgressionPlots
                * BestArms
                * Summary
            * Insights
                * Sensitivity Plots
                * Slice Plots
                * Contour Plots
                * OutcomeConstraintsAnalysis
                * MarginalEffectsPlot
                * TopSurfacesAnalysis
            * Diagnostic
                * CrossValidationPlots
            * Health Checks (only user-facing healthcheck cards and error cards;
              the group is omitted entirely when none qualify)
                * MetricFetchingErrorsAnalysis
                * EarlyStoppingAnalysis
                * CanGenerateCandidatesAnalysis
                * ConstraintsFeasibilityAnalysis
                * SearchSpaceAnalysis (one per CANDIDATE trial)
                * ComplexityRatingAnalysis
                * PredictableMetricsAnalysis
                * BaselineImprovementAnalysis
                * TransferLearningAnalysis
            * Trial-Level Analyses
                * Trial 0
                    * ArmEffectsPlot
                ...
    """

    def __init__(
        self,
        can_generate: bool | None = None,
        can_generate_reason: str | None = None,
        can_generate_days_till_fail: int | None = None,
        options: OrchestratorOptions | None = None,
        tier_metadata: dict[str, Any] | None = None,
        model_fit_threshold: float | None = None,
        sqa_config: Any = None,
        create_diff_paste_callable: Callable[[str, str, str], str] | None = None,
    ) -> None:
        """
        Args:
            can_generate: Forwarded to CanGenerateCandidatesAnalysis as
                ``can_generate_candidates``. That healthcheck only runs when this,
                ``can_generate_reason`` and ``can_generate_days_till_fail`` are all
                provided.
            can_generate_reason: Forwarded to CanGenerateCandidatesAnalysis as
                ``reason``.
            can_generate_days_till_fail: Forwarded to CanGenerateCandidatesAnalysis
                as ``days_till_fail``.
            options: Orchestrator options. Its ``early_stopping_strategy`` is passed
                to EarlyStoppingAnalysis, and ComplexityRatingAnalysis only runs
                when this is provided.
            tier_metadata: Forwarded to ComplexityRatingAnalysis.
            model_fit_threshold: If provided, overrides PredictableMetricsAnalysis'
                ``model_fit_threshold``; otherwise that analysis' default is used.
            sqa_config: Forwarded to TransferLearningAnalysis as ``config``.
            create_diff_paste_callable: Forwarded to TransferLearningAnalysis.
        """
        super().__init__()
        self.can_generate = can_generate
        self.can_generate_reason = can_generate_reason
        self.can_generate_days_till_fail = can_generate_days_till_fail
        self.options = options
        self.tier_metadata = tier_metadata
        self.model_fit_threshold = model_fit_threshold
        self.sqa_config = sqa_config
        self.create_diff_paste_callable = create_diff_paste_callable

    @override
    def validate_applicable_state(
        self,
        experiment: Experiment | None = None,
        generation_strategy: GenerationStrategy | None = None,
        adapter: Adapter | None = None,
    ) -> str | None:
        """
        Require an experiment. Trials and data are not required, because each
        sub-analysis handles its own applicability.
        """
        return validate_experiment(
            experiment=experiment,
            require_trials=False,
            require_data=False,
        )

    @override
    def compute(
        self,
        experiment: Experiment | None = None,
        generation_strategy: GenerationStrategy | None = None,
        adapter: Adapter | None = None,
    ) -> AnalysisCardGroup:
        """
        Compute every sub-analysis and assemble the results into a single
        overview card group. Sub-analysis failures surface as error cards
        (via ``compute_or_error_card``) rather than exceptions.

        Raises:
            UserInputError: If ``experiment`` is None.
        """
        # Fail fast: the health checks require an experiment, so there is no
        # point computing the results/insights/diagnostics groups before raising.
        if experiment is None:
            raise UserInputError(
                "OverviewAnalysis requires a non-null experiment to compute candidate "
                "trials. Please provide an experiment."
            )

        # Arm effects plots, scatter plots, etc.
        results_group = ResultsAnalysis().compute_or_error_card(
            experiment=experiment,
            generation_strategy=generation_strategy,
            adapter=adapter,
        )
        # Sensitivity plots, slice plots, contour plots, etc.
        insights_group = InsightsAnalysis().compute_or_error_card(
            experiment=experiment,
            generation_strategy=generation_strategy,
            adapter=adapter,
        )
        # Diagnostics section (cross validation plots).
        diagnostics_group = DiagnosticAnalysis().compute_or_error_card(
            experiment=experiment,
            generation_strategy=generation_strategy,
            adapter=adapter,
        )
        health_checks_group = self._compute_health_checks_group(
            experiment=experiment,
            generation_strategy=generation_strategy,
            adapter=adapter,
        )
        trials_group = AllTrialsAnalysis().compute_or_error_card(
            experiment=experiment,
            generation_strategy=generation_strategy,
            adapter=adapter,
        )

        return self._create_analysis_card_group(
            title=OVERVIEW_CARDGROUP_TITLE,
            subtitle=OVERVIEW_CARDGROUP_SUBTITLE,
            children=[
                group
                for group in [
                    results_group,
                    insights_group,
                    diagnostics_group,
                    health_checks_group,
                    trials_group,
                ]
                if group is not None
            ],
        )

    def _health_check_analyses(self, experiment: Experiment) -> list[Analysis]:
        """Select the healthcheck Analyses applicable to ``experiment`` and to
        this OverviewAnalysis' configuration, in display order."""
        # Several healthchecks below are skipped for experiments with BatchTrials.
        has_batch_trials = any(
            isinstance(trial, BatchTrial) for trial in experiment.trials.values()
        )
        has_map_metrics = any(
            isinstance(m, MapMetric) for m in experiment.metrics.values()
        )
        # Early stopping needs MapMetrics, data with a "step" column and no
        # BatchTrials. The data lookup is evaluated last so it is skipped
        # whenever the cheaper checks already rule early stopping out.
        run_early_stopping = (
            has_map_metrics
            and not has_batch_trials
            and experiment.lookup_data().has_step_column
        )
        candidate_trials = experiment.extract_relevant_trials(
            trial_statuses=[TrialStatus.CANDIDATE]
        )

        analyses: list[Analysis | None] = [
            MetricFetchingErrorsAnalysis(),
            (
                EarlyStoppingAnalysis(
                    early_stopping_strategy=(
                        self.options.early_stopping_strategy if self.options else None
                    ),
                )
                if run_early_stopping
                else None
            ),
            (
                CanGenerateCandidatesAnalysis(
                    can_generate_candidates=self.can_generate,
                    reason=self.can_generate_reason,
                    days_till_fail=self.can_generate_days_till_fail,
                )
                if self.can_generate is not None
                and self.can_generate_reason is not None
                and self.can_generate_days_till_fail is not None
                else None
            ),
            (
                ComplexityRatingAnalysis(
                    options=self.options,
                    tier_metadata=self.tier_metadata,
                )
                if self.options is not None and not has_batch_trials
                else None
            ),
            ConstraintsFeasibilityAnalysis(),
            (
                (
                    PredictableMetricsAnalysis()
                    if self.model_fit_threshold is None
                    else PredictableMetricsAnalysis(
                        model_fit_threshold=self.model_fit_threshold
                    )
                )
                if not has_batch_trials
                else None
            ),
            BaselineImprovementAnalysis() if not has_batch_trials else None,
            TransferLearningAnalysis(
                config=self.sqa_config,
                create_diff_paste_callable=self.create_diff_paste_callable,
            ),
            # One search space check per CANDIDATE trial.
            *[
                SearchSpaceAnalysis(trial_index=trial.index)
                for trial in candidate_trials
            ],
        ]
        return [analysis for analysis in analyses if analysis is not None]

    def _compute_health_checks_group(
        self,
        experiment: Experiment,
        generation_strategy: GenerationStrategy | None,
        adapter: Adapter | None,
    ) -> AnalysisCardGroup | None:
        """Run the applicable healthchecks and group the cards worth showing.

        Returns None when no card qualifies, so the overview omits the group.
        """
        health_check_cards = [
            analysis.compute_or_error_card(
                experiment=experiment,
                generation_strategy=generation_strategy,
                adapter=adapter,
            )
            for analysis in self._health_check_analyses(experiment=experiment)
        ]
        # Keep healthcheck cards flagged as user facing; error cards are always kept.
        user_facing_health_check_cards = [
            card
            for card in health_check_cards
            if (isinstance(card, HealthcheckAnalysisCard) and card.is_user_facing())
            or isinstance(card, ErrorAnalysisCard)
        ]
        if not user_facing_health_check_cards:
            return None
        return AnalysisCardGroup(
            name="HealthchecksAnalysis",
            title=HEALTH_CHECK_CARDGROUP_TITLE,
            subtitle=HEALTH_CHECK_CARDGROUP_SUBTITLE,
            children=user_facing_health_check_cards,
        )