-
Notifications
You must be signed in to change notification settings - Fork 366
Expand file tree
/
Copy pathbest_arms.py
More file actions
214 lines (185 loc) · 8.08 KB
/
best_arms.py
File metadata and controls
214 lines (185 loc) · 8.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from collections.abc import Sequence
from typing import final
from ax.adapter.base import Adapter
from ax.analysis.analysis import Analysis
from ax.analysis.summary import Summary
from ax.analysis.utils import validate_experiment
from ax.core.analysis_card import AnalysisCard
from ax.core.batch_trial import BatchTrial
from ax.core.experiment import Experiment
from ax.core.trial_status import TrialStatus
from ax.exceptions.core import DataRequiredError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.service.utils.best_point_utils import get_best_trial_indices
from pyre_extensions import none_throws, override
# Reusable sentence fragments that ``BestArms.compute`` stitches together when
# building the subtitle of its analysis card.

# How the best arm(s) were chosen when model predictions are used.
_MODEL_PREDICTIONS_DESCRIPTION = (
    "based on model predictions. "
    "The model predictions apply shrinkage for noise and adjust for "
    "non-stationarity, making them more representative of reproducible effects."
)

# How the best arm(s) were chosen when raw observed data is used.
_RAW_OBSERVATIONS_DESCRIPTION = (
    "based on raw observations. "
    "This reflects actual measured performance during execution."
)

# Explanation appended for multi-objective (Pareto frontier) results.
_PARETO_DESCRIPTION = (
    "These trials represent optimal trade-offs between competing objectives. "
    "Use this to understand the available trade-offs and select a trial "
    "that best balances your optimization goals."
)

# Explanation appended for single-objective results.
_SOO_DESCRIPTION = (
    "This trial achieved the optimal objective value and represents "
    "the recommended configuration for your optimization goal."
)
@final
class BestArms(Analysis):
    """
    High-level summary of the best trial(s) in the Experiment with one row per arm.
    Any values missing at compute time will be represented as None. Columns where
    every value is None will be omitted by default.

    For single-objective optimization (SOO), this analysis identifies the best trial
    while for multi-objective optimization (MOO), this analysis identifies the Pareto
    frontier trials.

    The DataFrame computed will contain one row per best arm and the following columns:
        - trial_index: The trial index of the arm
        - arm_name: The name of the arm
        - trial_status: The status of the trial (e.g. RUNNING, SUCCEEDED, FAILED)
        - failure_reason: The reason for the failure, if applicable
        - generation_node: The name of the ``GenerationNode`` that generated the arm
        - **METADATA: Any metadata associated with the trial, as specified by the
            Experiment's runner.run_metadata_report_keys field
        - **METRIC_NAME: The observed mean of the metric specified, for each metric
        - **PARAMETER_NAME: The parameter value for the arm, for each parameter

    Args:
        trial_statuses: If specified, only trials whose status is in this
            collection are considered when selecting the best trial(s).
            Defaults to ``[TrialStatus.COMPLETED]``.
        use_model_predictions: If True, use model predictions for best trial
            selection instead of raw observations. This is useful in noisy
            settings where model predictions can help filter out observation
            noise. When True (and for multi-objective problems),
            ``validate_applicable_state`` requires a ``GenerationStrategy``.
    """

    def __init__(
        self,
        trial_statuses: Sequence[TrialStatus] | None = None,
        use_model_predictions: bool = False,
    ) -> None:
        # With no explicit statuses, only COMPLETED trials are considered.
        self.trial_statuses: Sequence[TrialStatus] = (
            [TrialStatus.COMPLETED] if trial_statuses is None else trial_statuses
        )
        self.use_model_predictions = use_model_predictions

    def _eligible_trial_indices(self, experiment: Experiment) -> list[int]:
        """Return indices of ``experiment``'s trials whose status is in
        ``self.trial_statuses``, in the order ``experiment.trials`` yields them.

        Shared by ``validate_applicable_state`` and ``compute`` so both apply
        exactly the same status filter.
        """
        return [
            trial.index
            for trial in experiment.trials.values()
            if trial.status in self.trial_statuses
        ]

    @override
    def validate_applicable_state(
        self,
        experiment: Experiment | None = None,
        generation_strategy: GenerationStrategy | None = None,
        adapter: Adapter | None = None,
    ) -> str | None:
        """Check whether this analysis can be computed for the given inputs.

        Returns:
            A human-readable reason string if the analysis is not applicable,
            otherwise None.
        """
        # Basic experiment validation (trials and data must be present).
        error = validate_experiment(
            experiment=experiment,
            require_trials=True,
            require_data=True,
        )
        if error is not None:
            return error

        # Validate optimization config exists.
        if experiment is None or experiment.optimization_config is None:
            return (
                "`BestArms` analysis requires an `OptimizationConfig`. "
                "Ensure the `Experiment` has an `optimization_config` set to compute "
                "this analysis."
            )
        optimization_config = experiment.optimization_config

        # At least one trial must have one of the requested statuses.
        if not self._eligible_trial_indices(experiment=experiment):
            status_names = [s.name for s in self.trial_statuses]
            return f"No trials found with status in {status_names}."

        # A GenerationStrategy is required when using model predictions or MOO.
        needs_generation_strategy = (
            self.use_model_predictions or optimization_config.is_moo_problem
        )
        if needs_generation_strategy and generation_strategy is None:
            return (
                "`BestArms` analysis requires a `GenerationStrategy` input "
                "when using model predictions or for multi-objective "
                "optimization problems."
            )
        return None

    @override
    def compute(
        self,
        experiment: Experiment | None = None,
        generation_strategy: GenerationStrategy | None = None,
        adapter: Adapter | None = None,
    ) -> AnalysisCard:
        """Build an ``AnalysisCard`` summarizing the best trial(s).

        Raises:
            DataRequiredError: If ``get_best_trial_indices`` identifies no
                best trial among the status-filtered trials.
        """
        exp = none_throws(experiment)
        optimization_config = none_throws(exp.optimization_config)

        # Restrict the search for best trials to the status-filtered trials.
        trial_indices = get_best_trial_indices(
            experiment=exp,
            optimization_config=optimization_config,
            generation_strategy=generation_strategy,
            trial_indices=self._eligible_trial_indices(experiment=exp),
            use_model_predictions=self.use_model_predictions,
        )
        if not trial_indices:
            raise DataRequiredError(
                "No best arm(s) could be identified. This could be due to "
                "insufficient data or no trials meeting the optimization criteria."
            )

        # Delegate to the Summary analysis for the dataframe so formatting stays
        # consistent with it.
        summary_card = Summary(trial_indices=trial_indices).compute(
            experiment=exp,
            generation_strategy=generation_strategy,
            adapter=adapter,
        )

        # Build descriptive subtitle based on optimization type and prediction method.
        prediction_method = (
            _MODEL_PREDICTIONS_DESCRIPTION
            if self.use_model_predictions
            else _RAW_OBSERVATIONS_DESCRIPTION
        )
        if optimization_config.is_moo_problem:
            title_prefix = "Pareto Frontier Trials"
            subtitle = (
                f"Displays trials on the Pareto frontier {prediction_method} "
                "No trial is strictly better across all objectives. "
                f"{_PARETO_DESCRIPTION}"
            )
        else:
            title_prefix = "Best Trial"
            subtitle = (
                "Displays the trial with the best objective value "
                f"{prediction_method} {_SOO_DESCRIPTION}"
            )

        # Add trial status context.
        status_names = ", ".join(s.name for s in self.trial_statuses)
        subtitle += f" Only considering {status_names} trials."

        # Add relativization context. NOTE(review): this detects relativization by
        # searching Summary's subtitle text, so it silently breaks if that wording
        # changes — consider exposing relativization explicitly on Summary.
        if "relativized" in summary_card.subtitle:
            subtitle += " Metric values are shown relative to the status quo baseline."

        # NOTE(review): singular "BestArm" vs plural "BestTrials" looks
        # inconsistent — confirm the intended card names before changing them.
        has_batch_trials = any(
            isinstance(trial, BatchTrial) for trial in exp.trials.values()
        )
        display_name = "BestArm" if has_batch_trials else "BestTrials"

        card = self._create_analysis_card(
            title=f"{title_prefix} for Experiment",
            subtitle=subtitle,
            df=summary_card.df,
        )
        card.name = display_name
        return card