import io
import copy
import random
import requests
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

from PIL import Image, ImageDraw


DESCRIPTION = "# Florence-2 OpenVINO Demo"

colormap = [
    "blue",
    "orange",
    "green",
    "purple",
    "brown",
    "pink",
    "gray",
    "olive",
    "cyan",
    "red",
    "lime",
    "indigo",
    "violet",
    "aqua",
    "magenta",
    "coral",
    "gold",
    "tan",
    "skyblue",
]

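# (image URL, local filename) pairs used by gr.Examples; missing files are downloaded by make_demo().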
example_images = [
    ("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true", "car.jpg"),
| 42 | + ("https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11", "cat.png"), |
| 43 | + ("https://github.com/user-attachments/assets/8c9ae017-7837-4abc-ae92-c1054c9ec350", "hand-written.png"), |
]

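# Render a Matplotlib figure into an in-memory PNG and return it as a PIL image for gr.Image outputs.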
def fig_to_pil(fig):
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    plt.close(fig)  # release the figure so repeated calls do not accumulate open figures
    return Image.open(buf)

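# Draw detection-style results ({"bboxes": [...], "labels": [...]}) on the image and return the Matplotlib figure.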
def plot_bbox(image, data):
    fig, ax = plt.subplots()
    ax.imshow(image)
    for bbox, label in zip(data["bboxes"], data["labels"]):
        x1, y1, x2, y2 = bbox
        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=1, edgecolor="r", facecolor="none")
        ax.add_patch(rect)
        plt.text(x1, y1, label, color="white", fontsize=8, bbox=dict(facecolor="red", alpha=0.5))
    ax.axis("off")
    return fig

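# Draw segmentation polygons and their labels directly onto a PIL image; degenerate polygons (fewer than 3 points) are skipped.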
def draw_polygons(image, prediction, fill_mask=False):
    draw = ImageDraw.Draw(image)
    scale = 1
    for polygons, label in zip(prediction["polygons"], prediction["labels"]):
        color = random.choice(colormap)
        fill_color = random.choice(colormap) if fill_mask else None
        for _polygon in polygons:
            _polygon = np.array(_polygon).reshape(-1, 2)
            if len(_polygon) < 3:
                print("Invalid polygon:", _polygon)
                continue
            _polygon = (_polygon * scale).reshape(-1).tolist()
            if fill_mask:
                draw.polygon(_polygon, outline=color, fill=fill_color)
            else:
                draw.polygon(_polygon, outline=color)
            draw.text((_polygon[0] + 8, _polygon[1] + 2), label, fill=color)
    return image

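# Open-vocabulary detection results store labels under "bboxes_labels"; remap them to the layout plot_bbox() expects.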
def convert_to_od_format(data):
    bboxes = data.get("bboxes", [])
    labels = data.get("bboxes_labels", [])
    od_results = {"bboxes": bboxes, "labels": labels}
    return od_results

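# Draw the quadrilateral text boxes and recognized strings from an <OCR_WITH_REGION> result onto a PIL image.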
def draw_ocr_bboxes(image, prediction):
    scale = 1
    draw = ImageDraw.Draw(image)
    bboxes, labels = prediction["quad_boxes"], prediction["labels"]
    for box, label in zip(bboxes, labels):
        color = random.choice(colormap)
        new_box = (np.array(box) * scale).tolist()
        draw.polygon(new_box, width=3, outline=color)
        draw.text((new_box[0] + 8, new_box[1] + 2), "{}".format(label), align="right", fill=color)
    return image


css = """
  #output {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
  }
"""


single_task_list = [
    "Caption",
    "Detailed Caption",
    "More Detailed Caption",
    "Object Detection",
    "Dense Region Caption",
    "Region Proposal",
    "Caption to Phrase Grounding",
    "Referring Expression Segmentation",
    "Region to Segmentation",
    "Open Vocabulary Detection",
    "Region to Category",
    "Region to Description",
    "OCR",
    "OCR with Region",
]

cascaded_task_list = ["Caption + Grounding", "Detailed Caption + Grounding", "More Detailed Caption + Grounding"]

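# Swap the task dropdown contents when the user switches between single and cascaded task modes.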
def update_task_dropdown(choice):
    if choice == "Cascaded task":
        return gr.Dropdown(choices=cascaded_task_list, value="Caption + Grounding")
    else:
        return gr.Dropdown(choices=single_task_list, value="Caption")

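# Build the Gradio demo around a Florence-2 model/processor pair (e.g. the OpenVINO-converted model from the notebook).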
def make_demo(model, processor):
    for url, filename in example_images:
        if not Path(filename).exists():
            image = Image.open(requests.get(url, stream=True).raw)
            image.save(filename)

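    # Run a single Florence-2 task: build the prompt, generate with beam search, and post-process into the task-specific structure.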
    def run_example(task_prompt, image, text_input=None):
        if text_input is None:
            prompt = task_prompt
        else:
            prompt = task_prompt + text_input
        inputs = processor(text=prompt, images=image, return_tensors="pt")
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
        return parsed_answer

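    # Map the UI task name to the corresponding Florence-2 task token(s), run inference, and return (raw results, optional rendered image).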
    def process_image(image, task_prompt, text_input=None):
        image = Image.fromarray(image)  # Convert NumPy array to PIL Image
        if task_prompt == "Caption":
            task_prompt = "<CAPTION>"
            results = run_example(task_prompt, image)
            return results, None
        elif task_prompt == "Detailed Caption":
            task_prompt = "<DETAILED_CAPTION>"
            results = run_example(task_prompt, image)
            return results, None
        elif task_prompt == "More Detailed Caption":
            task_prompt = "<MORE_DETAILED_CAPTION>"
            results = run_example(task_prompt, image)
            return results, None
        elif task_prompt == "Caption + Grounding":
            task_prompt = "<CAPTION>"
            results = run_example(task_prompt, image)
            text_input = results[task_prompt]
            task_prompt = "<CAPTION_TO_PHRASE_GROUNDING>"
            results = run_example(task_prompt, image, text_input)
            results["<CAPTION>"] = text_input
            fig = plot_bbox(image, results["<CAPTION_TO_PHRASE_GROUNDING>"])
            return results, fig_to_pil(fig)
        elif task_prompt == "Detailed Caption + Grounding":
            task_prompt = "<DETAILED_CAPTION>"
            results = run_example(task_prompt, image)
            text_input = results[task_prompt]
            task_prompt = "<CAPTION_TO_PHRASE_GROUNDING>"
            results = run_example(task_prompt, image, text_input)
            results["<DETAILED_CAPTION>"] = text_input
            fig = plot_bbox(image, results["<CAPTION_TO_PHRASE_GROUNDING>"])
            return results, fig_to_pil(fig)
        elif task_prompt == "More Detailed Caption + Grounding":
            task_prompt = "<MORE_DETAILED_CAPTION>"
            results = run_example(task_prompt, image)
            text_input = results[task_prompt]
            task_prompt = "<CAPTION_TO_PHRASE_GROUNDING>"
            results = run_example(task_prompt, image, text_input)
            results["<MORE_DETAILED_CAPTION>"] = text_input
            fig = plot_bbox(image, results["<CAPTION_TO_PHRASE_GROUNDING>"])
            return results, fig_to_pil(fig)
        elif task_prompt == "Object Detection":
            task_prompt = "<OD>"
            results = run_example(task_prompt, image)
            fig = plot_bbox(image, results["<OD>"])
            return results, fig_to_pil(fig)
        elif task_prompt == "Dense Region Caption":
            task_prompt = "<DENSE_REGION_CAPTION>"
            results = run_example(task_prompt, image)
            fig = plot_bbox(image, results["<DENSE_REGION_CAPTION>"])
            return results, fig_to_pil(fig)
        elif task_prompt == "Region Proposal":
            task_prompt = "<REGION_PROPOSAL>"
            results = run_example(task_prompt, image)
            fig = plot_bbox(image, results["<REGION_PROPOSAL>"])
            return results, fig_to_pil(fig)
        elif task_prompt == "Caption to Phrase Grounding":
            task_prompt = "<CAPTION_TO_PHRASE_GROUNDING>"
            results = run_example(task_prompt, image, text_input)
            fig = plot_bbox(image, results["<CAPTION_TO_PHRASE_GROUNDING>"])
            return results, fig_to_pil(fig)
        elif task_prompt == "Referring Expression Segmentation":
            task_prompt = "<REFERRING_EXPRESSION_SEGMENTATION>"
            results = run_example(task_prompt, image, text_input)
            output_image = copy.deepcopy(image)
            output_image = draw_polygons(output_image, results["<REFERRING_EXPRESSION_SEGMENTATION>"], fill_mask=True)
            return results, output_image
        elif task_prompt == "Region to Segmentation":
            task_prompt = "<REGION_TO_SEGMENTATION>"
            results = run_example(task_prompt, image, text_input)
            output_image = copy.deepcopy(image)
            output_image = draw_polygons(output_image, results["<REGION_TO_SEGMENTATION>"], fill_mask=True)
            return results, output_image
        elif task_prompt == "Open Vocabulary Detection":
            task_prompt = "<OPEN_VOCABULARY_DETECTION>"
            results = run_example(task_prompt, image, text_input)
            bbox_results = convert_to_od_format(results["<OPEN_VOCABULARY_DETECTION>"])
            fig = plot_bbox(image, bbox_results)
            return results, fig_to_pil(fig)
        elif task_prompt == "Region to Category":
            task_prompt = "<REGION_TO_CATEGORY>"
            results = run_example(task_prompt, image, text_input)
            return results, None
        elif task_prompt == "Region to Description":
            task_prompt = "<REGION_TO_DESCRIPTION>"
            results = run_example(task_prompt, image, text_input)
            return results, None
        elif task_prompt == "OCR":
            task_prompt = "<OCR>"
            results = run_example(task_prompt, image)
            return results, None
        elif task_prompt == "OCR with Region":
            task_prompt = "<OCR_WITH_REGION>"
            results = run_example(task_prompt, image)
            output_image = copy.deepcopy(image)
            output_image = draw_ocr_bboxes(output_image, results["<OCR_WITH_REGION>"])
            return results, output_image
        else:
            return "", None

    with gr.Blocks(css=css) as demo:
        gr.Markdown(DESCRIPTION)
        with gr.Tab(label="Florence-2 Image Captioning"):
            with gr.Row():
                with gr.Column():
                    input_img = gr.Image(label="Input Picture")
                    task_type = gr.Radio(choices=["Single task", "Cascaded task"], label="Task type selector", value="Single task")
                    task_prompt = gr.Dropdown(choices=single_task_list, label="Task Prompt", value="Caption")
                    task_type.change(fn=update_task_dropdown, inputs=task_type, outputs=task_prompt)
                    text_input = gr.Textbox(label="Text Input (optional)")
                    submit_btn = gr.Button(value="Submit")
                with gr.Column():
                    output_text = gr.Textbox(label="Output Text")
                    output_img = gr.Image(label="Output Image")

            gr.Examples(
                examples=[["car.jpg", "Region to Segmentation"], ["hand-written.png", "OCR with Region"], ["cat.png", "Detailed Caption"]],
                inputs=[input_img, task_prompt],
                label="Try examples",
            )

            submit_btn.click(process_image, [input_img, task_prompt, text_input], [output_text, output_img])

    return demo
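
# The sketch below is not part of the original helper: a minimal way to launch the demo, assuming the
# reference "microsoft/Florence-2-base" checkpoint loaded through transformers with trust_remote_code.
# In the OpenVINO notebook, the converted OpenVINO model and its processor (exposing the same
# generate()/post_process_generation() interface) would be passed to make_demo() instead, and
# environment-specific loading options (dtype, attention implementation) may be needed.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoProcessor

    model_id = "microsoft/Florence-2-base"  # assumed checkpoint; any Florence-2 variant should work
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
    make_demo(model, processor).launch()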