From 1e0d0766977f3eb291ce3d0143ffdab8aa577e2b Mon Sep 17 00:00:00 2001
From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com>
Date: Mon, 9 Jun 2025 18:00:17 +0800
Subject: [PATCH 1/3] Add support for all dict methods to `ShardedH5IOStore`.

---
 keras/src/saving/saving_lib.py      | 218 +++++++++++++++++++++++++---
 keras/src/saving/saving_lib_test.py |  32 +++-
 2 files changed, 223 insertions(+), 27 deletions(-)

diff --git a/keras/src/saving/saving_lib.py b/keras/src/saving/saving_lib.py
index 72492cb4532c..74006fd2b9df 100644
--- a/keras/src/saving/saving_lib.py
+++ b/keras/src/saving/saving_lib.py
@@ -1040,15 +1040,20 @@ def __bool__(self):
         # will mistakenly using `__len__` to determine the value.
         return self.h5_file.__bool__()
 
-    def _get_h5_file(self, path_or_io):
+    def _get_h5_file(self, path_or_io, mode=None):
+        mode = mode or self.mode
+        if mode not in ("r", "w", "a"):
+            raise ValueError(
+                f"`mode` should be either 'r', 'w' or 'a'. Received: {mode}"
+            )
         if self.archive:
-            if self.mode == "w":
+            if mode == "w":
                 self.io_file = io.BytesIO()
             else:
                 self.io_file = self.archive.open(str(path_or_io), "r")
-            return h5py.File(self.io_file, mode=self.mode)
+            return h5py.File(self.io_file, mode=mode)
         else:
-            return h5py.File(path_or_io, mode=self.mode)
+            return h5py.File(path_or_io, mode=mode)
 
     def make(self, path, metadata=None):
         """Make a new H5 entry group.
@@ -1148,10 +1153,12 @@ def __getitem__(self, key):
             and value.attrs["dtype"] == "bfloat16"
         ):
             value = np.array(value, dtype=ml_dtypes.bfloat16)
+        else:
+            value = np.array(value)
         return value
 
     def __setitem__(self, key, value):
-        if self.mode != "w":
+        if self.mode not in ("w", "a"):
             raise ValueError("Setting a value is only allowed in write mode.")
         if not self._h5_entry_initialized:
             self._create_h5_group(self._h5_entry_path)
@@ -1164,7 +1171,7 @@ def __setitem__(self, key, value):
         self._h5_entry_group[key] = value
 
     def __delitem__(self, key):
-        if self.mode != "w":
+        if self.mode not in ("w", "a"):
             raise ValueError("Deleting a value is only allowed in write mode.")
         del self._h5_entry_group[key]
 
@@ -1202,7 +1209,7 @@
         self.archive = archive
         self.io_file = None
 
-        self.max_shard_size = float(max_shard_size)
+        self.max_shard_size = float(max_shard_size) * 1024**3  # To bytes.
         self.base_name = self.path.stem.replace(".weights", "")
         if self.path.suffix != ".json":
@@ -1226,6 +1233,7 @@
         self.current_shard_size = 0
         self.total_shard_size = 0  # In bytes.
         self.current_shard_path = None
+        self.current_shard_filenames = []
        if self.mode == "w":
             self.sharding_config = {
                 "metadata": {
@@ -1243,6 +1251,27 @@
                 self.sharding_config = json.load(map_file)
         self.h5_file = self._create_new_shard_file()
 
+    def make(self, path, metadata=None):
+        """Make a new H5 entry group.
+
+        This method is only available in write mode. It defers the creation of
+        the H5 entry group until `__setitem__` is called, preventing the
+        creation of empty groups.
+
+        The list of filenames tracked for the current shard is reset.
+
+        Args:
+            path: `str`. The variable path.
+            metadata: Optional `dict`. The metadata to save with the H5 entry
+                group. Defaults to `None`.
+ """ + self.current_shard_filenames = [] + if self.h5_file is not None: + self.current_shard_filenames.append( + pathlib.Path(self.h5_file.filename).name + ) + return super().make(path, metadata) + def get(self, path): """Get the H5 entry group. @@ -1259,9 +1288,17 @@ def get(self, path): # If not found, check shard map and switch files. weight_map = self.sharding_config["weight_map"] - filename = weight_map.get(parsed_path) or weight_map.get( + filenames = weight_map.get(parsed_path) or weight_map.get( "/" + parsed_path + "/vars" ) + if filenames is not None: + if not isinstance(filenames, list): + filenames = [filenames] + self.current_shard_filenames = filenames + filename = filenames[0] + else: + self.current_shard_filenames = [] + filename = None if filename is not None and filename != self.current_shard_path.name: self.close() @@ -1269,7 +1306,9 @@ def get(self, path): return super().get(path) def close(self): - self.h5_file.close() + if self.h5_file is not None: + self.h5_file.close() + self.h5_file = None if self.mode == "w": self.sharding_config["metadata"]["total_size"] = ( self.total_shard_size @@ -1289,28 +1328,128 @@ def close(self): # Shard-specific methods. def _create_new_shard_file(self): + """Create a new shard file and return the H5 file object.""" new_shard_path = ( f"{self.base_name}_{self.current_shard_index:05}.weights.h5" ) self.current_shard_index += 1 self.current_shard_path = self.path.with_name(new_shard_path) - return self._get_h5_file(self.current_shard_path) + h5_file = self._get_h5_file(self.current_shard_path) + self.current_shard_filenames.append(pathlib.Path(h5_file.filename).name) + self._h5_entry_initialized = False + return h5_file + + def _switch_h5_file(self, filename, mode): + """Switch to a different H5 file with the specified mode. + + This is useful for retrieving information from all shards, such as the + total length, keys, and items. + """ + if mode not in ("r", "a"): + raise ValueError( + f"`mode` should be either 'r' or 'a'. Received: {mode}" + ) + self.close() + self.h5_file = self._get_h5_file( + self.path.with_name(filename), mode=mode + ) + self._get_h5_group(self._h5_entry_path) + + def _restore_h5_file(self): + """Ensure the current shard is the last one created. + + We use mode="a" to avoid truncating the file during the switching. + """ + if ( + pathlib.Path(self.h5_file.filename).name + != self.current_shard_path.name + ): + self._switch_h5_file(self.current_shard_path.name, mode="a") # H5 entry level methods. + def _get_h5_group(self, path): + """Get the H5 entry group. If it doesn't exist, return an empty dict.""" + try: + if not path: + self._h5_entry_group = self.h5_file["vars"] + else: + self._h5_entry_group = self.h5_file[path]["vars"] + self._h5_entry_initialized = True + except KeyError: + self._h5_entry_group = {} + self._h5_entry_initialized = False + + # Dict methods. 
+
+    def __len__(self):
+        total_len = self._h5_entry_group.__len__()
+        for filename in self.current_shard_filenames:
+            if filename == self.current_shard_path.name:
+                continue
+            self._switch_h5_file(filename, mode="r")
+            total_len += self._h5_entry_group.__len__()
+        self._restore_h5_file()
+        return total_len
+
+    def keys(self):
+        keys = set(self._h5_entry_group.keys())
+        for filename in self.current_shard_filenames:
+            if filename == self.current_shard_path.name:
+                continue
+            self._switch_h5_file(filename, mode="r")
+            keys.update(self._h5_entry_group.keys())
+        self._restore_h5_file()
+        return keys
+
+    def items(self):
+        yield from self._h5_entry_group.items()
+        for filename in self.current_shard_filenames:
+            if filename == self.current_shard_path.name:
+                continue
+            self._switch_h5_file(filename, mode="r")
+            yield from self._h5_entry_group.items()
+        self._restore_h5_file()
+
+    def values(self):
+        yield from self._h5_entry_group.values()
+        for filename in self.current_shard_filenames:
+            if filename == self.current_shard_path.name:
+                continue
+            self._switch_h5_file(filename, mode="r")
+            yield from self._h5_entry_group.values()
+        self._restore_h5_file()
+
+    def __getitem__(self, key):
+        if key in self._h5_entry_group:
+            return super().__getitem__(key)
+
+        for filename in self.current_shard_filenames:
+            if filename == self.current_shard_path.name:
+                continue
+            self._switch_h5_file(filename, mode="r")
+            if key in self._h5_entry_group:
+                item = super().__getitem__(key)
+                self._restore_h5_file()
+                return item
+        raise KeyError(
+            f"Key '{key}' not found in any of the shards: "
+            f"{self.current_shard_filenames}"
+        )
+
     def __setitem__(self, key, value):
+        self._restore_h5_file()
+
+        # Accumulate `current_shard_size`.
         value = backend.convert_to_numpy(value)
         dtype = backend.standardize_dtype(value.dtype)
         weight_counts = math.prod(value.shape)
         per_param_size = dtype_utils.dtype_size(dtype)
-        value_size = weight_counts * per_param_size / (8.0 * 1024**3)  # To GB.
-        self.total_shard_size += weight_counts * per_param_size / 8  # In bytes.
+        value_size = weight_counts * per_param_size / 8  # In bytes.
+        self.total_shard_size += value_size
         if value_size > self.max_shard_size:
-            value_size_str = readable_memory_size(value_size * 1024**3)
-            max_shard_size_str = readable_memory_size(
-                self.max_shard_size * 1024**3
-            )
+            value_size_str = readable_memory_size(value_size)
+            max_shard_size_str = readable_memory_size(self.max_shard_size)
             raise ValueError(
                 f"The size of {key} is {value_size_str} which "
                 f"exceeds the maximum shard size {max_shard_size_str}. You "
@@ -1323,16 +1462,53 @@ def __setitem__(self, key, value):
         if self.current_shard_size > self.max_shard_size:
             self.close()
             self.h5_file = self._create_new_shard_file()
-            self.make(self._h5_entry_path)
             self.current_shard_size = value_size
         super().__setitem__(key, value)
 
+        # Update the weight map.
         variable_path = self._h5_entry_group.name
-        if variable_path not in self.sharding_config["weight_map"]:
-            self.sharding_config["weight_map"][variable_path] = (
-                self.current_shard_path.name
-            )
+        shard_filename = self.current_shard_path.name
+        weight_map = self.sharding_config["weight_map"]
+        if variable_path not in weight_map:
+            weight_map[variable_path] = shard_filename
+        else:
+            if not isinstance(weight_map[variable_path], list):
+                weight_map[variable_path] = [weight_map[variable_path]]
+            if shard_filename not in weight_map[variable_path]:
+                weight_map[variable_path].append(shard_filename)
+
+    def __delitem__(self, key):
+        if key in self._h5_entry_group:
+            super().__delitem__(key)
+            return
+
+        for filename in self.current_shard_filenames:
+            if filename == self.current_shard_path.name:
+                continue
+            self._switch_h5_file(filename, mode="a")
+            if key in self._h5_entry_group:
+                super().__delitem__(key)
+                self._restore_h5_file()
+                return
+        raise KeyError(
+            f"Key '{key}' not found in any of the shards: "
+            f"{self.current_shard_filenames}"
+        )
+
+    def __contains__(self, item):
+        if item in self._h5_entry_group:
+            return True
+
+        for filename in self.current_shard_filenames:
+            if filename == self.current_shard_path.name:
+                continue
+            self._switch_h5_file(filename, mode="r")
+            if item in self._h5_entry_group:
+                self._restore_h5_file()
+                return True
+        self._restore_h5_file()
+        return False
 
 
 class NpzIOStore:
diff --git a/keras/src/saving/saving_lib_test.py b/keras/src/saving/saving_lib_test.py
index 8ba3001f5f6b..36e861d64ac9 100644
--- a/keras/src/saving/saving_lib_test.py
+++ b/keras/src/saving/saving_lib_test.py
@@ -789,10 +789,10 @@ def test_weights_sharding(self, model_name, max_shard_size):
             model, temp_filepath, max_shard_size=max_shard_size
         )
         self.assertIn("mymodel.weights.json", os.listdir(temp_filepath.parent))
-        if max_shard_size == 512:
+        if max_shard_size == 1:
             # 1 sharded file + 1 config file = 2.
             self.assertLen(os.listdir(temp_filepath.parent), 2)
-        elif max_shard_size == 10:
+        elif max_shard_size == 0.01:
             # 3 sharded file + 1 config file = 4.
             self.assertLen(os.listdir(temp_filepath.parent), 4)
 
@@ -1272,24 +1272,30 @@ def test_sharded_h5_io_store_basics(self):
         name = "sharded_store"
         temp_filepath = Path(os.path.join(self.get_temp_dir(), f"{name}.json"))
 
-        # Pre-defined data.
-        a = np.random.random((2, 4)).astype("float32")
-        b = np.random.random((4, 8)).astype("int32")
+        # Pre-defined data. Each array is about 0.0037 GB.
+        a = np.random.random((1000, 1000)).astype("float32")
+        b = np.random.random((1000, 1000)).astype("int32")
 
         # Set.
-        store = saving_lib.ShardedH5IOStore(temp_filepath, mode="w")
+        store = saving_lib.ShardedH5IOStore(
+            temp_filepath, max_shard_size=0.005, mode="w"
+        )
         vars_store = store.make("vars")
         vars_store["a"] = a
         vars_store["b"] = b
         vars_store["c"] = 42
+        self.assertLen(store.sharding_config["weight_map"]["/vars/vars"], 2)
+        self.assertLen(vars_store, 3)
         self.assertAllClose(vars_store["a"], a)
         self.assertAllClose(vars_store["b"], b)
         self.assertEqual(int(vars_store["c"][()]), 42)
 
         # Delete.
         del vars_store["c"]
+        self.assertLen(vars_store, 2)
 
         # Contain.
+        self.assertIn("a", vars_store)
         self.assertNotIn("c", vars_store)
 
         store.close()
@@ -1301,10 +1307,24 @@ def test_sharded_h5_io_store_basics(self):
         # Get.
         store = saving_lib.ShardedH5IOStore(temp_filepath, mode="r")
         vars_store = store.get("vars")
+        self.assertLen(vars_store, 2)
         self.assertAllClose(vars_store["a"], a)
         self.assertAllClose(vars_store["b"], b)
         self.assertNotIn("c", vars_store)
 
+        # Keys.
+ for key in ["a", "b"]: + self.assertIn(key, vars_store.keys()) + + # Items. + for key, value in vars_store.items(): + if key == "a": + self.assertAllClose(value, a) + elif key == "b": + self.assertAllClose(value, b) + else: + raise ValueError(f"Unexpected key: {key}") + def test_sharded_h5_io_store_exception_raised(self): temp_filepath = Path(os.path.join(self.get_temp_dir(), "store.h5")) From fb066df39a76f74731933dca43f3678c0bfae80e Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Mon, 9 Jun 2025 18:20:38 +0800 Subject: [PATCH 2/3] Increase test coverage. --- keras/src/saving/saving_lib_test.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/keras/src/saving/saving_lib_test.py b/keras/src/saving/saving_lib_test.py index 36e861d64ac9..1e24f3b1e998 100644 --- a/keras/src/saving/saving_lib_test.py +++ b/keras/src/saving/saving_lib_test.py @@ -1293,6 +1293,9 @@ def test_sharded_h5_io_store_basics(self): # Delete. del vars_store["c"] self.assertLen(vars_store, 2) + del vars_store["a"] # Delete from an older shard. + self.assertLen(vars_store, 1) + vars_store["a"] = a # Contain. self.assertIn("a", vars_store) @@ -1325,6 +1328,15 @@ def test_sharded_h5_io_store_basics(self): else: raise ValueError(f"Unexpected key: {key}") + # Values. + for value in vars_store.values(): + if backend.standardize_dtype(value.dtype) == "float32": + self.assertAllClose(value, a) + elif backend.standardize_dtype(value.dtype) == "int32": + self.assertAllClose(value, b) + else: + raise ValueError(f"Unexpected value: {value}") + def test_sharded_h5_io_store_exception_raised(self): temp_filepath = Path(os.path.join(self.get_temp_dir(), "store.h5")) @@ -1354,4 +1366,16 @@ def test_sharded_h5_io_store_exception_raised(self): "float32" ) + # Bad `get`. + with self.assertRaisesRegex( + KeyError, r"Key 'abc' not found in any of the shards:" + ): + vars_store["abc"] + + # Bad `del`. + with self.assertRaisesRegex( + KeyError, r"Key 'abc' not found in any of the shards:" + ): + del vars_store["abc"] + store.close() From 3d828e969323b44f362c662e334034023eb288e0 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Tue, 10 Jun 2025 19:24:34 +0800 Subject: [PATCH 3/3] Update. --- keras/src/saving/saving_lib.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/keras/src/saving/saving_lib.py b/keras/src/saving/saving_lib.py index 74006fd2b9df..3d19e81ddec6 100644 --- a/keras/src/saving/saving_lib.py +++ b/keras/src/saving/saving_lib.py @@ -1153,7 +1153,11 @@ def __getitem__(self, key): and value.attrs["dtype"] == "bfloat16" ): value = np.array(value, dtype=ml_dtypes.bfloat16) - else: + elif ( + hasattr(value, "shape") + and hasattr(value, "dtype") + and not isinstance(value, np.ndarray) + ): value = np.array(value) return value