Skip to content

Commit f7fa9c6

Browse files
committed
fix: avoid issues when sigma_min is close to 0
1 parent 3e81246 commit f7fa9c6

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

denoiser.hpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -347,35 +347,37 @@ struct SmoothStepScheduler : SigmaScheduler {
347347
}
348348
};
349349

350-
// Implementation adapted from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
350+
/*
351+
* KL Optimal:
352+
* Original work from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608.
353+
* Implemented using https://github.com/comfyanonymous/ComfyUI/pull/6206 as a reference.
354+
*/
351355
struct KLOptimalScheduler : SigmaScheduler {
352356
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
353357
std::vector<float> sigmas;
354358

355359
if (n == 0) {
356360
return sigmas;
357361
}
362+
358363
if (n == 1) {
359364
sigmas.push_back(sigma_max);
360365
sigmas.push_back(0.0f);
361366
return sigmas;
362367
}
363368

369+
sigmas.reserve(n + 1);
370+
364371
float alpha_min = std::atan(sigma_min);
365372
float alpha_max = std::atan(sigma_max);
366373

367374
for (uint32_t i = 0; i < n; ++i) {
368-
// t goes from 0.0 to 1.0
369-
float t = static_cast<float>(i) / static_cast<float>(n - 1);
370375

371-
// Interpolate in the angle domain
376+
float t = static_cast<float>(i) / static_cast<float>(n - 1);
372377
float angle = t * alpha_min + (1.0f - t) * alpha_max;
373-
374-
// Convert back to sigma
375378
sigmas.push_back(std::tan(angle));
376379
}
377380

378-
// Append the final zero to sigma
379381
sigmas.push_back(0.0f);
380382

381383
return sigmas;
@@ -459,6 +461,10 @@ struct CompVisDenoiser : public Denoiser {
459461
}
460462

461463
float sigma_to_t(float sigma) override {
464+
if (sigma <= 1e-6f) {
465+
return (float)(TIMESTEPS - 1);
466+
}
467+
462468
float log_sigma = std::log(sigma);
463469
std::vector<float> dists;
464470
dists.reserve(TIMESTEPS);
@@ -734,8 +740,12 @@ static bool sample_k_diffusion(sample_method_t method,
734740
float* vec_x = (float*)x->data;
735741
float* vec_denoised = (float*)denoised->data;
736742

737-
for (int i = 0; i < ggml_nelements(d); i++) {
738-
vec_d[i] = (vec_x[i] - vec_denoised[i]) / sigma;
743+
if (sigma < 1e-6f) {
744+
ggml_set_f32(d, 0.0f);
745+
} else {
746+
for (int i = 0; i < ggml_nelements(d); i++) {
747+
vec_d[i] = (vec_x[i] - vec_denoised[i]) / sigma;
748+
}
739749
}
740750
}
741751

0 commit comments

Comments
 (0)