from typing import Literal

from timm.data.transforms import ResizeKeepRatio
from torchvision import transforms

from mmlearn.conf import external_store

@external_store(group="datasets/transforms")
def med_clip_vision_transform(
    image_crop_size: int = 224, job_type: Literal["train", "eval"] = "train"
) -> transforms.Compose:
    """Return transforms for training/evaluating CLIP with medical images.

    Parameters
    ----------
    image_crop_size : int, default=224
        Size of the image crop.
    job_type : {"train", "eval"}, default="train"
        Type of the job (training or evaluation) for which the transforms are needed.

    Returns
    -------
    transforms.Compose
        Composed transforms for training or evaluating CLIP with medical images.
    """
    return transforms.Compose(
        [
            # Resize while preserving aspect ratio; use a larger target (512) for
            # training so the random crop sees more context.
            ResizeKeepRatio(
                512 if job_type == "train" else image_crop_size, interpolation="bicubic"
            ),
            # Random crop as training augmentation; deterministic center crop for eval.
            transforms.RandomCrop(image_crop_size)
            if job_type == "train"
            else transforms.CenterCrop(image_crop_size),
            transforms.ToTensor(),
            # Normalize with the standard CLIP image statistics.
            transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711],
            ),
        ]
    )
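Because the function is registered with `@external_store(group="datasets/transforms")`, it is normally instantiated through Hydra/hydra-zen configs, but it can also be called directly. A minimal usage sketch; the dummy PIL image and the printed shape are illustrative assumptions, not part of the source:

from PIL import Image

# Build the evaluation pipeline: aspect-preserving resize to the crop size,
# followed by a deterministic center crop.
eval_transform = med_clip_vision_transform(image_crop_size=224, job_type="eval")

# Apply to a dummy RGB image; the output is a normalized tensor.
image = Image.new("RGB", (640, 480))
tensor = eval_transform(image)
print(tensor.shape)  # torch.Size([3, 224, 224])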