|
17 | 17 | import json
|
18 | 18 | import os
|
19 | 19 | from unittest.mock import patch
|
| 20 | +import tempfile |
20 | 21 |
|
21 | 22 | # Third Party
|
22 | 23 | import pytest
|
| 24 | +import filecmp |
23 | 25 |
|
24 | 26 | # Local
|
25 | 27 | from build.utils import process_accelerate_launch_args
|
| 28 | +from build.utils import copy_checkpoint |
26 | 29 |
|
27 | 30 | HAPPY_PATH_DUMMY_CONFIG_PATH = os.path.join(
|
28 | 31 | os.path.dirname(__file__), "dummy_job_config.json"
|
@@ -108,3 +111,132 @@ def test_process_accelerate_launch_custom_config_file(patch_path_exists):
|
108 | 111 | temp_job_config = {"accelerate_launch_args": {"config_file": dummy_config_path}}
|
109 | 112 | args = process_accelerate_launch_args(temp_job_config)
|
110 | 113 | assert args.config_file == dummy_config_path
|
| 114 | + |
| 115 | + |
class CopyCheckpointTestConfig:
    """Fixture helper that builds a source directory tree for copy tests.

    Constructing an instance creates the following tree of empty files
    underneath ``temp_root``:

        temp_root
            test_copytree_source
                tf1.txt
                tf2.txt
                tf3.txt
                subdir1
                    tf4.txt
                    tf5.txt
                    tf6.txt

    Attributes set by the constructor:
        test_root: the ``temp_root`` directory that was passed in.
        source_dir: path of ``test_copytree_source``.
        source_sub_dir: path of ``subdir1`` inside ``source_dir``.
    """

    def __init__(self, temp_root):
        """Create the source tree under ``temp_root``.

        Args:
            temp_root: an existing, writable directory.

        Raises:
            FileExistsError: if ``test_copytree_source`` already exists there.
        """
        self.test_root = temp_root
        self.source_dir = os.path.join(self.test_root, "test_copytree_source")
        self.source_sub_dir = os.path.join(self.source_dir, "subdir1")

        os.mkdir(self.source_dir)
        self._touch_files(self.source_dir, range(1, 4))  # tf1..tf3

        os.mkdir(self.source_sub_dir)
        self._touch_files(self.source_sub_dir, range(4, 7))  # tf4..tf6

    @staticmethod
    def _touch_files(directory, file_numbers):
        """Create an empty ``tf<N>.txt`` in ``directory`` for each number."""
        for file_number in file_numbers:
            file_path = os.path.join(directory, f"tf{file_number}.txt")
            # Opening in append mode creates the file; the context manager
            # closes it, so no explicit close() is needed.
            with open(file_path, "a", encoding="utf-8"):
                pass

    def are_dir_trees_equal(self, dir1, dir2):
        """Return True when two directory trees hold the same entries and bytes.

        Args:
            dir1: path of the first directory.
            dir2: path of the second directory.

        Returns:
            False if either side has an entry the other lacks, if an entry
            could not be compared, if an entry with the same name differs in
            type between the two sides, or if any common file differs in
            content; True otherwise. Subdirectories are compared recursively.
        """
        dirs_cmp = filecmp.dircmp(dir1, dir2)
        if (
            dirs_cmp.left_only
            or dirs_cmp.right_only
            or dirs_cmp.funny_files
            # Same name on both sides but a different type (e.g. a file on
            # one side and a directory on the other), or an entry that could
            # not be stat'ed: such trees are not equal.
            or dirs_cmp.common_funny
        ):
            return False

        # shallow=False compares file contents rather than os.stat signatures.
        (_, mismatch, errors) = filecmp.cmpfiles(
            dir1, dir2, dirs_cmp.common_files, shallow=False
        )
        if mismatch or errors:
            return False

        return all(
            self.are_dir_trees_equal(
                os.path.join(dir1, common_dir), os.path.join(dir2, common_dir)
            )
            for common_dir in dirs_cmp.common_dirs
        )
| 172 | + |
| 173 | + |
def test_copy_checkpoint_dest_dir_does_not_exist():
    """Copying into a destination that does not exist yet mirrors the source."""
    with tempfile.TemporaryDirectory() as test_root:
        # Build the source tree inside the temporary root.
        fixture = CopyCheckpointTestConfig(test_root)

        # The target path is only computed here, never created by the test.
        missing_target = os.path.join(fixture.test_root, "test_copytree_target")

        copy_checkpoint(fixture.source_dir, missing_target)

        trees_match = fixture.are_dir_trees_equal(fixture.source_dir, missing_target)
        assert trees_match
| 187 | + |
| 188 | + |
def test_copy_checkpoint_dest_dir_does_exist():
    """Copy into an existing directory: clashes are replaced, extras are kept.

    The target is seeded with two files before the copy:
      * ``tf1.txt`` shares its name with a source file but holds different
        content, so the final content comparison only passes if the copy
        really overwrote it.
      * ``tf9.txt`` has no counterpart in the source and must survive.
    """
    with tempfile.TemporaryDirectory() as test_root:
        # Init source directory
        config = CopyCheckpointTestConfig(test_root)

        # Init target directory
        target_dir_does_exist = os.path.join(config.test_root, "test_copytree_target2")
        os.mkdir(target_dir_does_exist)

        # Add a file to the target that must be overwritten during the copy.
        # The source files are empty, so non-empty content here makes the
        # overwrite observable.
        clashing_file = os.path.join(target_dir_does_exist, "tf1.txt")
        with open(clashing_file, "w", encoding="utf-8") as f:
            f.write("stale content that the copy must replace")

        # Add a file to the target. This file does not exist in source.
        extra_file = os.path.join(target_dir_does_exist, "tf9.txt")
        with open(extra_file, "a", encoding="utf-8"):
            pass

        # Execute the copy
        copy_checkpoint(config.source_dir, target_dir_does_exist)

        # Files that exist only in the target must be left in place.
        assert os.path.exists(extra_file)

        # Remove it so we can validate the dir trees are equal.
        os.remove(extra_file)
        assert config.are_dir_trees_equal(config.source_dir, target_dir_does_exist)
| 218 | + |
| 219 | + |
def test_copy_checkpoint_dest_dir_not_writeable():
    """Copying into a directory created without owner write access must fail."""
    with tempfile.TemporaryDirectory() as test_root:
        # Build the source tree inside the temporary root.
        fixture = CopyCheckpointTestConfig(test_root)

        # Create the target with mode 0o446 (no write bit for the owner).
        read_only_target = os.path.join(
            fixture.test_root, "test_copytree_notwriteable"
        )
        os.makedirs(read_only_target, mode=0o446)

        # The copy is expected to be rejected with a PermissionError.
        with pytest.raises(PermissionError) as exc_info:
            copy_checkpoint(fixture.source_dir, read_only_target)

        error_text = str(exc_info.value)
        assert "Permission denied:" in error_text
| 237 | + |
| 238 | + |
def test_copy_checkpoint_source_dir_does_not_exist():
    """A source path that does not exist must raise FileNotFoundError."""
    missing_source, destination = "/doesnotexist", "/tmp"

    with pytest.raises(FileNotFoundError) as exc_info:
        copy_checkpoint(missing_source, destination)

    error_text = str(exc_info.value)
    assert "No such file or directory" in error_text
0 commit comments