Skip to content

Commit 9cdb68d

Browse files
Merge pull request #208 from stergioc/hoptimus
Addin H-optimus-0 Model
2 parents a2afa34 + e51b69f commit 9cdb68d

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

sopa/patches/models/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
from .histo_ssl import HistoSSLFeatures
33
from .dinov2 import DINOv2Features
44
from .dummy import DummyFeatures
5+
from .hoptimus0 import HOPTIMUSFeatures
56

6-
__all__ = ["Resnet50Features", "HistoSSLFeatures", "DINOv2Features", "DummyFeatures"]
7+
__all__ = ["Resnet50Features", "HistoSSLFeatures", "DINOv2Features", "DummyFeatures", "HOPTIMUSFeatures"]
78

89
available_models = {
910
"resnet50": Resnet50Features,
1011
"histo_ssl": HistoSSLFeatures,
1112
"dinov2": DINOv2Features,
1213
"dummy": DummyFeatures,
14+
"hoptimus0": HOPTIMUSFeatures,
1315
}

sopa/patches/models/hoptimus0.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from torch import nn
2+
from torchvision import transforms
3+
4+
5+
class HOPTIMUSFeatures(nn.Module):
6+
def __init__(self):
7+
super().__init__()
8+
9+
import timm
10+
11+
self.model = timm.create_model("hf_hub:bioptimus/H-optimus-0", pretrained=True)
12+
13+
def forward(self, x):
14+
transform = transforms.Compose(
15+
[
16+
transforms.Normalize(mean=(0.707223, 0.578729, 0.703617), std=(0.211883, 0.230117, 0.177517)),
17+
]
18+
)
19+
20+
return self.model(transform(x))

0 commit comments

Comments
 (0)