Skip to content

Commit 3e61261

Browse files
Support serving vehicle attribute recognition pipeline (#2452)
1 parent 0160aa3 commit 3e61261

File tree

2 files changed

+113
-1
lines changed

2 files changed

+113
-1
lines changed

paddlex/inference/pipelines/serving/_pipeline_apps/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616

1717
from fastapi import FastAPI
1818

19-
from ...attribute_recognition import PedestrianAttributeRecPipeline
19+
from ...attribute_recognition import (
20+
PedestrianAttributeRecPipeline,
21+
VehicleAttributeRecPipeline,
22+
)
2023
from ...base import BasePipeline
2124
from ...formula_recognition import FormulaRecognitionPipeline
2225
from ...layout_parsing import LayoutParsingPipeline
@@ -52,6 +55,9 @@
5255
from .pedestrian_attribute_recognition import (
5356
create_pipeline_app as create_pedestrian_attribute_recognition_app,
5457
)
58+
from .vehicle_attribute_recognition import (
59+
create_pipeline_app as create_vehicle_attribute_recognition_app,
60+
)
5561
from .ppchatocrv3 import create_pipeline_app as create_ppchatocrv3_app
5662
from .seal_recognition import create_pipeline_app as create_seal_recognition_app
5763
from .semantic_segmentation import (
@@ -168,6 +174,12 @@ def create_pipeline_app(
168174
"Expected `pipeline` to be an instance of `PedestrianAttributeRecPipeline`."
169175
)
170176
return create_pedestrian_attribute_recognition_app(pipeline, app_config)
177+
elif pipeline_name == "vehicle_attribute_recognition":
178+
if not isinstance(pipeline, VehicleAttributeRecPipeline):
179+
raise TypeError(
180+
"Expected `pipeline` to be an instance of `VehicleAttributeRecPipeline`."
181+
)
182+
return create_vehicle_attribute_recognition_app(pipeline, app_config)
171183
else:
172184
if BasePipeline.get(pipeline_name):
173185
raise ValueError(
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List
16+
17+
from fastapi import FastAPI, HTTPException
18+
from pydantic import BaseModel, Field
19+
from typing_extensions import Annotated, TypeAlias
20+
21+
from .....utils import logging
22+
from ...attribute_recognition import VehicleAttributeRecPipeline
23+
from .. import utils as serving_utils
24+
from ..app import AppConfig, create_app
25+
from ..models import Response, ResultResponse
26+
27+
28+
class InferRequest(BaseModel):
29+
image: str
30+
31+
32+
BoundingBox: TypeAlias = Annotated[List[float], Field(min_length=4, max_length=4)]
33+
34+
35+
class Attribute(BaseModel):
36+
label: str
37+
score: float
38+
39+
40+
class Vehicle(BaseModel):
41+
bbox: BoundingBox
42+
attributes: List[Attribute]
43+
score: float
44+
45+
46+
class InferResult(BaseModel):
47+
vehicles: List[Vehicle]
48+
image: str
49+
50+
51+
def create_pipeline_app(
52+
pipeline: VehicleAttributeRecPipeline, app_config: AppConfig
53+
) -> FastAPI:
54+
app, ctx = create_app(
55+
pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
56+
)
57+
58+
@app.post(
59+
"/vehicle-attribute-recognition",
60+
operation_id="infer",
61+
responses={422: {"model": Response}},
62+
)
63+
async def _infer(request: InferRequest) -> ResultResponse[InferResult]:
64+
pipeline = ctx.pipeline
65+
aiohttp_session = ctx.aiohttp_session
66+
67+
try:
68+
file_bytes = await serving_utils.get_raw_bytes(
69+
request.image, aiohttp_session
70+
)
71+
image = serving_utils.image_bytes_to_array(file_bytes)
72+
73+
result = (await pipeline.infer(image))[0]
74+
75+
vehicles: List[Vehicle] = []
76+
for obj in result["boxes"]:
77+
vehicles.append(
78+
Vehicle(
79+
bbox=obj["coordinate"],
80+
attributes=[
81+
Attribute(label=l, score=s)
82+
for l, s in zip(obj["labels"], obj["cls_scores"])
83+
],
84+
score=obj["det_score"],
85+
)
86+
)
87+
output_image_base64 = serving_utils.image_to_base64(result.img)
88+
89+
return ResultResponse(
90+
logId=serving_utils.generate_log_id(),
91+
errorCode=0,
92+
errorMsg="Success",
93+
result=InferResult(vehicles=vehicles, image=output_image_base64),
94+
)
95+
96+
except Exception as e:
97+
logging.exception(e)
98+
raise HTTPException(status_code=500, detail="Internal server error")
99+
100+
return app

0 commit comments

Comments
 (0)