Skip to content
This repository was archived by the owner on Sep 26, 2025. It is now read-only.

Commit a6be717

Browse files
committed
add box segmenter solution
1 parent c8ff6d9 commit a6be717

12 files changed

+174
-1
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ ______________________________________________________________________
2121

2222
## Latest News 🔥
2323

24+
- Added the Box Segmenter all-in-one solution ([model](https://huggingface.co/finegrain/finegrain-box-segmenter), [HF space](https://huggingface.co/spaces/finegrain/finegrain-object-cutter))
25+
- Added [MVANet](https://arxiv.org/abs/2404.07445) for high resolution segmentation
2426
- Added [IC-Light](https://github.com/lllyasviel/IC-Light) to manipulate the illumination of images
2527
- Added Multi Upscaler for high-resolution image generation, inspired from [Clarity Upscaler](https://github.com/philz1337x/clarity-upscaler) ([HF Space](https://huggingface.co/spaces/finegrain/enhancer))
2628
- Added [HQ-SAM](https://arxiv.org/abs/2306.01567) for high quality mask prediction with Segment Anything

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ doc = [
6767
"mkdocstrings[python]>=0.24.0",
6868
"mkdocs-literate-nav>=0.6.1",
6969
]
70+
solutions = [
71+
"huggingface-hub>=0.24.6",
72+
]
7073

7174
[build-system]
7275
requires = ["hatchling"]

requirements.lock

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,10 @@ gitpython==3.1.43
100100
# via wandb
101101
griffe==0.48.0
102102
# via mkdocstrings-python
103-
huggingface-hub==0.24.5
103+
huggingface-hub==0.24.6
104104
# via datasets
105105
# via diffusers
106+
# via refiners
106107
# via timm
107108
# via tokenizers
108109
# via transformers

scripts/prepare_test_weights.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,15 @@ def download_mvanet():
466466
check_hash(dest_filename, "b915d492")
467467

468468

469+
def download_box_segmenter():
470+
download_file(
471+
"https://huggingface.co/finegrain/finegrain-box-segmenter/resolve/v0.1/model.safetensors",
472+
dest_folder=test_weights_dir,
473+
filename="finegrain-box-segmenter-v0-1.safetensors",
474+
expected_hash="e0450e8c",
475+
)
476+
477+
469478
def printg(msg: str):
470479
"""print in green color"""
471480
print("\033[92m" + msg + "\033[0m")
@@ -861,6 +870,7 @@ def download_all():
861870
download_sdxl_lightning_lora()
862871
download_ic_light()
863872
download_mvanet()
873+
download_box_segmenter()
864874

865875

866876
def convert_all():

src/refiners/solutions/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .box_segmenter import BoxSegmenter
2+
3+
__all__ = ["BoxSegmenter"]
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from pathlib import Path
2+
3+
import torch
4+
from PIL import Image
5+
6+
from refiners.fluxion.utils import image_to_tensor, no_grad, normalize, tensor_to_image
7+
from refiners.foundationals.swin.mvanet import MVANet
8+
9+
BoundingBox = tuple[int, int, int, int]
10+
11+
12+
class BoxSegmenter:
13+
def __init__(
14+
self,
15+
*,
16+
margin: float = 0.05,
17+
weights: Path | str | dict[str, torch.Tensor] | None = None,
18+
device: torch.device | str = "cpu",
19+
):
20+
assert margin >= 0
21+
self.margin = margin
22+
23+
self.device = torch.device(device)
24+
self.model = MVANet(device=self.device).eval()
25+
26+
if weights is None:
27+
from huggingface_hub.file_download import hf_hub_download # type: ignore[reportUnknownVariableType]
28+
29+
weights = hf_hub_download(
30+
repo_id="finegrain/finegrain-box-segmenter",
31+
filename="model.safetensors",
32+
revision="v0.1",
33+
)
34+
35+
if isinstance(weights, dict):
36+
self.model.load_state_dict(weights)
37+
else:
38+
self.model.load_from_safetensors(weights)
39+
40+
def __call__(self, img: Image.Image, box_prompt: BoundingBox | None = None) -> Image.Image:
41+
return self.run(img, box_prompt)
42+
43+
def add_margin(self, box: BoundingBox) -> BoundingBox:
44+
x0, y0, x1, y1 = box
45+
mx = int((x1 - x0) * self.margin)
46+
my = int((y1 - y0) * self.margin)
47+
return (x0 - mx, y0 - my, x1 + mx, y1 + my)
48+
49+
@staticmethod
50+
def crop_pad(img: Image.Image, box: BoundingBox) -> Image.Image:
51+
img = img.convert("RGB")
52+
53+
x0, y0, x1, y1 = box
54+
px0, py0, px1, py1 = (max(0, -x0), max(0, -y0), max(0, x1 - img.width), max(0, y1 - img.height))
55+
if (px0, py0, px1, py1) == (0, 0, 0, 0):
56+
return img.crop(box)
57+
58+
padded = Image.new("RGB", (img.width + px0 + px1, img.height + py0 + py1))
59+
padded.paste(img, (px0, py0))
60+
return padded.crop((x0 + px0, y0 + py0, x1 + px0, y1 + py0))
61+
62+
def predict(self, img: Image.Image) -> Image.Image:
63+
in_t = image_to_tensor(img.resize((1024, 1024), Image.Resampling.BILINEAR)).squeeze()
64+
in_t = normalize(in_t, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).unsqueeze(0)
65+
with no_grad():
66+
prediction: torch.Tensor = self.model(in_t.to(self.device)).sigmoid()
67+
return tensor_to_image(prediction).resize(img.size, Image.Resampling.BILINEAR)
68+
69+
def run(self, img: Image.Image, box_prompt: BoundingBox | None = None) -> Image.Image:
70+
if box_prompt is None:
71+
box_prompt = (0, 0, img.width, img.height)
72+
73+
box = self.add_margin(box_prompt)
74+
cropped = self.crop_pad(img, box)
75+
prediction = self.predict(cropped)
76+
77+
out = Image.new("L", (img.width, img.height))
78+
out.paste(prediction, box)
79+
return out

tests/e2e/test_solutions.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from pathlib import Path
2+
from warnings import warn
3+
4+
import pytest
5+
import torch
6+
from PIL import Image
7+
from tests.utils import ensure_similar_images
8+
9+
from refiners.solutions import BoxSegmenter
10+
11+
12+
def _img_open(path: Path) -> Image.Image:
13+
return Image.open(path) # type: ignore
14+
15+
16+
@pytest.fixture(scope="module")
17+
def ref_path(test_e2e_path: Path) -> Path:
18+
return test_e2e_path / "test_solutions_ref"
19+
20+
21+
@pytest.fixture(scope="module")
22+
def ref_shelves(ref_path: Path) -> Image.Image:
23+
return _img_open(ref_path / "shelves.jpg").convert("RGB")
24+
25+
26+
@pytest.fixture
27+
def expected_box_segmenter_plant_mask(ref_path: Path) -> Image.Image:
28+
return _img_open(ref_path / "expected_box_segmenter_plant_mask.png")
29+
30+
31+
@pytest.fixture
32+
def expected_box_segmenter_spray_mask(ref_path: Path) -> Image.Image:
33+
return _img_open(ref_path / "expected_box_segmenter_spray_mask.png")
34+
35+
36+
@pytest.fixture
37+
def expected_box_segmenter_spray_cropped_mask(ref_path: Path) -> Image.Image:
38+
return _img_open(ref_path / "expected_box_segmenter_spray_cropped_mask.png")
39+
40+
41+
@pytest.fixture(scope="module")
42+
def box_segmenter_weights(test_weights_path: Path) -> Path:
43+
weights = test_weights_path / "finegrain-box-segmenter-v0-1.safetensors"
44+
if not weights.is_file():
45+
warn(f"could not find weights at {test_weights_path}, skipping")
46+
pytest.skip(allow_module_level=True)
47+
return weights
48+
49+
50+
def test_box_segmenter(
51+
box_segmenter_weights: Path,
52+
ref_shelves: Image.Image,
53+
expected_box_segmenter_plant_mask: Image.Image,
54+
expected_box_segmenter_spray_mask: Image.Image,
55+
expected_box_segmenter_spray_cropped_mask: Image.Image,
56+
test_device: torch.device,
57+
):
58+
segmenter = BoxSegmenter(weights=box_segmenter_weights, device=test_device)
59+
60+
plant_mask = segmenter(ref_shelves, box_prompt=(504, 82, 754, 368))
61+
ensure_similar_images(plant_mask.convert("RGB"), expected_box_segmenter_plant_mask.convert("RGB"))
62+
63+
spray_box = (461, 542, 594, 823)
64+
spray_mask = segmenter(ref_shelves, box_prompt=spray_box)
65+
ensure_similar_images(spray_mask.convert("RGB"), expected_box_segmenter_spray_mask.convert("RGB"))
66+
67+
# Test left and bottom padding.
68+
off_l, off_b = 11, 7
69+
shelves_cropped = ref_shelves.crop((spray_box[0] - off_l, 0, ref_shelves.width, spray_box[3] + off_b))
70+
spray_cropped_box = (off_l, spray_box[1], spray_box[2] - spray_box[0] + off_l, spray_box[3])
71+
spray_cropped_mask = segmenter(shelves_cropped, box_prompt=spray_cropped_box)
72+
ensure_similar_images(spray_cropped_mask.convert("RGB"), expected_box_segmenter_spray_cropped_mask.convert("RGB"))
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
`shelves.jpg` is found here: https://www.freepik.com/free-photo/front-view-shelves-with-plants_6446859.htm
2+
3+
`expected_box_segmenter_plant_mask.png`, `expected_box_segmenter_spray_mask.png` and `expected_box_segmenter_spray_cropped_mask.png` are generated with Refiners.
19.3 KB
Loading
6.38 KB
Loading

0 commit comments

Comments
 (0)