Revert "revert padding change for sd chroma"

This reverts commit 7de88802f9.
This commit is contained in:
Concedo 2025-06-14 10:10:34 +08:00
parent 5f9e96e82d
commit bfb47cbcd8
2 changed files with 30 additions and 27 deletions

View file

@ -709,20 +709,6 @@ namespace Flux {
return ids;
}
void chroma_modify_mask_to_attend_padding(struct ggml_tensor* mask, int max_seq_length, int num_extra_padding = 8) {
float* mask_data = (float*)mask->data;
int num_pad = 0;
for (int64_t i = 0; i < max_seq_length; i++) {
if (num_pad >= num_extra_padding) {
break;
}
if (std::isinf(mask_data[i])) {
mask_data[i] = 0;
++num_pad;
}
}
// LOG_DEBUG("PAD: %d", num_pad);
}
// Generate positional embeddings
std::vector<float> gen_pe(int h, int w, int patch_size, int bs, int context_len, int theta, const std::vector<int>& axes_dim) {
@ -1098,19 +1084,6 @@ namespace Flux {
guidance = ggml_set_f32(guidance, 0);
}
int mask_pad = 1;
const char* SD_CHROMA_MASK_PAD_OVERRIDE = getenv("SD_CHROMA_MASK_PAD_OVERRIDE");
if (SD_CHROMA_MASK_PAD_OVERRIDE != nullptr) {
std::string mask_pad_str = SD_CHROMA_MASK_PAD_OVERRIDE;
try {
mask_pad = std::stoi(mask_pad_str);
} catch (const std::invalid_argument&) {
LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable is not a valid integer (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad);
} catch (const std::out_of_range&) {
LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable value is out of range for `int` type (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad);
}
}
flux.chroma_modify_mask_to_attend_padding(y, ggml_nelements(y), mask_pad);
const char* SD_CHROMA_USE_DIT_MASK = getenv("SD_CHROMA_USE_DIT_MASK");
if (SD_CHROMA_USE_DIT_MASK != nullptr) {