diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index b3ae569e..68f90dd4 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -905,7 +905,8 @@ int main(int argc, const char* argv[]) {
                               input_image_buffer};
 
     sd_image_t* control_image = NULL;
-    if (params.control_net_path.size() > 0 && params.control_image_path.size() > 0) {
+    if (params.control_image_path.size() > 0) {
+        printf("load image from '%s'\n", params.control_image_path.c_str());
         int c = 0;
         control_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
         if (control_image_buffer == NULL) {
diff --git a/flux.hpp b/flux.hpp
index 11045918..81b1c59a 100644
--- a/flux.hpp
+++ b/flux.hpp
@@ -984,7 +984,8 @@ namespace Flux {
                                     struct ggml_tensor* pe,
                                     struct ggml_tensor* mod_index_arange = NULL,
                                     std::vector<ggml_tensor*> ref_latents = {},
-                                    std::vector<int> skip_layers = {}) {
+                                    std::vector<int> skip_layers = {},
+                                    SDVersion version = VERSION_FLUX) {
             // Forward pass of DiT.
             // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
             // timestep: (N,) tensor of diffusion timesteps
@@ -1007,7 +1008,8 @@ namespace Flux {
             auto img = process_img(ctx, x);
             uint64_t img_tokens = img->ne[1];
 
-            if (c_concat != NULL) {
+            if (version == VERSION_FLUX_FILL) {
+                GGML_ASSERT(c_concat != NULL);
                 ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
                 ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
 
@@ -1015,6 +1017,29 @@ namespace Flux {
                 mask = process_img(ctx, mask);
 
                 img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
+            } else if (version == VERSION_FLEX_2) {
+                GGML_ASSERT(c_concat != NULL);
+                ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
+                ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
+                ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
+
+                masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
+                mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);
+                control = ggml_pad(ctx, control, pad_w, pad_h, 0, 0);
+
+                masked = patchify(ctx, masked, patch_size);
+                mask = patchify(ctx, mask, patch_size);
+                control = patchify(ctx, control, patch_size);
+
+                img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
+            } else if (version == VERSION_FLUX_CONTROLS) {
+                GGML_ASSERT(c_concat != NULL);
+
+                ggml_tensor* control = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);
+
+                control = patchify(ctx, control, patch_size);
+
+                img = ggml_concat(ctx, img, control, 0);
             }
 
             if (ref_latents.size() > 0) {
@@ -1055,13 +1080,17 @@ namespace Flux {
                    SDVersion version = VERSION_FLUX,
                    bool flash_attn = false,
                    bool use_mask = false)
-            : GGMLRunner(backend), use_mask(use_mask) {
+            : GGMLRunner(backend), version(version), use_mask(use_mask) {
             flux_params.flash_attn = flash_attn;
             flux_params.guidance_embed = false;
             flux_params.depth = 0;
             flux_params.depth_single_blocks = 0;
             if (version == VERSION_FLUX_FILL) {
                 flux_params.in_channels = 384;
+            } else if (version == VERSION_FLUX_CONTROLS) {
+                flux_params.in_channels = 128;
+            } else if (version == VERSION_FLEX_2) {
+                flux_params.in_channels = 196;
             }
             for (auto pair : tensor_types) {
                 std::string tensor_name = pair.first;
@@ -1171,7 +1200,8 @@ namespace Flux {
                                                pe,
                                                mod_index_arange,
                                                ref_latents,
-                                               skip_layers);
+                                               skip_layers,
+                                               version);
 
         ggml_build_forward_expand(gf, out);
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 9f6a4fef..da691f3e 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -380,18 +380,24 @@ __STATIC_INLINE__ void sd_mask_to_tensor(const uint8_t* image_data,
 
 __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
                                      struct ggml_tensor* mask,
-                                     struct ggml_tensor* output) {
+                                     struct ggml_tensor* output,
+                                     float masked_value = 0.5f) {
     int64_t width = output->ne[0];
     int64_t height = output->ne[1];
     int64_t channels = output->ne[2];
+    float rescale_mx = mask->ne[0]/output->ne[0];
+    float rescale_my = mask->ne[1]/output->ne[1];
     GGML_ASSERT(output->type == GGML_TYPE_F32);
     for (int ix = 0; ix < width; ix++) {
         for (int iy = 0; iy < height; iy++) {
-            float m = ggml_tensor_get_f32(mask, ix, iy);
+            int mx = (int)(ix * rescale_mx);
+            int my = (int)(iy * rescale_my);
+            float m = ggml_tensor_get_f32(mask, mx, my);
             m = round(m);  // inpaint models need binary masks
-            ggml_tensor_set_f32(mask, m, ix, iy);
+            ggml_tensor_set_f32(mask, m, mx, my);
             for (int k = 0; k < channels; k++) {
-                float value = (1 - m) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
+                float value = ggml_tensor_get_f32(image_data, ix, iy, k);
+                value = (1 - m) * (value - masked_value) + masked_value;
                 ggml_tensor_set_f32(output, value, ix, iy, k);
             }
         }
diff --git a/model.cpp b/model.cpp
index 2e40e004..022be4d2 100644
--- a/model.cpp
+++ b/model.cpp
@@ -1685,10 +1685,15 @@ SDVersion ModelLoader::get_sd_version() {
     }
 
     if (is_flux) {
-        is_inpaint = input_block_weight.ne[0] == 384;
-        if (is_inpaint) {
+        if (input_block_weight.ne[0] == 384) {
             return VERSION_FLUX_FILL;
         }
+        if (input_block_weight.ne[0] == 128) {
+            return VERSION_FLUX_CONTROLS;
+        }
+        if (input_block_weight.ne[0] == 196) {
+            return VERSION_FLEX_2;
+        }
         return VERSION_FLUX;
     }
 
diff --git a/model.h b/model.h
index a6266039..409258a7 100644
--- a/model.h
+++ b/model.h
@@ -31,11 +31,13 @@ enum SDVersion {
     VERSION_SD3,
     VERSION_FLUX,
     VERSION_FLUX_FILL,
+    VERSION_FLUX_CONTROLS,
+    VERSION_FLEX_2,
     VERSION_COUNT,
 };
 
 static inline bool sd_version_is_flux(SDVersion version) {
-    if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
+    if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2) {
         return true;
     }
     return false;
@@ -70,7 +72,7 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
 }
 
 static inline bool sd_version_is_inpaint(SDVersion version) {
-    if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
+    if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
         return true;
     }
     return false;
 }
@@ -87,8 +89,12 @@ static inline bool sd_version_is_unet_edit(SDVersion version) {
     return version == VERSION_SD1_PIX2PIX || version == VERSION_SDXL_PIX2PIX;
 }
 
+static inline bool sd_version_is_control(SDVersion version) {
+    return version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2;
+}
+
 static bool sd_version_is_inpaint_or_unet_edit(SDVersion version) {
-    return sd_version_is_unet_edit(version) || sd_version_is_inpaint(version);
+    return sd_version_is_unet_edit(version) || sd_version_is_inpaint(version) || sd_version_is_control(version);
 }
 
 enum PMVersion {
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 402585f1..28022e62 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -36,7 +36,10 @@ const char* model_version_to_str[] = {
     "SVD",
     "SD3.x",
     "Flux",
-    "Flux Fill"};
+    "Flux Fill",
+    "Flux Control",
+    "Flex.2",
+};
 
 const char* sampling_methods_str[] = {
     "Euler A",
@@ -95,7 +98,7 @@ class StableDiffusionGGML {
     std::shared_ptr<DiffusionModel> diffusion_model;
     std::shared_ptr<AutoEncoderKL> first_stage_model;
     std::shared_ptr<TinyAutoEncoder> tae_first_stage;
-    std::shared_ptr<ControlNet> control_net;
+    std::shared_ptr<ControlNet> control_net = NULL;
     std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
     std::shared_ptr<LoraModel> pmid_lora;
     std::shared_ptr<PhotoMakerIDEmbed> pmid_id_embeds;
@@ -297,6 +300,11 @@ class StableDiffusionGGML {
                 // TODO: shift_factor
             }
 
+        if (sd_version_is_control(version)) {
+            // Might need vae encode for control cond
+            vae_decode_only = false;
+        }
+
         bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu;
 
         if (version == VERSION_SVD) {
@@ -933,7 +941,7 @@ class StableDiffusionGGML {
 
             std::vector<struct ggml_tensor*> controls;
 
-            if (control_hint != NULL) {
+            if (control_hint != NULL && control_net != NULL) {
                 control_net->compute(n_threads, noised_input, control_hint, timesteps, cond.c_crossattn, cond.c_vector);
                 controls = control_net->controls;
                 // print_ggml_tensor(controls[12]);
@@ -972,7 +980,7 @@ class StableDiffusionGGML {
             float* negative_data = NULL;
             if (has_unconditioned) {
                 // uncond
-                if (control_hint != NULL) {
+                if (control_hint != NULL && control_net != NULL) {
                     control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector);
                     controls = control_net->controls;
                 }
@@ -1717,10 +1725,24 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
     int W = width / 8;
     int H = height / 8;
     LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
+
+    struct ggml_tensor* control_latent = NULL;
+    if (sd_version_is_control(sd_ctx->sd->version) && image_hint != NULL) {
+        if (!sd_ctx->sd->use_tiny_autoencoder) {
+            struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
+            control_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
+        } else {
+            control_latent = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
+        }
+        ggml_tensor_scale(control_latent, control_strength);
+    }
+
     if (sd_version_is_inpaint(sd_ctx->sd->version)) {
         int64_t mask_channels = 1;
         if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
             mask_channels = 8 * 8;  // flatten the whole mask
+        } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+            mask_channels = 1 + init_latent->ne[2];
         }
         auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
         // no mask, set the whole image as masked
@@ -1734,6 +1756,11 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
                     for (int64_t c = init_latent->ne[2]; c < empty_latent->ne[2]; c++) {
                         ggml_tensor_set_f32(empty_latent, 1, x, y, c);
                     }
+                } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                    for (int64_t c = 0; c < empty_latent->ne[2]; c++) {
+                        // 0x16,1x1,0x16
+                        ggml_tensor_set_f32(empty_latent, c == init_latent->ne[2], x, y, c);
+                    }
                 } else {
                     ggml_tensor_set_f32(empty_latent, 1, x, y, 0);
                     for (int64_t c = 1; c < empty_latent->ne[2]; c++) {
@@ -1742,7 +1769,28 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
                 }
             }
         }
-        if (concat_latent == NULL) {
+
+        if (sd_ctx->sd->version == VERSION_FLEX_2 && control_latent != NULL && sd_ctx->sd->control_net == NULL) {
+            bool no_inpaint = concat_latent == NULL;
+            if (no_inpaint) {
+                concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
+            }
+            // fill in the control image here
+            for (int64_t x = 0; x < control_latent->ne[0]; x++) {
+                for (int64_t y = 0; y < control_latent->ne[1]; y++) {
+                    if (no_inpaint) {
+                        for (int64_t c = 0; c < concat_latent->ne[2] - control_latent->ne[2]; c++) {
+                            // 0x16,1x1,0x16
+                            ggml_tensor_set_f32(concat_latent, c == init_latent->ne[2], x, y, c);
+                        }
+                    }
+                    for (int64_t c = 0; c < control_latent->ne[2]; c++) {
+                        float v = ggml_tensor_get_f32(control_latent, x, y, c);
+                        ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latent->ne[2] + c);
+                    }
+                }
+            }
+        } else if (concat_latent == NULL) {
             concat_latent = empty_latent;
         }
         cond.c_concat = concat_latent;
@@ -1752,10 +1800,20 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
         auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
         ggml_set_f32(empty_latent, 0);
         uncond.c_concat = empty_latent;
-        if (concat_latent == NULL) {
-            concat_latent = empty_latent;
+        cond.c_concat = ref_latents[0];
+        if (cond.c_concat == NULL) {
+            cond.c_concat = empty_latent;
+        }
+    } else if (sd_version_is_control(sd_ctx->sd->version)) {
+        auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
+        ggml_set_f32(empty_latent, 0);
+        uncond.c_concat = empty_latent;
+        if (sd_ctx->sd->control_net == NULL) {
+            cond.c_concat = control_latent;
+        }
+        if (cond.c_concat == NULL) {
+            cond.c_concat = empty_latent;
         }
-        cond.c_concat = ref_latents[0];
     }
     SDCondition img_cond;
     if (uncond.c_crossattn != NULL &&
@@ -1914,6 +1972,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
     size_t t0 = ggml_time_ms();
 
     ggml_tensor* init_latent = NULL;
+    ggml_tensor* init_moments = NULL;
     ggml_tensor* concat_latent = NULL;
     ggml_tensor* denoise_mask = NULL;
     std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sd_img_gen_params->sample_steps);
@@ -1935,19 +1994,35 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
         sd_mask_to_tensor(sd_img_gen_params->mask_image.data, mask_img);
         sd_image_to_tensor(sd_img_gen_params->init_image.data, init_img);
 
+        if (!sd_ctx->sd->use_tiny_autoencoder) {
+            init_moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
+            init_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, init_moments);
+        } else {
+            init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
+        }
+
         if (sd_version_is_inpaint(sd_ctx->sd->version)) {
             int64_t mask_channels = 1;
             if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
                 mask_channels = 8 * 8;  // flatten the whole mask
+            } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                mask_channels = 1 + init_latent->ne[2];
             }
-            ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
-            sd_apply_mask(init_img, mask_img, masked_img);
             ggml_tensor* masked_latent = NULL;
-            if (!sd_ctx->sd->use_tiny_autoencoder) {
-                ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
-                masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
+            if (sd_ctx->sd->version != VERSION_FLEX_2) {
+                // most inpaint models mask before vae
+                ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
+                sd_apply_mask(init_img, mask_img, masked_img);
+                if (!sd_ctx->sd->use_tiny_autoencoder) {
+                    ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
+                    masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
+                } else {
+                    masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
+                }
             } else {
-                masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
+                // mask after vae
+                masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
+                sd_apply_mask(init_latent, mask_img, masked_latent, 0.);
             }
             concat_latent = ggml_new_tensor_4d(work_ctx,
                                                GGML_TYPE_F32,
@@ -1973,12 +2048,18 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
                                 ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * 8 + y);
                             }
                         }
-                    } else {
+                    } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
                         float m = ggml_tensor_get_f32(mask_img, mx, my);
-                        ggml_tensor_set_f32(concat_latent, m, ix, iy, 0);
+                        // masked image
                         for (int k = 0; k < masked_latent->ne[2]; k++) {
                             float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
-                            ggml_tensor_set_f32(concat_latent, v, ix, iy, k + mask_channels);
+                            ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
+                        }
+                        // downsampled mask
+                        ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
+                        // control (todo: support this)
+                        for (int k = 0; k < masked_latent->ne[2]; k++) {
+                            ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
                         }
                     }
                 }
@@ -1998,12 +2079,6 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
             }
         }
 
-        if (!sd_ctx->sd->use_tiny_autoencoder) {
-            ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
-            init_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
-        } else {
-            init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
-        }
     } else {
         LOG_INFO("TXT2IMG");
         if (sd_version_is_inpaint(sd_ctx->sd->version)) {
diff --git a/vae.hpp b/vae.hpp
index 4add881f..7ad0a9c3 100644
--- a/vae.hpp
+++ b/vae.hpp
@@ -559,6 +559,7 @@ struct AutoEncoderKL : public GGMLRunner {
                  bool decode_graph,
                  struct ggml_tensor** output,
                  struct ggml_context* output_ctx = NULL) {
+        GGML_ASSERT(!decode_only || decode_graph);
         auto get_graph = [&]() -> struct ggml_cgraph* {
             return build_graph(z, decode_graph);
         };