
Commit 6b38215

Author: Guang Yang

Export to ExecuTorch: Code Skeleton

1 parent 7e8d857 · commit 6b38215

File tree: 16 files changed, +715 −1 lines changed


optimum/commands/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -14,5 +14,5 @@

 from .base import BaseOptimumCLICommand, CommandInfo, RootOptimumCLICommand
 from .env import EnvironmentCommand
-from .export import ExportCommand, ONNXExportCommand, TFLiteExportCommand
+from .export import ExecuTorchExportCommand, ExportCommand, ONNXExportCommand, TFLiteExportCommand
 from .optimum_cli import optimum_cli_subcommand

optimum/commands/export/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -14,5 +14,6 @@


 from .base import ExportCommand
+from .executorch import ExecuTorchExportCommand
 from .onnx import ONNXExportCommand
 from .tflite import TFLiteExportCommand

optimum/commands/export/base.py

Lines changed: 6 additions & 0 deletions
@@ -15,6 +15,7 @@
 """optimum.exporters command-line interface base classes."""

 from .. import BaseOptimumCLICommand, CommandInfo
+from .executorch import ExecuTorchExportCommand
 from .onnx import ONNXExportCommand
 from .tflite import TFLiteExportCommand

@@ -25,6 +26,11 @@ class ExportCommand(BaseOptimumCLICommand):
         help="Export PyTorch and TensorFlow models to several format.",
     )
     SUBCOMMANDS = (
+        CommandInfo(
+            name="executorch",
+            help="Export PyTorch model to ExecuTorch.",
+            subcommand_class=ExecuTorchExportCommand,
+        ),
         CommandInfo(
             name="onnx",
             help="Export PyTorch and TensorFlow to ONNX.",

optimum/commands/export/executorch.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
"""Defines the command line for the export with ExecuTorch."""

from pathlib import Path
from typing import TYPE_CHECKING

from ...exporters import TasksManager
from ..base import BaseOptimumCLICommand


if TYPE_CHECKING:
    from argparse import ArgumentParser


def parse_args_executorch(parser):
    required_group = parser.add_argument_group("Required arguments")
    required_group.add_argument(
        "-m", "--model", type=str, required=True, help="Model ID on huggingface.co or path on disk to load model from."
    )
    required_group.add_argument(
        "--output_dir", type=Path, help="Path indicating the directory where to store the generated ExecuTorch model."
    )

    optional_group = parser.add_argument_group("Optional arguments")
    optional_group.add_argument(
        "--task",
        default="auto",
        help=(
            "The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among:"
            f" {str(TasksManager.get_all_tasks())}. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder."
        ),
    )
    optional_group.add_argument(
        "--recipe",
        type=str,
        default="xnnpack",
        help='Pre-defined recipes for export to ExecuTorch. Defaults to "xnnpack".',
    )


class ExecuTorchExportCommand(BaseOptimumCLICommand):
    @staticmethod
    def parse_args(parser: "ArgumentParser"):
        return parse_args_executorch(parser)

    def run(self):
        from ...exporters.executorch import main_export

        main_export(
            model_name_or_path=self.args.model,
            task=self.args.task,
            recipe=self.args.recipe,
            output_dir=self.args.output_dir,
        )
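
For orientation, the snippet below is a hedged sketch of the programmatic equivalent of `optimum-cli export executorch`, mirroring what `ExecuTorchExportCommand.run` does above; the model ID and output directory are illustrative placeholders, not part of this commit.

from pathlib import Path

from optimum.exporters.executorch import main_export

# Equivalent to:
#   optimum-cli export executorch -m <model> --task text-generation --recipe xnnpack --output_dir <dir>
main_export(
    model_name_or_path="HuggingFaceTB/SmolLM-135M",  # illustrative model ID
    task="text-generation",
    recipe="xnnpack",
    output_dir=Path("smollm_executorch"),  # illustrative output directory
)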

optimum/executorchruntime/__init__.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
from typing import TYPE_CHECKING
from transformers.utils import _LazyModule


_import_structure = {
    "modeling_executorch": [
        "ExecuTorchModelForCausalLM",
    ],
}

if TYPE_CHECKING:
    from .modeling_executorch import ExecuTorchModelForCausalLM
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
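
A minimal sketch of what this `_LazyModule` wiring provides, assuming the package imports cleanly: the `modeling_executorch` submodule, and with it the heavy `torch`/`executorch` dependencies, is only imported when the attribute is first accessed.

import optimum.executorchruntime as executorchruntime

# modeling_executorch (and its torch/executorch imports) is loaded here,
# on first attribute access, rather than at package import time.
ExecuTorchModelForCausalLM = executorchruntime.ExecuTorchModelForCausalLM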
optimum/executorchruntime/modeling_executorch.py

Lines changed: 234 additions & 0 deletions

@@ -0,0 +1,234 @@
"""ExecuTorchModelForXXX classes, allowing to run ExecuTorch Models with ExecuTorch Runtime using the same API as Transformers."""

import logging
import os
import warnings
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.utils import EntryNotFoundError
from transformers import (
    AutoConfig,
    AutoModel,
    GenerationMixin,
    AutoModelForCausalLM,
    GenerationConfig,
)
from transformers.integrations.executorch import TorchExportableModuleWithStaticCache
from transformers.modeling_outputs import (
    BaseModelOutput,
    CausalLMOutput,
    CausalLMOutputWithPast,
    ModelOutput,
)

from ..exporters import TasksManager
from ..exporters.executorch import main_export
from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel

if TYPE_CHECKING:
    from transformers import PretrainedConfig


logger = logging.getLogger(__name__)


class ExecuTorchModelForCausalLM(OptimizedModel):
    """
    ExecuTorch model with a causal language modeling head for ExecuTorch Runtime inference.
    """

    auto_model_class = AutoModelForCausalLM

    def __init__(
        self,
        model: "ExecuTorchModule",
        config: "PretrainedConfig",
    ):
        super().__init__(model, config)
        self.et_model = model
        print(f"DEBUG all static methods: {self.et_model.method_names()}")
        # Read the metadata methods baked into the exported .pte program.
        self.use_kv_cache = self.et_model.run_method("use_kv_cache")[0]
        self.max_seq_len = self.et_model.run_method("get_max_seq_len")[0]
        self.max_batch_size = self.et_model.run_method("get_max_batch_size")[0]
        self.dtype = self.et_model.run_method("get_dtype")[0]
        self.bos_token_id = self.et_model.run_method("get_bos_id")[0]
        self.eos_token_id = self.et_model.run_method("get_eos_id")[0]
        self.vocab_size = self.et_model.run_method("get_vocab_size")[0]

    def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor:
        return self.et_model.forward((input_ids, cache_position))[0]

    @classmethod
    def _from_pretrained(
        cls,
        model_dir_path: Union[str, Path],
        task: str,
        recipe: str,
        config: "PretrainedConfig" = None,
        use_auth_token: Optional[Union[bool, str]] = None,
        token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        force_download: bool = False,
        cache_dir: str = HUGGINGFACE_HUB_CACHE,
        subfolder: str = "",
        local_files_only: bool = False,
    ) -> "ExecuTorchModelForCausalLM":
        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
                FutureWarning,
            )
            if token is not None:
                raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
            token = use_auth_token

        full_path = os.path.join(f"{model_dir_path}", "model.pte")
        model = _load_for_executorch(full_path)
        logging.debug(f"{model.method_meta('forward')}")
        return cls(
            model=model,
            config=config,
        )

    def _save_pretrained(self, save_directory):
        """
        Saves a model weights into a directory, so that it can be re-loaded using the
        [`from_pretrained`] class method.
        """
        raise NotImplementedError

    @classmethod
    def _export(
        cls,
        model_id: str,
        task: str,
        recipe: str,
        config: "PretrainedConfig",
        use_auth_token: Optional[Union[bool, str]] = None,
        token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        force_download: bool = False,
        cache_dir: str = HUGGINGFACE_HUB_CACHE,
        subfolder: str = "",
        local_files_only: bool = False,
        trust_remote_code: bool = False,
    ):
        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
                FutureWarning,
            )
            if token is not None:
                raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
            token = use_auth_token

        save_dir = TemporaryDirectory()
        save_dir_path = Path(save_dir.name)

        # Export to ExecuTorch and save the pte file to the temporary directory
        main_export(
            model_name_or_path=model_id,
            output=save_dir_path,
            task=task,
            recipe=recipe,
            subfolder=subfolder,
            revision=revision,
            cache_dir=cache_dir,
            token=token,
            local_files_only=local_files_only,
            force_download=force_download,
            trust_remote_code=trust_remote_code,
        )

        return cls._from_pretrained(
            model_dir_path=save_dir_path,
            task=task,
            recipe=recipe,
            config=config,
            use_auth_token=use_auth_token,
            subfolder=subfolder,
            revision=revision,
            cache_dir=cache_dir,
            token=token,
            local_files_only=local_files_only,
            force_download=force_download,
        )

    def generate(
        self,
        prompt_tokens: List[int],
        echo: bool = False,
        pos_base: int = 0,
    ) -> List[int]:

        self.device = torch.device("cpu")
        self.max_seq_len = 256
        generated_tokens = []

        # prefill
        for i, prompt_token in enumerate(prompt_tokens):
            logits = self.forward(
                input_ids=torch.tensor([prompt_token], dtype=torch.long, device=self.device).unsqueeze(0),
                cache_position=torch.tensor([i], dtype=torch.long, device=self.device),
            )

        next_token = torch.argmax(logits, dim=-1).item()
        generated_tokens = prompt_tokens + [next_token]

        # decode: greedy argmax, one token at a time, until EOS or max_seq_len
        while len(generated_tokens) < self.max_seq_len:
            logits = self.forward(
                input_ids=torch.tensor([next_token], dtype=torch.long, device=self.device).unsqueeze(0),
                cache_position=torch.tensor(
                    [pos_base + len(generated_tokens) - 1],
                    dtype=torch.long,
                    device=self.device,
                ),
            )
            next_token = torch.argmax(logits, dim=-1).item()
            generated_tokens.append(next_token)
            if next_token == self.eos_token_id:
                break

        return generated_tokens if echo else generated_tokens[len(prompt_tokens) :]

    def text_generation(
        self,
        tokenizer: "PreTrainedTokenizer",
        prompt: str,
        echo: bool = True,
    ) -> str:
        """
        Perform text completion for a prompt using the language model.

        Args:
            prompt (str): Text prompt for completion.
            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to True.

        Returns:
            Generated text, decoded from the generated tokens.

        Note:
            This method generates a text completion for the provided prompt using greedy decoding.
        """
        self.tokenizer = tokenizer
        if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.bos_token_id:
            raise ValueError(
                f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
            )
        if self.tokenizer.eos_token_id is not None and self.tokenizer.eos_token_id != self.eos_token_id:
            raise ValueError(
                f"The tokenizer's eos_token_id={self.tokenizer.eos_token_id} must be the same as the model's eos_token_id={self.eos_token_id}."
            )

        prompt_tokens = self.tokenizer.encode(prompt)
        generated_tokens = self.generate(
            prompt_tokens=prompt_tokens,
            echo=echo,
        )
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
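
A hedged usage sketch of the class above, assuming the public `from_pretrained(..., export=True)` entry point inherited from `OptimizedModel` dispatches to the `_export` and `_from_pretrained` hooks defined in this file; the model ID and prompt are illustrative, and the class is explicitly a skeleton at this commit.

from transformers import AutoTokenizer

from optimum.executorchruntime import ExecuTorchModelForCausalLM

model_id = "HuggingFaceTB/SmolLM-135M"  # illustrative model ID
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Export with the xnnpack recipe, then load the resulting model.pte through the
# ExecuTorch runtime bindings (_load_for_executorch) wrapped by the class above.
model = ExecuTorchModelForCausalLM.from_pretrained(
    model_id,
    export=True,
    task="text-generation",
    recipe="xnnpack",
)

# Greedy decoding via forward(input_ids, cache_position), one token at a time.
print(model.text_generation(tokenizer=tokenizer, prompt="Simply put, the theory of relativity states that"))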
optimum/exporters/executorch/__init__.py

Lines changed: 24 additions & 0 deletions

@@ -0,0 +1,24 @@
from typing import TYPE_CHECKING

from transformers.utils import _LazyModule


_import_structure = {
    "convert": [
        "export_to_executorch",
    ],
    "__main__": ["main_export"],
}

if TYPE_CHECKING:
    from .__main__ import main_export
    from .convert import export_to_executorch
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
