@@ -10,6 +10,7 @@
 # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations under the License.

+import logging
 from typing import Optional

 import torch
@@ -50,14 +51,22 @@ def __init__(
         """
         super().__init__()

-        if model.config.cache_implementation == "static":
+        if not hasattr(model.config, "use_cache") or model.config.use_cache is False:
+            raise ValueError("The model must have caching enabled to be performant.")
+
+        if not hasattr(model.config, "cache_implementation"):
+            # If `cache_implementation` is not specified explicitly in the config, `DynamicCache` will
+            # be used by default, so export will use `StaticCache` by default.
+            logging.info("Using `StaticCache` for export as `cache_implementation` is not specified in the config.")
             self.model = TorchExportableModuleWithStaticCache(model)
-        elif model.config.cache_implementation == "hybrid":
-            self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
         else:
-            raise ValueError(
-                f"Unsupported cache implementation in this export recipe: '{model.config.cache_implementation}'"
-            )
+            if model.config.cache_implementation == "hybrid":
+                self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
+            else:
+                raise ValueError(
+                    f"Unsupported cache implementation: {model.config.cache_implementation}. "
+                    "Please use `hybrid` or `static`."
+                )

     def forward(
         self,
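For context, here is a minimal usage sketch of the wrapper this hunk modifies. It is not part of the diff: the import path, checkpoint name, and cache sizes are illustrative assumptions.

```python
# Hypothetical usage of TorchExportableModuleForDecoderOnlyLM (import path assumed).
from transformers import AutoModelForCausalLM
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
model.config.use_cache = True  # caching must be enabled, or __init__ raises ValueError

# Without `cache_implementation` in the config, the wrapper logs an info message and
# falls back to TorchExportableModuleWithStaticCache; a config that sets
# cache_implementation="hybrid" selects TorchExportableModuleWithHybridCache instead.
wrapper = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size=1, max_cache_len=128)
```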
@@ -462,15 +471,19 @@ def convert_and_export_with_cache(
     model: PreTrainedModel,
     example_input_ids: Optional[torch.Tensor] = None,
     example_cache_position: Optional[torch.Tensor] = None,
+    dynamic_shapes: Optional[dict] = None,
+    strict: Optional[bool] = None,
 ):
     """
     Convert a `PreTrainedModel` into an exportable module and export it using `torch.export`,
     ensuring the exported model is compatible with `ExecuTorch`.

     Args:
         model (`PreTrainedModel`): The pretrained model to be exported.
-        example_input_ids (`torch.Tensor`): Example input token id used by `torch.export`.
-        example_cache_position (`torch.Tensor`): Example current cache position used by `torch.export`.
+        example_input_ids (`Optional[torch.Tensor]`): Example input token id used by `torch.export`.
+        example_cache_position (`Optional[torch.Tensor]`): Example current cache position used by `torch.export`.
+        dynamic_shapes (`Optional[dict]`): Dynamic shapes used by `torch.export`.
+        strict (`Optional[bool]`): Flag to instruct `torch.export` to use `torchdynamo`.

     Returns:
         Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
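A hedged example of calling the function with the new parameters (requires torch >= 2.6.0 for the dynamic-shapes path; the `Dim` bounds and the `input_ids`/`cache_position` key names are assumptions inferred from the export call in the next hunk):

```python
# Illustrative call using the new dynamic_shapes/strict parameters (torch >= 2.6.0).
import torch
from torch.export import Dim

# Example inputs with sequence length > 1, avoiding torch.export's 0/1 specialization.
example_input_ids = torch.zeros((1, 3), dtype=torch.long)   # (batch, seq)
example_cache_position = torch.arange(3, dtype=torch.long)  # (seq,)

seq_len = Dim("seq_len", min=1, max=128)  # bounds are an assumption
dynamic_shapes = {"input_ids": {1: seq_len}, "cache_position": {0: seq_len}}

exported_program = convert_and_export_with_cache(
    model,
    example_input_ids=example_input_ids,
    example_cache_position=example_cache_position,
    dynamic_shapes=dynamic_shapes,
    strict=False,  # non-strict export; True routes tracing through torchdynamo
)
```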
@@ -489,14 +502,21 @@ def convert_and_export_with_cache(
         example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long)
     )

-    if is_torch_greater_or_equal("2.5.0"):
+    if is_torch_greater_or_equal("2.6.0"):
         exported_program = torch.export.export(
             TorchExportableModuleWithStaticCache(model),
-            args=(example_input_ids,),
-            kwargs={"cache_position": example_cache_position},
-            strict=True,
+            args=(example_input_ids, example_cache_position),
+            kwargs={},
+            dynamic_shapes=dynamic_shapes,
+            strict=strict if strict is not None else True,
         )
     else:
+        if dynamic_shapes is not None:
+            logging.warning(
+                "Dynamic shapes spec will be ignored by convert_and_export_with_cache for torch < 2.6.0."
+            )
+        if strict is not None:
+            logging.warning("The strict flag will be ignored by convert_and_export_with_cache for torch < 2.6.0.")
         # We have to keep this path for BC.
         #
         # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
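The returned program is meant to be consumed by ExecuTorch; a rough sketch of that follow-on lowering step, assuming the `executorch` package and its standard `exir` flow (not part of this diff):

```python
# Hedged sketch: lowering the exported program to an ExecuTorch .pte artifact.
from executorch.exir import to_edge

edge_program = to_edge(exported_program)           # ATen dialect -> Edge dialect
executorch_program = edge_program.to_executorch()  # apply backend/memory passes
with open("model.pte", "wb") as f:
    f.write(executorch_program.buffer)
```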