Skip to content

Commit 42ed9a5

Browse files
authored
[Feat] Support serving pedestrian attribute recognition pipeline (#2437)
* Extract common functions * Add pedestrian attribute recognition serving app
1 parent ad00fd2 commit 42ed9a5

File tree

9 files changed

+327
-898
lines changed

9 files changed

+327
-898
lines changed

docs/pipeline_usage/tutorials/cv_pipelines/pedestrian_attribute_recognition.en.md

Lines changed: 94 additions & 334 deletions
Large diffs are not rendered by default.

docs/pipeline_usage/tutorials/cv_pipelines/pedestrian_attribute_recognition.md

Lines changed: 32 additions & 413 deletions
Large diffs are not rendered by default.

paddlex/inference/pipelines/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@
3939
from .ppchatocrv3 import PPChatOCRPipeline
4040
from .layout_parsing import LayoutParsingPipeline
4141
from .pp_shitu_v2 import ShiTuV2Pipeline
42-
from .attribute_recognition import AttributeRecPipeline
42+
from .attribute_recognition import (
43+
PedestrianAttributeRecPipeline,
44+
VehicleAttributeRecPipeline,
45+
)
4346

4447

4548
def load_pipeline_config(pipeline: str) -> Dict[str, Any]:

paddlex/inference/pipelines/attribute_recognition.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
class AttributeRecPipeline(BasePipeline):
2525
"""Attribute Rec Pipeline"""
2626

27-
entities = ["pedestrian_attribute_recognition", "vehicle_attribute_recognition"]
28-
2927
def __init__(
3028
self,
3129
det_model,
@@ -84,3 +82,11 @@ def get_final_result(self, det_res, cls_res):
8482
}
8583
)
8684
return AttributeRecResult(single_img_res)
85+
86+
87+
class PedestrianAttributeRecPipeline(AttributeRecPipeline):
88+
entities = "pedestrian_attribute_recognition"
89+
90+
91+
class VehicleAttributeRecPipeline(AttributeRecPipeline):
92+
entities = "vehicle_attribute_recognition"

paddlex/inference/pipelines/serving/_pipeline_apps/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from fastapi import FastAPI
1818

19+
from ...attribute_recognition import PedestrianAttributeRecPipeline
1920
from ...base import BasePipeline
2021
from ...formula_recognition import FormulaRecognitionPipeline
2122
from ...layout_parsing import LayoutParsingPipeline
@@ -48,6 +49,9 @@
4849
)
4950
from .object_detection import create_pipeline_app as create_object_detection_app
5051
from .ocr import create_pipeline_app as create_ocr_app
52+
from .pedestrian_attribute_recognition import (
53+
create_pipeline_app as create_pedestrian_attribute_recognition_app,
54+
)
5155
from .ppchatocrv3 import create_pipeline_app as create_ppchatocrv3_app
5256
from .seal_recognition import create_pipeline_app as create_seal_recognition_app
5357
from .semantic_segmentation import (
@@ -158,6 +162,12 @@ def create_pipeline_app(
158162
"Expected `pipeline` to be an instance of `LayoutParsingPipeline`."
159163
)
160164
return create_layout_parsing_app(pipeline, app_config)
165+
elif pipeline_name == "pedestrian_attribute_recognition":
166+
if not isinstance(pipeline, PedestrianAttributeRecPipeline):
167+
raise TypeError(
168+
"Expected `pipeline` to be an instance of `PedestrianAttributeRecPipeline`."
169+
)
170+
return create_pedestrian_attribute_recognition_app(pipeline, app_config)
161171
else:
162172
if BasePipeline.get(pipeline_name):
163173
raise ValueError(

paddlex/inference/pipelines/serving/_pipeline_apps/layout_parsing.py

Lines changed: 5 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,14 @@
1313
# limitations under the License.
1414

1515
import os
16-
import re
17-
import uuid
1816
from typing import Final, List, Literal, Optional, Tuple
19-
from urllib.parse import parse_qs, urlparse
2017

2118
import cv2
2219
import numpy as np
2320
from fastapi import FastAPI, HTTPException
2421
from numpy.typing import ArrayLike
2522
from pydantic import BaseModel, Field
26-
from typing_extensions import Annotated, TypeAlias, assert_never
23+
from typing_extensions import Annotated, TypeAlias
2724

2825
from .....utils import logging
2926
from ...layout_parsing import LayoutParsingPipeline
@@ -71,71 +68,6 @@ class InferResult(BaseModel):
7168
layoutParsingResults: List[LayoutParsingResult]
7269

7370

74-
def _generate_request_id() -> str:
75-
return str(uuid.uuid4())
76-
77-
78-
def _infer_file_type(url: str) -> FileType:
79-
# Is it more reliable to guess the file type based on the response headers?
80-
SUPPORTED_IMG_EXTS: Final[List[str]] = [".jpg", ".jpeg", ".png"]
81-
82-
url_parts = urlparse(url)
83-
ext = os.path.splitext(url_parts.path)[1]
84-
# HACK: The support for BOS URLs with query params is implementation-based,
85-
# not interface-based.
86-
is_bos_url = (
87-
re.fullmatch(r"(?:bj|bd|su|gz|cd|hkg|fwh|fsh)\.bcebos\.com", url_parts.netloc)
88-
is not None
89-
)
90-
if is_bos_url and url_parts.query:
91-
params = parse_qs(url_parts.query)
92-
if (
93-
"responseContentDisposition" not in params
94-
or len(params["responseContentDisposition"]) != 1
95-
):
96-
raise ValueError("`responseContentDisposition` not found")
97-
match_ = re.match(
98-
r"attachment;filename=(.*)", params["responseContentDisposition"][0]
99-
)
100-
if not match_ or not match_.groups()[0] is not None:
101-
raise ValueError(
102-
"Failed to extract the filename from `responseContentDisposition`"
103-
)
104-
ext = os.path.splitext(match_.groups()[0])[1]
105-
ext = ext.lower()
106-
if ext == ".pdf":
107-
return 0
108-
elif ext in SUPPORTED_IMG_EXTS:
109-
return 1
110-
else:
111-
raise ValueError("Unsupported file type")
112-
113-
114-
def _bytes_to_arrays(
115-
file_bytes: bytes,
116-
file_type: FileType,
117-
*,
118-
max_img_size: Tuple[int, int],
119-
max_num_imgs: int,
120-
) -> List[np.ndarray]:
121-
if file_type == 0:
122-
images = serving_utils.read_pdf(
123-
file_bytes, resize=True, max_num_imgs=max_num_imgs
124-
)
125-
elif file_type == 1:
126-
images = [serving_utils.image_bytes_to_array(file_bytes)]
127-
else:
128-
assert_never(file_type)
129-
h, w = images[0].shape[0:2]
130-
if w > max_img_size[1] or h > max_img_size[0]:
131-
if w / h > max_img_size[0] / max_img_size[1]:
132-
factor = max_img_size[0] / w
133-
else:
134-
factor = max_img_size[1] / h
135-
images = [cv2.resize(img, (int(factor * w), int(factor * h))) for img in images]
136-
return images
137-
138-
13971
def _postprocess_image(
14072
img: ArrayLike,
14173
request_id: str,
@@ -180,12 +112,12 @@ async def _infer(
180112
pipeline = ctx.pipeline
181113
aiohttp_session = ctx.aiohttp_session
182114

183-
request_id = _generate_request_id()
115+
request_id = serving_utils.generate_request_id()
184116

185117
if request.fileType is None:
186118
if serving_utils.is_url(request.file):
187119
try:
188-
file_type = _infer_file_type(request.file)
120+
file_type = serving_utils.infer_file_type(request.file)
189121
except Exception as e:
190122
logging.exception(e)
191123
raise HTTPException(
@@ -195,7 +127,7 @@ async def _infer(
195127
else:
196128
raise HTTPException(status_code=422, detail="Unknown file type")
197129
else:
198-
file_type = request.fileType
130+
file_type = "PDF" if request.fileType == 0 else "IMAGE"
199131

200132
if request.inferenceParams:
201133
max_long_side = request.inferenceParams.maxLongSide
@@ -210,7 +142,7 @@ async def _infer(
210142
request.file, aiohttp_session
211143
)
212144
images = await serving_utils.call_async(
213-
_bytes_to_arrays,
145+
serving_utils.file_to_images,
214146
file_bytes,
215147
file_type,
216148
max_img_size=ctx.extra["max_img_size"],
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List
16+
17+
from fastapi import FastAPI, HTTPException
18+
from pydantic import BaseModel, Field
19+
from typing_extensions import Annotated, TypeAlias
20+
21+
from .....utils import logging
22+
from ...attribute_recognition import PedestrianAttributeRecPipeline
23+
from .. import utils as serving_utils
24+
from ..app import AppConfig, create_app
25+
from ..models import Response, ResultResponse
26+
27+
28+
class InferRequest(BaseModel):
29+
image: str
30+
31+
32+
BoundingBox: TypeAlias = Annotated[List[float], Field(min_length=4, max_length=4)]
33+
34+
35+
class Attribute(BaseModel):
36+
label: str
37+
score: float
38+
39+
40+
class Pedestrian(BaseModel):
41+
bbox: BoundingBox
42+
attributes: List[Attribute]
43+
score: float
44+
45+
46+
class InferResult(BaseModel):
47+
pedestrians: List[Pedestrian]
48+
image: str
49+
50+
51+
def create_pipeline_app(
52+
pipeline: PedestrianAttributeRecPipeline, app_config: AppConfig
53+
) -> FastAPI:
54+
app, ctx = create_app(
55+
pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
56+
)
57+
58+
@app.post(
59+
"/pedestrian-attribute-recognition",
60+
operation_id="infer",
61+
responses={422: {"model": Response}},
62+
)
63+
async def _infer(request: InferRequest) -> ResultResponse[InferResult]:
64+
pipeline = ctx.pipeline
65+
aiohttp_session = ctx.aiohttp_session
66+
67+
try:
68+
file_bytes = await serving_utils.get_raw_bytes(
69+
request.image, aiohttp_session
70+
)
71+
image = serving_utils.image_bytes_to_array(file_bytes)
72+
73+
result = (await pipeline.infer(image))[0]
74+
75+
pedestrians: List[Pedestrian] = []
76+
for obj in result["boxes"]:
77+
pedestrians.append(
78+
Pedestrian(
79+
bbox=obj["coordinate"],
80+
attributes=[
81+
Attribute(label=l, score=s)
82+
for l, s in zip(obj["labels"], obj["cls_scores"])
83+
],
84+
score=obj["det_score"],
85+
)
86+
)
87+
output_image_base64 = serving_utils.image_to_base64(result.img)
88+
89+
return ResultResponse(
90+
logId=serving_utils.generate_log_id(),
91+
errorCode=0,
92+
errorMsg="Success",
93+
result=InferResult(pedestrians=pedestrians, image=output_image_base64),
94+
)
95+
96+
except Exception as e:
97+
logging.exception(e)
98+
raise HTTPException(status_code=500, detail="Internal server error")
99+
100+
return app

0 commit comments

Comments
 (0)