Skip to content

Commit cd2d10f

Browse files
committed
add smart_open
1 parent f8d2de6 commit cd2d10f

File tree

1 file changed

+75
-26
lines changed

1 file changed

+75
-26
lines changed

torchdrug/utils/file.py

Lines changed: 75 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
2+
import struct
23
import logging
4+
from tqdm import tqdm
35

46

57
logger = logging.getLogger(__name__)
@@ -30,12 +32,35 @@ def download(url, path, save_file=None, md5=None):
3032
return save_file
3133

3234

35+
def smart_open(file_name, mode="rb"):
    """
    Open a regular file or a compressed (``.gz`` / ``.bz2``) file.

    This function can be used as a drop-in replacement of the builtin function `open()`.
    The compression format is chosen from the file extension only;
    the file content is not inspected.

    Parameters:
        file_name (str): file name
        mode (str, optional): open mode for the file stream.
            Text modes such as ``"rt"`` and ``"wt"`` are supported for compressed files too.
            For compressed files, a mode without ``"t"`` (e.g. ``"r"``) is binary.

    Returns:
        file object: the opened stream, usable as a context manager
    """
    import bz2
    import gzip

    extension = os.path.splitext(file_name)[1]
    # Use the module-level ``open`` helpers rather than the BZ2File / GzipFile
    # constructors: the constructors raise ValueError on text modes ("rt", "wt"),
    # while the helpers wrap the binary stream in a text layer when asked.
    # For binary modes both forms return the same kind of object.
    if extension == ".bz2":
        return bz2.open(file_name, mode)
    elif extension == ".gz":
        return gzip.open(file_name, mode)
    else:
        return open(file_name, mode)
55+
56+
3357
def extract(zip_file, member=None):
3458
"""
3559
Extract files from a zip file. Currently, ``zip``, ``gz``, ``tar.gz``, ``tar`` file types are supported.
3660
3761
Parameters:
38-
member (str, optional): extract a specific member from the zip file.
62+
zip_file (str): file name
63+
member (str, optional): extract specific member from the zip file.
3964
If not specified, extract all members.
4065
"""
4166
import gzip
@@ -47,40 +72,64 @@ def extract(zip_file, member=None):
4772
if zip_name.endswith(".tar"):
4873
extension = ".tar" + extension
4974
zip_name = zip_name[:-4]
50-
51-
if member is None:
52-
save_file = zip_name
53-
else:
54-
save_file = os.path.join(os.path.dirname(zip_name), os.path.basename(member))
55-
if os.path.exists(save_file):
56-
return save_file
57-
58-
if member is None:
59-
logger.info("Extracting %s to %s" % (zip_file, save_file))
60-
else:
61-
logger.info("Extracting %s from %s to %s" % (member, zip_file, save_file))
75+
save_path = os.path.dirname(zip_file)
6276

6377
if extension == ".gz":
64-
with gzip.open(zip_file, "rb") as fin, open(save_file, "wb") as fout:
65-
shutil.copyfileobj(fin, fout)
78+
member = os.path.basename(zip_name)
79+
members = [member]
80+
save_files = [os.path.join(save_path, member)]
81+
for _member, save_file in zip(members, save_files):
82+
with open(zip_file, "rb") as fin:
83+
fin.seek(-4, 2)
84+
file_size = struct.unpack("<I", fin.read())[0]
85+
with gzip.open(zip_file, "rb") as fin:
86+
if not os.path.exists(save_file) or file_size != os.path.getsize(save_file):
87+
logger.info("Extracting %s to %s" % (zip_file, save_file))
88+
with open(save_file, "wb") as fout:
89+
shutil.copyfileobj(fin, fout)
6690
elif extension in [".tar.gz", ".tgz", ".tar"]:
67-
if member is None:
68-
with tarfile.open(zip_file, "r") as fin:
69-
fin.extractall(save_file)
91+
tar = tarfile.open(zip_file, "r")
92+
if member is not None:
93+
members = [member]
94+
save_files = [os.path.join(save_path, os.path.basename(member))]
95+
logger.info("Extracting %s from %s to %s" % (member, zip_file, save_files[0]))
7096
else:
71-
with tarfile.open(zip_file, "r").extractfile(member) as fin, open(save_file, "wb") as fout:
72-
shutil.copyfileobj(fin, fout)
97+
members = tar.getnames()
98+
save_files = [os.path.join(save_path, _member) for _member in members]
99+
logger.info("Extracting %s to %s" % (zip_file, save_path))
100+
for _member, save_file in zip(members, save_files):
101+
if tar.getmember(_member).isdir():
102+
os.makedirs(save_file, exist_ok=True)
103+
continue
104+
os.makedirs(os.path.dirname(save_file), exist_ok=True)
105+
if not os.path.exists(save_file) or tar.getmember(_member).size != os.path.getsize(save_file):
106+
with tar.extractfile(_member) as fin, open(save_file, "wb") as fout:
107+
shutil.copyfileobj(fin, fout)
73108
elif extension == ".zip":
74-
if member is None:
75-
with zipfile.ZipFile(zip_file) as fin:
76-
fin.extractall(save_file)
109+
zipped = zipfile.ZipFile(zip_file)
110+
if member is not None:
111+
members = [member]
112+
save_files = [os.path.join(save_path, os.path.basename(member))]
113+
logger.info("Extracting %s from %s to %s" % (member, zip_file, save_files[0]))
77114
else:
78-
with zipfile.ZipFile(zip_file).open(member, "r") as fin, open(save_file, "wb") as fout:
79-
shutil.copyfileobj(fin, fout)
115+
members = zipped.namelist()
116+
save_files = [os.path.join(save_path, _member) for _member in members]
117+
logger.info("Extracting %s to %s" % (zip_file, save_path))
118+
for _member, save_file in zip(members, save_files):
119+
if zipped.getinfo(_member).is_dir():
120+
os.makedirs(save_file, exist_ok=True)
121+
continue
122+
os.makedirs(os.path.dirname(save_file), exist_ok=True)
123+
if not os.path.exists(save_file) or zipped.getinfo(_member).file_size != os.path.getsize(save_file):
124+
with zipped.open(_member, "r") as fin, open(save_file, "wb") as fout:
125+
shutil.copyfileobj(fin, fout)
80126
else:
81127
raise ValueError("Unknown file extension `%s`" % extension)
82128

83-
return save_file
129+
if len(save_files) == 1:
130+
return save_files[0]
131+
else:
132+
return save_path
84133

85134

86135
def compute_md5(file_name, chunk_size=65536):

0 commit comments

Comments
 (0)