Skip to content

Commit 4a8e993

Browse files
aswinklforst
andauthored
Add support for classifier "scorers" (#1553)
Co-authored-by: Luca Forstner <luca.forstner@gmail.com>
1 parent 25d25d5 commit 4a8e993

13 files changed

Lines changed: 601 additions & 207 deletions

File tree

js/dev/server.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import {
88
EvalScorer,
99
EvaluatorDef,
1010
OneOrMoreScores,
11+
classifierName,
1112
scorerName,
1213
} from "../src/framework";
1314
import { errorHandler } from "./errorHandler";
@@ -117,9 +118,12 @@ export function runDevServer(
117118

118119
evalDefs[name] = {
119120
parameters,
120-
scores: evaluator.scores.map((score, idx) => ({
121+
scores: (evaluator.scores ?? []).map((score, idx) => ({
121122
name: scorerName(score, idx),
122123
})),
124+
classifiers: (evaluator.classifiers ?? []).map((classifier, idx) => ({
125+
name: classifierName(classifier, idx),
126+
})),
123127
};
124128
}
125129

@@ -209,7 +213,7 @@ export function runDevServer(
209213
{
210214
...evaluator,
211215
data: evalData.data,
212-
scores: evaluator.scores.concat(
216+
scores: (evaluator.scores ?? []).concat(
213217
scores?.map((score) =>
214218
makeScorer(
215219
state,

js/dev/types.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ export type SerializedParametersContainer = z.infer<
111111
export const evaluatorDefinitionSchema = z.object({
112112
parameters: serializedParametersContainerSchema.optional(),
113113
scores: z.array(z.object({ name: z.string() })).optional(),
114+
classifiers: z.array(z.object({ name: z.string() })).optional(),
114115
});
115116
export type EvaluatorDefinition = z.infer<typeof evaluatorDefinitionSchema>;
116117

js/src/cli/functions/infer-source.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ export async function findCodeDefinition({
8686
fn =
8787
location.position.type === "task"
8888
? evaluator.task
89-
: evaluator.scores[location.position.index];
89+
: location.position.type === "scorer"
90+
? (evaluator.scores ?? [])[location.position.index]
91+
: (evaluator.classifiers ?? [])[location.position.index];
9092
}
9193
} else if (location.type === "function") {
9294
fn = outFileModule.functions[location.index].handler;

js/src/cli/functions/upload.test.ts

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,37 @@ describe("buildBundledFunctionEntry", () => {
8484

8585
expect(entry.tags).toBeUndefined();
8686
});
87+
88+
test("preserves classifier experiment locations", async () => {
89+
const entry = await buildBundledFunctionEntry({
90+
spec: {
91+
project_id: "proj-123",
92+
name: "test-classifier",
93+
slug: "test-classifier",
94+
description: "Test classifier",
95+
location: {
96+
type: "experiment" as const,
97+
eval_name: "eval-1",
98+
position: {
99+
type: "classifier" as const,
100+
index: 0,
101+
},
102+
},
103+
function_type: "classifier" as const,
104+
},
105+
runtime_context: { runtime: "node", version: "22.0.0" },
106+
bundleId: "bundle-123",
107+
sourceMapContext: undefined,
108+
});
109+
110+
expect(entry.function_type).toBe("classifier");
111+
expect(entry.function_data.data.location).toEqual({
112+
type: "experiment",
113+
eval_name: "eval-1",
114+
position: {
115+
type: "classifier",
116+
index: 0,
117+
},
118+
});
119+
});
87120
});

js/src/cli/functions/upload.ts

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import {
44
type IfExistsType as IfExists,
55
} from "../../generated_types";
66
import type { BuildSuccess, EvaluatorState, FileHandle } from "../types";
7-
import { scorerName, warning } from "../../framework";
7+
import { classifierName, scorerName, warning } from "../../framework";
88
import {
99
_internalGetGlobalState,
1010
Experiment,
@@ -181,23 +181,42 @@ export async function uploadHandleBundles({
181181
function_type: "task",
182182
origin,
183183
},
184-
...evaluator.evaluator.scores.map((score, i): BundledFunctionSpec => {
185-
const name = scorerName(score, i);
186-
return {
187-
...baseInfo,
188-
// There is a very small chance that someone names a function with the same convention, but
189-
// let's assume it's low enough that it doesn't matter.
190-
...formatNameAndSlug(["eval", namePrefix, "scorer", name]),
191-
description: `Score ${name} for eval ${namePrefix}`,
192-
location: {
193-
type: "experiment",
194-
eval_name: evaluator.evaluator.evalName,
195-
position: { type: "scorer", index: i },
196-
},
197-
function_type: "scorer",
198-
origin,
199-
};
200-
}),
184+
...(evaluator.evaluator.scores ?? []).map(
185+
(score, i): BundledFunctionSpec => {
186+
const name = scorerName(score, i);
187+
return {
188+
...baseInfo,
189+
// There is a very small chance that someone names a function with the same convention, but
190+
// let's assume it's low enough that it doesn't matter.
191+
...formatNameAndSlug(["eval", namePrefix, "scorer", name]),
192+
description: `Score ${name} for eval ${namePrefix}`,
193+
location: {
194+
type: "experiment",
195+
eval_name: evaluator.evaluator.evalName,
196+
position: { type: "scorer", index: i },
197+
},
198+
function_type: "scorer",
199+
origin,
200+
};
201+
},
202+
),
203+
...(evaluator.evaluator.classifiers ?? []).map(
204+
(classifier, i): BundledFunctionSpec => {
205+
const name = classifierName(classifier, i);
206+
return {
207+
...baseInfo,
208+
...formatNameAndSlug(["eval", namePrefix, "classifier", name]),
209+
description: `Classifier ${name} for eval ${namePrefix}`,
210+
location: {
211+
type: "experiment",
212+
eval_name: evaluator.evaluator.evalName,
213+
position: { type: "classifier", index: i },
214+
},
215+
function_type: "classifier",
216+
origin,
217+
};
218+
},
219+
),
201220
];
202221

203222
bundleSpecs.push(...fileSpecs);
@@ -220,9 +239,14 @@ export async function uploadHandleBundles({
220239
serializeRemoteEvalParametersContainer(resolvedParameters),
221240
}
222241
: {}),
223-
scores: evaluator.evaluator.scores.map((score, i) => ({
242+
scores: (evaluator.evaluator.scores ?? []).map((score, i) => ({
224243
name: scorerName(score, i),
225244
})),
245+
classifiers: (evaluator.evaluator.classifiers ?? []).map(
246+
(classifier, i) => ({
247+
name: classifierName(classifier, i),
248+
}),
249+
),
226250
};
227251

228252
bundleSpecs.push({

js/src/exports.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ export type {
193193
EvalResult,
194194
EvalScorerArgs,
195195
EvalScorer,
196+
EvalClassifier,
196197
EvaluatorDef,
197198
EvaluatorFile,
198199
ReporterBody,

0 commit comments

Comments
 (0)