sd: add eta support (#2164)

This commit is contained in:
Wagner Bruna 2026-04-25 08:04:13 -03:00 committed by GitHub
parent 18a3bedf63
commit c04832bb2b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 13 additions and 0 deletions

View file

@ -225,6 +225,7 @@ struct sd_generation_inputs
const int seed = 0;
const char * sample_method = nullptr;
const char * scheduler = nullptr;
const float eta = -1.0f;
const int clip_skip = -1;
const int vid_req_frames = 1;
const int video_output_type = 0; //0=gif, 1=avi, 2=both

View file

@ -389,6 +389,7 @@ class sd_generation_inputs(ctypes.Structure):
("seed", ctypes.c_int),
("sample_method", ctypes.c_char_p),
("scheduler", ctypes.c_char_p),
("eta", ctypes.c_float),
("clip_skip", ctypes.c_int),
("vid_req_frames", ctypes.c_int),
("video_output_type", ctypes.c_int),
@ -2645,6 +2646,7 @@ def sd_generate(genparams):
sample_method = (genparams.get("sampler_name") or "default")
scheduler = (genparams.get("scheduler") or "default").lower()
clip_skip = tryparseint(genparams.get("clip_skip", -1),-1)
eta = tryparsefloat(genparams.get("eta", None), None)
vid_req_frames = tryparseint(genparams.get("frames", 1),1)
vid_req_frames = 1 if (not vid_req_frames or vid_req_frames < 1) else vid_req_frames
video_output_type = genparams.get("video_output_type", 0)
@ -2697,6 +2699,7 @@ def sd_generate(genparams):
inputs.seed = ((seed + 2**31) % 2**32) - 2**31
inputs.sample_method = sd_sampler_canonical_name(sample_method).encode("UTF-8")
inputs.scheduler = scheduler.encode("UTF-8")
inputs.eta = -1.0 if eta is None else eta
inputs.clip_skip = clip_skip
inputs.vid_req_frames = vid_req_frames
inputs.video_output_type = video_output_type

View file

@ -133,6 +133,7 @@ struct SDParams {
float distilled_guidance = -1.0f;
float shifted_timestep = 0;
float flow_shift = -1.0f;
float eta = -1.0f;
float strength = 0.75f;
int64_t seed = 42;
bool clip_on_cpu = false;
@ -600,6 +601,8 @@ static std::string get_image_params(const sd_img_gen_params_t & params, const st
<< " | Size: " << params.width << "x" << params.height
<< " | Sampler: " << sd_sample_method_name(params.sample_params.sample_method)
<< get_scheduler_name(params.sample_params.scheduler, true);
if (params.sample_params.eta != -1.0f)
ss << "| Eta: " << params.sample_params.eta;
if (params.sample_params.shifted_timestep != 0)
ss << "| Timestep Shift: " << params.sample_params.shifted_timestep;
if (params.sample_params.flow_shift > 0.f && params.sample_params.flow_shift != INFINITY)
@ -978,6 +981,7 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
sd_params->sample_steps = inputs.sample_steps;
sd_params->shifted_timestep = inputs.shifted_timestep;
sd_params->flow_shift = inputs.flow_shift;
sd_params->eta = inputs.eta;
sd_params->seed = inputs.seed;
sd_params->width = inputs.width;
sd_params->height = inputs.height;
@ -1212,6 +1216,9 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
params.sample_params.scheduler = sd_params->scheduler;
params.sample_params.sample_steps = sd_params->sample_steps;
params.sample_params.shifted_timestep = sd_params->shifted_timestep;
if (sd_params->eta >= 0.f && sd_params->eta <= 1.f) {
params.sample_params.eta = sd_params->eta;
}
if (sd_params->flow_shift > 0.f && sd_params->flow_shift != INFINITY) {
params.sample_params.flow_shift = sd_params->flow_shift;
}
@ -1418,6 +1425,8 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
jsoninfo["extra_generation_params"] = nlohmann::json::object();
if (params.sample_params.scheduler != scheduler_t::SCHEDULER_COUNT)
jsoninfo["extra_generation_params"]["Schedule type"] = get_scheduler_name(params.sample_params.scheduler);
if (params.sample_params.eta >= 0 && params.sample_params.eta <= 1)
jsoninfo["eta"] = params.sample_params.eta;
if (is_img2img)
jsoninfo["denoising_strength"] = params.strength;
if (sd_params->model_path.empty())