1
1
import os
2
+ import struct
2
3
import logging
4
+ from tqdm import tqdm
3
5
4
6
5
7
logger = logging .getLogger (__name__ )
@@ -30,12 +32,35 @@ def download(url, path, save_file=None, md5=None):
30
32
return save_file
31
33
32
34
35
+ def smart_open (file_name , mode = "rb" ):
36
+ """
37
+ Open a regular file or a zipped file.
38
+
39
+ This function can be used as drop-in replacement of the builtin function `open()`.
40
+
41
+ Parameters:
42
+ file_name (str): file name
43
+ mode (str, optional): open mode for the file stream
44
+ """
45
+ import bz2
46
+ import gzip
47
+
48
+ extension = os .path .splitext (file_name )[1 ]
49
+ if extension == '.bz2' :
50
+ return bz2 .BZ2File (file_name , mode )
51
+ elif extension == '.gz' :
52
+ return gzip .GzipFile (file_name , mode )
53
+ else :
54
+ return open (file_name , mode )
55
+
56
+
33
57
def extract (zip_file , member = None ):
34
58
"""
35
59
Extract files from a zip file. Currently, ``zip``, ``gz``, ``tar.gz``, ``tar`` file types are supported.
36
60
37
61
Parameters:
38
- member (str, optional): extract a specific member from the zip file.
62
+ zip_file (str): file name
63
+ member (str, optional): extract specific member from the zip file.
39
64
If not specified, extract all members.
40
65
"""
41
66
import gzip
@@ -47,40 +72,64 @@ def extract(zip_file, member=None):
47
72
if zip_name .endswith (".tar" ):
48
73
extension = ".tar" + extension
49
74
zip_name = zip_name [:- 4 ]
50
-
51
- if member is None :
52
- save_file = zip_name
53
- else :
54
- save_file = os .path .join (os .path .dirname (zip_name ), os .path .basename (member ))
55
- if os .path .exists (save_file ):
56
- return save_file
57
-
58
- if member is None :
59
- logger .info ("Extracting %s to %s" % (zip_file , save_file ))
60
- else :
61
- logger .info ("Extracting %s from %s to %s" % (member , zip_file , save_file ))
75
+ save_path = os .path .dirname (zip_file )
62
76
63
77
if extension == ".gz" :
64
- with gzip .open (zip_file , "rb" ) as fin , open (save_file , "wb" ) as fout :
65
- shutil .copyfileobj (fin , fout )
78
+ member = os .path .basename (zip_name )
79
+ members = [member ]
80
+ save_files = [os .path .join (save_path , member )]
81
+ for _member , save_file in zip (members , save_files ):
82
+ with open (zip_file , "rb" ) as fin :
83
+ fin .seek (- 4 , 2 )
84
+ file_size = struct .unpack ("<I" , fin .read ())[0 ]
85
+ with gzip .open (zip_file , "rb" ) as fin :
86
+ if not os .path .exists (save_file ) or file_size != os .path .getsize (save_file ):
87
+ logger .info ("Extracting %s to %s" % (zip_file , save_file ))
88
+ with open (save_file , "wb" ) as fout :
89
+ shutil .copyfileobj (fin , fout )
66
90
elif extension in [".tar.gz" , ".tgz" , ".tar" ]:
67
- if member is None :
68
- with tarfile .open (zip_file , "r" ) as fin :
69
- fin .extractall (save_file )
91
+ tar = tarfile .open (zip_file , "r" )
92
+ if member is not None :
93
+ members = [member ]
94
+ save_files = [os .path .join (save_path , os .path .basename (member ))]
95
+ logger .info ("Extracting %s from %s to %s" % (member , zip_file , save_files [0 ]))
70
96
else :
71
- with tarfile .open (zip_file , "r" ).extractfile (member ) as fin , open (save_file , "wb" ) as fout :
72
- shutil .copyfileobj (fin , fout )
97
+ members = tar .getnames ()
98
+ save_files = [os .path .join (save_path , _member ) for _member in members ]
99
+ logger .info ("Extracting %s to %s" % (zip_file , save_path ))
100
+ for _member , save_file in zip (members , save_files ):
101
+ if tar .getmember (_member ).isdir ():
102
+ os .makedirs (save_file , exist_ok = True )
103
+ continue
104
+ os .makedirs (os .path .dirname (save_file ), exist_ok = True )
105
+ if not os .path .exists (save_file ) or tar .getmember (_member ).size != os .path .getsize (save_file ):
106
+ with tar .extractfile (_member ) as fin , open (save_file , "wb" ) as fout :
107
+ shutil .copyfileobj (fin , fout )
73
108
elif extension == ".zip" :
74
- if member is None :
75
- with zipfile .ZipFile (zip_file ) as fin :
76
- fin .extractall (save_file )
109
+ zipped = zipfile .ZipFile (zip_file )
110
+ if member is not None :
111
+ members = [member ]
112
+ save_files = [os .path .join (save_path , os .path .basename (member ))]
113
+ logger .info ("Extracting %s from %s to %s" % (member , zip_file , save_files [0 ]))
77
114
else :
78
- with zipfile .ZipFile (zip_file ).open (member , "r" ) as fin , open (save_file , "wb" ) as fout :
79
- shutil .copyfileobj (fin , fout )
115
+ members = zipped .namelist ()
116
+ save_files = [os .path .join (save_path , _member ) for _member in members ]
117
+ logger .info ("Extracting %s to %s" % (zip_file , save_path ))
118
+ for _member , save_file in zip (members , save_files ):
119
+ if zipped .getinfo (_member ).is_dir ():
120
+ os .makedirs (save_file , exist_ok = True )
121
+ continue
122
+ os .makedirs (os .path .dirname (save_file ), exist_ok = True )
123
+ if not os .path .exists (save_file ) or zipped .getinfo (_member ).file_size != os .path .getsize (save_file ):
124
+ with zipped .open (_member , "r" ) as fin , open (save_file , "wb" ) as fout :
125
+ shutil .copyfileobj (fin , fout )
80
126
else :
81
127
raise ValueError ("Unknown file extension `%s`" % extension )
82
128
83
- return save_file
129
+ if len (save_files ) == 1 :
130
+ return save_files [0 ]
131
+ else :
132
+ return save_path
84
133
85
134
86
135
def compute_md5 (file_name , chunk_size = 65536 ):
0 commit comments