Skip to content

ImageGen models support for export_models.py #3301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions demos/common/export_models/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def add_common_arguments(parser):
parser_rerank.add_argument('--num_streams', default="1", help='The number of parallel execution streams to use for the model. Use at least 2 on 2 socket CPU systems.', dest='num_streams')
parser_rerank.add_argument('--max_doc_length', default=16000, type=int, help='Maximum length of input documents in tokens', dest='max_doc_length')
parser_rerank.add_argument('--version', default="1", help='version of the model', dest='version')

parser_image_generation = subparsers.add_parser('image_generation', help='export model for image generation endpoint')
add_common_arguments(parser_image_generation)
parser_image_generation.add_argument('--resolution', default="512x512", help='Resolution of generated images if not specified by the request', dest='resolution') # unused for now, param as an example
args = vars(parser.parse_args())

embedding_graph_template = """input_stream: "REQUEST_PAYLOAD:input"
Expand Down Expand Up @@ -213,6 +217,23 @@ def add_common_arguments(parser):
]
}"""

image_generation_graph_template = """input_stream: "HTTP_REQUEST_PAYLOAD:input"
output_stream: "HTTP_RESPONSE_PAYLOAD:output"

node: {
name: "ImageGenExecutor"
calculator: "ImageGenCalculator"
input_stream: "HTTP_REQUEST_PAYLOAD:input"
input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes"
output_stream: "HTTP_RESPONSE_PAYLOAD:output"
node_options: {
[type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: {
models_path: "{{model_path}}",
#resolution: "{{resolution}}", # unused for now
}
}
}"""

def export_rerank_tokenizer(source_model, destination_path, max_length):
import openvino as ov
from openvino_tokenizers import convert_tokenizer
Expand Down Expand Up @@ -448,6 +469,27 @@ def export_rerank_model(model_repository_path, source_model, model_name, precisi
add_servable_to_config(config_file_path, model_name, os.path.relpath( os.path.join(model_repository_path, model_name), os.path.dirname(config_file_path)))


def export_image_generation_model(model_repository_path, source_model, model_name, precision, task_parameters, config_file_path, resolution):
model_path = "./"
model_index_path = os.path.join(target_path, 'model_index.json')

if os.path.isfile(model_index_path):
print("Model index file already exists. Skipping conversion.")
return

optimum_command = "optimum-cli export openvino --model {} --weight-format {} {}".format(source_model, precision, target_path)

if os.system(optimum_command):
raise ValueError("Failed to export image generation model model", source_model)

gtemplate = jinja2.Environment(loader=jinja2.BaseLoader).from_string(image_generation_graph_template)
graph_content = gtemplate.render(model_path=model_path, resolution=resolution)
with open(os.path.join(model_repository_path, model_name, 'graph.pbtxt'), 'w') as f:
f.write(graph_content)
print("Created graph {}".format(os.path.join(model_repository_path, model_name, 'graph.pbtxt')))
add_servable_to_config(config_file_path, model_name, os.path.relpath( os.path.join(model_repository_path, model_name), os.path.dirname(config_file_path)))


if not os.path.isdir(args['model_repository_path']):
raise ValueError(f"The model repository path '{args['model_repository_path']}' is not a valid directory.")
if args['source_model'] is None:
Expand Down Expand Up @@ -477,4 +519,6 @@ def export_rerank_model(model_repository_path, source_model, model_name, precisi
elif args['task'] == 'rerank':
export_rerank_model(args['model_repository_path'], args['source_model'], args['model_name'] ,args['precision'], template_parameters, str(args['version']), args['config_file_path'], args['max_doc_length'])

elif args['task'] == 'image_generation':
export_image_generation_model(args['model_repository_path'], args['source_model'], args['model_name'], args['precision'], template_parameters, args['config_file_path'], args['resolution'])

1 change: 1 addition & 0 deletions demos/common/export_models/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ einops
torchvision==0.21.0
timm==1.0.15
auto-gptq==0.7.1
diffusers==0.33.1 # for image generation