|
| 1 | +# --- |
| 2 | +# jupyter: |
| 3 | +# jupytext: |
| 4 | +# text_representation: |
| 5 | +# extension: .py |
| 6 | +# format_name: percent |
| 7 | +# format_version: '1.3' |
| 8 | +# jupytext_version: 1.16.7 |
| 9 | +# kernelspec: |
| 10 | +# display_name: Python 3 (ipykernel) |
| 11 | +# language: python |
| 12 | +# name: python3 |
| 13 | +# --- |
| 14 | + |
| 15 | +# %% [markdown] |
| 16 | +# # Text-to-Image 手法の実践 |
| 17 | + |
| 18 | +# %% [markdown] |
| 19 | +# [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/py-img-gen/python-image-generation/blob/main/notebooks/1-2_text-to-image-generation.ipynb) |
| 20 | + |
| 21 | +# %% [markdown] |
| 22 | +# ## 準備 |
| 23 | + |
| 24 | +# %% |
| 25 | +# !pip install -qq py-img-gen |
| 26 | + |
# %%
import torch

# Use the GPU (= cuda) when one is available; otherwise fall back to the CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)
# Single precision (float32) is the usual default. Half precision (float16)
# is used on the GPU to reduce memory usage, but it is only selected when
# CUDA is actually available: half-precision pipelines are not reliably
# runnable on the CPU, so the CPU fallback keeps float32.
dtype = torch.float16 if device.type == "cuda" else torch.float32
# Fixed seed value so that generation results are reproducible.
seed = 42
| 40 | + |
# %%
import logging

# Only show messages of level ERROR or higher from this logger.
logger_name = "diffusers.pipelines.pipeline_utils"
pipeline_logger = logging.getLogger(logger_name)
pipeline_logger.setLevel(logging.ERROR)
| 47 | + |
| 48 | +# %% [markdown] |
| 49 | +# ## Stable Diffusion を扱うパイプラインの構築 |
| 50 | + |
# %%
from diffusers import StableDiffusionPipeline

# Identifier of the pretrained model passed to `from_pretrained`.
model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Load the pretrained pipeline with the module-level `dtype`.
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)

# %%
# Report the target device, then move the pipeline onto it.
print(f"Move pipeline to {device}")

pipe = pipe.to(device)
| 64 | + |
# %%
# Text prompt describing the image to generate.
text = "a photograph of an astronaut riding a horse"

# Run the pipeline on the prompt to generate an image.
output = pipe(prompt=text)

# Take the first generated image; it is in
# Pillow (https://pillow.readthedocs.io/en/stable) format.
image = output.images[0]

# Leave the image as the cell's last expression so that it is displayed.
image
| 78 | + |
# %%
# Seed the random number generator with the fixed seed value.
generator = torch.manual_seed(seed)

# Hand the seeded generator to the pipeline through its `generator` argument.
output = pipe(
    prompt=text,
    generator=generator,
)
image = output.images[0]
image  # display the image
| 87 | + |
# %%
# Re-seed so this run starts from the same random state as the previous one.
generator = torch.manual_seed(seed)

# Set the number of inference steps, num_inference_steps, to 15
# (the default is 50).
output = pipe(
    prompt=text,
    generator=generator,
    num_inference_steps=15,
)
image = output.images[0]
image  # display the image
| 97 | + |
# %%
from diffusers.utils import make_image_grid

text = "a photograph of an astronaut riding a horse"

# Grid layout: number of rows and columns.
num_rows, num_cols = 4, 3
# Total number of images to generate (one per grid slot).
num_images = num_rows * num_cols

# Generate all images for the single prompt in one pipeline call.
output = pipe(
    prompt=text,
    num_images_per_prompt=num_images,
)

# Arrange the generated images on a grid with make_image_grid and display it.
make_image_grid(images=output.images, rows=num_rows, cols=num_cols)
| 112 | + |
| 113 | +# %% [markdown] |
| 114 | +# ## Stable Diffusion v1 による画像生成 |
| 115 | + |
# %%
# Opening line of "I Am a Cat" by Natsume Soseki.
text = "I am a cat. As yet I have no name."

# Generate an image with the seed fixed.
generator = torch.manual_seed(seed)
output = pipe(prompt=text, generator=generator)
image = output.images[0]
image  # display the image

# %%
# Opening line of "Snow Country" by Yasunari Kawabata.
text = "The train came out of the long tunnel into the snow country."

# Generate an image with the seed fixed.
generator = torch.manual_seed(seed)
output = pipe(prompt=text, generator=generator)
image = output.images[0]
image  # display the image

# %%
# Opening passage of "The Pillow Book" by Sei Shonagon.
text = (
    "In the dawn of spring, the mountains are turning white, "
    "and the purple clouds are trailing thinly with a little light"
)

# Generate an image with the seed fixed.
generator = torch.manual_seed(seed)
output = pipe(prompt=text, generator=generator)
image = output.images[0]
image  # display the image
| 148 | + |
# %%
import gc

# Tear down the current pipeline before the next model is loaded:
# move it to the CPU, delete the name, run garbage collection, and then
# ask PyTorch to empty its CUDA cache. The order of these steps matters.
pipe = pipe.to("cpu")
del pipe
gc.collect()
torch.cuda.empty_cache()
| 156 | + |
| 157 | + |
| 158 | +# %% [markdown] |
| 159 | +# ## Stable Diffusion v2 による画像生成 |
| 160 | + |
# %%
# Identifier of the pretrained Stable Diffusion v2 model.
model_id = "stabilityai/stable-diffusion-2"

# Load the pipeline with the module-level `dtype`, then move it to `device`.
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe = pipe.to(device)

# %%
# Turn on attention slicing for this pipeline.
pipe.enable_attention_slicing()
| 171 | + |
# %%
text = "a photograph of an astronaut riding a horse"

# Generate an image with the seed fixed.
generator = torch.manual_seed(seed)
output = pipe(
    prompt=text,
    generator=generator,
)
image = output.images[0]
image  # display the image

# %%
# Opening line of "I Am a Cat".
text = "I am a cat. As yet I have no name."

generator = torch.manual_seed(seed)
output = pipe(prompt=text, generator=generator)
image = output.images[0]
image  # display the image

# %%
# Opening line of "Snow Country".
text = "The train came out of the long tunnel into the snow country."

generator = torch.manual_seed(seed)
output = pipe(prompt=text, generator=generator)
image = output.images[0]
image  # display the image

# %%
# Opening passage of "The Pillow Book".
text = (
    "In the dawn of spring, the mountains are turning white, "
    "and the purple clouds are trailing thinly with a little light"
)

generator = torch.manual_seed(seed)
output = pipe(prompt=text, generator=generator)
image = output.images[0]
image  # display the image
| 209 | + |
# %%
# Tear down the Stable Diffusion v2 pipeline before loading the next model:
# move it to the CPU, delete the name, collect garbage, and empty the
# CUDA cache, in that order.
pipe = pipe.to("cpu")
del pipe
gc.collect()
torch.cuda.empty_cache()
| 215 | + |
| 216 | +# %% [markdown] |
| 217 | +# ## waifu-diffusion による画像生成 |
| 218 | + |
# %%
# Load the waifu-diffusion weights with the module-level `dtype`,
# then move the pipeline to `device`.
pipe = StableDiffusionPipeline.from_pretrained(
    "hakurei/waifu-diffusion",
    torch_dtype=dtype,
)
pipe = pipe.to(device)
| 224 | + |
# %%
# Comma-separated tag-style prompt, split across lines for readability;
# the adjacent literals concatenate into a single string.
text = (
    "1girl, aqua eyes, baseball cap, blonde hair, closed mouth, "
    "earrings, green background, hat, hoop earrings, jewelry, "
    "looking at viewer, shirt, short hair, simple background, "
    "solo, upper body, yellow shirt"
)

# Grid layout and the resulting number of images to generate.
num_rows, num_cols = 4, 3
num_images = num_rows * num_cols

# Generate all images in one call with the seed fixed.
generator = torch.manual_seed(seed)
output = pipe(
    prompt=text, generator=generator, num_images_per_prompt=num_images
)

# Display the generated images on a grid.
make_image_grid(output.images, rows=num_rows, cols=num_cols)
| 239 | + |
# %%
# Tear down the waifu-diffusion pipeline before loading the next model:
# move it to the CPU, delete the name, collect garbage, and empty the
# CUDA cache, in that order.
pipe = pipe.to("cpu")
del pipe
gc.collect()
torch.cuda.empty_cache()
| 245 | + |
| 246 | + |
| 247 | +# %% [markdown] |
| 248 | +# ## nitro-diffusion による画像生成 |
| 249 | + |
# %%
# Load the nitro-diffusion weights with the module-level `dtype`,
# then move the pipeline to `device`.
pipe = StableDiffusionPipeline.from_pretrained(
    "nitrosocke/nitro-diffusion",
    torch_dtype=dtype,
)
pipe = pipe.to(device)
| 255 | + |
# %%
# Prompt for the nitro-diffusion model.
text = "archer arcane style magical princess with golden hair"

# Grid layout and the resulting number of images to generate.
num_rows, num_cols = 4, 3
num_images = num_rows * num_cols

# Generate all images in one call with the seed fixed.
generator = torch.manual_seed(seed)
output = pipe(
    prompt=text, generator=generator, num_images_per_prompt=num_images
)

# Display the generated images on a grid.
make_image_grid(output.images, rows=num_rows, cols=num_cols)
| 272 | + |
# %%
# Tear down the nitro-diffusion pipeline at the end of the notebook:
# move it to the CPU, delete the name, collect garbage, and empty the
# CUDA cache, in that order.
pipe = pipe.to("cpu")
del pipe
gc.collect()
torch.cuda.empty_cache()
0 commit comments