mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 09:34:37 +00:00
revert padding change for sd chroma
This commit is contained in:
parent
1cf7648305
commit
7de88802f9
2 changed files with 27 additions and 30 deletions
|
@ -1288,21 +1288,6 @@ struct PixArtCLIPEmbedder : public Conditioner {
|
||||||
return {t5_tokens, t5_weights, t5_mask};
|
return {t5_tokens, t5_weights, t5_mask};
|
||||||
}
|
}
|
||||||
|
|
||||||
void modify_mask_to_attend_padding(struct ggml_tensor* mask, int max_seq_length, int num_extra_padding = 8) {
|
|
||||||
float* mask_data = (float*)mask->data;
|
|
||||||
int num_pad = 0;
|
|
||||||
for (int64_t i = 0; i < max_seq_length; i++) {
|
|
||||||
if (num_pad >= num_extra_padding) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (std::isinf(mask_data[i])) {
|
|
||||||
mask_data[i] = 0;
|
|
||||||
++num_pad;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// LOG_DEBUG("PAD: %d", num_pad);
|
|
||||||
}
|
|
||||||
|
|
||||||
SDCondition get_learned_condition_common(ggml_context* work_ctx,
|
SDCondition get_learned_condition_common(ggml_context* work_ctx,
|
||||||
int n_threads,
|
int n_threads,
|
||||||
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> token_and_weights,
|
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> token_and_weights,
|
||||||
|
@ -1389,21 +1374,6 @@ struct PixArtCLIPEmbedder : public Conditioner {
|
||||||
hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256);
|
hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256);
|
||||||
ggml_set_f32(hidden_states, 0.f);
|
ggml_set_f32(hidden_states, 0.f);
|
||||||
}
|
}
|
||||||
|
|
||||||
int mask_pad = 1;
|
|
||||||
const char* SD_CHROMA_MASK_PAD_OVERRIDE = getenv("SD_CHROMA_MASK_PAD_OVERRIDE");
|
|
||||||
if (SD_CHROMA_MASK_PAD_OVERRIDE != nullptr) {
|
|
||||||
std::string mask_pad_str = SD_CHROMA_MASK_PAD_OVERRIDE;
|
|
||||||
try {
|
|
||||||
mask_pad = std::stoi(mask_pad_str);
|
|
||||||
} catch (const std::invalid_argument&) {
|
|
||||||
LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable is not a valid integer (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad);
|
|
||||||
} catch (const std::out_of_range&) {
|
|
||||||
LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable value is out of range for `int` type (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
modify_mask_to_attend_padding(t5_attn_mask, ggml_nelements(t5_attn_mask), mask_pad);
|
|
||||||
|
|
||||||
return SDCondition(hidden_states, t5_attn_mask, NULL);
|
return SDCondition(hidden_states, t5_attn_mask, NULL);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -709,6 +709,20 @@ namespace Flux {
|
||||||
return ids;
|
return ids;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void chroma_modify_mask_to_attend_padding(struct ggml_tensor* mask, int max_seq_length, int num_extra_padding = 8) {
|
||||||
|
float* mask_data = (float*)mask->data;
|
||||||
|
int num_pad = 0;
|
||||||
|
for (int64_t i = 0; i < max_seq_length; i++) {
|
||||||
|
if (num_pad >= num_extra_padding) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (std::isinf(mask_data[i])) {
|
||||||
|
mask_data[i] = 0;
|
||||||
|
++num_pad;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// LOG_DEBUG("PAD: %d", num_pad);
|
||||||
|
}
|
||||||
|
|
||||||
// Generate positional embeddings
|
// Generate positional embeddings
|
||||||
std::vector<float> gen_pe(int h, int w, int patch_size, int bs, int context_len, int theta, const std::vector<int>& axes_dim) {
|
std::vector<float> gen_pe(int h, int w, int patch_size, int bs, int context_len, int theta, const std::vector<int>& axes_dim) {
|
||||||
|
@ -1084,6 +1098,19 @@ namespace Flux {
|
||||||
guidance = ggml_set_f32(guidance, 0);
|
guidance = ggml_set_f32(guidance, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int mask_pad = 1;
|
||||||
|
const char* SD_CHROMA_MASK_PAD_OVERRIDE = getenv("SD_CHROMA_MASK_PAD_OVERRIDE");
|
||||||
|
if (SD_CHROMA_MASK_PAD_OVERRIDE != nullptr) {
|
||||||
|
std::string mask_pad_str = SD_CHROMA_MASK_PAD_OVERRIDE;
|
||||||
|
try {
|
||||||
|
mask_pad = std::stoi(mask_pad_str);
|
||||||
|
} catch (const std::invalid_argument&) {
|
||||||
|
LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable is not a valid integer (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad);
|
||||||
|
} catch (const std::out_of_range&) {
|
||||||
|
LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable value is out of range for `int` type (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
flux.chroma_modify_mask_to_attend_padding(y, ggml_nelements(y), mask_pad);
|
||||||
|
|
||||||
const char* SD_CHROMA_USE_DIT_MASK = getenv("SD_CHROMA_USE_DIT_MASK");
|
const char* SD_CHROMA_USE_DIT_MASK = getenv("SD_CHROMA_USE_DIT_MASK");
|
||||||
if (SD_CHROMA_USE_DIT_MASK != nullptr) {
|
if (SD_CHROMA_USE_DIT_MASK != nullptr) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue