Skip to content

Commit 111b1d2

Browse files
committed
PyMJCF recursive include tags relative to base model
1 parent d6f9cb4 commit 111b1d2

File tree

1 file changed

+90
-22
lines changed

1 file changed

+90
-22
lines changed

dm_control/mjcf/parser.py

Lines changed: 90 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626

2727

2828
def from_xml_string(xml_string, escape_separators=False,
29-
model_dir='', resolve_references=True, assets=None):
29+
model_dir='', resolve_references=True, assets=None,
30+
base_model_dir=None):
3031
"""Parses an XML string into an MJCF object model.
3132
3233
Args:
@@ -41,6 +42,9 @@ def from_xml_string(xml_string, escape_separators=False,
4142
assets: (optional) A dictionary of pre-loaded assets, of the form
4243
`{filename: bytestring}`. If present, PyMJCF will search for assets in
4344
this dictionary before attempting to load them from the filesystem.
45+
base_model_dir: (optional) Path to the directory containing the base model.
46+
This is used to prefix the paths of <include> elements' file attributes
47+
to support nested includes as in the MuJoCo compiler.
4448
4549
Returns:
4650
An `mjcf.RootElement`.
@@ -49,11 +53,12 @@ def from_xml_string(xml_string, escape_separators=False,
4953
return _parse(xml_root, escape_separators,
5054
model_dir=model_dir,
5155
resolve_references=resolve_references,
52-
assets=assets)
56+
assets=assets, base_model_dir=base_model_dir)
5357

5458

5559
def from_file(file_handle, escape_separators=False,
56-
model_dir='', resolve_references=True, assets=None):
60+
model_dir='', resolve_references=True, assets=None,
61+
base_model_dir=None):
5762
"""Parses an XML file into an MJCF object model.
5863
5964
Args:
@@ -68,6 +73,9 @@ def from_file(file_handle, escape_separators=False,
6873
assets: (optional) A dictionary of pre-loaded assets, of the form
6974
`{filename: bytestring}`. If present, PyMJCF will search for assets in
7075
this dictionary before attempting to load them from the filesystem.
76+
base_model_dir: (optional) Path to the directory containing the base model.
77+
This is used to prefix the paths of <include> elements' file attributes
78+
to support nested includes as in the MuJoCo compiler.
7179
7280
Returns:
7381
An `mjcf.RootElement`.
@@ -76,11 +84,11 @@ def from_file(file_handle, escape_separators=False,
7684
return _parse(xml_root, escape_separators,
7785
model_dir=model_dir,
7886
resolve_references=resolve_references,
79-
assets=assets)
87+
assets=assets, base_model_dir=base_model_dir)
8088

8189

8290
def from_path(path, escape_separators=False, resolve_references=True,
83-
assets=None):
91+
assets=None, base_model_dir=None):
8492
"""Parses an XML file into an MJCF object model.
8593
8694
Args:
@@ -94,6 +102,9 @@ def from_path(path, escape_separators=False, resolve_references=True,
94102
assets: (optional) A dictionary of pre-loaded assets, of the form
95103
`{filename: bytestring}`. If present, PyMJCF will search for assets in
96104
this dictionary before attempting to load them from the filesystem.
105+
base_model_dir: (optional) Path to the directory containing the base model.
106+
This is used to prefix the paths of <include> elements' file attributes
107+
to support nested includes as in the MuJoCo compiler.
97108
98109
Returns:
99110
An `mjcf.RootElement`.
@@ -103,11 +114,12 @@ def from_path(path, escape_separators=False, resolve_references=True,
103114
xml_root = etree.fromstring(contents)
104115
return _parse(xml_root, escape_separators,
105116
model_dir=model_dir, resolve_references=resolve_references,
106-
assets=assets)
117+
assets=assets, base_model_dir=base_model_dir)
107118

108119

109120
def _parse(xml_root, escape_separators=False,
110-
model_dir='', resolve_references=True, assets=None):
121+
model_dir='', resolve_references=True, assets=None,
122+
base_model_dir=None):
111123
"""Parses a complete MJCF model from an XML.
112124
113125
Args:
@@ -122,6 +134,9 @@ def _parse(xml_root, escape_separators=False,
122134
assets: (optional) A dictionary of pre-loaded assets, of the form
123135
`{filename: bytestring}`. If present, PyMJCF will search for assets in
124136
this dictionary before attempting to load them from the filesystem.
137+
base_model_dir: (optional) Path to the directory containing the base model.
138+
This is used to prefix the paths of <include> elements' file attributes
139+
to support nested includes as in the MuJoCo compiler.
125140
126141
Returns:
127142
An `mjcf.RootElement`.
@@ -140,20 +155,9 @@ def _parse(xml_root, escape_separators=False,
140155
# Recursively parse any included XML files.
141156
to_include = []
142157
for include_tag in xml_root.findall('include'):
143-
try:
144-
# First look for the path to the included XML file in the assets dict.
145-
path_or_xml_string = assets[include_tag.attrib['file']]
146-
parsing_func = from_xml_string
147-
except KeyError:
148-
# If it's not present in the assets dict then attempt to load the XML
149-
# from the filesystem.
150-
path_or_xml_string = os.path.join(model_dir, include_tag.attrib['file'])
151-
parsing_func = from_path
152-
included_mjcf = parsing_func(
153-
path_or_xml_string,
154-
escape_separators=escape_separators,
155-
resolve_references=resolve_references,
156-
assets=assets)
158+
included_mjcf = _parse_include(include_tag, escape_separators, model_dir,
159+
resolve_references, assets, base_model_dir)
160+
157161
to_include.append(included_mjcf)
158162
# We must remove <include/> tags before parsing the main XML file, since
159163
# these are a schema violation.
@@ -165,7 +169,7 @@ def _parse(xml_root, escape_separators=False,
165169
except KeyError:
166170
model = None
167171
mjcf_root = element.RootElement(
168-
model=model, model_dir=model_dir, assets=assets)
172+
model=model, model_dir=base_model_dir or model_dir, assets=assets)
169173
_parse_children(xml_root, mjcf_root, escape_separators)
170174

171175
# Merge in the included XML files.
@@ -180,6 +184,70 @@ def _parse(xml_root, escape_separators=False,
180184
return mjcf_root
181185

182186

187+
def _parse_include(include_tag, escape_separators, model_dir, resolve_references, assets, base_model_dir):
188+
"""
189+
Parses an included XML file.
190+
191+
Args:
192+
include_tag: An `etree.Element` object with tag 'include'.
193+
escape_separators: (optional) A boolean, whether to replace '/' characters
194+
in element identifiers. If `False`, any '/' present in the XML causes
195+
a ValueError to be raised.
196+
model_dir: (optional) Path to the directory containing the model XML file.
197+
This is used to prefix the paths of all asset files.
198+
resolve_references: (optional) A boolean indicating whether the parser
199+
should attempt to resolve reference attributes to a corresponding element.
200+
assets: (optional) A dictionary of pre-loaded assets, of the form
201+
`{filename: bytestring}`. If present, PyMJCF will search for assets in
202+
this dictionary before attempting to load them from the filesystem.
203+
base_model_dir: (optional) Path to the directory containing the base model.
204+
This is used to prefix the paths of <include> elements' file attributes
205+
to support nested includes as in the MuJoCo compiler.
206+
207+
Returns:
208+
An `mjcf.RootElement`.
209+
210+
Raises:
211+
FileNotFoundError: If the included the inner paths of the included XML could
212+
not be resolved.
213+
"""
214+
215+
base_dirs = [model_dir] # always look in the current model dir first
216+
if base_model_dir is not None:
217+
base_dirs.append(base_model_dir) # then look in the base model dir if provided
218+
219+
not_found_exception = None # a container for the final exception if some file references are not resolved
220+
221+
# try to parse the included XML file from each of the base dirs
222+
for working_dir in base_dirs:
223+
224+
# setup new parsing kwargs dict with current base model dir
225+
parsing_func_kwargs = dict(
226+
escape_separators=escape_separators,
227+
resolve_references=resolve_references,
228+
assets=assets,
229+
base_model_dir=working_dir
230+
)
231+
232+
try:
233+
# First look for the path to the included XML file in the assets dict.
234+
path_or_xml_string = assets[include_tag.attrib['file']]
235+
parsing_func = from_xml_string
236+
parsing_func_kwargs.update(dict(model_dir=working_dir)) # requires explicit model dir
237+
except KeyError:
238+
# If it's not present in the assets dict then attempt to load the XML
239+
# from the filesystem.
240+
path_or_xml_string = os.path.join(working_dir, include_tag.attrib['file'])
241+
parsing_func = from_path
242+
try:
243+
# if successfully parsed the included XML file, stop searching
244+
return parsing_func(path_or_xml_string, **parsing_func_kwargs)
245+
except FileNotFoundError as e:
246+
# base model dir did not resolve the inner include paths
247+
not_found_exception = e
248+
249+
raise FileNotFoundError('Could not find an appropriate base path for include tag') from not_found_exception
250+
183251
def _parse_children(xml_element, mjcf_element, escape_separators=False):
184252
"""Parses all children of a given XML element into an MJCF element.
185253

0 commit comments

Comments
 (0)