Skip to content

Commit 9d2805e

Browse files
sbohezcopybara-github
authored andcommitted
Add size.memory to MJCF schema.
PiperOrigin-RevId: 509904430 Change-Id: I1265c3e1d57613146ae4382e49f069e475295abf
1 parent c467f60 commit 9d2805e

3 files changed

Lines changed: 44 additions & 1 deletion

File tree

dm_control/mjcf/element.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,40 @@
3434

3535
_raw_property = property # pylint: disable=invalid-name
3636

37+
_UNITS = ('K', 'M', 'G', 'T', 'P', 'E')
3738

38-
_CONFLICT_BEHAVIOR_FUNC = {'min': min, 'max': max}
39+
40+
def _to_bytes(value_str):
41+
"""Converts a `str` value representing a size in bytes to `int`.
42+
43+
Args:
44+
value_str: `str` value to be converted.
45+
46+
Returns:
47+
`int` corresponding size in bytes.
48+
49+
Raises:
50+
ValueError: if the `str` value passed has an unsupported unit.
51+
"""
52+
if value_str.isdigit():
53+
value_int = int(value_str)
54+
else:
55+
value_int = int(value_str[:-1])
56+
unit = value_str[-1].upper()
57+
if unit not in _UNITS:
58+
raise ValueError(
59+
f'unit of `size.memory` should be one of [{", ".join(_UNITS)}], got'
60+
f' {unit}')
61+
power = 10 * (_UNITS.index(unit) + 1)
62+
value_int *= 2**power
63+
return value_int
64+
65+
66+
def _max_bytes(first, second):
67+
return str(max(_to_bytes(first), _to_bytes(second)))
68+
69+
70+
_CONFLICT_BEHAVIOR_FUNC = {'min': min, 'max': max, 'max_bytes': _max_bytes}
3971

4072

4173
def property(method): # pylint: disable=redefined-builtin

dm_control/mjcf/element_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,6 +1055,16 @@ def testMaxConflictingValues(self):
10551055
self.assertEqual(model_1.size.nconmax, 345)
10561056
self.assertEqual(model_1.size.njmax, 456)
10571057

1058+
def testMaxBytesConflictingValues(self):
1059+
model_1 = mjcf.RootElement()
1060+
model_1.size.memory = '10000'
1061+
1062+
model_2 = mjcf.RootElement()
1063+
model_2.size.memory = '1M'
1064+
1065+
model_1.attach(model_2)
1066+
self.assertEqual(model_1.size.memory, '1048576')
1067+
10581068

10591069
if __name__ == '__main__':
10601070
absltest.main()

dm_control/mjcf/schema.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
</element>
9696
<element name="size">
9797
<attributes>
98+
<attribute name="memory" type="string" conflict_allowed="true" conflict_behavior="max_bytes"/>
9899
<attribute name="njmax" type="int" conflict_allowed="true" conflict_behavior="max"/>
99100
<attribute name="nconmax" type="int" conflict_allowed="true" conflict_behavior="max"/>
100101
<attribute name="nstack" type="int"/>

0 commit comments

Comments
 (0)