Skip to content

Commit b95c56f

Browse files
committed
Fix issue with params short format and add regression test
1 parent 6ca1ecb commit b95c56f

4 files changed

Lines changed: 62 additions & 2 deletions

File tree

cli/polyaxon/_flow/operations/operation.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
StrictStr,
88
field_validator,
99
model_validator,
10-
validation_after,
10+
validation_after, validation_before,
1111
)
1212
from clipped.config.patch_strategy import PatchStrategy
1313
from clipped.config.schema import skip_partial, to_partial
@@ -16,7 +16,7 @@
1616
from polyaxon._flow.component.component import V1Component
1717
from polyaxon._flow.hooks import V1Hook
1818
from polyaxon._flow.operations.base import BaseOp
19-
from polyaxon._flow.params import V1Param
19+
from polyaxon._flow.params import V1Param, normalize_param_value
2020
from polyaxon._flow.references import V1DagRef, V1HubRef, V1PathRef, V1UrlRef
2121
from polyaxon._flow.run.patch import validate_run_patch
2222
from polyaxon._flow.templates import TemplateMixinConfig, V1Template
@@ -544,6 +544,13 @@ class V1Operation(BaseOp, TemplateMixinConfig):
544544
run_patch: Optional[Dict] = Field(alias="runPatch", default=None)
545545
template: Optional[V1Template] = None
546546

547+
@field_validator("params", **validation_before)
548+
@classmethod
549+
def validate_params(cls, params):
550+
if not params:
551+
return params
552+
return {k: normalize_param_value(v) for k, v in params.items()}
553+
547554
@model_validator(**validation_after)
548555
@skip_partial
549556
def validate_reference(cls, values):

cli/polyaxon/_flow/params/params.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ def is_short_form_param(value: Any) -> bool:
5656
Returns:
5757
True if the value is in short-form, False if it's a full-form V1Param dict
5858
"""
59+
if isinstance(value, V1Param):
60+
return False
61+
5962
if not isinstance(value, Mapping):
6063
# Non-dict values are always short-form
6164
return True
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
---
2+
version: 1.1
3+
kind: operation
4+
params:
5+
loss: MeanSquaredError
6+
flag: true
7+
component:
8+
tags: [foo, bar]
9+
inputs:
10+
- name: loss
11+
type: str
12+
- name: flag
13+
type: bool
14+
isFlag: true
15+
run:
16+
kind: job
17+
container:
18+
image: my_image
19+
command: ["/bin/sh", "-c"]
20+
args: video_prediction_train --loss={{loss}} {{ flag }}

cli/tests/test_polyaxonfile/test_polyaxonfile.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,35 @@ def test_passing_params_overrides_polyaxonfiles(self):
227227
== "video_prediction_train --loss=some-loss --flag"
228228
)
229229

230+
def test_passing_short_params_overrides_polyaxonfiles(self):
231+
# Test that short form params (without the "value" wrapper) are parsed correctly
232+
op_config = OperationSpecification.read(
233+
[
234+
os.path.abspath(
235+
"tests/fixtures/typing/required_inputs_with_short_params.yml"
236+
),
237+
]
238+
)
239+
240+
# Verify params are parsed correctly from short form
241+
assert op_config.params["loss"].value == "MeanSquaredError"
242+
assert op_config.params["flag"].value is True
243+
244+
# Compile the operation and apply params
245+
run_config = OperationSpecification.compile_operation(op_config)
246+
# Apply the params from the operation to the compiled operation
247+
run_config.apply_params(params=op_config.params)
248+
run_config = CompiledOperationSpecification.apply_operation_contexts(run_config)
249+
run_config = CompiledOperationSpecification.apply_runtime_contexts(run_config)
250+
assert run_config.version == 1.1
251+
assert run_config.tags == ["foo", "bar"]
252+
assert run_config.run.container.image == "my_image"
253+
assert run_config.run.container.command == ["/bin/sh", "-c"]
254+
assert (
255+
run_config.run.container.args
256+
== "video_prediction_train --loss=MeanSquaredError --flag"
257+
)
258+
230259
def test_passing_wrong_params_raises(self):
231260
with self.assertRaises(PolyaxonfileError):
232261
check_polyaxonfile(
@@ -830,3 +859,4 @@ def test_specification_with_context_requirement(self):
830859
run_config, contexts=contexts
831860
)
832861
assert run_config.run.to_dict() == expected_run
862+

0 commit comments

Comments
 (0)