Skip to content

Commit 7dfd4e7

Browse files
olson-ibm and anhuong authored
Replace shutil.copytree() to fix permission error (#251)
* Closes #1089 Signed-off-by: Joe Olson <joe.olson@ibm.com> * added unit tests, fixed other issues raised by review. Signed-off-by: Joe Olson <joe.olson@ibm.com> * Closes 1089 Signed-off-by: Joe Olson <joe.olson@ibm.com> * Closes #1089 Signed-off-by: Joe Olson <joe.olson@ibm.com> * Closes #1089 Signed-off-by: Joe Olson <joe.olson@ibm.com> * Closes #1089 Signed-off-by: Joe Olson <joe.olson@ibm.com> * Closes #1089 Signed-off-by: Joe Olson <joe.olson@ibm.com> * Closes #1089 Signed-off-by: Joe Olson <joe.olson@ibm.com> * Closes #1089 Signed-off-by: Joe Olson <joe.olson@ibm.com> --------- Signed-off-by: Joe Olson <joe.olson@ibm.com> Co-authored-by: Anh Uong <anh.uong@ibm.com>
1 parent 6d15cf9 commit 7dfd4e7

File tree

3 files changed

+152
-4
lines changed

3 files changed

+152
-4
lines changed

build/accelerate_launch.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
process_accelerate_launch_args,
3636
serialize_args,
3737
get_highest_checkpoint,
38+
copy_checkpoint,
3839
)
3940
from tuning.utils.config_utils import get_json_config
4041
from tuning.config.tracker_configs import FileLoggingTrackerConfig
@@ -124,10 +125,8 @@ def main():
124125
pt_checkpoint_dir,
125126
original_output_dir,
126127
)
127-
shutil.copytree(
128-
os.path.join(tempdir, pt_checkpoint_dir),
129-
original_output_dir,
130-
dirs_exist_ok=True,
128+
copy_checkpoint(
129+
os.path.join(tempdir, pt_checkpoint_dir), original_output_dir
131130
)
132131
except Exception as e: # pylint: disable=broad-except
133132
logging.error(traceback.format_exc())

build/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,23 @@
2121
# Third Party
2222
import torch
2323
from accelerate.commands.launch import launch_command_parser
24+
import shutil
25+
26+
27+
def copy_checkpoint(source, destination):
    """Recursively copy the directory tree at ``source`` into ``destination``.

    Entries already present in ``destination`` that have no counterpart in
    ``source`` are left in place; files with the same name are overwritten
    via ``shutil.copy2``.

    Directory metadata is copied with ``shutil.copystat`` only for
    directories this function creates. A destination directory that already
    exists is never passed to ``copystat``.

    Args:
        source: Path of the directory to copy from.
        destination: Path of the directory to copy into. It is created
            (including missing parents) when it does not exist.

    Raises:
        FileNotFoundError: If ``source`` does not exist.
        OSError: Any other error raised by the underlying ``os`` / ``shutil``
            calls (for example ``PermissionError`` or ``NotADirectoryError``)
            propagates unchanged.
    """
    # List the source before touching the destination, so that an invalid
    # source fails without leaving a freshly created, empty destination.
    entries = os.listdir(source)

    if not os.path.exists(destination):
        os.makedirs(destination)
        shutil.copystat(source, destination)

    for item in entries:
        source_file = os.path.join(source, item)
        destination_file = os.path.join(destination, item)
        if os.path.isdir(source_file):
            # Recurse into subdirectories.
            copy_checkpoint(source_file, destination_file)
        else:
            # Plain file: copy contents and metadata.
            shutil.copy2(source_file, destination_file)
2441

2542

2643
def get_highest_checkpoint(dir_path):

tests/build/test_utils.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@
1717
import json
1818
import os
1919
from unittest.mock import patch
20+
import tempfile
2021

2122
# Third Party
2223
import pytest
24+
import filecmp
2325

2426
# Local
2527
from build.utils import process_accelerate_launch_args
28+
from build.utils import copy_checkpoint
2629

2730
HAPPY_PATH_DUMMY_CONFIG_PATH = os.path.join(
2831
os.path.dirname(__file__), "dummy_job_config.json"
@@ -108,3 +111,132 @@ def test_process_accelerate_launch_custom_config_file(patch_path_exists):
108111
temp_job_config = {"accelerate_launch_args": {"config_file": dummy_config_path}}
109112
args = process_accelerate_launch_args(temp_job_config)
110113
assert args.config_file == dummy_config_path
114+
115+
116+
class CopyCheckpointTestConfig:
    """Builds a small source directory tree used by the copy_checkpoint tests.

    Attributes set by ``__init__``:
        test_root: The directory passed in by the caller.
        source_dir: ``<test_root>/test_copytree_source``.
        source_sub_dir: ``<source_dir>/subdir1``.
    """

    def __init__(self, temp_root):
        """Create the source tree under ``temp_root``.

        Args:
            temp_root: Existing directory in which the tree is created.
        """
        # Create the following file tree for testing:
        # test_root
        #    test_copytree_source
        #       tf1.txt
        #       tf2.txt
        #       tf3.txt
        #       subdir1
        #          tf4.txt
        #          tf5.txt
        #          tf6.txt

        self.test_root = temp_root
        self.source_dir = os.path.join(self.test_root, "test_copytree_source")
        self.source_sub_dir = os.path.join(self.source_dir, "subdir1")

        os.mkdir(self.source_dir)
        self._create_files(self.source_dir, first_number=1)

        os.mkdir(self.source_sub_dir)
        self._create_files(self.source_sub_dir, first_number=4)

    @staticmethod
    def _create_files(directory, first_number, count=3):
        """Create ``count`` files named ``tf<N>.txt`` starting at ``first_number``."""
        for file_number in range(first_number, first_number + count):
            file_name = f"tf{file_number}.txt"
            with open(
                os.path.join(directory, file_name),
                "w",
                encoding="utf-8",
            ) as f:
                # Distinct, non-empty content so that the byte-wise comparison
                # in are_dir_trees_equal can actually detect a bad copy.
                f.write(f"contents of {file_name}\n")

    def are_dir_trees_equal(self, dir1, dir2):
        """Return True when ``dir1`` and ``dir2`` hold the same tree.

        Two trees are equal when both contain the same entry names, no entry
        is reported as "funny" by ``filecmp.dircmp``, every common file
        compares equal byte-for-byte, and the same holds recursively for
        every common subdirectory.
        """
        dirs_cmp = filecmp.dircmp(dir1, dir2)
        if dirs_cmp.left_only or dirs_cmp.right_only or dirs_cmp.funny_files:
            return False

        # shallow=False forces a content comparison instead of an os.stat one.
        (_, mismatch, errors) = filecmp.cmpfiles(
            dir1, dir2, dirs_cmp.common_files, shallow=False
        )
        if mismatch or errors:
            return False

        return all(
            self.are_dir_trees_equal(
                os.path.join(dir1, common_dir), os.path.join(dir2, common_dir)
            )
            for common_dir in dirs_cmp.common_dirs
        )
172+
173+
174+
def test_copy_checkpoint_dest_dir_does_not_exist():
    """Copying into a destination that does not exist reproduces the source tree."""
    with tempfile.TemporaryDirectory() as scratch:
        # Build the source tree.
        fixture = CopyCheckpointTestConfig(scratch)

        # The destination path is deliberately never created here.
        missing_target = os.path.join(fixture.test_root, "test_copytree_target")

        copy_checkpoint(fixture.source_dir, missing_target)

        trees_match = fixture.are_dir_trees_equal(fixture.source_dir, missing_target)
        assert trees_match
187+
188+
189+
def test_copy_checkpoint_dest_dir_does_exist():
    """Copying into an existing destination overwrites clashes and keeps extras."""
    with tempfile.TemporaryDirectory() as scratch:
        # Build the source tree.
        fixture = CopyCheckpointTestConfig(scratch)

        # Build a destination that already exists before the copy.
        existing_target = os.path.join(fixture.test_root, "test_copytree_target2")
        os.mkdir(existing_target)

        # tf1.txt also exists in the source and will be overwritten by the copy;
        # tf9.txt exists only in the destination.
        clashing_file = os.path.join(existing_target, "tf1.txt")
        extra_file = os.path.join(existing_target, "tf9.txt")
        for seeded_path in (clashing_file, extra_file):
            with open(seeded_path, "a", encoding="utf-8"):
                pass

        copy_checkpoint(fixture.source_dir, existing_target)

        # The destination-only file must survive the copy.
        assert os.path.exists(extra_file)

        # Drop it so the two trees can be compared for equality.
        os.remove(extra_file)
        assert fixture.are_dir_trees_equal(fixture.source_dir, existing_target)
218+
219+
220+
def test_copy_checkpoint_dest_dir_not_writeable():
    """Copying into a destination the user cannot write to raises PermissionError."""
    # A process running as root is not subject to permission bits, so the copy
    # would succeed and this test would fail for the wrong reason.
    if hasattr(os, "geteuid") and os.geteuid() == 0:
        pytest.skip("permission bits are not enforced for root")

    # Init source directory
    with tempfile.TemporaryDirectory() as test_root:
        config = CopyCheckpointTestConfig(test_root)

        # Init target directory
        target_dir_not_writeable = os.path.join(
            config.test_root, "test_copytree_notwriteable"
        )

        # Mode 0o446 gives the owner read permission only.
        os.makedirs(target_dir_not_writeable, mode=0o446)

        # Execute the copy. Should FAIL
        with pytest.raises(PermissionError) as e:
            copy_checkpoint(config.source_dir, target_dir_not_writeable)
        assert "Permission denied:" in str(e.value)
237+
238+
239+
def test_copy_checkpoint_source_dir_does_not_exist():
    """Copying from a source path that does not exist raises FileNotFoundError."""
    # Use a fresh temporary directory rather than hard-coded absolute paths, so
    # the test does not depend on the host filesystem layout: the source is
    # guaranteed to be missing and the destination is guaranteed to exist.
    with tempfile.TemporaryDirectory() as test_root:
        missing_source = os.path.join(test_root, "doesnotexist")
        with pytest.raises(FileNotFoundError) as e:
            copy_checkpoint(missing_source, test_root)
        assert "No such file or directory" in str(e.value)

0 commit comments

Comments
 (0)