@@ -347,35 +347,37 @@ struct SmoothStepScheduler : SigmaScheduler {
347347 }
348348};
349349
350- // Implementation adapted from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
350+ /*
351+ * KL Optimal:
352+ * Original work from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608.
353+ * Implemented using https://github.com/comfyanonymous/ComfyUI/pull/6206 as a reference.
354+ */
351355struct KLOptimalScheduler : SigmaScheduler {
352356 std::vector<float > get_sigmas (uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
353357 std::vector<float > sigmas;
354358
355359 if (n == 0 ) {
356360 return sigmas;
357361 }
362+
358363 if (n == 1 ) {
359364 sigmas.push_back (sigma_max);
360365 sigmas.push_back (0 .0f );
361366 return sigmas;
362367 }
363368
369+ sigmas.reserve (n + 1 );
370+
364371 float alpha_min = std::atan (sigma_min);
365372 float alpha_max = std::atan (sigma_max);
366373
367374 for (uint32_t i = 0 ; i < n; ++i) {
368- // t goes from 0.0 to 1.0
369- float t = static_cast <float >(i) / static_cast <float >(n - 1 );
370375
371- // Interpolate in the angle domain
376+ float t = static_cast < float >(i) / static_cast < float >(n - 1 );
372377 float angle = t * alpha_min + (1 .0f - t) * alpha_max;
373-
374- // Convert back to sigma
375378 sigmas.push_back (std::tan (angle));
376379 }
377380
378- // Append the final zero to sigma
379381 sigmas.push_back (0 .0f );
380382
381383 return sigmas;
@@ -459,6 +461,10 @@ struct CompVisDenoiser : public Denoiser {
459461 }
460462
461463 float sigma_to_t (float sigma) override {
464+ if (sigma <= 1e-6f ) {
465+ return (float )(TIMESTEPS - 1 );
466+ }
467+
462468 float log_sigma = std::log (sigma);
463469 std::vector<float > dists;
464470 dists.reserve (TIMESTEPS);
@@ -734,8 +740,12 @@ static bool sample_k_diffusion(sample_method_t method,
734740 float * vec_x = (float *)x->data ;
735741 float * vec_denoised = (float *)denoised->data ;
736742
737- for (int i = 0 ; i < ggml_nelements (d); i++) {
738- vec_d[i] = (vec_x[i] - vec_denoised[i]) / sigma;
743+ if (sigma < 1e-6f ) {
744+ ggml_set_f32 (d, 0 .0f );
745+ } else {
746+ for (int i = 0 ; i < ggml_nelements (d); i++) {
747+ vec_d[i] = (vec_x[i] - vec_denoised[i]) / sigma;
748+ }
739749 }
740750 }
741751
0 commit comments