Skip to content

Add Custom Scheduler #694

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ arguments:
--sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm}
sampling method (default: "euler_a")
--steps STEPS number of sample steps (default: 20)
--sigmas SIGMAS Custom sigma values for the sampler, comma-separated list (e.g., "14.61,7.8,3.5,0.0")
--rng {std_default, cuda} RNG (default: cuda)
-s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)
-b, --batch-count COUNT number of images to generate
Expand Down
4 changes: 4 additions & 0 deletions denoiser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,10 @@ static void sample_k_diffusion(sample_method_t method,

for (int i = 0; i < steps; i++) {
float sigma = sigmas[i];
float sigma_next = sigmas[i+1]; // For logging

// Log the sigma values for the current step
LOG_INFO("Step %d/%zu: sigma_current = %.4f, sigma_next = %.4f", i + 1, steps, sigma, sigma_next);

// denoise
ggml_tensor* denoised = model(x, sigma, i + 1);
Expand Down
75 changes: 70 additions & 5 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
#include <random>
#include <string>
#include <vector>

#include <sstream>
#include <iomanip>
// #include "preprocessing.hpp"
#include "flux.hpp"
#include "stable-diffusion.h"
Expand Down Expand Up @@ -129,6 +130,7 @@ struct SDParams {
float slg_scale = 0.f;
float skip_layer_start = 0.01f;
float skip_layer_end = 0.2f;
std::vector<float> custom_sigmas;
};

void print_params(SDParams params) {
Expand Down Expand Up @@ -175,6 +177,13 @@ void print_params(SDParams params) {
printf(" strength(img2img): %.2f\n", params.strength);
printf(" rng: %s\n", rng_type_to_str[params.rng_type]);
printf(" seed: %ld\n", params.seed);
if (!params.custom_sigmas.empty()) {
printf(" custom_sigmas: [");
for (size_t i = 0; i < params.custom_sigmas.size(); ++i) {
printf("%.4f%s", params.custom_sigmas[i], i == params.custom_sigmas.size() - 1 ? "" : ", ");
}
printf("]\n");
}
printf(" batch_count: %d\n", params.batch_count);
printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false");
printf(" upscale_repeats: %d\n", params.upscale_repeats);
Expand Down Expand Up @@ -231,8 +240,12 @@ void print_usage(int argc, const char* argv[]) {
printf(" --steps STEPS number of sample steps (default: 20)\n");
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
printf(" --sigmas SIGMAS Custom sigma values for the sampler, comma-separated (e.g., \"14.61,7.8,3.5,0.0\").\n");
printf(" Overrides --schedule. Number of provided sigmas can be less than steps;\n");
printf(" it will be padded with zeros. The last sigma is always forced to 0.\n");
printf(" -b, --batch-count COUNT number of images to generate\n");
printf(" --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)\n");
printf(" --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete).\n");
printf(" Ignored if --sigmas is used.\n");
printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
Expand Down Expand Up @@ -629,6 +642,44 @@ void parse_args(int argc, const char** argv, SDParams& params) {
break;
}
params.skip_layer_end = std::stof(argv[i]);
} else if (arg == "--sigmas") {
if (++i >= argc) {
invalid_arg = true;
break;
}
std::string sigmas_str = argv[i];
if (!sigmas_str.empty() && sigmas_str.front() == '[') {
sigmas_str.erase(0, 1);
}
if (!sigmas_str.empty() && sigmas_str.back() == ']') {
sigmas_str.pop_back();
}

std::stringstream ss(sigmas_str);
std::string item;
while(std::getline(ss, item, ',')) {
item.erase(0, item.find_first_not_of(" \t\n\r\f\v"));
item.erase(item.find_last_not_of(" \t\n\r\f\v") + 1);
if (!item.empty()) {
try {
params.custom_sigmas.push_back(std::stof(item));
} catch (const std::invalid_argument& e) {
fprintf(stderr, "error: invalid float value '%s' in --sigmas\n", item.c_str());
invalid_arg = true;
break;
} catch (const std::out_of_range& e) {
fprintf(stderr, "error: float value '%s' out of range in --sigmas\n", item.c_str());
invalid_arg = true;
break;
}
}
}
if (invalid_arg) break;
if (params.custom_sigmas.empty() && !sigmas_str.empty()) {
fprintf(stderr, "error: could not parse any sigma values from '%s'\n", argv[i]);
invalid_arg = true;
break;
}
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
print_usage(argc, argv);
Expand Down Expand Up @@ -736,8 +787,16 @@ std::string get_image_params(SDParams params, int64_t seed) {
parameter_string += "Model: " + sd_basename(params.model_path) + ", ";
parameter_string += "RNG: " + std::string(rng_type_to_str[params.rng_type]) + ", ";
parameter_string += "Sampler: " + std::string(sample_method_str[params.sample_method]);
if (params.schedule == KARRAS) {
parameter_string += " karras";
if (!params.custom_sigmas.empty()) {
parameter_string += ", Custom Sigmas: [";
for (size_t i = 0; i < params.custom_sigmas.size(); ++i) {
std::ostringstream oss;
oss << std::fixed << std::setprecision(4) << params.custom_sigmas[i];
parameter_string += oss.str() + (i == params.custom_sigmas.size() - 1 ? "" : ", ");
}
parameter_string += "]";
} else if (params.schedule != DEFAULT) { // Only show schedule if not using custom sigmas
parameter_string += " " + std::string(schedule_str[params.schedule]);
}
parameter_string += ", ";
parameter_string += "Version: stable-diffusion.cpp";
Expand Down Expand Up @@ -963,6 +1022,8 @@ int main(int argc, const char* argv[]) {
params.style_ratio,
params.normalize_input,
params.input_id_images_path.c_str(),
params.custom_sigmas.empty() ? nullptr : params.custom_sigmas.data(),
(int)params.custom_sigmas.size(),
params.skip_layers.data(),
params.skip_layers.size(),
params.slg_scale,
Expand All @@ -988,7 +1049,9 @@ int main(int argc, const char* argv[]) {
params.sample_method,
params.sample_steps,
params.strength,
params.seed);
params.seed,
params.custom_sigmas.empty() ? nullptr : params.custom_sigmas.data(),
(int)params.custom_sigmas.size());
if (results == NULL) {
printf("generate failed\n");
free_sd_ctx(sd_ctx);
Expand Down Expand Up @@ -1032,6 +1095,8 @@ int main(int argc, const char* argv[]) {
params.style_ratio,
params.normalize_input,
params.input_id_images_path.c_str(),
params.custom_sigmas.empty() ? nullptr : params.custom_sigmas.data(),
(int)params.custom_sigmas.size(),
params.skip_layers.data(),
params.skip_layers.size(),
params.slg_scale,
Expand Down
95 changes: 83 additions & 12 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
float slg_scale = 0,
float skip_layer_start = 0.01,
float skip_layer_end = 0.2,
const std::vector<float>& sigmas_override = {},
ggml_tensor* masked_image = NULL) {
if (seed < 0) {
// Generally, when using the provided command line, the seed is always >0.
Expand All @@ -1227,7 +1228,12 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
// }
// std::cout << std::endl;

int sample_steps = sigmas.size() - 1;
const std::vector<float>& sigmas_to_use = sigmas_override;
int sample_steps = sigmas_to_use.size() > 1 ? sigmas_to_use.size() - 1 : 0;
if (sample_steps == 0 && !sigmas_to_use.empty()) { // e.g. if sigmas_override has only one element
LOG_WARN("Received sigmas_override with %zu elements, implying 0 steps. This might not be intended.", sigmas_to_use.size());
}


// Apply lora
auto result_pair = extract_and_remove_lora(prompt);
Expand Down Expand Up @@ -1463,7 +1469,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
guidance,
eta,
sample_method,
sigmas,
sigmas_to_use,
start_merge_step,
id_cond,
skip_layers,
Expand Down Expand Up @@ -1539,6 +1545,8 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
float style_ratio,
bool normalize_input,
const char* input_id_images_path_c_str,
const float* custom_sigmas,
int custom_sigmas_count,
int* skip_layers = NULL,
size_t skip_layers_count = 0,
float slg_scale = 0,
Expand Down Expand Up @@ -1575,7 +1583,26 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,

size_t t0 = ggml_time_ms();

std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
std::vector<float> sigmas_for_generation;
if (custom_sigmas_count > 0 && custom_sigmas != nullptr) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could extract this block into a method and share it between txt2img, img2img and img2vid. txt2img seems to have additional logging the other two are missing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@idostyle thanks for the review. I’ve extracted the sigma preparation into a shared helper, you're right, it's better for maintainability

LOG_INFO("Using custom sigmas provided by user.");
sigmas_for_generation.assign(custom_sigmas, custom_sigmas + custom_sigmas_count);
size_t target_len = static_cast<size_t>(sample_steps) + 1;
if (sigmas_for_generation.size() < target_len) {
sigmas_for_generation.resize(target_len, 0.0f);
} else if (sigmas_for_generation.size() > target_len) {
sigmas_for_generation.resize(target_len);
}
if (!sigmas_for_generation.empty()) {
sigmas_for_generation.back() = 0.0f; // Ensure the last sigma is 0
}
if (sd_ctx->sd->denoiser->schedule->version == DEFAULT && custom_sigmas_count > 0) {
LOG_INFO("Custom sigmas are used, --schedule option is ignored.");
}
} else {
sigmas_for_generation = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
}


int C = 4;
if (sd_version_is_sd3(sd_ctx->sd->version)) {
Expand Down Expand Up @@ -1610,7 +1637,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
width,
height,
sample_method,
sigmas,
sigmas_for_generation,
seed,
batch_count,
control_cond,
Expand All @@ -1621,7 +1648,9 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
skip_layers_vec,
slg_scale,
skip_layer_start,
skip_layer_end);
skip_layer_end,
sigmas_for_generation,
nullptr /* masked_image for txt2img is null */);

size_t t1 = ggml_time_ms();

Expand Down Expand Up @@ -1651,6 +1680,8 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
float style_ratio,
bool normalize_input,
const char* input_id_images_path_c_str,
const float* custom_sigmas,
int custom_sigmas_count,
int* skip_layers = NULL,
size_t skip_layers_count = 0,
float slg_scale = 0,
Expand Down Expand Up @@ -1770,13 +1801,35 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
size_t t1 = ggml_time_ms();
LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);

std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
std::vector<float> base_sigmas;
if (custom_sigmas_count > 0 && custom_sigmas != nullptr) {
LOG_INFO("Using custom sigmas provided by user for img2img base schedule.");
base_sigmas.assign(custom_sigmas, custom_sigmas + custom_sigmas_count);
size_t target_len = static_cast<size_t>(sample_steps) + 1;
if (base_sigmas.size() < target_len) {
base_sigmas.resize(target_len, 0.0f);
} else if (base_sigmas.size() > target_len) {
base_sigmas.resize(target_len);
}
if (!base_sigmas.empty()) {
base_sigmas.back() = 0.0f; // Ensure the last sigma is 0
}
} else {
base_sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
}

size_t t_enc = static_cast<size_t>(sample_steps * strength);
if (t_enc == sample_steps)
t_enc--;
LOG_INFO("target t_enc is %zu steps", t_enc);
std::vector<float> sigma_sched;
sigma_sched.assign(sigmas.begin() + sample_steps - t_enc - 1, sigmas.end());
std::vector<float> sigmas_for_generation;
if (sample_steps - t_enc -1 < base_sigmas.size()) { // Check bounds
sigmas_for_generation.assign(base_sigmas.begin() + sample_steps - t_enc - 1, base_sigmas.end());
} else {
LOG_WARN("Cannot create sub-schedule for img2img due to strength/steps/custom_sigmas combination. Using full base_sigmas.");
sigmas_for_generation = base_sigmas;
}


sd_image_t* result_images = generate_image(sd_ctx,
work_ctx,
Expand All @@ -1790,7 +1843,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
width,
height,
sample_method,
sigma_sched,
sigmas_for_generation,
seed,
batch_count,
control_cond,
Expand All @@ -1802,6 +1855,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
slg_scale,
skip_layer_start,
skip_layer_end,
sigmas_for_generation,
masked_image);

size_t t2 = ggml_time_ms();
Expand All @@ -1824,14 +1878,31 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
enum sample_method_t sample_method,
int sample_steps,
float strength,
int64_t seed) {
int64_t seed,
const float* custom_sigmas,
int custom_sigmas_count) {
if (sd_ctx == NULL) {
return NULL;
}

LOG_INFO("img2vid %dx%d", width, height);

std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
std::vector<float> sigmas_for_generation;
if (custom_sigmas_count > 0 && custom_sigmas != nullptr) {
LOG_INFO("Using custom sigmas provided by user for img2vid.");
sigmas_for_generation.assign(custom_sigmas, custom_sigmas + custom_sigmas_count);
size_t target_len = static_cast<size_t>(sample_steps) + 1;
if (sigmas_for_generation.size() < target_len) {
sigmas_for_generation.resize(target_len, 0.0f);
} else if (sigmas_for_generation.size() > target_len) {
sigmas_for_generation.resize(target_len);
}
if (!sigmas_for_generation.empty()) {
sigmas_for_generation.back() = 0.0f; // Ensure the last sigma is 0
}
} else {
sigmas_for_generation = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
}

struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024) * 1024; // 10 MB
Expand Down Expand Up @@ -1902,7 +1973,7 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
0.f,
0.f,
sample_method,
sigmas,
sigmas_for_generation,
-1,
SDCondition(NULL, NULL, NULL));

Expand Down
8 changes: 7 additions & 1 deletion stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
float style_strength,
bool normalize_input,
const char* input_id_images_path,
const float* custom_sigmas,
int custom_sigmas_count,
int* skip_layers,
size_t skip_layers_count,
float slg_scale,
Expand Down Expand Up @@ -199,6 +201,8 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
float style_strength,
bool normalize_input,
const char* input_id_images_path,
const float* custom_sigmas,
int custom_sigmas_count,
int* skip_layers,
size_t skip_layers_count,
float slg_scale,
Expand All @@ -218,7 +222,9 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
enum sample_method_t sample_method,
int sample_steps,
float strength,
int64_t seed);
int64_t seed,
const float* custom_sigmas,
int custom_sigmas_count);

typedef struct upscaler_ctx_t upscaler_ctx_t;

Expand Down