Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions keras_hub/src/utils/timm/preset_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,10 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):

def load_image_converter(self, cls, **kwargs):
pretrained_cfg = self.config.get("pretrained_cfg", None)
if not pretrained_cfg or "input_size" not in pretrained_cfg:
if not pretrained_cfg:
return None
# This assumes the same basic setup for all timm preprocessing, We may
# need to extend this as we cover more model types.
input_size = pretrained_cfg["input_size"]
mean = pretrained_cfg["mean"]
std = pretrained_cfg["std"]
scale = [1.0 / 255.0 / s for s in std]
Expand All @@ -63,7 +62,6 @@ def load_image_converter(self, cls, **kwargs):
if interpolation not in ("bilinear", "nearest", "bicubic"):
interpolation = "bilinear" # Unsupported interpolation type.
return cls(
image_size=input_size[1:],
scale=scale,
offset=offset,
interpolation=interpolation,
Expand Down