|
| 1 | +# --- |
| 2 | +# jupyter: |
| 3 | +# jupytext: |
| 4 | +# text_representation: |
| 5 | +# extension: .py |
| 6 | +# format_name: percent |
| 7 | +# format_version: '1.3' |
| 8 | +# jupytext_version: 1.16.7 |
| 9 | +# kernelspec: |
| 10 | +# display_name: Python 3 (ipykernel) |
| 11 | +# language: python |
| 12 | +# name: python3 |
| 13 | +# --- |
| 14 | + |
| 15 | +# %% [markdown] |
| 16 | +# # Attend-and-Excite の実装 |
| 17 | + |
| 18 | +# %% [markdown] |
| 19 | +# [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/py-img-gen/python-image-generation/blob/main/notebooks/5-2-1_attend-and-excite.ipynb) |
| 20 | + |
| 21 | +# %% [markdown] |
| 22 | +# ## 準備 |
| 23 | + |
| 24 | +# %% |
| 25 | +# !pip install -qq py-img-gen |
| 26 | + |
| 27 | +# %% |
import warnings

import torch

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Half precision is only used on CUDA. float16 pipelines are not expected to
# run (or run acceptably) on CPU, so the CPU fallback uses float32 instead.
dtype = torch.float16 if device.type == "cuda" else torch.float32

# Seed reused for both the baseline and the Attend-and-Excite generation so
# that the two results start from the same random state.
seed = 42

# Suppress every warning category for the rest of the session.
warnings.simplefilter("ignore")
| 39 | + |
| 40 | +# %% [markdown] |
| 41 | +# ## オリジナルの StableDiffusionPipeline の読み込み |
| 42 | + |
| 43 | +# %% |
from diffusers import StableDiffusionPipeline

# Hub repository id of the Stable Diffusion v1.5 checkpoint used in this notebook.
model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Baseline pipeline: load the pretrained weights in the configured precision
# and move the pipeline to the selected device in one chained expression.
pipe_sd = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=dtype,
).to(device)
| 52 | + |
| 53 | +# %% [markdown] |
| 54 | +# ## Attend and Excite を実装した StableDiffusionAttendAndExcitePipeline の読み込み |
| 55 | + |
| 56 | +# %% |
from diffusers import StableDiffusionAttendAndExcitePipeline

# Attend-and-Excite variant of the pipeline, loaded from the same checkpoint
# and precision as the baseline, then moved to the selected device.
pipe_ae = StableDiffusionAttendAndExcitePipeline.from_pretrained(
    model_id,
    torch_dtype=dtype,
).to(device)
| 65 | + |
| 66 | +# %% [markdown] |
| 67 | +# ## StableDiffusion での画像生成 |
| 68 | + |
| 69 | +# %% |
from diffusers.utils import make_image_grid

# Prompt naming two subjects; it is reused for the Attend-and-Excite run.
prompt = "A horse and a dog"

# Reset the global RNG to the fixed seed right before sampling.
generator = torch.manual_seed(seed)

# Generate two images with the plain Stable Diffusion pipeline.
images_sd = pipe_sd(
    prompt=prompt,
    num_images_per_prompt=2,
    generator=generator,
).images
| 79 | + |
| 80 | +# %% |
# Tile the two baseline images side by side (1 row x 2 columns).
gen_result_sd = make_image_grid(images_sd, rows=1, cols=2)

# Bare expression on the last line so the notebook displays the grid.
gen_result_sd
| 85 | + |
| 86 | +# %% [markdown] |
| 87 | +# ## Attend and Excite を適用した Stable Diffusion での画像生成 |
| 88 | + |
| 89 | +# %% |
# Use `get_indices` to look up the positions of the target tokens
# ("horse" and "dog") in the prompt.
# Confirm that 2 and 5 correspond to "horse" and "dog" respectively.
print(f"Indices: {pipe_ae.get_indices(prompt)}")
| 93 | + |
| 94 | +# %% |
# Token positions identified in the previous step.
token_indices = [2, 5]

# Reset the global RNG to the same fixed seed used for the baseline run.
generator = torch.manual_seed(seed)

# Generate two images with the Attend-and-Excite pipeline.
images_ae = pipe_ae(
    prompt=prompt,
    num_images_per_prompt=2,
    generator=generator,
    # Additional argument for Attend-and-Excite: the target token positions.
    token_indices=token_indices,
).images
| 109 | + |
| 110 | +# %% |
# Tile the two Attend-and-Excite images side by side (1 row x 2 columns).
gen_result_ae = make_image_grid(images_ae, rows=1, cols=2)

# Bare expression on the last line so the notebook displays the grid.
gen_result_ae
| 115 | + |
| 116 | +# %% [markdown] |
| 117 | +# ## 生成結果の比較 |
| 118 | + |
| 119 | +# %% |
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

# One wide figure holding a 1 x 2 grid of axes: baseline result on the left,
# Attend-and-Excite result on the right.
fig = plt.figure(figsize=(20, 5))
grid = ImageGrid(fig, rect=111, nrows_ncols=(1, 2), axes_pad=0.1)
fig.suptitle(f"Prompt: {prompt}")

images = [gen_result_sd, gen_result_ae]
titles = [
    r"Stable Diffusion ${\it without}$ Attend-and-Excite",
    r"Stable Diffusion ${\it with}$ Attend-and-Excite",
]

# Pair each grid axis with its image and title; hide the axis ticks and frame.
for ax, image, title in zip(grid, images, titles):
    ax.imshow(image)
    ax.axis("off")
    ax.set_title(title)
0 commit comments