Skip to content

Commit a7895e3

Browse files
committed
apply review comments
1 parent abd1b68 commit a7895e3

File tree

3 files changed

+4
-8
lines changed

3 files changed

+4
-8
lines changed

optimum/exporters/openvino/__main__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -410,10 +410,10 @@ class StoreAttr(object):
410410
)
411411
model.config.pad_token_id = pad_token_id
412412

413-
if hasattr(model.config, "export_model_type") and model.config.export_model_type is not None:
413+
if hasattr(model.config, "export_model_type"):
414414
model_type = model.config.export_model_type.replace("_", "-")
415415
else:
416-
model_type = (getattr(model.config, "model_type", "") or "").replace("_", "-")
416+
model_type = model.config.model_type.replace("_", "-")
417417

418418
if (
419419
not custom_architecture

optimum/exporters/openvino/convert.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import inspect
1919
import logging
2020
import os
21-
import types
2221
from pathlib import Path
2322
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
2423

@@ -1115,10 +1114,7 @@ def get_ltx_video_models_for_export(pipeline, exporter, int_dtype, float_dtype):
11151114
}
11161115
)
11171116

1118-
def _vae_decode_forward(self, latent_sample, timestep=None):
1119-
return self.decode(z=latent_sample, temb=timestep)
1120-
1121-
vae_decoder.forward = types.MethodType(_vae_decode_forward, vae_decoder)
1117+
vae_decoder.forward = lambda latent_sample, timestep=None: vae_decoder.decode(z=latent_sample, temb=timestep)
11221118

11231119
vae_config_constructor = TasksManager.get_exporter_config_constructor(
11241120
model=vae_decoder,

tests/openvino/test_diffusion.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1006,7 +1006,7 @@ def test_textual_inversion(self):
10061006
np.testing.assert_allclose(ov_output, diffusers_output, atol=1e-4, rtol=1e-2)
10071007

10081008

1009-
@unittest.skipIf(is_transformers_version("<", "4.45"))
1009+
@unittest.skipIf(is_transformers_version("<", "4.45"), "Required transformers >= 4.45")
10101010
class OVPipelineForText2VideoTest(unittest.TestCase):
10111011
SUPPORTED_ARCHITECTURES = []
10121012
if is_transformers_version(">=", "4.45.0"):

0 commit comments

Comments
 (0)