Skip to content

Commit 223a449

Browse files
authored
Improve tarfile utilities (#104)
This PR unifies writing to tarfiles, though it would be nice if it were possible to not have to seek back to the beginning of the file
1 parent bdf1279 commit 223a449

File tree

2 files changed

+53
-16
lines changed

2 files changed

+53
-16
lines changed

src/pystow/utils.py

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,37 @@ def open_zipfile(
735735
raise ValueError
736736

737737

738+
@contextlib.contextmanager
739+
def open_tarfile(
740+
path: str | Path,
741+
inner_path: str,
742+
*,
743+
operation: Operation = "read",
744+
representation: Representation = "binary",
745+
) -> Generator[typing.IO[bytes], None, None]:
746+
"""Open a tar file."""
747+
if representation != "binary":
748+
raise NotImplementedError
749+
750+
if operation == "read":
751+
with tarfile.open(path, "r") as tar:
752+
member = tar.getmember(inner_path)
753+
file = tar.extractfile(member)
754+
if file is None:
755+
raise FileNotFoundError(f"could not find {inner_path} in tarfile {path}")
756+
yield file
757+
elif operation == "write":
758+
file = BytesIO()
759+
yield file
760+
file.seek(0)
761+
tarinfo = tarfile.TarInfo(name=inner_path)
762+
tarinfo.size = len(file.getbuffer())
763+
with tarfile.TarFile(path, mode="w") as tar_file:
764+
tar_file.addfile(tarinfo, file)
765+
else:
766+
raise ValueError
767+
768+
738769
@contextlib.contextmanager
739770
def open_zip_reader(
740771
path: str | Path, inner_path: str, delimiter: str = "\t", **kwargs: Any
@@ -888,11 +919,8 @@ def write_tarfile_csv(
888919
:param kwargs: Additional kwargs to pass to :func:`get_df_io` and transitively to
889920
:func:`pandas.DataFrame.to_csv`.
890921
"""
891-
s = df.to_csv(sep=sep, index=index, **kwargs)
892-
tarinfo = tarfile.TarInfo(name=inner_path)
893-
tarinfo.size = len(s)
894-
with tarfile.TarFile(path, mode="w") as tar_file:
895-
tar_file.addfile(tarinfo, BytesIO(s.encode("utf-8")))
922+
with open_tarfile(path, inner_path, operation="write") as file:
923+
df.to_csv(file, sep=sep, index=index, **kwargs)
896924

897925

898926
def write_tarfile_xml(
@@ -911,11 +939,9 @@ def write_tarfile_xml(
911939
from lxml import etree
912940

913941
kwargs.setdefault("pretty_print", True)
914-
s = etree.tostring(element_tree, **kwargs)
915-
tarinfo = tarfile.TarInfo(name=inner_path)
916-
tarinfo.size = len(s)
917-
with tarfile.TarFile(path, mode="w") as tar_file:
918-
tar_file.addfile(tarinfo, BytesIO(s))
942+
943+
with open_tarfile(path, inner_path, operation="write") as file:
944+
file.write(etree.tostring(element_tree, **kwargs))
919945

920946

921947
def read_tarfile_csv(
@@ -932,9 +958,8 @@ def read_tarfile_csv(
932958
"""
933959
import pandas as pd
934960

935-
with tarfile.open(path) as tar_file:
936-
with tar_file.extractfile(inner_path) as file: # type: ignore
937-
return pd.read_csv(file, sep=sep, **kwargs)
961+
with open_tarfile(path, inner_path) as file:
962+
return pd.read_csv(file, sep=sep, **kwargs)
938963

939964

940965
def read_tarfile_xml(path: str | Path, inner_path: str, **kwargs: Any) -> lxml.etree.ElementTree:
@@ -948,9 +973,8 @@ def read_tarfile_xml(path: str | Path, inner_path: str, **kwargs: Any) -> lxml.e
948973
"""
949974
from lxml import etree
950975

951-
with tarfile.open(path) as tar_file:
952-
with tar_file.extractfile(inner_path) as file: # type: ignore
953-
return etree.parse(file, **kwargs)
976+
with open_tarfile(path, inner_path) as file:
977+
return etree.parse(file, **kwargs)
954978

955979

956980
def read_rdf(path: str | Path, **kwargs: Any) -> rdflib.Graph:

tests/test_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
mock_envvar,
2727
n,
2828
name_from_url,
29+
open_tarfile,
2930
open_zip_reader,
3031
open_zip_writer,
3132
open_zipfile,
@@ -265,6 +266,18 @@ def test_zip_writer_exc(self) -> None:
265266
with open_zipfile(path, "test.tsv", operation="write", representation="lolno"): # type:ignore
266267
pass
267268

269+
def test_tar_open(self) -> None:
270+
"""Test writing and reading a tar file."""
271+
with tempfile.TemporaryDirectory() as directory:
272+
path = Path(directory).joinpath("test.tar.gz")
273+
inner = "test_inner.tsv"
274+
with open_tarfile(path, inner, operation="write") as file:
275+
file.write(b"c1\tc2\nv1\tv2")
276+
277+
with open_tarfile(path, inner, operation="read") as file:
278+
self.assertEqual(b"c1\tc2\n", next(file))
279+
self.assertEqual(b"v1\tv2", next(file))
280+
268281

269282
class TestDownload(unittest.TestCase):
270283
"""Tests for downloading."""

0 commit comments

Comments
 (0)