Skip to content

Commit 6ac2c78

Browse files
authored
Merge pull request #54 from mobiusml/pydantic_v2_migration
Pydantic v2 Migration
2 parents d9a62fa + 55ce3bc commit 6ac2c78

File tree

204 files changed

+6211
-1519
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

204 files changed

+6211
-1519
lines changed

.devcontainer/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,2 +1,2 @@
1-
FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04
1+
FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu22.04
22
RUN apt-get update && apt-get install -y libgl1 libglib2.0-0 ffmpeg

aana/api/api_generation.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
from fastapi.responses import StreamingResponse
88
from mobius_pipeline.node.socket import Socket
99
from mobius_pipeline.pipeline.pipeline import Pipeline
10-
from pydantic import BaseModel, Field, ValidationError, create_model, parse_raw_as
10+
from pydantic import BaseModel, Field, ValidationError, create_model
1111

1212
from aana.api.app import custom_exception_handler
1313
from aana.api.responses import AanaJSONResponse
@@ -237,9 +237,9 @@ def get_file_upload_field(
237237
continue
238238

239239
# check if pydantic model has file_upload field and it's set to True
240-
file_upload_enabled = getattr(data_model.Config, "file_upload", False)
241-
file_upload_description = getattr(
242-
data_model.Config, "file_upload_description", ""
240+
file_upload_enabled = data_model.model_config.get("file_upload", False)
241+
file_upload_description = data_model.model_config.get(
242+
"file_upload_description", ""
243243
)
244244

245245
if file_upload_enabled and file_upload_field is None:
@@ -330,7 +330,7 @@ def create_endpoint_func( # noqa: C901
330330

331331
async def route_func_body(body: str, files: list[UploadFile] | None = None): # noqa: C901
332332
# parse form data as a pydantic model and validate it
333-
data = parse_raw_as(RequestModel, body)
333+
data = RequestModel.model_validate_json(body)
334334

335335
# if the input requires file upload, add the files to the data
336336
if file_upload_field and files:
@@ -341,7 +341,7 @@ async def route_func_body(body: str, files: list[UploadFile] | None = None): #
341341
# data.dict() will convert all nested models to dicts
342342
# and we want to keep them as pydantic models
343343
data_dict = {}
344-
for field_name in data.__fields__:
344+
for field_name in data.model_fields:
345345
field_value = getattr(data, field_name)
346346
data_dict[field_name] = field_value
347347

aana/api/app.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ async def validation_exception_handler(request: Request, exc: ValidationError):
2828
error="ValidationError",
2929
message="Validation error",
3030
data=exc.errors(),
31-
).dict(),
31+
).model_dump(),
3232
)
3333

3434

@@ -77,7 +77,7 @@ def custom_exception_handler(request: Request | None, exc_raw: Exception):
7777
status_code=status_code,
7878
content=ExceptionResponseModel(
7979
error=error, message=message, data=data, stacktrace=stacktrace
80-
).dict(),
80+
).model_dump(),
8181
)
8282

8383

aana/api/request_handler.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
# TODO: improve type annotations
1212

1313

14-
@serve.deployment(route_prefix="/", num_replicas=1, ray_actor_options={"num_cpus": 0.1})
14+
@serve.deployment(ray_actor_options={"num_cpus": 0.1})
1515
@serve.ingress(app)
1616
class RequestHandler:
1717
"""This class is used to handle requests to the Aana application."""

aana/configs/db.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,12 @@
11
from enum import Enum
22
from os import PathLike
33
from pathlib import Path
4-
from typing import TypeAlias, TypedDict
4+
from typing import TypeAlias
55

66
from alembic import command
77
from alembic.config import Config
88
from sqlalchemy import String, TypeDecorator, create_engine
9+
from typing_extensions import TypedDict
910

1011
from aana.models.pydantic.media_id import MediaId
1112

aana/configs/deployments.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,12 +22,13 @@
2222
model="TheBloke/Llama-2-7b-Chat-AWQ",
2323
dtype="auto",
2424
quantization="awq",
25-
gpu_memory_reserved=10000,
25+
gpu_memory_reserved=13000,
26+
enforce_eager=True,
2627
default_sampling_params=SamplingParams(
2728
temperature=0.0, top_p=1.0, top_k=-1, max_tokens=1024
2829
),
2930
chat_template="llama2",
30-
).dict(),
31+
).model_dump(),
3132
),
3233
"hf_blip2_deployment_opt_2_7b": HFBlip2Deployment.options(
3334
num_replicas=1,
@@ -38,7 +39,7 @@
3839
dtype=Dtype.FLOAT16,
3940
batch_size=2,
4041
num_processing_threads=2,
41-
).dict(),
42+
).model_dump(),
4243
),
4344
"whisper_deployment_medium": WhisperDeployment.options(
4445
num_replicas=1,
@@ -47,7 +48,7 @@
4748
user_config=WhisperConfig(
4849
model_size=WhisperModelSize.MEDIUM,
4950
compute_type=WhisperComputeType.FLOAT16,
50-
).dict(),
51+
).model_dump(),
5152
),
5253
"stablediffusion2_deployment": StableDiffusion2Deployment.options(
5354
num_replicas=1,

aana/configs/settings.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
from pathlib import Path
22

3-
from pydantic import BaseSettings
3+
from pydantic_settings import BaseSettings
44

55
from aana.configs.db import DBConfig
66

@@ -17,8 +17,8 @@ class Settings(BaseSettings):
1717
"""A pydantic model for SDK settings."""
1818

1919
tmp_data_dir: Path = Path("/tmp/aana_data") # noqa: S108
20-
image_dir = tmp_data_dir / "images"
21-
video_dir = tmp_data_dir / "videos"
20+
image_dir: Path = tmp_data_dir / "images"
21+
video_dir: Path = tmp_data_dir / "videos"
2222
num_workers: int = 2
2323

2424
db_config: DBConfig = {

aana/deployments/hf_blip2_deployment.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,11 @@
1-
from typing import Any, TypedDict
1+
from typing import Any
22

33
import torch
44
import transformers
55
from pydantic import BaseModel, Field
66
from ray import serve
77
from transformers import Blip2ForConditionalGeneration, Blip2Processor
8+
from typing_extensions import TypedDict
89

910
from aana.deployments.base_deployment import BaseDeployment
1011
from aana.exceptions.general import InferenceException

aana/deployments/vllm_deployment.py

Lines changed: 17 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,20 +1,26 @@
1+
import contextlib
12
from collections.abc import AsyncGenerator
2-
from typing import Any, TypedDict
3+
from typing import Any
34

45
from pydantic import BaseModel, Field
56
from ray import serve
7+
from typing_extensions import TypedDict
68
from vllm.engine.arg_utils import AsyncEngineArgs
79
from vllm.engine.async_llm_engine import AsyncLLMEngine
8-
from vllm.model_executor.utils import set_random_seed
10+
11+
with contextlib.suppress(ImportError):
12+
from vllm.model_executor.utils import (
13+
set_random_seed, # Ignore if we don't have GPU and only run on CPU with test cache
14+
)
915
from vllm.sampling_params import SamplingParams as VLLMSamplingParams
10-
from vllm.utils import get_gpu_memory, random_uuid
16+
from vllm.utils import random_uuid
1117

1218
from aana.deployments.base_deployment import BaseDeployment
1319
from aana.exceptions.general import InferenceException, PromptTooLongException
1420
from aana.models.pydantic.chat_message import ChatDialog, ChatMessage
1521
from aana.models.pydantic.sampling_params import SamplingParams
1622
from aana.utils.chat_template import apply_chat_template
17-
from aana.utils.general import merged_options
23+
from aana.utils.general import get_gpu_memory, merged_options
1824
from aana.utils.test import test_cache
1925

2026

@@ -28,6 +34,9 @@ class VLLMConfig(BaseModel):
2834
gpu_memory_reserved (float): the GPU memory reserved for the model in mb
2935
default_sampling_params (SamplingParams): the default sampling parameters.
3036
max_model_len (int): the maximum generated text length in tokens (optional, default: None)
37+
chat_template (str): the name of the chat template, if not provided, the chat template from the model will be used
38+
but some models may not have a chat template (optional, default: None)
39+
enforce_eager (bool): whether to enforce eager execution (optional, default: False)
3140
"""
3241

3342
model: str
@@ -37,6 +46,7 @@ class VLLMConfig(BaseModel):
3746
default_sampling_params: SamplingParams
3847
max_model_len: int | None = Field(default=None)
3948
chat_template: str | None = Field(default=None)
49+
enforce_eager: bool | None = Field(default=False)
4050

4151

4252
class LLMOutput(TypedDict):
@@ -107,6 +117,7 @@ async def apply_config(self, config: dict[str, Any]):
107117
model=config_obj.model,
108118
dtype=config_obj.dtype,
109119
quantization=config_obj.quantization,
120+
enforce_eager=config_obj.enforce_eager,
110121
gpu_memory_utilization=self.gpu_memory_utilization,
111122
max_model_len=config_obj.max_model_len,
112123
)
@@ -116,7 +127,7 @@ async def apply_config(self, config: dict[str, Any]):
116127

117128
# create the engine
118129
self.engine = AsyncLLMEngine.from_engine_args(args)
119-
self.tokenizer = self.engine.engine.tokenizer
130+
self.tokenizer = self.engine.engine.tokenizer.tokenizer
120131
self.model_config = await self.engine.get_model_config()
121132

122133
@test_cache
@@ -148,7 +159,7 @@ async def generate_stream(
148159
try:
149160
# convert SamplingParams to VLLMSamplingParams
150161
sampling_params_vllm = VLLMSamplingParams(
151-
**sampling_params.dict(exclude_unset=True)
162+
**sampling_params.model_dump(exclude_unset=True)
152163
)
153164
# start the request
154165
request_id = random_uuid()

aana/deployments/whisper_deployment.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,12 @@
11
from collections.abc import AsyncGenerator
22
from enum import Enum
3-
from typing import Any, TypedDict, cast
3+
from typing import Any, cast
44

55
import torch
66
from faster_whisper import WhisperModel
77
from pydantic import BaseModel, Field
88
from ray import serve
9+
from typing_extensions import TypedDict
910

1011
from aana.deployments.base_deployment import BaseDeployment
1112
from aana.exceptions.general import InferenceException
@@ -161,7 +162,7 @@ async def transcribe(
161162
params = WhisperParams()
162163
media_path: str = str(media.path)
163164
try:
164-
segments, info = self.model.transcribe(media_path, **params.dict())
165+
segments, info = self.model.transcribe(media_path, **params.model_dump())
165166
except Exception as e:
166167
raise InferenceException(self.model_name) from e
167168

@@ -196,7 +197,7 @@ async def transcribe_stream(
196197
params = WhisperParams()
197198
media_path: str = str(media.path)
198199
try:
199-
segments, info = self.model.transcribe(media_path, **params.dict())
200+
segments, info = self.model.transcribe(media_path, **params.model_dump())
200201
except Exception as e:
201202
raise InferenceException(self.model_name) from e
202203

aana/models/db/transcript.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -69,5 +69,5 @@ def from_asr_output(
6969
language=info.language,
7070
language_confidence=info.language_confidence,
7171
transcript=transcription.text,
72-
segments=[s.dict() for s in segments],
72+
segments=[s.model_dump() for s in segments],
7373
)

0 commit comments

Comments
 (0)