mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-12 18:09:42 +00:00
Merge base support for Chroma; however, it's not working correctly yet
This commit is contained in:
parent
dcf88d6e78
commit
30cf433ab4
5 changed files with 554 additions and 105 deletions
|
@ -117,6 +117,7 @@ namespace Flux {
|
|||
struct ggml_tensor* k,
|
||||
struct ggml_tensor* v,
|
||||
struct ggml_tensor* pe,
|
||||
struct ggml_tensor* mask,
|
||||
bool flash_attn) {
|
||||
// q,k,v: [N, L, n_head, d_head]
|
||||
// pe: [L, d_head/2, 2, 2]
|
||||
|
@ -124,7 +125,7 @@ namespace Flux {
|
|||
q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head]
|
||||
k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head]
|
||||
|
||||
auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], NULL, false, true, flash_attn); // [N, L, n_head*d_head]
|
||||
auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], mask, false, true, flash_attn); // [N, L, n_head*d_head]
|
||||
return x;
|
||||
}
|
||||
|
||||
|
@ -167,13 +168,13 @@ namespace Flux {
|
|||
return x;
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* pe) {
|
||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* pe, struct ggml_tensor* mask) {
|
||||
// x: [N, n_token, dim]
|
||||
// pe: [n_token, d_head/2, 2, 2]
|
||||
// return [N, n_token, dim]
|
||||
auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head]
|
||||
x = attention(ctx, qkv[0], qkv[1], qkv[2], pe, flash_attn); // [N, n_token, dim]
|
||||
x = post_attention(ctx, x); // [N, n_token, dim]
|
||||
auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head]
|
||||
x = attention(ctx, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn); // [N, n_token, dim]
|
||||
x = post_attention(ctx, x); // [N, n_token, dim]
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
@ -185,6 +186,13 @@ namespace Flux {
|
|||
|
||||
ModulationOut(ggml_tensor* shift = NULL, ggml_tensor* scale = NULL, ggml_tensor* gate = NULL)
|
||||
: shift(shift), scale(scale), gate(gate) {}
|
||||
|
||||
// Builds shift/scale/gate as three consecutive 2D views into `vec`.
//   ctx:    ggml context the view tensors are created in.
//   vec:    tensor holding one [ne0, ne1] plane per modulation entry.
//   offset: index of the plane used for `shift`; `scale` and `gate` are the
//           two planes that follow it.
ModulationOut(struct ggml_context* ctx, ggml_tensor* vec, int64_t offset) {
    // Byte distance between two consecutive planes of `vec`.
    const int64_t plane_bytes = vec->nb[1] * vec->ne[1];

    // 2D view of the plane located `k` planes after `offset`.
    auto plane_at = [&](int64_t k) {
        return ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], plane_bytes * (offset + k));
    };

    shift = plane_at(0);  // [N, dim]
    scale = plane_at(1);  // [N, dim]
    gate  = plane_at(2);  // [N, dim]
}
|
||||
};
|
||||
|
||||
struct Modulation : public GGMLBlock {
|
||||
|
@ -210,19 +218,12 @@ namespace Flux {
|
|||
auto m = ggml_reshape_3d(ctx, out, vec->ne[0], multiplier, vec->ne[1]); // [N, multiplier, dim]
|
||||
m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [multiplier, N, dim]
|
||||
|
||||
int64_t offset = m->nb[1] * m->ne[1];
|
||||
auto shift_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, dim]
|
||||
auto scale_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, dim]
|
||||
auto gate_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, dim]
|
||||
|
||||
ModulationOut m_0 = ModulationOut(ctx, m, 0);
|
||||
if (is_double) {
|
||||
auto shift_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, dim]
|
||||
auto scale_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, dim]
|
||||
auto gate_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, dim]
|
||||
return {ModulationOut(shift_0, scale_0, gate_0), ModulationOut(shift_1, scale_1, gate_1)};
|
||||
return {m_0, ModulationOut(ctx, m, 3)};
|
||||
}
|
||||
|
||||
return {ModulationOut(shift_0, scale_0, gate_0), ModulationOut()};
|
||||
return {m_0, ModulationOut()};
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -242,25 +243,33 @@ namespace Flux {
|
|||
|
||||
struct DoubleStreamBlock : public GGMLBlock {
|
||||
bool flash_attn;
|
||||
bool prune_mod;
|
||||
int idx = 0;
|
||||
|
||||
public:
|
||||
DoubleStreamBlock(int64_t hidden_size,
|
||||
int64_t num_heads,
|
||||
float mlp_ratio,
|
||||
int idx = 0,
|
||||
bool qkv_bias = false,
|
||||
bool flash_attn = false)
|
||||
: flash_attn(flash_attn) {
|
||||
bool flash_attn = false,
|
||||
bool prune_mod = false)
|
||||
: idx(idx), flash_attn(flash_attn), prune_mod(prune_mod) {
|
||||
int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
|
||||
blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
|
||||
blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
|
||||
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn));
|
||||
if (!prune_mod) {
|
||||
blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
|
||||
}
|
||||
blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
|
||||
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn));
|
||||
|
||||
blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
|
||||
blocks["img_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim));
|
||||
// img_mlp.1 is nn.GELU(approximate="tanh")
|
||||
blocks["img_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size));
|
||||
|
||||
blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
|
||||
if (!prune_mod) {
|
||||
blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
|
||||
}
|
||||
blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
|
||||
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn));
|
||||
|
||||
|
@ -270,17 +279,34 @@ namespace Flux {
|
|||
blocks["txt_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size));
|
||||
}
|
||||
|
||||
std::vector<ModulationOut> get_distil_img_mod(struct ggml_context* ctx, struct ggml_tensor* vec) {
|
||||
// TODO: not hardcoded?
|
||||
const int single_blocks_count = 38;
|
||||
const int double_blocks_count = 19;
|
||||
|
||||
int64_t offset = 6 * idx + 3 * single_blocks_count;
|
||||
return {ModulationOut(ctx, vec, offset), ModulationOut(ctx, vec, offset + 3)};
|
||||
}
|
||||
|
||||
std::vector<ModulationOut> get_distil_txt_mod(struct ggml_context* ctx, struct ggml_tensor* vec) {
|
||||
// TODO: not hardcoded?
|
||||
const int single_blocks_count = 38;
|
||||
const int double_blocks_count = 19;
|
||||
|
||||
int64_t offset = 6 * idx + 6 * double_blocks_count + 3 * single_blocks_count;
|
||||
return {ModulationOut(ctx, vec, offset), ModulationOut(ctx, vec, offset + 3)};
|
||||
}
|
||||
|
||||
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
|
||||
struct ggml_tensor* img,
|
||||
struct ggml_tensor* txt,
|
||||
struct ggml_tensor* vec,
|
||||
struct ggml_tensor* pe) {
|
||||
struct ggml_tensor* pe,
|
||||
struct ggml_tensor* mask = NULL) {
|
||||
// img: [N, n_img_token, hidden_size]
|
||||
// txt: [N, n_txt_token, hidden_size]
|
||||
// pe: [n_img_token + n_txt_token, d_head/2, 2, 2]
|
||||
// return: ([N, n_img_token, hidden_size], [N, n_txt_token, hidden_size])
|
||||
|
||||
auto img_mod = std::dynamic_pointer_cast<Modulation>(blocks["img_mod"]);
|
||||
auto img_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["img_norm1"]);
|
||||
auto img_attn = std::dynamic_pointer_cast<SelfAttention>(blocks["img_attn"]);
|
||||
|
||||
|
@ -288,7 +314,6 @@ namespace Flux {
|
|||
auto img_mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["img_mlp.0"]);
|
||||
auto img_mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["img_mlp.2"]);
|
||||
|
||||
auto txt_mod = std::dynamic_pointer_cast<Modulation>(blocks["txt_mod"]);
|
||||
auto txt_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["txt_norm1"]);
|
||||
auto txt_attn = std::dynamic_pointer_cast<SelfAttention>(blocks["txt_attn"]);
|
||||
|
||||
|
@ -296,10 +321,22 @@ namespace Flux {
|
|||
auto txt_mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["txt_mlp.0"]);
|
||||
auto txt_mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["txt_mlp.2"]);
|
||||
|
||||
auto img_mods = img_mod->forward(ctx, vec);
|
||||
std::vector<ModulationOut> img_mods;
|
||||
if (prune_mod) {
|
||||
img_mods = get_distil_img_mod(ctx, vec);
|
||||
} else {
|
||||
auto img_mod = std::dynamic_pointer_cast<Modulation>(blocks["img_mod"]);
|
||||
img_mods = img_mod->forward(ctx, vec);
|
||||
}
|
||||
ModulationOut img_mod1 = img_mods[0];
|
||||
ModulationOut img_mod2 = img_mods[1];
|
||||
auto txt_mods = txt_mod->forward(ctx, vec);
|
||||
std::vector<ModulationOut> txt_mods;
|
||||
if (prune_mod) {
|
||||
txt_mods = get_distil_txt_mod(ctx, vec);
|
||||
} else {
|
||||
auto txt_mod = std::dynamic_pointer_cast<Modulation>(blocks["txt_mod"]);
|
||||
txt_mods = txt_mod->forward(ctx, vec);
|
||||
}
|
||||
ModulationOut txt_mod1 = txt_mods[0];
|
||||
ModulationOut txt_mod2 = txt_mods[1];
|
||||
|
||||
|
@ -324,7 +361,7 @@ namespace Flux {
|
|||
auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
||||
auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
||||
|
||||
auto attn = attention(ctx, q, k, v, pe, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head]
|
||||
auto attn = attention(ctx, q, k, v, pe, mask, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head]
|
||||
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
|
||||
auto txt_attn_out = ggml_view_3d(ctx,
|
||||
attn,
|
||||
|
@ -373,14 +410,18 @@ namespace Flux {
|
|||
int64_t hidden_size;
|
||||
int64_t mlp_hidden_dim;
|
||||
bool flash_attn;
|
||||
bool prune_mod;
|
||||
int idx = 0;
|
||||
|
||||
public:
|
||||
SingleStreamBlock(int64_t hidden_size,
|
||||
int64_t num_heads,
|
||||
float mlp_ratio = 4.0f,
|
||||
int idx = 0,
|
||||
float qk_scale = 0.f,
|
||||
bool flash_attn = false)
|
||||
: hidden_size(hidden_size), num_heads(num_heads), flash_attn(flash_attn) {
|
||||
bool flash_attn = false,
|
||||
bool prune_mod = false)
|
||||
: hidden_size(hidden_size), num_heads(num_heads), idx(idx), flash_attn(flash_attn), prune_mod(prune_mod) {
|
||||
int64_t head_dim = hidden_size / num_heads;
|
||||
float scale = qk_scale;
|
||||
if (scale <= 0.f) {
|
||||
|
@ -393,26 +434,37 @@ namespace Flux {
|
|||
blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim));
|
||||
blocks["pre_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
|
||||
// mlp_act is nn.GELU(approximate="tanh")
|
||||
blocks["modulation"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, false));
|
||||
if (!prune_mod) {
|
||||
blocks["modulation"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, false));
|
||||
}
|
||||
}
|
||||
|
||||
// Selects this block's shift/scale/gate triple from the precomputed
// (distilled) modulation tensor `vec`.
ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) {
    // Three consecutive entries per block, so block `idx` starts at 3 * idx.
    const int64_t first_entry = 3 * idx;
    ModulationOut mod(ctx, vec, first_entry);
    return mod;
}
|
||||
|
||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x,
|
||||
struct ggml_tensor* vec,
|
||||
struct ggml_tensor* pe) {
|
||||
struct ggml_tensor* pe,
|
||||
struct ggml_tensor* mask = NULL) {
|
||||
// x: [N, n_token, hidden_size]
|
||||
// pe: [n_token, d_head/2, 2, 2]
|
||||
// return: [N, n_token, hidden_size]
|
||||
|
||||
auto linear1 = std::dynamic_pointer_cast<Linear>(blocks["linear1"]);
|
||||
auto linear2 = std::dynamic_pointer_cast<Linear>(blocks["linear2"]);
|
||||
auto norm = std::dynamic_pointer_cast<QKNorm>(blocks["norm"]);
|
||||
auto pre_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["pre_norm"]);
|
||||
auto modulation = std::dynamic_pointer_cast<Modulation>(blocks["modulation"]);
|
||||
|
||||
auto mods = modulation->forward(ctx, vec);
|
||||
ModulationOut mod = mods[0];
|
||||
auto linear1 = std::dynamic_pointer_cast<Linear>(blocks["linear1"]);
|
||||
auto linear2 = std::dynamic_pointer_cast<Linear>(blocks["linear2"]);
|
||||
auto norm = std::dynamic_pointer_cast<QKNorm>(blocks["norm"]);
|
||||
auto pre_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["pre_norm"]);
|
||||
ModulationOut mod;
|
||||
if (prune_mod) {
|
||||
mod = get_distil_mod(ctx, vec);
|
||||
} else {
|
||||
auto modulation = std::dynamic_pointer_cast<Modulation>(blocks["modulation"]);
|
||||
|
||||
mod = modulation->forward(ctx, vec)[0];
|
||||
}
|
||||
auto x_mod = Flux::modulate(ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale);
|
||||
auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim]
|
||||
qkv_mlp = ggml_cont(ctx, ggml_permute(ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token]
|
||||
|
@ -443,7 +495,7 @@ namespace Flux {
|
|||
auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
|
||||
q = norm->query_norm(ctx, q);
|
||||
k = norm->key_norm(ctx, k);
|
||||
auto attn = attention(ctx, q, k, v, pe, flash_attn); // [N, n_token, hidden_size]
|
||||
auto attn = attention(ctx, q, k, v, pe, mask, flash_attn); // [N, n_token, hidden_size]
|
||||
|
||||
auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim]
|
||||
auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size]
|
||||
|
@ -454,13 +506,27 @@ namespace Flux {
|
|||
};
|
||||
|
||||
struct LastLayer : public GGMLBlock {
|
||||
bool prune_mod;
|
||||
|
||||
public:
|
||||
LastLayer(int64_t hidden_size,
|
||||
int64_t patch_size,
|
||||
int64_t out_channels) {
|
||||
blocks["norm_final"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
|
||||
blocks["linear"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, patch_size * patch_size * out_channels));
|
||||
blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, 2 * hidden_size));
|
||||
int64_t out_channels,
|
||||
bool prune_mod = false) : prune_mod(prune_mod) {
|
||||
blocks["norm_final"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
|
||||
blocks["linear"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, patch_size * patch_size * out_channels));
|
||||
if (!prune_mod) {
|
||||
blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, 2 * hidden_size));
|
||||
}
|
||||
}
|
||||
|
||||
// Selects shift and scale for the final layer from the precomputed
// (distilled) modulation tensor `vec`: the last two planes along ne[2].
// The returned ModulationOut has no gate (NULL).
ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) {
    // Byte distance between two consecutive planes of `vec`.
    const int64_t plane_bytes = vec->nb[1] * vec->ne[1];
    // Index of the second-to-last plane.
    const int64_t first_entry = vec->ne[2] - 2;

    // 2D view of plane number `index`.
    auto plane_at = [&](int64_t index) {
        return ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], plane_bytes * index);
    };

    struct ggml_tensor* shift = plane_at(first_entry);      // [N, dim]
    struct ggml_tensor* scale = plane_at(first_entry + 1);  // [N, dim]
    return ModulationOut(shift, scale, NULL);
}
|
||||
|
||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||
|
@ -469,17 +535,24 @@ namespace Flux {
|
|||
// x: [N, n_token, hidden_size]
|
||||
// c: [N, hidden_size]
|
||||
// return: [N, n_token, patch_size * patch_size * out_channels]
|
||||
auto norm_final = std::dynamic_pointer_cast<LayerNorm>(blocks["norm_final"]);
|
||||
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
|
||||
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
|
||||
auto norm_final = std::dynamic_pointer_cast<LayerNorm>(blocks["norm_final"]);
|
||||
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
|
||||
struct ggml_tensor *shift, *scale;
|
||||
if (prune_mod) {
|
||||
auto mod = get_distil_mod(ctx, c);
|
||||
shift = mod.shift;
|
||||
scale = mod.scale;
|
||||
} else {
|
||||
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
|
||||
|
||||
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size]
|
||||
m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
|
||||
m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
|
||||
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size]
|
||||
m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
|
||||
m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
|
||||
|
||||
int64_t offset = m->nb[1] * m->ne[1];
|
||||
auto shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
|
||||
auto scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
|
||||
int64_t offset = m->nb[1] * m->ne[1];
|
||||
shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
|
||||
scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
|
||||
}
|
||||
|
||||
x = Flux::modulate(ctx, norm_final->forward(ctx, x), shift, scale);
|
||||
x = linear->forward(ctx, x);
|
||||
|
@ -488,6 +561,34 @@ namespace Flux {
|
|||
}
|
||||
};
|
||||
|
||||
struct ChromaApproximator : public GGMLBlock {
|
||||
int64_t inner_size = 5120;
|
||||
int64_t n_layers = 5;
|
||||
ChromaApproximator(int64_t in_channels = 64, int64_t hidden_size = 3072) {
|
||||
blocks["in_proj"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, inner_size, true));
|
||||
for (int i = 0; i < n_layers; i++) {
|
||||
blocks["norms." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new RMSNorm(inner_size));
|
||||
blocks["layers." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(inner_size, inner_size));
|
||||
}
|
||||
blocks["out_proj"] = std::shared_ptr<GGMLBlock>(new Linear(inner_size, hidden_size, true));
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
||||
auto in_proj = std::dynamic_pointer_cast<Linear>(blocks["in_proj"]);
|
||||
auto out_proj = std::dynamic_pointer_cast<Linear>(blocks["out_proj"]);
|
||||
|
||||
x = in_proj->forward(ctx, x);
|
||||
for (int i = 0; i < n_layers; i++) {
|
||||
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norms." + std::to_string(i)]);
|
||||
auto embed = std::dynamic_pointer_cast<MLPEmbedder>(blocks["layers." + std::to_string(i)]);
|
||||
x = ggml_add_inplace(ctx, x, embed->forward(ctx, norm->forward(ctx, x)));
|
||||
}
|
||||
x = out_proj->forward(ctx, x);
|
||||
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
struct FluxParams {
|
||||
int64_t in_channels = 64;
|
||||
int64_t out_channels = 64;
|
||||
|
@ -504,6 +605,7 @@ namespace Flux {
|
|||
bool qkv_bias = true;
|
||||
bool guidance_embed = true;
|
||||
bool flash_attn = true;
|
||||
bool is_chroma = false;
|
||||
};
|
||||
|
||||
struct Flux : public GGMLBlock {
|
||||
|
@ -607,6 +709,7 @@ namespace Flux {
|
|||
return ids;
|
||||
}
|
||||
|
||||
|
||||
// Generate positional embeddings
|
||||
std::vector<float> gen_pe(int h, int w, int patch_size, int bs, int context_len, int theta, const std::vector<int>& axes_dim) {
|
||||
std::vector<std::vector<float>> ids = gen_ids(h, w, patch_size, bs, context_len);
|
||||
|
@ -645,11 +748,15 @@ namespace Flux {
|
|||
: params(params) {
|
||||
int64_t pe_dim = params.hidden_size / params.num_heads;
|
||||
|
||||
blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
|
||||
blocks["time_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
|
||||
blocks["vector_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(params.vec_in_dim, params.hidden_size));
|
||||
if (params.guidance_embed) {
|
||||
blocks["guidance_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
|
||||
blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
|
||||
if (params.is_chroma) {
|
||||
blocks["distilled_guidance_layer"] = std::shared_ptr<GGMLBlock>(new ChromaApproximator(params.in_channels, params.hidden_size));
|
||||
} else {
|
||||
blocks["time_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
|
||||
blocks["vector_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(params.vec_in_dim, params.hidden_size));
|
||||
if (params.guidance_embed) {
|
||||
blocks["guidance_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
|
||||
}
|
||||
}
|
||||
blocks["txt_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.context_in_dim, params.hidden_size, true));
|
||||
|
||||
|
@ -657,19 +764,23 @@ namespace Flux {
|
|||
blocks["double_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new DoubleStreamBlock(params.hidden_size,
|
||||
params.num_heads,
|
||||
params.mlp_ratio,
|
||||
i,
|
||||
params.qkv_bias,
|
||||
params.flash_attn));
|
||||
params.flash_attn,
|
||||
params.is_chroma));
|
||||
}
|
||||
|
||||
for (int i = 0; i < params.depth_single_blocks; i++) {
|
||||
blocks["single_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new SingleStreamBlock(params.hidden_size,
|
||||
params.num_heads,
|
||||
params.mlp_ratio,
|
||||
i,
|
||||
0.f,
|
||||
params.flash_attn));
|
||||
params.flash_attn,
|
||||
params.is_chroma));
|
||||
}
|
||||
|
||||
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, params.out_channels));
|
||||
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, params.out_channels, params.is_chroma));
|
||||
}
|
||||
|
||||
struct ggml_tensor* patchify(struct ggml_context* ctx,
|
||||
|
@ -726,25 +837,54 @@ namespace Flux {
|
|||
struct ggml_tensor* y,
|
||||
struct ggml_tensor* guidance,
|
||||
struct ggml_tensor* pe,
|
||||
struct ggml_tensor* arange = NULL,
|
||||
std::vector<int> skip_layers = std::vector<int>()) {
|
||||
auto img_in = std::dynamic_pointer_cast<Linear>(blocks["img_in"]);
|
||||
auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]);
|
||||
auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]);
|
||||
auto txt_in = std::dynamic_pointer_cast<Linear>(blocks["txt_in"]);
|
||||
auto final_layer = std::dynamic_pointer_cast<LastLayer>(blocks["final_layer"]);
|
||||
|
||||
img = img_in->forward(ctx, img);
|
||||
auto vec = time_in->forward(ctx, ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f));
|
||||
img = img_in->forward(ctx, img);
|
||||
struct ggml_tensor* vec;
|
||||
struct ggml_tensor* txt_img_mask = NULL;
|
||||
if (params.is_chroma) {
|
||||
int64_t mod_index_length = 344;
|
||||
auto approx = std::dynamic_pointer_cast<ChromaApproximator>(blocks["distilled_guidance_layer"]);
|
||||
auto distill_timestep = ggml_nn_timestep_embedding(ctx, timesteps, 16, 10000, 1000.f);
|
||||
auto distill_guidance = ggml_nn_timestep_embedding(ctx, guidance, 16, 10000, 1000.f);
|
||||
|
||||
if (params.guidance_embed) {
|
||||
GGML_ASSERT(guidance != NULL);
|
||||
auto guidance_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["guidance_in"]);
|
||||
// bf16 and fp16 result is different
|
||||
auto g_in = ggml_nn_timestep_embedding(ctx, guidance, 256, 10000, 1000.f);
|
||||
vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in));
|
||||
// auto arange = ggml_arange(ctx, 0, (float)mod_index_length, 1); // Not working on a lot of backends, precomputing it on CPU instead
|
||||
GGML_ASSERT(arange != NULL);
|
||||
auto modulation_index = ggml_nn_timestep_embedding(ctx, arange, 32, 10000, 1000.f); // [1, 344, 32]
|
||||
|
||||
// Batch broadcast (will it ever be useful)
|
||||
modulation_index = ggml_repeat(ctx, modulation_index, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, modulation_index->ne[0], modulation_index->ne[1], img->ne[2])); // [N, 344, 32]
|
||||
|
||||
auto timestep_guidance = ggml_concat(ctx, distill_timestep, distill_guidance, 0); // [N, 1, 32]
|
||||
timestep_guidance = ggml_repeat(ctx, timestep_guidance, modulation_index); // [N, 344, 32]
|
||||
|
||||
vec = ggml_concat(ctx, timestep_guidance, modulation_index, 0); // [N, 344, 64]
|
||||
// Permute for consistency with non-distilled modulation implementation
|
||||
vec = ggml_cont(ctx, ggml_permute(ctx, vec, 0, 2, 1, 3)); // [344, N, 64]
|
||||
vec = approx->forward(ctx, vec); // [344, N, hidden_size]
|
||||
|
||||
if (y != NULL) {
|
||||
txt_img_mask = ggml_pad(ctx, y, img->ne[1], 0, 0, 0);
|
||||
}
|
||||
} else {
|
||||
auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]);
|
||||
auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]);
|
||||
vec = time_in->forward(ctx, ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f));
|
||||
if (params.guidance_embed) {
|
||||
GGML_ASSERT(guidance != NULL);
|
||||
auto guidance_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["guidance_in"]);
|
||||
// bf16 and fp16 result is different
|
||||
auto g_in = ggml_nn_timestep_embedding(ctx, guidance, 256, 10000, 1000.f);
|
||||
vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in));
|
||||
}
|
||||
|
||||
vec = ggml_add(ctx, vec, vector_in->forward(ctx, y));
|
||||
}
|
||||
|
||||
vec = ggml_add(ctx, vec, vector_in->forward(ctx, y));
|
||||
txt = txt_in->forward(ctx, txt);
|
||||
|
||||
for (int i = 0; i < params.depth; i++) {
|
||||
|
@ -754,7 +894,7 @@ namespace Flux {
|
|||
|
||||
auto block = std::dynamic_pointer_cast<DoubleStreamBlock>(blocks["double_blocks." + std::to_string(i)]);
|
||||
|
||||
auto img_txt = block->forward(ctx, img, txt, vec, pe);
|
||||
auto img_txt = block->forward(ctx, img, txt, vec, pe, txt_img_mask);
|
||||
img = img_txt.first; // [N, n_img_token, hidden_size]
|
||||
txt = img_txt.second; // [N, n_txt_token, hidden_size]
|
||||
}
|
||||
|
@ -766,7 +906,7 @@ namespace Flux {
|
|||
}
|
||||
auto block = std::dynamic_pointer_cast<SingleStreamBlock>(blocks["single_blocks." + std::to_string(i)]);
|
||||
|
||||
txt_img = block->forward(ctx, txt_img, vec, pe);
|
||||
txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask);
|
||||
}
|
||||
|
||||
txt_img = ggml_cont(ctx, ggml_permute(ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
|
||||
|
@ -781,7 +921,6 @@ namespace Flux {
|
|||
img = ggml_cont(ctx, ggml_permute(ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
|
||||
|
||||
img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
return img;
|
||||
}
|
||||
|
||||
|
@ -793,6 +932,7 @@ namespace Flux {
|
|||
struct ggml_tensor* y,
|
||||
struct ggml_tensor* guidance,
|
||||
struct ggml_tensor* pe,
|
||||
struct ggml_tensor* arange = NULL,
|
||||
std::vector<int> skip_layers = std::vector<int>()) {
|
||||
// Forward pass of DiT.
|
||||
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
|
@ -830,7 +970,7 @@ namespace Flux {
|
|||
img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
|
||||
}
|
||||
|
||||
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size]
|
||||
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, arange, skip_layers); // [N, h*w, C * patch_size * patch_size]
|
||||
|
||||
// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
|
||||
out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w]
|
||||
|
@ -845,7 +985,8 @@ namespace Flux {
|
|||
public:
|
||||
FluxParams flux_params;
|
||||
Flux flux;
|
||||
std::vector<float> pe_vec; // for cache
|
||||
std::vector<float> pe_vec, range; // for cache
|
||||
SDVersion version;
|
||||
|
||||
FluxRunner(ggml_backend_t backend,
|
||||
std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
|
||||
|
@ -868,6 +1009,10 @@ namespace Flux {
|
|||
// not schnell
|
||||
flux_params.guidance_embed = true;
|
||||
}
|
||||
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
|
||||
// Chroma
|
||||
flux_params.is_chroma = true;
|
||||
}
|
||||
size_t db = tensor_name.find("double_blocks.");
|
||||
if (db != std::string::npos) {
|
||||
tensor_name = tensor_name.substr(db); // remove prefix
|
||||
|
@ -887,7 +1032,9 @@ namespace Flux {
|
|||
}
|
||||
|
||||
LOG_INFO("Flux blocks: %d double, %d single", flux_params.depth, flux_params.depth_single_blocks);
|
||||
if (!flux_params.guidance_embed) {
|
||||
if (flux_params.is_chroma) {
|
||||
LOG_INFO("Using pruned modulation (Chroma)");
|
||||
} else if (!flux_params.guidance_embed) {
|
||||
LOG_INFO("Flux guidance is disabled (Schnell mode)");
|
||||
}
|
||||
|
||||
|
@ -913,14 +1060,51 @@ namespace Flux {
|
|||
GGML_ASSERT(x->ne[3] == 1);
|
||||
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
|
||||
|
||||
struct ggml_tensor* precompute_arange = NULL;
|
||||
|
||||
x = to_backend(x);
|
||||
context = to_backend(context);
|
||||
if (c_concat != NULL) {
|
||||
c_concat = to_backend(c_concat);
|
||||
}
|
||||
y = to_backend(y);
|
||||
if (flux_params.is_chroma) {
|
||||
const char* SD_CHROMA_ENABLE_GUIDANCE = getenv("SD_CHROMA_ENABLE_GUIDANCE");
|
||||
bool disable_guidance = true;
|
||||
if (SD_CHROMA_ENABLE_GUIDANCE != NULL) {
|
||||
std::string enable_guidance_str = SD_CHROMA_ENABLE_GUIDANCE;
|
||||
if (enable_guidance_str == "ON" || enable_guidance_str == "TRUE") {
|
||||
LOG_WARN("Chroma guidance has been enabled. Image might be broken. (SD_CHROMA_ENABLE_GUIDANCE env variable to \"OFF\" to disable)", SD_CHROMA_ENABLE_GUIDANCE);
|
||||
disable_guidance = false;
|
||||
} else if (enable_guidance_str != "OFF" && enable_guidance_str != "FALSE") {
|
||||
LOG_WARN("SD_CHROMA_ENABLE_GUIDANCE environment variable has unexpected value. Assuming default (\"OFF\"). (Expected \"ON\"/\"TRUE\" or\"OFF\"/\"FALSE\", got \"%s\")", SD_CHROMA_ENABLE_GUIDANCE);
|
||||
}
|
||||
}
|
||||
if (disable_guidance) {
|
||||
// LOG_DEBUG("Forcing guidance to 0 for chroma model (SD_CHROMA_ENABLE_GUIDANCE env variable to \"ON\" to enable)");
|
||||
guidance = ggml_set_f32(guidance, 0);
|
||||
}
|
||||
|
||||
|
||||
const char* SD_CHROMA_USE_DIT_MASK = getenv("SD_CHROMA_USE_DIT_MASK");
|
||||
if (SD_CHROMA_USE_DIT_MASK != nullptr) {
|
||||
std::string sd_chroma_use_DiT_mask_str = SD_CHROMA_USE_DIT_MASK;
|
||||
if (sd_chroma_use_DiT_mask_str == "OFF" || sd_chroma_use_DiT_mask_str == "FALSE") {
|
||||
y = NULL;
|
||||
} else if (sd_chroma_use_DiT_mask_str != "ON" && sd_chroma_use_DiT_mask_str != "TRUE") {
|
||||
LOG_WARN("SD_CHROMA_USE_DIT_MASK environment variable has unexpected value. Assuming default (\"ON\"). (Expected \"ON\"/\"TRUE\" or\"OFF\"/\"FALSE\", got \"%s\")", SD_CHROMA_USE_DIT_MASK);
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_arange is not working on some backends, so precompute the range on the CPU and upload it through a dedicated tensor (y is left untouched)
|
||||
range = arange(0, 344);
|
||||
precompute_arange = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, range.size());
|
||||
set_backend_tensor_data(precompute_arange, range.data());
|
||||
// y = NULL;
|
||||
}
|
||||
y = to_backend(y);
|
||||
|
||||
timesteps = to_backend(timesteps);
|
||||
if (flux_params.guidance_embed) {
|
||||
if (flux_params.guidance_embed || flux_params.is_chroma) {
|
||||
guidance = to_backend(guidance);
|
||||
}
|
||||
|
||||
|
@ -941,6 +1125,7 @@ namespace Flux {
|
|||
y,
|
||||
guidance,
|
||||
pe,
|
||||
precompute_arange,
|
||||
skip_layers);
|
||||
|
||||
ggml_build_forward_expand(gf, out);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue