Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 157 additions & 41 deletions src/pystow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import csv
import gzip
import hashlib
import io
import logging
import lzma
import os
Expand Down Expand Up @@ -56,6 +57,7 @@

__all__ = [
"DownloadBackend",
"DownloadError",
"Hash",
"HexDigestError",
"HexDigestMismatch",
Expand All @@ -81,6 +83,8 @@
"n",
"name_from_s3_key",
"name_from_url",
"open_zip_reader",
"open_zip_writer",
"path_to_sqlite",
"raise_on_digest_mismatch",
"read_rdf",
Expand All @@ -98,8 +102,10 @@
"write_lzma_csv",
"write_pickle_gz",
"write_tarfile_csv",
"write_tarfile_xml",
"write_zipfile_csv",
"write_zipfile_np",
"write_zipfile_rdf",
"write_zipfile_xml",
]

Expand All @@ -113,6 +119,17 @@
#: so we can do type checking
Hash: TypeAlias = "hashlib._Hash"

#: Alias for the private ``_csv._reader`` class, given as a string so it can be
#: used in annotations of the CSV reader helpers.
Reader: TypeAlias = "_csv._reader"
#: Alias for the private ``_csv._writer`` class, given as a string so it can be
#: used in annotations of the CSV writer helpers.
Writer: TypeAlias = "_csv._writer"

#: A human-readable flag for whether a file is opened for reading or for writing.
Operation: TypeAlias = Literal["read", "write"]
#: The set of literal values accepted by :data:`Operation`.
OPERATION_VALUES: set[str] = set(typing.get_args(Operation))

#: A human-readable flag for whether a file is handled as text or as binary.
Representation: TypeAlias = Literal["text", "binary"]
#: The set of literal values accepted by :data:`Representation`.
REPRESENTATION_VALUES: set[str] = set(typing.get_args(Representation))


class HexDigestMismatch(NamedTuple):
"""Contains information about a hexdigest mismatch."""
Expand Down Expand Up @@ -650,9 +667,8 @@ def write_zipfile_csv(
:func:`pandas.DataFrame.to_csv`.
"""
bytes_io = get_df_io(df, sep=sep, index=index, **kwargs)
with zipfile.ZipFile(file=path, mode="w") as zip_file:
with zip_file.open(inner_path, mode="w") as file:
file.write(bytes_io.read())
with open_zipfile(path, inner_path, operation="write", representation="binary") as file:
file.write(bytes_io.read())


def read_zipfile_csv(
Expand All @@ -669,9 +685,88 @@ def read_zipfile_csv(
"""
import pandas as pd

with zipfile.ZipFile(file=path) as zip_file:
with zip_file.open(inner_path) as file:
return pd.read_csv(file, sep=sep, **kwargs)
with open_zipfile(path, inner_path, representation="text", operation="read") as file:
return pd.read_csv(file, sep=sep, **kwargs)


# docstr-coverage:excused `overload`
@typing.overload
@contextlib.contextmanager
def open_zipfile(
    path: str | Path,
    inner_path: str,
    *,
    operation: Operation = ...,
    representation: Literal["text"],
) -> Generator[typing.TextIO, None, None]: ...


# docstr-coverage:excused `overload`
@typing.overload
@contextlib.contextmanager
def open_zipfile(
    path: str | Path,
    inner_path: str,
    *,
    operation: Operation = ...,
    representation: Literal["binary"],
) -> Generator[typing.BinaryIO, None, None]: ...


@contextlib.contextmanager
def open_zipfile(
    path: str | Path,
    inner_path: str,
    *,
    operation: Operation = "read",
    representation: Representation,
) -> Generator[typing.TextIO, None, None] | Generator[typing.BinaryIO, None, None]:
    """Open a file inside a zip archive for reading or writing.

    :param path: The path to the zip archive
    :param inner_path: The path inside the zip archive to open
    :param operation: Either ``"read"`` or ``"write"``. Defaults to ``"read"``.
    :param representation: Either ``"text"`` (the inner file is wrapped in a UTF-8
        text layer) or ``"binary"`` (the raw inner file object is yielded).

    :yields: An open file object for the inner file. Both it and the archive
        are closed when the context exits.

    :raises ValueError: If ``operation`` or ``representation`` is not one of
        the accepted values. This is checked before the archive is opened.
    """
    # Validate both flags before touching the filesystem: opening the archive
    # in write mode replaces it, so a typo must never get that far.
    mode: Literal["r", "w"]
    if operation == "read":
        mode = "r"
    elif operation == "write":
        mode = "w"
    else:
        raise ValueError(f"invalid operation: {operation!r}. Use 'read' or 'write'")
    if representation not in ("text", "binary"):
        raise ValueError(
            f"invalid representation: {representation!r}. Use 'text' or 'binary'"
        )

    # The same mode applies to the archive itself and to the member inside it.
    with zipfile.ZipFile(file=path, mode=mode) as zip_file:
        with zip_file.open(inner_path, mode=mode) as binary_file:
            if representation == "text":
                with io.TextIOWrapper(binary_file, encoding="utf-8") as text_file:
                    yield text_file
            else:
                yield cast(typing.BinaryIO, binary_file)


@contextlib.contextmanager
def open_zip_reader(
    path: str | Path, inner_path: str, delimiter: str = "\t", **kwargs: Any
) -> Generator[Reader, None, None]:
    """Open a CSV reader over a file stored inside a zip archive.

    The inner file is opened as text through :func:`open_zipfile` and stays
    open only while the context is active.

    :param path: The path to the zip archive
    :param inner_path: The path of the CSV file inside the zip archive
    :param delimiter: The column separator of the CSV. Defaults to tab.
    :param kwargs: Additional kwargs to pass to :func:`csv.reader`.

    :yields: A reader over the rows of the inner file
    """
    with open_zipfile(path, inner_path, representation="text") as text_file:
        reader = csv.reader(text_file, delimiter=delimiter, **kwargs)
        yield reader


@contextlib.contextmanager
def open_zip_writer(
    path: str | Path, inner_path: str, delimiter: str = "\t", **kwargs: Any
) -> Generator[Writer, None, None]:
    """Open a CSV writer onto a file stored inside a zip archive.

    The archive is opened through :func:`open_zipfile` with the ``"write"``
    operation and a text representation. Rows must be written while the
    context is active.

    :param path: The path to the zip archive
    :param inner_path: The path of the CSV file inside the zip archive
    :param delimiter: The column separator of the CSV. Defaults to tab.
    :param kwargs: Additional kwargs to pass to :func:`csv.writer`.

    :yields: A writer onto the inner file
    """
    with open_zipfile(
        path, inner_path, operation="write", representation="text"
    ) as text_file:
        writer = csv.writer(text_file, delimiter=delimiter, **kwargs)
        yield writer


def write_zipfile_xml(
Expand All @@ -684,15 +779,14 @@ def write_zipfile_xml(

:param element_tree: An XML element tree
:param path: The path to the resulting zip archive
:param inner_path: The path inside the zip archive to write the dataframe
:param kwargs: Additional kwargs to pass to :func:`tostring`
:param inner_path: The path inside the zip archive to write the XML element
:param kwargs: Additional kwargs to pass to :func:`lxml.etree.tostring`
"""
from lxml import etree

kwargs.setdefault("pretty_print", True)
with zipfile.ZipFile(file=path, mode="w") as zip_file:
with zip_file.open(inner_path, mode="w") as file:
file.write(etree.tostring(element_tree, **kwargs))
with open_zipfile(path, inner_path, operation="write", representation="binary") as file:
file.write(etree.tostring(element_tree, **kwargs))


def read_zipfile_xml(path: str | Path, inner_path: str, **kwargs: Any) -> lxml.etree.ElementTree:
Expand All @@ -706,9 +800,8 @@ def read_zipfile_xml(path: str | Path, inner_path: str, **kwargs: Any) -> lxml.e
"""
from lxml import etree

with zipfile.ZipFile(file=path) as zip_file:
with zip_file.open(inner_path) as file:
return etree.parse(file, **kwargs)
with open_zipfile(path, inner_path, operation="read", representation="binary") as file:
return etree.parse(file, **kwargs)


def write_zipfile_np(
Expand All @@ -725,10 +818,10 @@ def write_zipfile_np(
:param kwargs: Additional kwargs to pass to :func:`get_np_io` and transitively to
:func:`numpy.save`.
"""
bytes_io = get_np_io(arr, **kwargs)
with zipfile.ZipFile(file=path, mode="w") as zip_file:
with zip_file.open(inner_path, mode="w") as file:
file.write(bytes_io.read())
import numpy as np

with open_zipfile(path, inner_path, operation="write", representation="binary") as file:
np.save(file, arr, **kwargs)


def read_zip_np(path: str | Path, inner_path: str, **kwargs: Any) -> numpy.typing.ArrayLike:
Expand All @@ -742,29 +835,41 @@ def read_zip_np(path: str | Path, inner_path: str, **kwargs: Any) -> numpy.typin
"""
import numpy as np

with zipfile.ZipFile(file=path) as zip_file:
with zip_file.open(inner_path) as file:
return cast(np.typing.ArrayLike, np.load(file, **kwargs))
with open_zipfile(path, inner_path, operation="read", representation="binary") as file:
return cast(np.typing.ArrayLike, np.load(file, **kwargs))


def read_zipfile_rdf(path: str | Path, inner_path: str, **kwargs: Any) -> rdflib.Graph:
    """Read an inner RDF file from a zip archive.

    :param path: The path to the zip archive
    :param inner_path: The path inside the zip archive to the RDF file
    :param kwargs: Additional kwargs to pass to :meth:`rdflib.Graph.parse`.

    :returns: A graph
    """
    import rdflib

    graph = rdflib.Graph()
    # Parse the inner file exactly once, through the shared zip helper.
    with open_zipfile(path, inner_path, operation="read", representation="binary") as file:
        graph.parse(file, **kwargs)
    return graph


def write_zipfile_rdf(
    graph: rdflib.Graph, path: str | Path, inner_path: str, **kwargs: Any
) -> None:
    """Write an RDF graph to an inner file of a zip archive.

    :param graph: The graph to write
    :param path: The path to the resulting zip archive
    :param inner_path: The path inside the zip archive to write the RDF file
    :param kwargs: Additional kwargs to pass to :meth:`rdflib.Graph.serialize`.
    """
    with open_zipfile(path, inner_path, operation="write", representation="binary") as file:
        graph.serialize(file, **kwargs)


def write_tarfile_csv(
df: pandas.DataFrame,
path: str | Path,
Expand All @@ -790,6 +895,29 @@ def write_tarfile_csv(
tar_file.addfile(tarinfo, BytesIO(s.encode("utf-8")))


def write_tarfile_xml(
    element_tree: lxml.etree.ElementTree,
    path: str | Path,
    inner_path: str,
    **kwargs: Any,
) -> None:
    """Write an XML document to a tar archive.

    :param element_tree: An XML element tree
    :param path: The path to the resulting tar archive
    :param inner_path: The path inside the tar archive to write the XML document
    :param kwargs: Additional kwargs to pass to :func:`lxml.etree.tostring`.
        ``pretty_print`` defaults to ``True`` unless given explicitly.
    """
    from lxml import etree

    kwargs.setdefault("pretty_print", True)
    payload = etree.tostring(element_tree, **kwargs)
    # ``tostring`` returns text instead of bytes when ``encoding="unicode"`` is
    # requested; the tar member needs bytes, and its size is a byte count.
    if isinstance(payload, str):
        payload = payload.encode("utf-8")
    tarinfo = tarfile.TarInfo(name=inner_path)
    tarinfo.size = len(payload)
    with tarfile.TarFile(path, mode="w") as tar_file:
        tar_file.addfile(tarinfo, BytesIO(payload))


def read_tarfile_csv(
path: str | Path, inner_path: str, sep: str = "\t", **kwargs: Any
) -> pandas.DataFrame:
Expand Down Expand Up @@ -835,13 +963,9 @@ def read_rdf(path: str | Path, **kwargs: Any) -> rdflib.Graph:
"""
import rdflib

if isinstance(path, str):
path = Path(path)
graph = rdflib.Graph()
with (
gzip.open(path, "rb") if isinstance(path, Path) and path.suffix == ".gz" else open(path)
) as file:
graph.parse(file, **kwargs) # type:ignore
with safe_open(path, representation="binary", operation="read") as file:
graph.parse(file, **kwargs)
return graph


Expand Down Expand Up @@ -1121,14 +1245,6 @@ def gunzip(source: str | Path, target: str | Path) -> None:
shutil.copyfileobj(in_file, out_file)


#: A human-readable flag for whether a file is opened for reading or for writing.
Operation: TypeAlias = Literal["read", "write"]
#: The set of literal values accepted by :data:`Operation`.
OPERATION_VALUES: set[str] = set(typing.get_args(Operation))

#: A human-readable flag for whether a file is handled as text or as binary.
Representation: TypeAlias = Literal["text", "binary"]
#: The set of literal values accepted by :data:`Representation`.
REPRESENTATION_VALUES: set[str] = set(typing.get_args(Representation))

MODE_MAP: dict[tuple[Operation, Representation], Literal["rt", "wt", "rb", "wb"]] = {
("read", "text"): "rt",
("read", "binary"): "rb",
Expand Down Expand Up @@ -1181,7 +1297,7 @@ def safe_open(
@contextlib.contextmanager
def safe_open_writer(
f: str | Path | TextIO, *, delimiter: str = "\t", **kwargs: Any
) -> Generator[_csv._writer, None, None]:
) -> Generator[Writer, None, None]:
"""Open a CSV writer, wrapping :func:`csv.writer`.

:param f: A path to a file, or an already open text-based IO object
Expand Down Expand Up @@ -1224,7 +1340,7 @@ def safe_open_dict_writer(
@contextlib.contextmanager
def safe_open_reader(
f: str | Path | TextIO, *, delimiter: str = "\t", **kwargs: Any
) -> Generator[_csv._reader, None, None]:
) -> Generator[Reader, None, None]:
"""Open a CSV reader, wrapping :func:`csv.reader`.

:param f: A path to a file, or an already open text-based IO object
Expand Down
Loading
Loading