Skip to content

Commit c65e591

Browse files
committed
fix: various mistakes related to structs and obj hashing
1 parent 44a5689 commit c65e591

4 files changed

Lines changed: 140 additions & 9 deletions

File tree

libdestruct/common/obj.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,10 @@ def _compare_value(self: obj, other: object) -> tuple[object, object] | None:
145145
return self_val, other
146146
return None
147147

148+
# Restore identity hashing — Python blanks __hash__ when __eq__ is defined.
149+
# Equality is value-based but hash is identity, so {a, b} won't dedupe equal values.
150+
__hash__ = object.__hash__
151+
148152
def __eq__(self: obj, other: object) -> bool:
149153
"""Return whether the object is equal to the given value."""
150154
pair = self._compare_value(other)

libdestruct/common/struct/struct_impl.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -183,14 +183,18 @@ def _resolve_field(
183183
Either resolved_inflater or bitfield_field will be non-None (not both).
184184
explicit_offset is set when an OffsetAttribute is present.
185185
"""
186-
# Unwrap Annotated[type, metadata...] — extract the real type and any metadata
187186
annotated_offset = None
188187
if get_origin(annotation) is Annotated:
189188
ann_args = get_args(annotation)
190189
annotation = ann_args[0]
191-
for meta in ann_args[1:]:
192-
if isinstance(meta, OffsetAttribute):
193-
annotated_offset = meta.offset
190+
ann_offsets = [m.offset for m in ann_args[1:] if isinstance(m, OffsetAttribute)]
191+
if len(ann_offsets) > 1:
192+
raise ValueError(
193+
f"Field {name!r} has multiple OffsetAttribute entries in its Annotated metadata; "
194+
f"only one is allowed.",
195+
)
196+
if ann_offsets:
197+
annotated_offset = ann_offsets[0]
194198

195199
if name not in reference.__dict__:
196200
return inflater.inflater_for(annotation, owner=owner), None, annotated_offset
@@ -202,6 +206,13 @@ def _resolve_field(
202206
if sum(isinstance(attr, Field) for attr in attrs) > 1:
203207
raise ValueError("Only one Field is allowed per attribute.")
204208

209+
attr_offsets = sum(isinstance(a, OffsetAttribute) for a in attrs)
210+
if attr_offsets + (1 if annotated_offset is not None else 0) > 1:
211+
raise ValueError(
212+
f"Field {name!r} has multiple OffsetAttribute entries (across Annotated metadata "
213+
f"and attribute tuple); only one is allowed.",
214+
)
215+
205216
resolved_type = None
206217
bitfield_field = None
207218
explicit_offset = annotated_offset
@@ -231,6 +242,7 @@ def compute_own_size(cls: type[struct_impl], reference_type: type) -> None:
231242
bf_tracker = BitfieldTracker()
232243
aligned = getattr(reference_type, "_aligned_", False)
233244
seen_vla = False
245+
seen_names: set[str] = set()
234246

235247
for name, annotation, reference in iterate_annotation_chain(reference_type, terminate_at=struct):
236248
if name == "_aligned_":
@@ -245,13 +257,24 @@ def compute_own_size(cls: type[struct_impl], reference_type: type) -> None:
245257
# Detect VLA from default value or subscript annotation
246258
default = getattr(reference, name, None) if hasattr(reference, name) else None
247259
is_vla = isinstance(default, VLAField)
248-
if not is_vla and isinstance(annotation, GenericAlias):
260+
count_field_name: str | None = None
261+
if is_vla:
262+
count_field_name = default.count_field
263+
elif isinstance(annotation, GenericAlias):
249264
args = annotation.__args__
250265
if len(args) == 2 and isinstance(args[1], str):
251266
is_vla = True
267+
count_field_name = args[1]
252268
if is_vla:
269+
if count_field_name is not None and count_field_name not in seen_names:
270+
raise ValueError(
271+
f"VLA field {name!r} references undefined count field {count_field_name!r}. "
272+
f"The count field must be declared before the VLA in the same struct.",
273+
)
253274
seen_vla = True
254275

276+
seen_names.add(name)
277+
255278
resolved_type, bitfield_field, explicit_offset = struct_impl._resolve_field(
256279
name, annotation, reference, cls._inflater, owner=(None, cls),
257280
)
@@ -351,13 +374,13 @@ def freeze(self: struct_impl) -> None:
351374
super().freeze()
352375

353376
def reset(self: struct_impl) -> None:
354-
"""Reset each member to its frozen value."""
377+
"""Restore the struct's memory region to the bytes captured at freeze time."""
355378
if not object.__getattribute__(self, "_frozen"):
356379
raise RuntimeError("Cannot reset a struct that has not been frozen.")
357380

358-
members = object.__getattribute__(self, "_members")
359-
for member in members.values():
360-
member.reset()
381+
resolver = object.__getattribute__(self, "resolver")
382+
frozen_bytes = object.__getattribute__(self, "_frozen_struct_bytes")
383+
resolver.modify(len(frozen_bytes), 0, frozen_bytes)
361384

362385
def to_str(self: struct_impl, indent: int = 0) -> str:
363386
"""Return a string representation of the struct."""
@@ -384,6 +407,8 @@ def __repr__(self: struct_impl) -> str:
384407
}}
385408
}}"""
386409

410+
__hash__ = object.__hash__
411+
387412
def __eq__(self: struct_impl, value: object) -> bool:
388413
"""Return whether the struct is equal to the given value."""
389414
if not isinstance(value, struct_impl):

test/scripts/struct_test.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,5 +704,91 @@ class s_t(struct):
704704
self.assertEqual(s.to_bytes(), memory)
705705

706706

707+
class StructResetWithCompositesTest(unittest.TestCase):
708+
"""struct.reset() must work for any composite member shape, including arrays."""
709+
710+
def test_reset_struct_with_array(self):
711+
from libdestruct.backing.memory_resolver import MemoryResolver
712+
713+
class S(struct):
714+
arr: array[c_int, 3]
715+
716+
memory = bytearray((1).to_bytes(4, "little") + (2).to_bytes(4, "little") + (3).to_bytes(4, "little"))
717+
s = S(MemoryResolver(memory, 0))
718+
s.freeze()
719+
memory[0:4] = (99).to_bytes(4, "little")
720+
memory[4:8] = (98).to_bytes(4, "little")
721+
s.reset()
722+
self.assertEqual(int.from_bytes(memory[0:4], "little"), 1)
723+
self.assertEqual(int.from_bytes(memory[4:8], "little"), 2)
724+
725+
def test_reset_struct_with_nested_array_of_struct(self):
726+
from libdestruct.backing.memory_resolver import MemoryResolver
727+
728+
class Inner(struct):
729+
x: c_int
730+
731+
class Outer(struct):
732+
items: array[Inner, 2]
733+
734+
memory = bytearray((10).to_bytes(4, "little") + (20).to_bytes(4, "little"))
735+
s = Outer(MemoryResolver(memory, 0))
736+
s.freeze()
737+
memory[0:4] = (999).to_bytes(4, "little")
738+
s.reset()
739+
self.assertEqual(int.from_bytes(memory[0:4], "little"), 10)
740+
741+
742+
class AnnotatedMultipleOffsetTest(unittest.TestCase):
743+
"""Multiple OffsetAttribute on a single field must raise instead of silently using one."""
744+
745+
def test_two_offsets_in_annotated_raises(self):
746+
from libdestruct.common.attributes.offset_attribute import OffsetAttribute
747+
from libdestruct.backing.memory_resolver import MemoryResolver
748+
749+
class S(struct):
750+
pad: c_int
751+
field: Annotated[c_int, OffsetAttribute(4), OffsetAttribute(8)]
752+
753+
with self.assertRaises(ValueError):
754+
S(MemoryResolver(bytearray(16), 0))
755+
756+
def test_offset_in_annotated_and_attribute_tuple_raises(self):
757+
from libdestruct.common.attributes.offset_attribute import OffsetAttribute
758+
from libdestruct.backing.memory_resolver import MemoryResolver
759+
760+
class S(struct):
761+
pad: c_int
762+
field: Annotated[c_int, OffsetAttribute(4)] = OffsetAttribute(8)
763+
764+
with self.assertRaises(ValueError):
765+
S(MemoryResolver(bytearray(16), 0))
766+
767+
768+
class VLAUndefinedCountFieldTest(unittest.TestCase):
769+
"""VLA referencing a count field that doesn't exist must fail at struct definition (inflation)."""
770+
771+
def test_undefined_count_field_subscript_form(self):
772+
from libdestruct.backing.memory_resolver import MemoryResolver
773+
774+
class BadVLA(struct):
775+
n: c_int
776+
data: array[c_int, "nonexistent_field"]
777+
778+
with self.assertRaises(ValueError):
779+
BadVLA(MemoryResolver(bytearray(16), 0))
780+
781+
def test_undefined_count_field_descriptor_form(self):
782+
from libdestruct.backing.memory_resolver import MemoryResolver
783+
from libdestruct.common.array.vla_of import vla_of
784+
785+
class BadVLA(struct):
786+
n: c_int
787+
data: array = vla_of(c_int, "nonexistent_field")
788+
789+
with self.assertRaises(ValueError):
790+
BadVLA(MemoryResolver(bytearray(16), 0))
791+
792+
707793
if __name__ == "__main__":
708794
unittest.main()

test/scripts/types_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,22 @@ def test_unwrap_bytes_returns_live_data(self):
970970
self.assertEqual(p.unwrap(4), b"BBBB")
971971

972972

973+
class ObjHashableTest(unittest.TestCase):
974+
"""obj subclasses must be hashable (Python sets __hash__ = None when only __eq__ is defined)."""
975+
976+
def test_c_int_in_set(self):
977+
x = c_int.from_bytes((1).to_bytes(4, "little"))
978+
self.assertIn(x, {x})
979+
980+
def test_struct_in_dict(self):
981+
class S(struct):
982+
x: c_int
983+
984+
s = S.from_bytes((42).to_bytes(4, "little"))
985+
d = {s: "value"}
986+
self.assertEqual(d[s], "value")
987+
988+
973989
class PtrArithmeticSubclassTest(unittest.TestCase):
974990
"""Pointer arithmetic must preserve subclass identity (e.g. for narrower pointer widths)."""
975991

0 commit comments

Comments
 (0)