[Data] added XML datasource #52539

Open
wants to merge 9 commits into master
14 changes: 14 additions & 0 deletions python/ray/data/BUILD
@@ -1322,6 +1322,20 @@ py_test(
    ],
)

py_test(
    name = "test_xml",
    size = "small",
    srcs = ["tests/test_xml.py"],
    tags = [
        "exclusive",
        "team:data",
    ],
    deps = [
        ":conftest",
        "//:ray_lib",
    ],
)
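For reference, the new target can be run through Bazel or by invoking the test module directly with pytest (the Bazel label below is inferred from this BUILD file's location and may differ in CI):

    bazel test //python/ray/data:test_xml
    # or, without Bazel:
    pytest -v python/ray/data/tests/test_xml.py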

py_test(
    name = "test_state_export",
    size = "small",
2 changes: 2 additions & 0 deletions python/ray/data/__init__.py
@@ -67,6 +67,7 @@
    read_tfrecords,
    read_videos,
    read_webdataset,
    read_xml,
)

# Module-level cached global functions for callable classes. It needs to be defined here
@@ -165,6 +166,7 @@
    "read_tfrecords",
    "read_videos",
    "read_webdataset",
    "read_xml",
    "Preprocessor",
    "TFXReadOptions",
]
101 changes: 101 additions & 0 deletions python/ray/data/_internal/datasource/xml_datasource.py
@@ -0,0 +1,101 @@
import logging
from io import BytesIO
from typing import TYPE_CHECKING, Iterable, List, Optional, Union

from ray.data.block import DataBatch
from ray.data.datasource.file_based_datasource import FileBasedDatasource

if TYPE_CHECKING:
    import pyarrow

logger = logging.getLogger(__name__)

_XML_ROWS_PER_CHUNK = 10000


def _element_to_dict(element, parent_key="", sep="."):
    """Recursively flattens an XML element into a flat dict."""
    d = {}

    # Include attributes if present, keyed with an "@" prefix.
    for k, v in element.attrib.items():
        d[f"{parent_key}@{k}" if parent_key else f"@{k}"] = v

    children = list(element)
    if children:
        # For each child, recurse. Note: repeated sibling tags overwrite
        # one another in the flat dict.
        for child in children:
            child_key = f"{parent_key}{sep}{child.tag}" if parent_key else child.tag
            child_dict = _element_to_dict(child, child_key, sep=sep)
            d.update(child_dict)
    else:
        # Leaf node: keep the element text.
        if element.text and element.text.strip():
            d[parent_key] = element.text.strip()
        elif element.attrib:
            pass  # Already added above
        else:
            d[parent_key] = None  # No info

    return d


class XMLDatasource(FileBasedDatasource):
    """XML datasource that flattens nested XML elements into columns."""

    _FILE_EXTENSIONS = [
        "xml",
        "xml.gz",
        "xml.br",
        "xml.zst",
        "xml.lz4",
    ]

    def __init__(
        self,
        paths: Union[str, List[str]],
        record_tag: Optional[str] = None,
        sep: str = ".",
        **file_based_datasource_kwargs,
    ):
        """
        Args:
            paths: The file or directory paths.
            record_tag: Tag corresponding to a repeated record, e.g. 'record' or 'row'.
            sep: Separator for flattened nested keys; default is '.' (as in 'user.name').
        """
        super().__init__(paths, **file_based_datasource_kwargs)
        self.record_tag = record_tag
        self.sep = sep

    def _parse_xml_buffer(self, buffer: "pyarrow.lib.Buffer") -> Iterable[DataBatch]:
        import xml.etree.ElementTree as ET

        import pyarrow as pa

        if buffer.size == 0:
            return

        tree = ET.parse(BytesIO(buffer))
        root = tree.getroot()

        # Determine the tag used for each record. findall() with a plain tag
        # name only matches direct children of the root.
        tag = self.record_tag or (root[0].tag if len(root) > 0 else None)
        if tag is None:
            raise ValueError("Cannot determine XML record tag.")

        batch = []
        for elem in root.findall(tag):
            row = _element_to_dict(elem, sep=self.sep)
            batch.append(row)
            if len(batch) >= _XML_ROWS_PER_CHUNK:
                yield pa.Table.from_pylist(batch)
                batch = []

        if batch:
            yield pa.Table.from_pylist(batch)

    def _read_stream(self, f: "pyarrow.NativeFile", path: str):
        import pyarrow as pa

        buffer: pa.lib.Buffer = f.read_buffer()
        yield from self._parse_xml_buffer(buffer)
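
For reference, the flattening scheme implemented by _element_to_dict can be exercised standalone. The sketch below mirrors the helper above (it is not part of the diff): attributes become "@name" keys, nested elements are joined with sep, and leaf text is kept as a string.

    import xml.etree.ElementTree as ET

    def element_to_dict(element, parent_key="", sep="."):
        # Mirror of _element_to_dict above, trimmed to the non-empty cases.
        d = {}
        for k, v in element.attrib.items():
            d[f"{parent_key}@{k}" if parent_key else f"@{k}"] = v
        children = list(element)
        if children:
            for child in children:
                key = f"{parent_key}{sep}{child.tag}" if parent_key else child.tag
                d.update(element_to_dict(child, key, sep=sep))
        elif element.text and element.text.strip():
            d[parent_key] = element.text.strip()
        return d

    root = ET.fromstring('<user id="1"><name>John</name><info><age>35</age></info></user>')
    print(element_to_dict(root))
    # -> {'@id': '1', 'name': 'John', 'info.age': '35'}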
53 changes: 53 additions & 0 deletions python/ray/data/read_api.py
@@ -48,6 +48,7 @@
from ray.data._internal.datasource.torch_datasource import TorchDatasource
from ray.data._internal.datasource.video_datasource import VideoDatasource
from ray.data._internal.datasource.webdataset_datasource import WebDatasetDatasource
from ray.data._internal.datasource.xml_datasource import XMLDatasource
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.logical.operators.from_operators import (
    FromArrow,
@@ -3589,6 +3590,58 @@ def read_clickhouse(
    )


@PublicAPI
def read_xml(
    *,
    paths: Union[str, List[str]],
    record_tag: Optional[str] = None,
    sep: str = ".",
    ray_remote_args: Optional[Dict[str, Any]] = None,
    concurrency: Optional[int] = None,
    override_num_blocks: Optional[int] = None,
) -> Dataset:
    """
    Create a :class:`~ray.data.Dataset` from XML files.

    Examples:
        >>> import ray
        >>> ds = ray.data.read_xml(  # doctest: +SKIP
        ...     paths="my_data.xml",
        ...     record_tag="record",
        ... )

    Args:
        paths: A string or list of strings representing the file paths to the XML
            files. The paths can be local or remote (e.g., S3, GCS).
        record_tag: The XML tag name that represents a single record in the XML
            files. If not specified, the tag of the first child of the root
            element is used.
        sep: The separator used to flatten nested XML structures. Default is ".".
        ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks.
        concurrency: The maximum number of Ray tasks to run concurrently. Set this
            to control the number of tasks that run concurrently. This doesn't
            change the total number of tasks run or the total number of output
            blocks. By default, concurrency is dynamically decided based on the
            available resources.
        override_num_blocks: Override the number of output blocks from all read
            tasks. By default, the number of output blocks is dynamically decided
            based on input data size and available resources. You shouldn't
            manually set this value in most cases.

    Returns:
        A :class:`~ray.data.Dataset` producing records read from the XML files.
    """  # noqa: E501
    datasource = XMLDatasource(
        paths=paths,
        record_tag=record_tag,
        sep=sep,
    )
    return read_datasource(
        datasource=datasource,
        ray_remote_args=ray_remote_args,
        concurrency=concurrency,
        override_num_blocks=override_num_blocks,
    )


def _get_datasource_or_legacy_reader(
    ds: Datasource,
    ctx: DataContext,
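As a usage sketch (the file name and layout here are hypothetical; see the test fixtures below for concrete XML), a custom separator changes how nested keys are spelled:

    import ray

    # "users.xml" is a hypothetical file shaped like <root><user ...>...</user></root>.
    ds = ray.data.read_xml(paths="users.xml", record_tag="user", sep="_")
    # Nested elements flatten with the separator (e.g. "info_age");
    # attributes keep the "@" prefix (e.g. "@id").
    print(ds.schema())
    print(ds.take_all())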
166 changes: 166 additions & 0 deletions python/ray/data/tests/test_xml.py
@@ -0,0 +1,166 @@
import os

import pytest

import ray

# Example single-record and multi-record XML
SIMPLE_XML = """<root>
  <user id="1" active="true">
    <name>John</name>
    <email>john@example.com</email>
    <info>
      <age>35</age>
      <city>NY</city>
    </info>
  </user>
</root>"""

MULTI_XML = """<root>
  <user id="1" active="true">
    <name>John</name>
    <email>john@example.com</email>
    <info>
      <age>35</age>
      <city>NY</city>
    </info>
  </user>
  <user id="2">
    <name>Jane</name>
    <info>
      <age>25</age>
      <city>LA</city>
    </info>
  </user>
</root>"""

EMPTY_XML = "<root></root>"


def write_xml(tmp_path, fname, content):
    path = os.path.join(tmp_path, fname)
    with open(path, "w") as f:
        f.write(content)
    return path


def test_read_xml_simple(tmp_path):
    path = write_xml(tmp_path, "simple.xml", SIMPLE_XML)
    ds = ray.data.read_xml(paths=path, record_tag="user")
    rows = ds.take_all()
    assert len(rows) == 1
    row = rows[0]
    assert row["@id"] == "1"
    assert row["name"] == "John"
    assert row["info.age"] == "35"
    assert row["info.city"] == "NY"
    assert row["@active"] == "true"
    assert row["email"] == "john@example.com"


def test_read_xml_multiple(tmp_path):
    path = write_xml(tmp_path, "multi.xml", MULTI_XML)
    ds = ray.data.read_xml(paths=path, record_tag="user")
    rows = ds.take_all()
    assert len(rows) == 2
    john, jane = rows
    assert john["@id"] == "1"
    assert jane["name"] == "Jane"
    assert jane["info.city"] == "LA"
    assert "email" not in jane


def test_empty_xml(tmp_path):
    path = write_xml(tmp_path, "empty.xml", EMPTY_XML)
    ds = ray.data.read_xml(paths=path, record_tag="user")
    assert ds.count() == 0


def test_read_xml_many_files(tmp_path):
    path1 = write_xml(tmp_path, "multi1.xml", MULTI_XML)
    path2 = write_xml(tmp_path, "multi2.xml", MULTI_XML)
    ds = ray.data.read_xml(paths=[path1, path2], record_tag="user")
    rows = ds.take_all()
    assert len(rows) == 4
    # Ensure all user ids are present (unordered)
    ids = sorted([int(r["@id"]) for r in rows if "@id" in r])
    assert ids == [1, 1, 2, 2]


def test_read_xml_empty_files(tmp_path):
    path1 = write_xml(tmp_path, "e1.xml", EMPTY_XML)
    path2 = write_xml(tmp_path, "e2.xml", EMPTY_XML)
    ds = ray.data.read_xml(paths=[path1, path2], record_tag="user")
    assert ds.count() == 0


@pytest.mark.parametrize("ignore_missing_paths", [True, False])
def test_read_xml_ignore_missing_paths(tmp_path, ignore_missing_paths):
    path = write_xml(tmp_path, "multi.xml", MULTI_XML)
    paths = [path, "missing.xml"]
    if ignore_missing_paths:
        ds = ray.data.read_xml(
            paths=paths, ignore_missing_paths=True, record_tag="user"
        )
        assert ds.count() == 2
    else:
        with pytest.raises(FileNotFoundError):
            ray.data.read_xml(
                paths=paths, ignore_missing_paths=False, record_tag="user"
            ).materialize()


def test_read_xml_schema(tmp_path):
    path = write_xml(tmp_path, "multi.xml", MULTI_XML)
    ds = ray.data.read_xml(paths=path, record_tag="user")
    schema = ds.schema()
    field_names = set(schema.names)
    # Check that the expected fields are present
    assert "name" in field_names
    assert "info.age" in field_names
    assert "info.city" in field_names
    assert "@id" in field_names


def test_read_xml_row_missing_fields(tmp_path):
    # Jane is missing the @active attribute and the email field.
    path = write_xml(tmp_path, "multi.xml", MULTI_XML)
    ds = ray.data.read_xml(paths=path, record_tag="user")
    jane = ds.take_all()[1]
    assert "email" not in jane
    assert "@active" not in jane


def test_read_xml_large(tmp_path):
    """Test with a large XML file."""
    n = 500
    content = (
        "<root>"
        + "".join(
            [
                f'<user id="{i}"><name>User{i}</name><info><age>{20 + i}</age></info></user>'
                for i in range(n)
            ]
        )
        + "</root>"
    )
    path = write_xml(tmp_path, "large.xml", content)
    ds = ray.data.read_xml(paths=path, record_tag="user")
    assert ds.count() == n
    df = ds.to_pandas()
    assert df.shape[0] == n
    assert all(df["name"].str.startswith("User"))


# Test reading without an explicit record_tag.
def test_record_tag_none(tmp_path):
    xml = "<root><person><foo>x</foo></person></root>"
    path = write_xml(tmp_path, "x.xml", xml)
    ds = ray.data.read_xml(paths=path)
    rows = ds.take_all()
    assert len(rows) == 1
    assert rows[0]["foo"] == "x"


if __name__ == "__main__":
    import sys

    sys.exit(pytest.main(["-v", __file__]))