Skip to content

Commit d9a62fa

Browse files
authored
Merge pull request #44 from mobiusml/edr/getting_started_doc
Edr/getting started doc
2 parents ca6df6a + 240c8ee commit d9a62fa

File tree

10 files changed

+1178
-2
lines changed

10 files changed

+1178
-2
lines changed

aana/configs/deployments.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
from aana.deployments.hf_blip2_deployment import HFBlip2Config, HFBlip2Deployment
2+
from aana.deployments.stablediffusion2_deployment import (
3+
StableDiffusion2Config,
4+
StableDiffusion2Deployment,
5+
)
26
from aana.deployments.vllm_deployment import VLLMConfig, VLLMDeployment
37
from aana.deployments.whisper_deployment import (
48
WhisperComputeType,
@@ -45,4 +49,13 @@
4549
compute_type=WhisperComputeType.FLOAT16,
4650
).dict(),
4751
),
52+
"stablediffusion2_deployment": StableDiffusion2Deployment.options(
53+
num_replicas=1,
54+
max_concurrent_queries=1000,
55+
ray_actor_options={"num_gpus": 1},
56+
user_config=StableDiffusion2Config(
57+
model="stabilityai/stable-diffusion-2",
58+
dtype=Dtype.FLOAT16,
59+
).dict(),
60+
),
4861
}

aana/configs/endpoints.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,4 +186,17 @@
186186
],
187187
),
188188
],
189+
"stablediffusion2": [
190+
Endpoint(
191+
name="imagegen",
192+
path="/generate_image",
193+
summary="Generates an image from a text prompt",
194+
outputs=[
195+
EndpointOutput(
196+
name="image_path_stablediffusion2",
197+
output="image_path_stablediffusion2",
198+
)
199+
],
200+
)
201+
],
189202
}

aana/configs/pipeline.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
It is used to generate the pipeline and the API endpoints.
44
"""
5+
import PIL.Image
56

67
from aana.models.pydantic.asr_output import (
78
AsrSegments,
@@ -721,6 +722,41 @@
721722
}
722723
],
723724
},
725+
{
726+
"name": "stable-diffusion-2-imagegen",
727+
"type": "ray_deployment",
728+
"deployment_name": "stablediffusion2_deployment",
729+
"method": "generate",
730+
"inputs": [{"name": "prompt", "key": "prompt", "path": "prompt"}],
731+
"outputs": [
732+
{
733+
"name": "image_stablediffusion2",
734+
"key": "image",
735+
"path": "stablediffusion2-image",
736+
"data_model": PIL.Image.Image,
737+
}
738+
],
739+
},
740+
{
741+
"name": "save_image_stablediffusion2",
742+
"type": "function",
743+
"function": "aana.utils.image.save_image",
744+
"dict_output": True,
745+
"inputs": [
746+
{
747+
"name": "image_stablediffusion2",
748+
"key": "image",
749+
"path": "stablediffusion2-image",
750+
},
751+
],
752+
"outputs": [
753+
{
754+
"name": "image_path_stablediffusion2",
755+
"key": "path",
756+
"path": "image_path",
757+
}
758+
],
759+
},
724760
{
725761
"name": "save_video",
726762
"type": "function",
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from typing import Any, TypedDict
2+
3+
import PIL
4+
import torch
5+
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline
6+
from pydantic import BaseModel, Field
7+
from ray import serve
8+
9+
from aana.deployments.base_deployment import BaseDeployment
10+
from aana.models.core.dtype import Dtype
11+
from aana.models.pydantic.prompt import Prompt
12+
13+
14+
class StableDiffusion2Output(TypedDict):
15+
"""Output class for the StableDiffusion2 deployment."""
16+
17+
image: PIL.Image.Image
18+
19+
20+
class StableDiffusion2Config(BaseModel):
21+
"""The configuration for the Stable Diffusion 2 deployment.
22+
23+
Attributes:
24+
model (str): the model ID on HuggingFace
25+
dtype (str): the data type (optional, default: "auto"), one of "auto", "float32", "float16"
26+
"""
27+
28+
model: str
29+
dtype: Dtype = Field(default=Dtype.AUTO)
30+
31+
32+
@serve.deployment
33+
class StableDiffusion2Deployment(BaseDeployment):
34+
"""Stable Diffusion 2 deployment."""
35+
36+
async def apply_config(self, config: dict[str, Any]):
37+
"""Apply the configuration.
38+
39+
The method is called when the deployment is created or updated.
40+
41+
It loads the model and scheduler from HuggingFace.
42+
43+
The configuration should conform to the StableDiffusion2Config schema.
44+
"""
45+
config_obj = StableDiffusion2Config(**config)
46+
47+
# Load the model and scheduler from HuggingFace
48+
self.model_id = config_obj.model
49+
self.dtype = config_obj.dtype
50+
if self.dtype == Dtype.INT8:
51+
self.torch_dtype = Dtype.FLOAT16.to_torch()
52+
else:
53+
self.torch_dtype = self.dtype.to_torch()
54+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
55+
self.model = StableDiffusionPipeline.from_pretrained(
56+
self.model_id,
57+
torch_dtype=self.torch_dtype,
58+
scheduler=EulerDiscreteScheduler.from_pretrained(
59+
self.model_id, subfolder="scheduler"
60+
),
61+
device_map="auto",
62+
)
63+
64+
self.model.to(self.device)
65+
66+
async def generate(self, prompt: Prompt) -> StableDiffusion2Output:
67+
"""Runs the model on a given prompt and returns the first output.
68+
69+
Arguments:
70+
prompt (Prompt): the prompt to the model.
71+
72+
Returns:
73+
StableDiffusion2Output: a dictionary with one key containing the result
74+
"""
75+
image = self.model(str(prompt)).images[0]
76+
return {"image": image}

aana/models/core/file.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from pathlib import Path
2+
from typing import TypedDict
3+
4+
5+
class PathResult(TypedDict):
6+
"""Represents a path result describing a file on disk."""
7+
8+
path: Path

aana/utils/image.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from pathlib import Path
2+
from uuid import uuid4
3+
4+
import PIL.Image
5+
6+
from aana.configs.settings import settings
7+
from aana.models.core.file import PathResult
8+
9+
10+
def save_image(image: PIL.Image.Image, full_path: Path | None = None) -> PathResult:
11+
"""Saves an image to the given full path, or randomly generates one if no path is supplied.
12+
13+
Arguments:
14+
image (Image): the image to save
15+
full_path (Path|None): the path to save the image to. If None, will generate one randomly.
16+
17+
Returns:
18+
PathResult: contains the path to the saved image.
19+
"""
20+
if not full_path:
21+
full_path = settings.image_dir / f"{uuid4()}.png"
22+
image.save(full_path)
23+
return {"path": full_path}

docs/diagram.png

161 KB
Loading

0 commit comments

Comments
 (0)