Skip to content

Commit 1575d5e

Browse files
committed
Fix multipart binary composed-schema matching
1 parent aef7c45 commit 1575d5e

5 files changed

Lines changed: 248 additions & 49 deletions

File tree

openapi_core/deserializing/media_types/deserializers.py

Lines changed: 87 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from dataclasses import dataclass
12
from typing import TYPE_CHECKING
23
from typing import Any
4+
from typing import Iterator
35
from typing import Mapping
46
from typing import Optional
57
from xml.etree.ElementTree import ParseError
@@ -23,6 +25,7 @@
2325
from openapi_core.schema.protocols import SuportsGetAll
2426
from openapi_core.schema.protocols import SuportsGetList
2527
from openapi_core.schema.schemas import get_properties
28+
from openapi_core.validation.schemas.exceptions import ValidateError
2629
from openapi_core.validation.schemas.validators import SchemaValidator
2730

2831
if TYPE_CHECKING:
@@ -63,6 +66,12 @@ def get_deserializer_callable(
6366
return self.media_type_deserializers[mimetype]
6467

6568

69+
@dataclass(frozen=True)
70+
class FormMediaSchemaMatch:
71+
schema: SchemaPath
72+
decoded_candidate: Mapping[str, Any]
73+
74+
6675
class MediaTypeDeserializer:
6776
def __init__(
6877
self,
@@ -97,7 +106,7 @@ def deserialize(self, value: bytes) -> Any:
97106
):
98107
return deserialized
99108

100-
# decode multipart request bodies if schema provided
109+
# Decode form-media bodies only when a schema is available.
101110
if self.schema is not None:
102111
return self.decode(deserialized)
103112

@@ -126,43 +135,30 @@ def evolve(
126135
schema=schema,
127136
schema_validator=schema_validator,
128137
schema_caster=schema_caster,
138+
encoding=self.encoding,
139+
**self.parameters,
129140
)
130141

131142
def decode(
132143
self, location: Mapping[str, Any], schema_only: bool = False
133144
) -> Mapping[str, Any]:
134-
# schema is required for multipart
145+
# Form-media decoding always needs a schema to resolve properties.
135146
assert self.schema is not None
136147
properties: dict[str, Any] = {}
137148

138-
# For urlencoded/multipart, use caster for oneOf/anyOf detection if validator available
149+
# For form media, select composed branches from decoded candidates.
139150
if self.schema_validator is not None:
140-
one_of_schema = self.schema_validator.get_one_of_schema(
141-
location, caster=self.schema_caster
142-
)
143-
if one_of_schema is not None:
144-
one_of_properties = self.evolve(one_of_schema).decode(
145-
location, schema_only=True
146-
)
147-
properties.update(one_of_properties)
151+
one_of_match = self.get_form_media_one_of_match(location)
152+
if one_of_match is not None:
153+
properties.update(one_of_match.decoded_candidate)
148154

149-
any_of_schemas = self.schema_validator.iter_any_of_schemas(
150-
location, caster=self.schema_caster
151-
)
152-
for any_of_schema in any_of_schemas:
153-
any_of_properties = self.evolve(any_of_schema).decode(
154-
location, schema_only=True
155-
)
156-
properties.update(any_of_properties)
155+
any_of_matches = self.iter_form_media_any_of_matches(location)
156+
for any_of_match in any_of_matches:
157+
properties.update(any_of_match.decoded_candidate)
157158

158-
all_of_schemas = self.schema_validator.iter_all_of_schemas(
159-
location
160-
)
161-
for all_of_schema in all_of_schemas:
162-
all_of_properties = self.evolve(all_of_schema).decode(
163-
location, schema_only=True
164-
)
165-
properties.update(all_of_properties)
159+
all_of_matches = self.iter_form_media_all_of_matches(location)
160+
for all_of_match in all_of_matches:
161+
properties.update(all_of_match.decoded_candidate)
166162

167163
for prop_name, prop_schema in get_properties(self.schema).items():
168164
try:
@@ -179,6 +175,70 @@ def decode(
179175

180176
return properties
181177

178+
def get_form_media_one_of_match(
179+
self,
180+
location: Mapping[str, Any],
181+
) -> Optional[FormMediaSchemaMatch]:
182+
if self.schema is None or "oneOf" not in self.schema:
183+
return None
184+
185+
for subschema in self.schema / "oneOf":
186+
match = self.get_form_media_schema_match(subschema, location)
187+
if match is not None:
188+
return match
189+
190+
return None
191+
192+
def iter_form_media_any_of_matches(
193+
self,
194+
location: Mapping[str, Any],
195+
) -> list[FormMediaSchemaMatch]:
196+
if self.schema is None or "anyOf" not in self.schema:
197+
return []
198+
199+
return list(self.iter_form_media_schema_matches("anyOf", location))
200+
201+
def iter_form_media_all_of_matches(
202+
self,
203+
location: Mapping[str, Any],
204+
) -> list[FormMediaSchemaMatch]:
205+
if self.schema is None or "allOf" not in self.schema:
206+
return []
207+
208+
return list(self.iter_form_media_schema_matches("allOf", location))
209+
210+
def iter_form_media_schema_matches(
211+
self,
212+
keyword: str,
213+
location: Mapping[str, Any],
214+
) -> Iterator[FormMediaSchemaMatch]:
215+
assert self.schema is not None
216+
217+
for subschema in self.schema / keyword:
218+
if keyword == "allOf" and "type" not in subschema:
219+
continue
220+
match = self.get_form_media_schema_match(subschema, location)
221+
if match is not None:
222+
yield match
223+
224+
def get_form_media_schema_match(
225+
self,
226+
subschema: SchemaPath,
227+
location: Mapping[str, Any],
228+
) -> Optional[FormMediaSchemaMatch]:
229+
assert self.schema_validator is not None
230+
231+
deserializer = self.evolve(subschema)
232+
decoded_candidate = deserializer.decode(location, schema_only=True)
233+
validator = self.schema_validator.evolve(subschema)
234+
235+
try:
236+
validator.validate(decoded_candidate)
237+
except ValidateError:
238+
return None
239+
240+
return FormMediaSchemaMatch(subschema, decoded_candidate)
241+
182242
def decode_property(
183243
self,
184244
prop_name: str,

openapi_core/unmarshalling/unmarshallers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def _get_param_or_header_and_schema(
117117
def _get_content_and_schema(
118118
self, raw: Any, content: SchemaPath, mimetype: Optional[str] = None
119119
) -> Tuple[Any, Optional[SchemaPath]]:
120-
casted, schema = super()._get_content_and_schema(
120+
casted, schema = self._get_content_schema_value_and_schema(
121121
raw, content, mimetype
122122
)
123123
if schema is None:

openapi_core/validation/schemas/validators.py

Lines changed: 159 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
log = logging.getLogger(__name__)
2121

2222

23+
_MISSING = object()
24+
25+
2326
class SchemaValidator:
2427
def __init__(
2528
self,
@@ -33,7 +36,8 @@ def __contains__(self, schema_format: str) -> bool:
3336
return schema_format in self.validator.format_checker.checkers
3437

3538
def validate(self, value: Any) -> None:
36-
errors_iter = self.validator.iter_errors(value)
39+
validation_value = self.get_binary_validation_value(value)
40+
errors_iter = self.validator.iter_errors(validation_value)
3741
errors = tuple(errors_iter)
3842
if errors:
3943
schema_type = (self.schema / "type").read_str_or_list("any")
@@ -93,6 +97,8 @@ def get_primitive_type(self, value: Any) -> Optional[str]:
9397
schema_types = sorted(self.validator.TYPE_CHECKER._type_checkers)
9498
assert isinstance(schema_types, list)
9599
for schema_type in schema_types:
100+
if self.accepts_binary_string_value(schema_type, value):
101+
return schema_type
96102
result = self.type_validator(value, type_override=schema_type)
97103
if not result:
98104
continue
@@ -104,6 +110,158 @@ def get_primitive_type(self, value: Any) -> Optional[str]:
104110
# OpenAPI 3.0: None is not a primitive type so None value will not find any type
105111
return None
106112

113+
def accepts_binary_string_value(
114+
self, schema_type: Optional[str], value: Any
115+
) -> bool:
116+
if schema_type != "string" or not isinstance(value, bytes):
117+
return False
118+
119+
schema_format = (self.schema / "format").read_str(None)
120+
return schema_format in ("binary", "byte")
121+
122+
def get_binary_validation_value(self, value: Any) -> Any:
123+
# OpenAPI binary and byte string values are represented as bytes,
124+
# but jsonschema validates string schemas against text values.
125+
if self.accepts_binary_string_value(
126+
(self.schema / "type").read_str(None), value
127+
):
128+
return self.decode_binary_string_value(value)
129+
130+
normalized = value
131+
132+
for keyword in ["oneOf", "anyOf", "allOf"]:
133+
if keyword not in self.schema:
134+
continue
135+
for subschema in self.schema / keyword:
136+
branch_value = self.evolve(
137+
subschema
138+
).get_binary_validation_value(value)
139+
normalized = self.merge_binary_validation_value(
140+
normalized, branch_value
141+
)
142+
143+
if isinstance(normalized, dict):
144+
return self.get_binary_validation_mapping_value(normalized)
145+
146+
if isinstance(normalized, list) and "items" in self.schema:
147+
return self.get_binary_validation_array_value(normalized)
148+
149+
return normalized
150+
151+
def decode_binary_string_value(self, value: bytes) -> str:
152+
try:
153+
return value.decode("utf-8")
154+
except UnicodeDecodeError:
155+
return value.decode("ASCII", errors="surrogateescape")
156+
157+
def get_binary_validation_mapping_value(self, value: Any) -> Any:
158+
normalized = value
159+
160+
if "properties" in self.schema:
161+
for prop_name, prop_schema in (self.schema / "properties").items():
162+
if prop_name not in value:
163+
continue
164+
prop_value = self.evolve(
165+
prop_schema
166+
).get_binary_validation_value(value[prop_name])
167+
if prop_value is value[prop_name]:
168+
continue
169+
if normalized is value:
170+
normalized = dict(value)
171+
normalized[prop_name] = prop_value
172+
173+
additional_properties = self.schema.get("additionalProperties", True)
174+
if additional_properties in (True, False):
175+
return normalized
176+
177+
property_names = set()
178+
if "properties" in self.schema:
179+
property_names = set((self.schema / "properties").keys())
180+
additional_validator = self.evolve(
181+
self.schema / "additionalProperties"
182+
)
183+
for prop_name, prop_value in value.items():
184+
if prop_name in property_names:
185+
continue
186+
normalized_prop_value = (
187+
additional_validator.get_binary_validation_value(prop_value)
188+
)
189+
if normalized_prop_value is prop_value:
190+
continue
191+
if normalized is value:
192+
normalized = dict(value)
193+
normalized[prop_name] = normalized_prop_value
194+
195+
return normalized
196+
197+
def get_binary_validation_array_value(self, value: Any) -> Any:
198+
item_validator = self.evolve(self.schema / "items")
199+
normalized = None
200+
201+
for idx, item in enumerate(value):
202+
normalized_item = item_validator.get_binary_validation_value(item)
203+
if normalized_item is item:
204+
continue
205+
if normalized is None:
206+
normalized = list(value)
207+
normalized[idx] = normalized_item
208+
209+
if normalized is None:
210+
return value
211+
212+
return normalized
213+
214+
def merge_binary_validation_value(
215+
self, value: Any, normalized_value: Any
216+
) -> Any:
217+
if normalized_value is value:
218+
return value
219+
220+
if isinstance(value, dict) and isinstance(normalized_value, dict):
221+
merged_dict = value
222+
for key, normalized_item in normalized_value.items():
223+
item = value.get(key, _MISSING)
224+
if item is _MISSING:
225+
if merged_dict is value:
226+
merged_dict = dict(value)
227+
merged_dict[key] = normalized_item
228+
continue
229+
230+
merged_item = self.merge_binary_validation_value(
231+
item, normalized_item
232+
)
233+
if merged_item is item:
234+
continue
235+
if merged_dict is value:
236+
merged_dict = dict(value)
237+
merged_dict[key] = merged_item
238+
239+
return merged_dict
240+
241+
if isinstance(value, list) and isinstance(normalized_value, list):
242+
if len(value) != len(normalized_value):
243+
return normalized_value
244+
245+
merged_list = None
246+
for idx, (item, normalized_item) in enumerate(
247+
zip(value, normalized_value)
248+
):
249+
merged_item = self.merge_binary_validation_value(
250+
item, normalized_item
251+
)
252+
if merged_item is item:
253+
continue
254+
if merged_list is None:
255+
merged_list = list(value)
256+
merged_list[idx] = merged_item
257+
258+
if merged_list is None:
259+
return value
260+
261+
return merged_list
262+
263+
return normalized_value
264+
107265
def iter_valid_schemas(self, value: Any) -> Iterator[SchemaPath]:
108266
yield self.schema
109267

tests/integration/unmarshalling/test_request_unmarshaller.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import json
22
from base64 import b64encode
3-
from email.generator import _make_boundary
43

54
import pytest
65

@@ -469,16 +468,10 @@ def test_request_body_with_object_default(self):
469468
assert result.errors == []
470469
assert result.body == {"tags": []}
471470

472-
@pytest.mark.xfail(
473-
reason=(
474-
"multipart composed-schema branch selection is not binary-aware"
475-
),
476-
strict=True,
477-
)
478471
def test_request_body_multipart_oneof_binary_field(self):
479472
from openapi_core import OpenAPI
480473

481-
boundary = _make_boundary()
474+
boundary = "testboundary"
482475
spec = OpenAPI.from_dict(
483476
{
484477
"openapi": "3.1.0",

0 commit comments

Comments
 (0)