# np, st, torch and autocast are imported explicitly for the names used below;
# the wildcard import from streamlit_helpers supplies the demo helpers
# (init_st, load_model, get_batch, get_unique_embedder_keys_from_conditioner).
import numpy as np
import streamlit as st
import torch
from torch import autocast

from streamlit_helpers import *
from st_keyup import st_keyup

from sgm.modules.diffusionmodules.sampling import EulerAncestralSampler

VERSION2SPECS = {
    "SDXL-Turbo": {
        "H": 512,
        "W": 512,
        "C": 4,  # latent channels
        "f": 8,  # autoencoder downsampling factor
        "is_legacy": False,
        "config": "configs/inference/sd_xl_base.yaml",
        "ckpt": "checkpoints/sd_xl_turbo_1.0.safetensors",
    },
}


class SubstepSampler(EulerAncestralSampler):
    """Euler-ancestral sampler that only visits a small subset of the noise schedule."""

    def __init__(self, n_sample_steps=1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_sample_steps = n_sample_steps
        # Indices into the discretized schedule; the first n_sample_steps entries
        # plus the final entry are the only sigmas actually sampled.
        self.steps_subset = [0, 100, 200, 300, 1000]

    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        sigmas = self.discretization(
            self.num_steps if num_steps is None else num_steps, device=self.device
        )
        sigmas = sigmas[
            self.steps_subset[: self.n_sample_steps] + self.steps_subset[-1:]
        ]
        uc = cond  # with uc identical to cond, classifier-free guidance is a no-op
        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)  # scale unit-variance noise to the initial noise level
        num_sigmas = len(sigmas)
        s_in = x.new_ones([x.shape[0]])
        return x, s_in, sigmas, num_sigmas, cond, uc
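
# Worked illustration of the sigma subset (a sketch, assuming the discretization
# appends a final zero sigma so that num_steps=1000 yields 1001 entries): with
# n_sample_steps=2 the kept indices are [0, 100] + [1000], i.e. the two largest
# selected sigmas plus sigma = 0, so the loop takes two denoising steps instead
# of walking the full schedule.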


def seeded_randn(shape, seed):
    # Deterministic Gaussian noise: seed numpy, then move the tensor to the GPU.
    randn = np.random.RandomState(seed).randn(*shape)
    randn = torch.from_numpy(randn).to(device="cuda", dtype=torch.float32)
    return randn


class SeededNoise:
    """Noise sampler that advances its seed on every call, keeping runs reproducible."""

    def __init__(self, seed):
        self.seed = seed

    def __call__(self, x):
        self.seed = self.seed + 1
        return seeded_randn(x.shape, self.seed)
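
# Usage sketch (hypothetical values): SeededNoise(seed=41)(x) returns noise of
# x's shape drawn with seed 42, and the next call uses seed 43, so the ancestral
# noise differs per step but is identical across reruns with the same start seed.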


def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
    # Build the value_dict consumed by the SDXL conditioner; the negative prompt
    # is hard-coded to an empty string here (the negative_prompt argument is unused).
    value_dict = {}
    for key in keys:
        if key == "txt":
            value_dict["prompt"] = prompt
            value_dict["negative_prompt"] = ""

        if key == "original_size_as_tuple":
            value_dict["orig_width"] = init_dict["orig_width"]
            value_dict["orig_height"] = init_dict["orig_height"]

        if key == "crop_coords_top_left":
            value_dict["crop_coords_top"] = 0
            value_dict["crop_coords_left"] = 0

        if key == "aesthetic_score":
            value_dict["aesthetic_score"] = 6.0
            value_dict["negative_aesthetic_score"] = 2.5

        if key == "target_size_as_tuple":
            value_dict["target_width"] = init_dict["target_width"]
            value_dict["target_height"] = init_dict["target_height"]

    return value_dict
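
# Example of the result (illustrative, for a 512x512 SDXL-Turbo run): with keys
# ["txt", "original_size_as_tuple", "crop_coords_top_left", "target_size_as_tuple"],
# the dict holds prompt/negative_prompt, orig_width=orig_height=512,
# crop_coords_top=crop_coords_left=0 and target_width=target_height=512.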


def sample(
    model,
    sampler,
    prompt="A lush garden with oversized flowers and vibrant colors, inhabited by miniature animals.",
    H=1024,
    W=1024,
    seed=0,
    filter=None,
):
    F = 8  # autoencoder downsampling factor
    C = 4  # latent channels
    shape = (1, C, H // F, W // F)

    value_dict = init_embedder_options(
        keys=get_unique_embedder_keys_from_conditioner(model.conditioner),
        init_dict={
            "orig_width": W,
            "orig_height": H,
            "target_width": W,
            "target_height": H,
        },
        prompt=prompt,
    )

    if seed is None:
        seed = torch.seed()
    precision_scope = autocast
    with torch.no_grad():
        with precision_scope("cuda"):
            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner),
                value_dict,
                [1],  # a single sample
            )
            c = model.conditioner(batch)
            uc = None  # no separate unconditional conditioning for Turbo
            randn = seeded_randn(shape, seed)

            def denoiser(input, sigma, c):
                return model.denoiser(
                    model.model,
                    input,
                    sigma,
                    c,
                )

            samples_z = sampler(denoiser, randn, cond=c, uc=uc)
            samples_x = model.decode_first_stage(samples_z)
            # Map decoder output from [-1, 1] to [0, 1].
            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
            if filter is not None:
                samples = filter(samples)
            # Convert to HWC uint8 numpy images for display.
            samples = (
                (255 * samples)
                .to(dtype=torch.uint8)
                .permute(0, 2, 3, 1)
                .detach()
                .cpu()
                .numpy()
            )
    return samples
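
# Standalone usage sketch (hypothetical, outside the Streamlit UI), assuming
# init_st from streamlit_helpers returns a dict with a loaded "model":
#
#   state = init_st(VERSION2SPECS["SDXL-Turbo"], load_filter=False)
#   turbo_sampler = SubstepSampler(
#       n_sample_steps=1,
#       num_steps=1000,
#       eta=1.0,
#       discretization_config=dict(
#           target="sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization"
#       ),
#   )
#   out = sample(state["model"], turbo_sampler, H=512, W=512, seed=0, prompt="a photo of a cat")
#   # out[0] is a (512, 512, 3) uint8 numpy array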


def v_spacer(height) -> None:
    # Crude vertical spacing for the Streamlit layout.
    for _ in range(height):
        st.write("\n")


if __name__ == "__main__":
    st.title("Turbo")

    head_cols = st.columns([1, 1, 1])
    with head_cols[0]:
        version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
        version_dict = VERSION2SPECS[version]

    with head_cols[1]:
        v_spacer(2)
        if st.checkbox("Load Model"):
            mode = "txt2img"
        else:
            mode = "skip"

    if mode != "skip":
        state = init_st(version_dict, load_filter=True)
        if state["msg"]:
            st.info(state["msg"])
        model = state["model"]
        load_model(model)

    # Seed handling: the buttons below step through seeds for quick variations.
    if "seed" not in st.session_state:
        st.session_state.seed = 0

    def increment_counter():
        st.session_state.seed += 1

    def decrement_counter():
        if st.session_state.seed > 0:
            st.session_state.seed -= 1

    with head_cols[2]:
        n_steps = st.number_input(label="number of steps", min_value=1, max_value=4)

    sampler = SubstepSampler(
        n_sample_steps=1,
        num_steps=1000,
        eta=1.0,
        discretization_config=dict(
            target="sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization"
        ),
    )
    sampler.n_sample_steps = n_steps
    default_prompt = (
        "A cinematic shot of a baby racoon wearing an intricate italian priest robe."
    )
    prompt = st_keyup(
        "Enter a value", value=default_prompt, debounce=300, key="interactive_text"
    )

    cols = st.columns([1, 5, 1])
    if mode != "skip":
        with cols[0]:
            v_spacer(14)
            st.button("↩", on_click=decrement_counter)
        with cols[2]:
            v_spacer(14)
            st.button("↪", on_click=increment_counter)

        # Resample on every rerun (keyup or seed change) with the current seed.
        sampler.noise_sampler = SeededNoise(seed=st.session_state.seed)
        out = sample(
            model,
            sampler,
            H=512,
            W=512,
            seed=st.session_state.seed,
            prompt=prompt,
            filter=state.get("filter"),
        )
        with cols[1]:
            st.image(out[0])
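
# To launch the demo (the path is an assumption based on this repo's demo layout):
#   streamlit run scripts/demo/turbo.py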