Skip to content

Commit ca6df6a

Browse files
authored
Merge pull request #47 from mobiusml/add_media_id_model
Media ID and Question Models
2 parents 84917e1 + 757a661 commit ca6df6a

29 files changed

+349
-150
lines changed

aana/configs/db.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,32 @@
55

66
from alembic import command
77
from alembic.config import Config
8-
from sqlalchemy import String, create_engine
8+
from sqlalchemy import String, TypeDecorator, create_engine
99

10-
# These are here so we can change types in a single place.
10+
from aana.models.pydantic.media_id import MediaId
1111

12-
media_id_type: TypeAlias = str
13-
MediaIdSqlType: TypeAlias = String
12+
13+
class MediaIdType(TypeDecorator):
    """SQLAlchemy column type that stores MediaId values as plain strings.

    Values are converted with ``str()`` on the way into the database and
    wrapped in ``MediaId`` again on the way out; ``None`` passes through
    unchanged in both directions.
    """

    # Underlying column type used for storage.
    impl = String

    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Return the string form of ``value`` for storage, or None if it is None."""
        return None if value is None else str(value)

    def process_result_value(self, value, dialect):
        """Return a MediaId built from the stored string, or None if it is None."""
        return None if value is None else MediaId(value)
31+
32+
33+
MediaIdSqlType: TypeAlias = MediaIdType
1434

1535

1636
class SQLiteConfig(TypedDict):

aana/configs/pipeline.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
from aana.models.pydantic.captions import CaptionsList, VideoCaptionsList
1515
from aana.models.pydantic.chat_message import ChatDialog
1616
from aana.models.pydantic.image_input import ImageInputList
17+
from aana.models.pydantic.media_id import MediaId
1718
from aana.models.pydantic.prompt import Prompt
19+
from aana.models.pydantic.question import Question
1820
from aana.models.pydantic.sampling_params import SamplingParams
1921
from aana.models.pydantic.video_input import VideoInput, VideoInputList
2022
from aana.models.pydantic.video_metadata import VideoMetadata
@@ -525,6 +527,7 @@
525527
"name": "media_id",
526528
"key": "media_id",
527529
"path": "media_id",
530+
"data_model": MediaId,
528531
}
529532
],
530533
},
@@ -537,6 +540,7 @@
537540
"name": "question",
538541
"key": "question",
539542
"path": "question",
543+
"data_model": Question,
540544
}
541545
],
542546
},

aana/exceptions/database.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
from mobius_pipeline.exceptions import BaseException
22

3-
from aana.configs.db import media_id_type
3+
from aana.models.pydantic.media_id import MediaId
44

55

66
class NotFoundException(BaseException):
77
"""Raised when an item searched by id is not found."""
88

9-
def __init__(self, table_name: str, id: int | media_id_type): # noqa: A002
9+
def __init__(self, table_name: str, id: int | MediaId): # noqa: A002
1010
"""Constructor.
1111
1212
Args:
1313
table_name (str): the name of the table being queried.
14-
id (media_id_type): the id of the item to be retrieved.
14+
id (int | MediaId): the id of the item to be retrieved.
1515
"""
1616
super().__init__(table=table_name, id=id)
1717
self.table_name = table_name
@@ -26,12 +26,12 @@ def __reduce__(self):
2626
class MediaIdAlreadyExistsException(BaseException):
2727
"""Raised when a media_id already exists."""
2828

29-
def __init__(self, table_name: str, media_id: media_id_type):
29+
def __init__(self, table_name: str, media_id: MediaId):
3030
"""Constructor.
3131
3232
Args:
3333
table_name (str): the name of the table being queried.
34-
media_id (media_id_type): the id of the item to be retrieved.
34+
media_id (MediaId): the id of the item to be retrieved.
3535
"""
3636
super().__init__(table=table_name, id=media_id)
3737
self.table_name = table_name

aana/exceptions/general.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -153,28 +153,6 @@ def __reduce__(self):
153153
return (self.__class__, (self.prompt_len, self.max_len))
154154

155155

156-
class MediaIdNotFoundException(BaseException):
157-
"""Exception raised when a media ID is not found.
158-
159-
Attributes:
160-
media_id (str): the media ID
161-
"""
162-
163-
def __init__(self, media_id: str):
164-
"""Initialize the exception.
165-
166-
Args:
167-
media_id (str): the media ID
168-
"""
169-
super().__init__(media_id=media_id)
170-
self.media_id = media_id
171-
self.http_status_code = 404
172-
173-
def __reduce__(self):
174-
"""Used for pickling."""
175-
return (self.__class__, (self.media_id,))
176-
177-
178156
class EndpointNotFoundException(BaseException):
179157
"""Exception raised when an endpoint is not found.
180158

aana/models/core/image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class Image(Media):
130130
url (str): The URL of the image.
131131
content (bytes): The content of the image in bytes (image file as bytes).
132132
numpy (np.ndarray): The image as a numpy array.
133-
media_id (str): The ID of the image, generated automatically if not provided.
133+
media_id (MediaId): The ID of the image, generated automatically if not provided.
134134
"""
135135

136136
media_dir: Path | None = settings.image_dir

aana/models/core/media.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import hashlib
2-
import uuid
32
from dataclasses import dataclass, field
43
from pathlib import Path
54

5+
from aana.models.pydantic.media_id import MediaId
66
from aana.utils.general import download_file
77

88

@@ -19,13 +19,13 @@ class Media:
1919
path (Path): the path to the media file
2020
url (str): the URL of the media
2121
content (bytes): the content of the media in bytes
22-
media_id (str): the ID of the media. If not provided, it will be generated automatically.
22+
media_id (MediaId): the ID of the media. If not provided, it will be generated automatically.
2323
"""
2424

2525
path: Path | None = None
2626
url: str | None = None
2727
content: bytes | None = None
28-
media_id: str = field(default_factory=lambda: str(uuid.uuid4()))
28+
media_id: MediaId = field(default_factory=lambda: MediaId.random())
2929
save_on_disk: bool = True
3030
is_saved: bool = False
3131
media_dir: Path | None = None

aana/models/core/video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class Video(Media):
2020
path (Path): the path to the video file
2121
url (str): the URL of the video
2222
content (bytes): the content of the video in bytes
23-
media_id (str): the ID of the video. If not provided, it will be generated automatically.
23+
media_id (MediaId): the ID of the video. If not provided, it will be generated automatically.
2424
title (str): the title of the video
2525
description (str): the description of the video
2626
media_dir (Path): the directory to save the video in

aana/models/db/caption.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
from sqlalchemy import CheckConstraint, Column, Float, ForeignKey, Integer, String
66
from sqlalchemy.orm import relationship
77

8-
from aana.configs.db import MediaIdSqlType, media_id_type
8+
from aana.configs.db import MediaIdSqlType
99
from aana.models.db.base import BaseEntity, TimeStampEntity
1010

1111
if typing.TYPE_CHECKING:
1212
from aana.models.pydantic.captions import Caption
13+
from aana.models.pydantic.media_id import MediaId
1314

1415

1516
class CaptionEntity(BaseEntity, TimeStampEntity):
@@ -52,7 +53,7 @@ class CaptionEntity(BaseEntity, TimeStampEntity):
5253
def from_caption_output(
5354
cls,
5455
model_name: str,
55-
media_id: media_id_type,
56+
media_id: MediaId,
5657
video_id: int,
5758
frame_id: int,
5859
frame_timestamp: float,

aana/models/db/timeline.py.bak

Lines changed: 0 additions & 33 deletions
This file was deleted.

aana/models/db/transcript.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sqlalchemy import JSON, CheckConstraint, Column, Float, ForeignKey, Integer, String
66
from sqlalchemy.orm import relationship
77

8-
from aana.configs.db import MediaIdSqlType, media_id_type
8+
from aana.configs.db import MediaIdSqlType
99
from aana.models.db.base import BaseEntity, TimeStampEntity
1010

1111
if TYPE_CHECKING:
@@ -14,6 +14,7 @@
1414
AsrTranscription,
1515
AsrTranscriptionInfo,
1616
)
17+
from aana.models.pydantic.media_id import MediaId
1718

1819

1920
class TranscriptEntity(BaseEntity, TimeStampEntity):
@@ -54,7 +55,7 @@ class TranscriptEntity(BaseEntity, TimeStampEntity):
5455
def from_asr_output(
5556
cls,
5657
model_name: str,
57-
media_id: media_id_type,
58+
media_id: MediaId,
5859
video_id: int,
5960
info: AsrTranscriptionInfo,
6061
transcription: AsrTranscription,

aana/models/pydantic/base.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from pydantic import BaseModel
2+
from pyparsing import Any
23

34

45
class BaseListModel(BaseModel):
@@ -34,3 +35,75 @@ def __contains__(self, item):
3435
def __add__(self, other):
3536
"""Add two models."""
3637
return self.__class__(__root__=self.__root__ + other.__root__)
38+
39+
40+
class BaseStringModel(BaseModel):
    """The base model for pydantic models that are just strings.

    The wrapped string lives in the ``__root__`` field. The model mirrors the
    common read-only string operations (comparison, hashing, indexing, length,
    iteration, containment, concatenation) and forwards any other attribute
    lookup to the wrapped string.
    """

    __root__: str

    def __init__(self, __root__value: Any = None, **data):
        """Initialize the model.

        Args:
            __root__value (Any): the string value, passed positionally.
                NOTE: because the name starts with two underscores and does not
                end with two, Python mangles it inside the class body, so it
                can only be supplied positionally.
            **data: keyword fields forwarded to the pydantic constructor
                (e.g. ``__root__=...``).
        """
        if __root__value is not None:
            super().__init__(__root__=__root__value, **data)
        else:
            super().__init__(**data)

    def __str__(self) -> str:
        """Convert to a string."""
        return self.__root__

    def __repr__(self) -> str:
        """Convert to a string representation."""
        return f"{self.__class__.__name__}({self.__root__!r})"

    def __eq__(self, other: Any) -> bool:
        """Check if two models are equal.

        A model compares equal to another instance of the same class with the
        same string, and to a plain ``str`` with the same content.
        """
        if isinstance(other, self.__class__):
            return self.__root__ == other.__root__
        if isinstance(other, str):
            return self.__root__ == other
        return NotImplemented

    def __hash__(self) -> int:
        """Get hash of model (the hash of the wrapped string)."""
        return hash(self.__root__)

    def __getitem__(self, key):
        """Get item at key of model."""
        return self.__root__[key]

    def __len__(self) -> int:
        """Get length of model."""
        return len(self.__root__)

    def __iter__(self):
        """Get iterator for model."""
        return iter(self.__root__)

    def __contains__(self, item):
        """Check if model contains item."""
        return item in self.__root__

    def __add__(self, other):
        """Add two models or a model and a string.

        Returns:
            A new model of the same class when ``other`` is a model of this
            class, or a plain ``str`` when ``other`` is a ``str``.
        """
        if isinstance(other, self.__class__):
            return self.__class__(__root__=self.__root__ + other.__root__)
        if isinstance(other, str):
            return str(self.__root__) + other
        return NotImplemented

    def __getattr__(self, item):
        """Delegate attributes that are not found on the model to ``self.__root__``.

        This handles string methods like startswith, endswith, and split. The
        attribute fetched from the string is returned as-is; methods are
        already bound to the string, so no extra wrapper is needed.

        Raises:
            AttributeError: if ``__root__`` itself is missing from the instance,
                or if the wrapped string has no such attribute.
        """
        # __getattr__ only runs after normal lookup failed. If the missing name
        # is ``__root__`` itself, reading ``self.__root__`` below would re-enter
        # this method without end, so fail with a normal AttributeError instead.
        if item == "__root__":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        return getattr(self.__root__, item)

aana/models/pydantic/image_input.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import io
2-
import uuid
32
from pathlib import Path
43
from types import MappingProxyType
54

@@ -9,6 +8,7 @@
98

109
from aana.models.core.image import Image
1110
from aana.models.pydantic.base import BaseListModel
11+
from aana.models.pydantic.media_id import MediaId
1212

1313

1414
class ImageInput(BaseModel):
@@ -45,8 +45,8 @@ class ImageInput(BaseModel):
4545
"Set this field to 'file' to upload files to the endpoint."
4646
),
4747
)
48-
media_id: str = Field(
49-
default_factory=lambda: str(uuid.uuid4()),
48+
media_id: MediaId = Field(
49+
default_factory=lambda: MediaId.random(),
5050
description="The ID of the image. If not provided, it will be generated automatically.",
5151
)
5252

@@ -55,7 +55,7 @@ def media_id_must_not_be_empty(cls, media_id):
5555
"""Validates that the media_id is not an empty string.
5656
5757
Args:
58-
media_id (str): The value of the media_id field.
58+
media_id (MediaId): The value of the media_id field.
5959
6060
Raises:
6161
ValueError: If the media_id is an empty string.

0 commit comments

Comments
 (0)