Skip to content

Commit 8a4384d

Browse files
authored
Convert to string when needed + faster .zstd (#7683)
convert to string when needed + faster .zstd
1 parent 611f5a5 commit 8a4384d

File tree

3 files changed

+28 additions, -8 deletions

src/datasets/filesystems/compression.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@ class BaseCompressedFileFileSystem(AbstractArchiveFileSystem):
1414
None # protocol passed in prefix to the url. ex: "gzip", for gzip://file.txt::http://foo.bar/file.txt.gz
1515
)
1616
compression: str = None # compression type in fsspec. ex: "gzip"
17-
extension: str = None # extension of the filename to strip. ex: "".gz" to get file.txt from file.txt.gz
17+
extensions: list[str] = None # extensions of the filename to strip. ex: ".gz" to get file.txt from file.txt.gz
1818

1919
def __init__(
2020
self, fo: str = "", target_protocol: Optional[str] = None, target_options: Optional[dict] = None, **kwargs
@@ -90,31 +90,31 @@ class Bz2FileSystem(BaseCompressedFileFileSystem):
9090

9191
protocol = "bz2"
9292
compression = "bz2"
93-
extension = ".bz2"
93+
extensions = [".bz2"]
9494

9595

9696
class GzipFileSystem(BaseCompressedFileFileSystem):
9797
"""Read contents of GZIP file as a filesystem with one file inside."""
9898

9999
protocol = "gzip"
100100
compression = "gzip"
101-
extension = ".gz"
101+
extensions = [".gz", ".gzip"]
102102

103103

104104
class Lz4FileSystem(BaseCompressedFileFileSystem):
105105
"""Read contents of LZ4 file as a filesystem with one file inside."""
106106

107107
protocol = "lz4"
108108
compression = "lz4"
109-
extension = ".lz4"
109+
extensions = [".lz4"]
110110

111111

112112
class XzFileSystem(BaseCompressedFileFileSystem):
113113
"""Read contents of .xz (LZMA) file as a filesystem with one file inside."""
114114

115115
protocol = "xz"
116116
compression = "xz"
117-
extension = ".xz"
117+
extensions = [".xz"]
118118

119119

120120
class ZstdFileSystem(BaseCompressedFileFileSystem):
@@ -124,4 +124,4 @@ class ZstdFileSystem(BaseCompressedFileFileSystem):
124124

125125
protocol = "zstd"
126126
compression = "zstd"
127-
extension = ".zst"
127+
extensions = [".zst", ".zstd"]

src/datasets/packaged_modules/json/json.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -90,6 +90,20 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:
9090
for column_name in set(self.config.features) - set(pa_table.column_names):
9191
type = self.config.features.arrow_schema.field(column_name).type
9292
pa_table = pa_table.append_column(column_name, pa.array([None] * len(pa_table), type=type))
93+
# convert to string when needed
94+
for i, column_name in enumerate(pa_table.column_names):
95+
if pa.types.is_struct(pa_table[column_name].type) and self.config.features.get(
96+
column_name, None
97+
) == datasets.Value("string"):
98+
jsonl = (
99+
pa_table[column_name]
100+
.to_pandas(types_mapper=pd.ArrowDtype)
101+
.to_json(orient="records", lines=True)
102+
)
103+
string_array = pa.array(
104+
("{" + x.rstrip() for x in ("\n" + jsonl).split("\n{") if x), type=pa.string()
105+
)
106+
pa_table = pa_table.set_column(i, column_name, string_array)
93107
# more expensive cast to support nested structures with keys in a different order
94108
# allows str <-> int/float or str to Audio for example
95109
pa_table = table_cast(pa_table, self.config.features.arrow_schema)

src/datasets/utils/file_utils.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -461,12 +461,18 @@ def readline(f: io.RawIOBase):
461461
]
462462
COMPRESSION_EXTENSION_TO_PROTOCOL = {
463463
# single file compression
464-
**{fs_class.extension.lstrip("."): fs_class.protocol for fs_class in COMPRESSION_FILESYSTEMS},
464+
**{
465+
extension.lstrip("."): fs_class.protocol
466+
for fs_class in COMPRESSION_FILESYSTEMS
467+
for extension in fs_class.extensions
468+
},
465469
# archive compression
466470
"zip": "zip",
467471
}
468472
SINGLE_FILE_COMPRESSION_EXTENSION_TO_PROTOCOL = {
469-
fs_class.extension.lstrip("."): fs_class.protocol for fs_class in COMPRESSION_FILESYSTEMS
473+
extension.lstrip("."): fs_class.protocol
474+
for fs_class in COMPRESSION_FILESYSTEMS
475+
for extension in fs_class.extensions
470476
}
471477
SINGLE_FILE_COMPRESSION_PROTOCOLS = {fs_class.protocol for fs_class in COMPRESSION_FILESYSTEMS}
472478
SINGLE_SLASH_AFTER_PROTOCOL_PATTERN = re.compile(r"(?<!:):/")

0 commit comments

Comments
 (0)