From 48ec5605d40028bb4439319c36bbf0847ee59f14 Mon Sep 17 00:00:00 2001 From: ProgramadorArtificial Date: Tue, 24 Sep 2024 18:16:07 -0300 Subject: [PATCH 1/6] Add option to save model as float16 (half) --- tools/misc/publish_model.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/tools/misc/publish_model.py b/tools/misc/publish_model.py index addf4cca64..d024f3151a 100644 --- a/tools/misc/publish_model.py +++ b/tools/misc/publish_model.py @@ -20,12 +20,21 @@ def parse_args(): type=str, default=['meta', 'state_dict'], help='keys to save in published checkpoint (default: meta state_dict)') + parser.add_argument( + '--float16', + action='store_true', + default=False, + help='Whether save model as float16') args = parser.parse_args() return args -def process_checkpoint(in_file, out_file, save_keys=['meta', 'state_dict']): +def process_checkpoint(in_file, + out_file, + save_keys=['meta', 'state_dict'], + float16=False): checkpoint = torch.load(in_file, map_location='cpu') + checkpoint['meta']['float16'] = float16 # only keep `meta` and `state_dict` for smaller file size ckpt_keys = list(checkpoint.keys()) @@ -41,6 +50,17 @@ def process_checkpoint(in_file, out_file, save_keys=['meta', 'state_dict']): # if it is necessary to remove some sensitive data in checkpoint['meta'], # add the code here. + if float16: + print(save_keys) + if 'meta' not in save_keys: + raise ValueError( + 'Key `meta` must be in save_keys to save model as float16. ' + 'Change float16 to False or add `meta` in save_keys.') + print_log('Saving model as float16.', logger='current') + for key in checkpoint['state_dict'].keys(): + checkpoint['state_dict'][key] = checkpoint['state_dict'][key].half( + ) + if digit_version(TORCH_VERSION) >= digit_version('1.8.0'): torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False) else: @@ -58,7 +78,8 @@ def process_checkpoint(in_file, out_file, save_keys=['meta', 'state_dict']): def main(): args = parse_args() - process_checkpoint(args.in_file, args.out_file, args.save_keys) + process_checkpoint(args.in_file, args.out_file, args.save_keys, + args.float16) if __name__ == '__main__': From 23a897e551eb43d1579d1ebf2aa790d6e0f95419 Mon Sep 17 00:00:00 2001 From: ProgramadorArtificial Date: Tue, 24 Sep 2024 18:16:31 -0300 Subject: [PATCH 2/6] Add validation to change input image to half if the model is in half and CUDA --- mmpose/models/pose_estimators/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mmpose/models/pose_estimators/base.py b/mmpose/models/pose_estimators/base.py index 216f592fda..db902ec6ae 100644 --- a/mmpose/models/pose_estimators/base.py +++ b/mmpose/models/pose_estimators/base.py @@ -158,6 +158,9 @@ def forward(self, if self.metainfo is not None: for data_sample in data_samples: data_sample.set_metainfo(self.metainfo) + param = next(self.backbone.parameters()) + if param.is_cuda and param.dtype == torch.float16: + inputs = inputs.half() return self.predict(inputs, data_samples) elif mode == 'tensor': return self._forward(inputs) From 14b1032c0b30ed562acf466f2385684fae2cf4e9 Mon Sep 17 00:00:00 2001 From: ProgramadorArtificial Date: Tue, 24 Sep 2024 18:16:40 -0300 Subject: [PATCH 3/6] Update documentation with option to save model as float16 (half) --- docs/en/user_guides/how_to_deploy.md | 6 ++++++ docs/zh_cn/user_guides/how_to_deploy.md | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/docs/en/user_guides/how_to_deploy.md b/docs/en/user_guides/how_to_deploy.md index 0b8e31a395..458bc3eb3f 100644 --- a/docs/en/user_guides/how_to_deploy.md +++ b/docs/en/user_guides/how_to_deploy.md @@ -32,6 +32,12 @@ For example: python tools/misc/publish_model.py ./epoch_10.pth ./epoch_10_publish.pth ``` +To save model as float16 (half) add --float16, which is as follows: + +```shell +python tools/misc/publish_model.py ${IN_FILE} ${OUT_FILE} --float16 +``` + The script will automatically simplify the model, save the simplified model to the specified path, and add a timestamp to the filename, for example, `./epoch_10_publish-21815b2c_20230726.pth`. ## Deployment with MMDeploy diff --git a/docs/zh_cn/user_guides/how_to_deploy.md b/docs/zh_cn/user_guides/how_to_deploy.md index 2349fcca09..4cad3c70d3 100644 --- a/docs/zh_cn/user_guides/how_to_deploy.md +++ b/docs/zh_cn/user_guides/how_to_deploy.md @@ -32,6 +32,12 @@ python tools/misc/publish_model.py ${IN_FILE} ${OUT_FILE} python tools/misc/publish_model.py ./epoch_10.pth ./epoch_10_publish.pth ``` +要将模型保存为 float16 (half),请添加 --float16,如下所示: + +```shell +python tools/misc/publish_model.py ${IN_FILE} ${OUT_FILE} --float16 +``` + 脚本会自动对模型进行精简,并将精简后的模型保存到制定路径,并在文件名的最后加上时间戳,例如 `./epoch_10_publish-21815b2c_20230726.pth`。 ## 使用 MMDeploy 部署 From a7300af76b1afae2465633fd436d92b5f0696a1f Mon Sep 17 00:00:00 2001 From: ProgramadorArtificial Date: Wed, 25 Sep 2024 18:11:04 -0300 Subject: [PATCH 4/6] Fix numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject --- requirements/tests.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/tests.txt b/requirements/tests.txt index c63bc90822..2b9f4868db 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -2,6 +2,7 @@ coverage flake8 interrogate isort==4.3.21 +numpy==1.26.4 parameterized pytest pytest-runner From 2db9cb92af2bc174776eb53373e7c97c046d16a9 Mon Sep 17 00:00:00 2001 From: ProgramadorArtificial Date: Wed, 25 Sep 2024 18:41:14 -0300 Subject: [PATCH 5/6] Change numpy version --- requirements/tests.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/tests.txt b/requirements/tests.txt index 2b9f4868db..23298c9b88 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -2,7 +2,7 @@ coverage flake8 interrogate isort==4.3.21 -numpy==1.26.4 +numpy==1.21.6 parameterized pytest pytest-runner From a235610b2eefe4f9b40598f33fe98025e557033c Mon Sep 17 00:00:00 2001 From: ProgramadorArtificial Date: Thu, 26 Sep 2024 09:06:35 -0300 Subject: [PATCH 6/6] Remove numpy from requirements test. Not worked --- requirements/tests.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements/tests.txt b/requirements/tests.txt index 23298c9b88..c63bc90822 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -2,7 +2,6 @@ coverage flake8 interrogate isort==4.3.21 -numpy==1.21.6 parameterized pytest pytest-runner