
Commit b7fe854

guangy10 (Guang Yang) authored and committed

Allow override inputs to export recipe (huggingface#37508)

Add option to specify dynamic shapes during export

Co-authored-by: Guang Yang <guangyang@fb.com>

1 parent fb8628c · commit b7fe854
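
The new `dynamic_shapes` and `strict` arguments let callers override the example inputs used by `torch.export`. A minimal sketch of the intended call, assuming torch >= 2.6.0; the checkpoint name and the generation-config wiring mirror the existing static-cache test and are assumptions, not part of this diff:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
    from transformers.integrations.executorch import convert_and_export_with_cache

    # Illustrative checkpoint; the static-cache generation config follows the existing test setup.
    model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
    max_cache_len = 1234
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        attn_implementation="sdpa",
        generation_config=GenerationConfig(
            use_cache=True,
            cache_implementation="static",
            max_length=max_cache_len,
            cache_config={"batch_size": 1, "max_cache_len": max_cache_len},
        ),
    )
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids

    # Mark dim 1 (sequence length) of `input_ids` as dynamic; `cache_position` keeps a static shape.
    dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}

    # On torch >= 2.6.0 the spec and flag are forwarded to torch.export;
    # on older torch they are ignored with a warning and the legacy path is used.
    exported_program = convert_and_export_with_cache(
        model,
        example_input_ids=input_ids,
        dynamic_shapes=dynamic_shapes,
        strict=False,  # defaults to strict=True when left as None
    )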

File tree

2 files changed (+90, -16)

src/transformers/integrations/executorch.py

Lines changed: 32 additions & 12 deletions
@@ -10,6 +10,7 @@
 # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations under the License.
 
+import logging
 from typing import Optional
 
 import torch
@@ -50,14 +51,22 @@ def __init__(
         """
         super().__init__()
 
-        if model.config.cache_implementation == "static":
+        if not hasattr(model.config, "use_cache") or model.config.use_cache is False:
+            raise ValueError("The model must have caching enabled to be performant.")
+
+        if not hasattr(model.config, "cache_implementation"):
+            # If `cache_implementation` is not specified explicitly in the config, `DynamicCache` will
+            # be used by default, so export will use `StaticCache` by default.
+            logging.info("Using `StaticCache` for export as `cache_implementation` is not specified in the config.")
             self.model = TorchExportableModuleWithStaticCache(model)
-        elif model.config.cache_implementation == "hybrid":
-            self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
         else:
-            raise ValueError(
-                f"Unsupported cache implementation in this export recipe: '{model.config.cache_implementation}'"
-            )
+            if model.config.cache_implementation == "hybrid":
+                self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
+            else:
+                raise ValueError(
+                    f"Unsupported cache implementation: {model.config.cache_implementation}. "
+                    "Please use `hybrid` or `static`."
+                )
 
     def forward(
         self,
@@ -462,15 +471,19 @@ def convert_and_export_with_cache(
     model: PreTrainedModel,
     example_input_ids: Optional[torch.Tensor] = None,
     example_cache_position: Optional[torch.Tensor] = None,
+    dynamic_shapes: Optional[dict] = None,
+    strict: Optional[bool] = None,
 ):
     """
    Convert a `PreTrainedModel` into an exportable module and export it using `torch.export`,
    ensuring the exported model is compatible with `ExecuTorch`.
 
    Args:
        model (`PreTrainedModel`): The pretrained model to be exported.
-        example_input_ids (`torch.Tensor`): Example input token id used by `torch.export`.
-        example_cache_position (`torch.Tensor`): Example current cache position used by `torch.export`.
+        example_input_ids (`Optional[torch.Tensor]`): Example input token id used by `torch.export`.
+        example_cache_position (`Optional[torch.Tensor]`): Example current cache position used by `torch.export`.
+        dynamic_shapes (`Optional[dict]`): Dynamic shapes used by `torch.export`.
+        strict (`Optional[bool]`): Flag to instruct `torch.export` to use `torchdynamo`.
 
    Returns:
        Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
@@ -489,14 +502,21 @@ def convert_and_export_with_cache(
         example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long)
     )
 
-    if is_torch_greater_or_equal("2.5.0"):
+    if is_torch_greater_or_equal("2.6.0"):
         exported_program = torch.export.export(
             TorchExportableModuleWithStaticCache(model),
-            args=(example_input_ids,),
-            kwargs={"cache_position": example_cache_position},
-            strict=True,
+            args=(example_input_ids, example_cache_position),
+            kwargs={},
+            dynamic_shapes=dynamic_shapes,
+            strict=strict if strict is not None else True,
         )
     else:
+        if dynamic_shapes is not None:
+            logging.warning(
+                "Dynamic shapes spec will be ignored by convert_and_export_with_cache for torch < 2.6.0."
+            )
+        if strict is not None:
+            logging.warning("The strict flag will be ignored by convert_and_export_with_cache for torch < 2.6.0.")
         # We have to keep this path for BC.
         #
         # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
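
For hybrid-cache models, the same dynamic-shape spec can be passed to `TorchExportableModuleForDecoderOnlyLM.export`. A short sketch mirroring the new hybrid-cache test added below (checkpoint name and cache sizes are the test's values); this is illustrative, not a prescribed recipe:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

    # Values taken from the new hybrid-cache test below.
    model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
    max_batch_size = 1
    max_cache_len = 23

    model = AutoModelForCausalLM.from_pretrained(model_id)
    model.eval()

    # The wrapper dispatches to the hybrid-cache module because config.cache_implementation == "hybrid".
    exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size, max_cache_len)

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids

    # Sequence-length dimension of `input_ids` is dynamic; `cache_position` stays static.
    dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}

    exported_program = exportable_module.export(
        input_ids=input_ids,
        dynamic_shapes=dynamic_shapes,
        strict=False,
    )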

tests/utils/test_cache_utils.py

Lines changed: 58 additions & 4 deletions
@@ -25,7 +25,6 @@
     is_torch_available,
     require_gptq,
     require_non_xpu,
-    require_read_token,
     require_torch,
     require_torch_accelerator,
     require_torch_gpu,
@@ -693,8 +692,6 @@ def test_dynamic_cache_exportability(self):
         for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
             self.assertTrue(torch.allclose(v1, v2))
 
-    @slow
-    @require_read_token
     def test_static_cache_exportability(self):
         """
         Tests that static cache works with `torch.export()`
@@ -709,8 +706,9 @@ def test_static_cache_exportability(self):
         attn_implementation = "sdpa"  # Export and ExecuTorch only works for SdpaAttention
         batch_size = 1
         max_cache_len = 1234
+        model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
         model = AutoModelForCausalLM.from_pretrained(
-            "google/gemma-2b",
+            model_id,
             device_map=device,
             torch_dtype=dtype,
             attn_implementation=attn_implementation,
@@ -748,3 +746,59 @@ def test_static_cache_exportability(self):
                 n_static_value_caches = n_static_value_caches + 1
         self.assertEqual(n_static_key_caches, model.config.num_hidden_layers)
         self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
+
+        # Export with dynamic shapes using Dim.AUTO
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids
+        dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
+        exported_program = convert_and_export_with_cache(
+            model,
+            example_input_ids=input_ids,
+            dynamic_shapes=dynamic_shapes,
+            strict=False,
+        )
+
+    def test_hybrid_cache_exportability(self):
+        """
+        Tests that hybrid cache works with `torch.export()`
+        """
+        if not is_torch_greater_or_equal("2.6"):
+            self.skipTest(reason="This test requires torch >= 2.6 to run.")
+
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+        set_seed(0)
+        model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        model.eval()
+        self.assertEqual(model.config.use_cache, True)
+        self.assertEqual(model.config.cache_implementation, "hybrid")
+
+        # Export + HybridCache
+        model.eval()
+        max_batch_size = 1
+        max_cache_len = 23
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size, max_cache_len)
+        exported_program = exportable_module.export()
+        n_g_key_caches = n_g_value_caches = 0
+        for buffer_name, buffer in exported_program.named_buffers():
+            if buffer_name.startswith("key_cache"):
+                self.assertTrue(buffer.shape[0] == max_batch_size)
+                self.assertTrue(buffer.shape[2] == max_cache_len)
+                n_g_key_caches = n_g_key_caches + 1
+            if buffer_name.startswith("value_cache"):
+                self.assertTrue(buffer.shape[0] == max_batch_size)
+                self.assertTrue(buffer.shape[2] == max_cache_len)
+                n_g_value_caches = n_g_value_caches + 1
+        self.assertEqual(n_g_key_caches, model.config.num_hidden_layers)
+        self.assertEqual(n_g_value_caches, model.config.num_hidden_layers)
+
+        # Export with dynamic shapes using Dim.AUTO
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids
+        dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
+        exported_program = exportable_module.export(
+            input_ids=input_ids,
+            dynamic_shapes=dynamic_shapes,
+            strict=False,
+        )
