mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-04-26 10:41:25 +00:00
sd: add eta support (#2164)
This commit is contained in:
parent
18a3bedf63
commit
c04832bb2b
3 changed files with 13 additions and 0 deletions
1
expose.h
1
expose.h
|
|
@ -225,6 +225,7 @@ struct sd_generation_inputs
|
|||
const int seed = 0;
|
||||
const char * sample_method = nullptr;
|
||||
const char * scheduler = nullptr;
|
||||
const float eta = -1.0f;
|
||||
const int clip_skip = -1;
|
||||
const int vid_req_frames = 1;
|
||||
const int video_output_type = 0; //0=gif, 1=avi, 2=both
|
||||
|
|
|
|||
|
|
@ -389,6 +389,7 @@ class sd_generation_inputs(ctypes.Structure):
|
|||
("seed", ctypes.c_int),
|
||||
("sample_method", ctypes.c_char_p),
|
||||
("scheduler", ctypes.c_char_p),
|
||||
("eta", ctypes.c_float),
|
||||
("clip_skip", ctypes.c_int),
|
||||
("vid_req_frames", ctypes.c_int),
|
||||
("video_output_type", ctypes.c_int),
|
||||
|
|
@ -2645,6 +2646,7 @@ def sd_generate(genparams):
|
|||
sample_method = (genparams.get("sampler_name") or "default")
|
||||
scheduler = (genparams.get("scheduler") or "default").lower()
|
||||
clip_skip = tryparseint(genparams.get("clip_skip", -1),-1)
|
||||
eta = tryparsefloat(genparams.get("eta", None), None)
|
||||
vid_req_frames = tryparseint(genparams.get("frames", 1),1)
|
||||
vid_req_frames = 1 if (not vid_req_frames or vid_req_frames < 1) else vid_req_frames
|
||||
video_output_type = genparams.get("video_output_type", 0)
|
||||
|
|
@ -2697,6 +2699,7 @@ def sd_generate(genparams):
|
|||
inputs.seed = ((seed + 2**31) % 2**32) - 2**31
|
||||
inputs.sample_method = sd_sampler_canonical_name(sample_method).encode("UTF-8")
|
||||
inputs.scheduler = scheduler.encode("UTF-8")
|
||||
inputs.eta = -1.0 if eta is None else eta
|
||||
inputs.clip_skip = clip_skip
|
||||
inputs.vid_req_frames = vid_req_frames
|
||||
inputs.video_output_type = video_output_type
|
||||
|
|
|
|||
|
|
@ -133,6 +133,7 @@ struct SDParams {
|
|||
float distilled_guidance = -1.0f;
|
||||
float shifted_timestep = 0;
|
||||
float flow_shift = -1.0f;
|
||||
float eta = -1.0f;
|
||||
float strength = 0.75f;
|
||||
int64_t seed = 42;
|
||||
bool clip_on_cpu = false;
|
||||
|
|
@ -600,6 +601,8 @@ static std::string get_image_params(const sd_img_gen_params_t & params, const st
|
|||
<< " | Size: " << params.width << "x" << params.height
|
||||
<< " | Sampler: " << sd_sample_method_name(params.sample_params.sample_method)
|
||||
<< get_scheduler_name(params.sample_params.scheduler, true);
|
||||
if (params.sample_params.eta != -1.0f)
|
||||
ss << "| Eta: " << params.sample_params.eta;
|
||||
if (params.sample_params.shifted_timestep != 0)
|
||||
ss << "| Timestep Shift: " << params.sample_params.shifted_timestep;
|
||||
if (params.sample_params.flow_shift > 0.f && params.sample_params.flow_shift != INFINITY)
|
||||
|
|
@ -978,6 +981,7 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
|
|||
sd_params->sample_steps = inputs.sample_steps;
|
||||
sd_params->shifted_timestep = inputs.shifted_timestep;
|
||||
sd_params->flow_shift = inputs.flow_shift;
|
||||
sd_params->eta = inputs.eta;
|
||||
sd_params->seed = inputs.seed;
|
||||
sd_params->width = inputs.width;
|
||||
sd_params->height = inputs.height;
|
||||
|
|
@ -1212,6 +1216,9 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
|
|||
params.sample_params.scheduler = sd_params->scheduler;
|
||||
params.sample_params.sample_steps = sd_params->sample_steps;
|
||||
params.sample_params.shifted_timestep = sd_params->shifted_timestep;
|
||||
if (sd_params->eta >= 0.f && sd_params->eta <= 1.f) {
|
||||
params.sample_params.eta = sd_params->eta;
|
||||
}
|
||||
if (sd_params->flow_shift > 0.f && sd_params->flow_shift != INFINITY) {
|
||||
params.sample_params.flow_shift = sd_params->flow_shift;
|
||||
}
|
||||
|
|
@ -1418,6 +1425,8 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
|
|||
jsoninfo["extra_generation_params"] = nlohmann::json::object();
|
||||
if (params.sample_params.scheduler != scheduler_t::SCHEDULER_COUNT)
|
||||
jsoninfo["extra_generation_params"]["Schedule type"] = get_scheduler_name(params.sample_params.scheduler);
|
||||
if (params.sample_params.eta >= 0 && params.sample_params.eta <= 1)
|
||||
jsoninfo["eta"] = params.sample_params.eta;
|
||||
if (is_img2img)
|
||||
jsoninfo["denoising_strength"] = params.strength;
|
||||
if (sd_params->model_path.empty())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue