Skip to content

Commit 7e8d857

Browse files
authored
Add ONNX export support for granite models (#2043)
* feat(exporters/onnx): Add GraniteOnnxConfig and task support list
* feat: Add granite's normalized config for inference
* feat(onnx opt): Add onnx optimization support for granite
* fix(onnx/granite): Use LlamaOnnxConfig as the base for GraniteOnnxConfig
* fix(onnxruntime): Add "granite" to list of model types with grouped attention
* fix: Add granite to the list of models that require position_ids
* fix(granite): Add MIN_TORCH_VERSION for recently fixed torch bug (#2043 comment)
* test(granite): Add tiny random granite test for onnx exporter
* tests(onnxruntime): Add granite to onnxruntime tests

Branch: OnnxGranite

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 6802a0c commit 7e8d857

File tree

9 files changed

+19
-1
lines changed

9 files changed

+19
-1
lines changed

optimum/exporters/onnx/model_configs.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -298,6 +298,11 @@ class GemmaOnnxConfig(LlamaOnnxConfig):
298298
pass
299299

300300

301+
class GraniteOnnxConfig(LlamaOnnxConfig):
302+
MIN_TRANSFORMERS_VERSION = version.parse("4.45.0")
303+
MIN_TORCH_VERSION = version.parse("2.5.0")
304+
305+
301306
class PhiOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
302307
DEFAULT_ONNX_OPSET = 14 # Phi now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
303308
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

optimum/exporters/onnx/utils.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,7 @@
8686
"phi",
8787
"phi3",
8888
"qwen2",
89+
"granite",
8990
}
9091

9192

optimum/exporters/tasks.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -915,6 +915,13 @@ class TasksManager:
915915
"text-classification",
916916
onnx="LlamaOnnxConfig",
917917
),
918+
"granite": supported_tasks_mapping(
919+
"feature-extraction",
920+
"feature-extraction-with-past",
921+
"text-generation",
922+
"text-generation-with-past",
923+
onnx="GraniteOnnxConfig",
924+
),
918925
"pegasus": supported_tasks_mapping(
919926
"feature-extraction",
920927
"feature-extraction-with-past",

optimum/onnxruntime/modeling_decoder.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -340,7 +340,7 @@ def prepare_past_key_values(
340340
if self.model_type == "gemma":
341341
num_attention_heads = self.normalized_config.num_key_value_heads
342342
embed_size_per_head = self.normalized_config.head_dim
343-
elif self.model_type in {"mistral", "llama", "qwen2"}:
343+
elif self.model_type in {"mistral", "llama", "qwen2", "granite"}:
344344
num_attention_heads = self.normalized_config.num_key_value_heads
345345
else:
346346
num_attention_heads = self.normalized_config.num_attention_heads

optimum/onnxruntime/utils.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -128,6 +128,7 @@ class ORTConfigManager:
128128
"gpt-neo": "gpt2",
129129
"gpt-neox": "gpt2",
130130
"gptj": "gpt2",
131+
"granite": "gpt2",
131132
# longt5 with O4 results in segmentation fault
132133
"longt5": "bert",
133134
"llama": "gpt2",

optimum/utils/normalized_config.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -281,6 +281,7 @@ class NormalizedConfigManager:
281281
"xlm-roberta": NormalizedTextConfig,
282282
"yolos": NormalizedVisionConfig,
283283
"qwen2": NormalizedTextConfig,
284+
"granite": NormalizedTextConfigWithGQA,
284285
}
285286

286287
@classmethod

tests/exporters/exporters_utils.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -100,6 +100,7 @@
100100
"gpt-neo": "hf-internal-testing/tiny-random-GPTNeoModel",
101101
"gpt-neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
102102
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
103+
"granite": "hf-internal-testing/tiny-random-GraniteForCausalLM",
103104
"groupvit": "hf-internal-testing/tiny-random-groupvit",
104105
"ibert": "hf-internal-testing/tiny-random-IBertModel",
105106
"imagegpt": "hf-internal-testing/tiny-random-ImageGPTModel",

tests/onnxruntime/test_modeling.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2324,6 +2324,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
23242324
"gpt_neo",
23252325
"gpt_neox",
23262326
"gptj",
2327+
"granite",
23272328
"llama",
23282329
"mistral",
23292330
"mpt",

tests/onnxruntime/utils_onnxruntime_tests.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -104,6 +104,7 @@
104104
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
105105
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
106106
"gptj": "hf-internal-testing/tiny-random-GPTJForCausalLM",
107+
"granite": "hf-internal-testing/tiny-random-GraniteForCausalLM",
107108
"groupvit": "hf-internal-testing/tiny-random-groupvit",
108109
"hubert": "hf-internal-testing/tiny-random-HubertModel",
109110
"ibert": "hf-internal-testing/tiny-random-IBertModel",

0 commit comments

Comments (0)