sd 3.5 medium

Concedo 2024-11-03 23:27:06 +08:00
parent f32a874966
commit 5233e8ed1d
5 changed files with 204 additions and 49 deletions


@@ -252,6 +252,7 @@ struct DismantledBlock : public GGMLBlock {
public:
int64_t num_heads;
bool pre_only;
bool self_attn;
public:
DismantledBlock(int64_t hidden_size,
@@ -259,14 +260,19 @@
float mlp_ratio = 4.0,
std::string qk_norm = "",
bool qkv_bias = false,
bool pre_only = false)
: num_heads(num_heads), pre_only(pre_only) {
bool pre_only = false,
bool self_attn = false)
: num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) {
// rmsnorm is always false
// scale_mod_only is always false
// swiglu is always false
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only));
if (self_attn) {
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false));
}
if (!pre_only) {
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
int64_t mlp_hidden_dim = (int64_t)(hidden_size * mlp_ratio);
@@ -277,9 +283,52 @@ public:
if (pre_only) {
n_mods = 2;
}
if (self_attn) {
n_mods = 9;
}
blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, n_mods * hidden_size));
}
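The constructor picks the width of the adaLN modulation projection from the block variant: the default number of parameter sets (initialised outside this hunk, 6 in the full source: shift/scale/gate for the attention path and for the MLP path) drops to 2 for pre_only blocks (shift/scale only, no MLP/output path) and rises to 9 when self_attn adds a second attention branch with its own shift/scale/gate. A minimal standalone sketch of that selection (plain C++, not part of the patch; the helper name is illustrative):

#include <cassert>

// Illustrative helper mirroring the n_mods selection above.
static int adaLN_n_mods(bool pre_only, bool self_attn) {
    int n_mods = 6;  // assumed default from the full source, outside this hunk
    if (pre_only) {
        n_mods = 2;
    }
    if (self_attn) {
        n_mods = 9;  // in this patch, self_attn blocks are never pre_only
    }
    return n_mods;
}

int main() {
    assert(adaLN_n_mods(false, false) == 6);  // regular DismantledBlock
    assert(adaLN_n_mods(true, false) == 2);   // final-layer context block
    assert(adaLN_n_mods(false, true) == 9);   // MMDiT-X x_block with attn2
    return 0;
}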
std::tuple<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention_x(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* c) {
GGML_ASSERT(self_attn);
// x: [N, n_token, hidden_size]
// c: [N, hidden_size]
auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
auto attn2 = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
int64_t n_mods = 9;
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, n_mods * hidden_size]
m = ggml_reshape_3d(ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size]
m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
int64_t offset = m->nb[1] * m->ne[1];
auto shift_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
auto scale_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
auto gate_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size]
auto shift_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size]
auto scale_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size]
auto gate_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size]
auto shift_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6); // [N, hidden_size]
auto scale_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7); // [N, hidden_size]
auto gate_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8); // [N, hidden_size]
auto x_norm = norm1->forward(ctx, x);
auto attn_in = modulate(ctx, x_norm, shift_msa, scale_msa);
auto qkv = attn->pre_attention(ctx, attn_in);
auto attn2_in = modulate(ctx, x_norm, shift_msa2, scale_msa2);
auto qkv2 = attn2->pre_attention(ctx, attn2_in);
return {qkv, qkv2, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2}};
}
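pre_attention_x runs adaLN_modulation.1 on silu(c) to produce all nine modulation vectors in one [N, 9 * hidden_size] tensor, reshapes and permutes it so each parameter set is a contiguous [N, hidden_size] slab, and takes nine strided views. Both attention branches share the same norm1(x) and differ only in which shift/scale pair modulates it; modulate here follows the usual DiT convention, x * (1 + scale) + shift. A plain-C++ sketch of the slab offsets and the modulation, with no ggml dependency (all names and sizes are illustrative):

#include <cstddef>
#include <cstdio>
#include <vector>

// After the reshape and permute, the modulation tensor is laid out as
// [n_mods][N][hidden_size], so view k starts at element offset k * N * hidden.
int main() {
    const int N = 2, hidden = 4, n_mods = 9;
    std::vector<float> m(static_cast<std::size_t>(n_mods) * N * hidden);
    for (std::size_t i = 0; i < m.size(); i++) m[i] = 0.01f * (float)i;  // dummy adaLN output

    auto slab = [&](int k) { return m.data() + static_cast<std::size_t>(k) * N * hidden; };
    const float* shift_msa = slab(0);
    const float* scale_msa = slab(1);
    // gate_msa = slab(2), shift_mlp = slab(3), scale_mlp = slab(4),
    // gate_mlp = slab(5), shift_msa2 = slab(6), scale_msa2 = slab(7), gate_msa2 = slab(8)

    // modulate one token of batch element n with the first shift/scale pair
    std::vector<float> x_tok(hidden, 1.0f);  // stands in for one token of norm1(x)
    const int n = 0;
    for (int i = 0; i < hidden; i++) {
        x_tok[i] = x_tok[i] * (1.0f + scale_msa[n * hidden + i]) + shift_msa[n * hidden + i];
    }
    printf("modulated token[0] = %f\n", x_tok[0]);
    return 0;
}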
std::pair<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* c) {
@@ -319,6 +368,44 @@ public:
}
}
struct ggml_tensor* post_attention_x(struct ggml_context* ctx,
struct ggml_tensor* attn_out,
struct ggml_tensor* attn2_out,
struct ggml_tensor* x,
struct ggml_tensor* gate_msa,
struct ggml_tensor* shift_mlp,
struct ggml_tensor* scale_mlp,
struct ggml_tensor* gate_mlp,
struct ggml_tensor* gate_msa2) {
// attn_out: [N, n_token, hidden_size]
// x: [N, n_token, hidden_size]
// gate_msa: [N, hidden_size]
// shift_mlp: [N, hidden_size]
// scale_mlp: [N, hidden_size]
// gate_mlp: [N, hidden_size]
// return: [N, n_token, hidden_size]
GGML_ASSERT(!pre_only);
auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
auto attn2 = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
auto mlp = std::dynamic_pointer_cast<Mlp>(blocks["mlp"]);
gate_msa = ggml_reshape_3d(ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]); // [N, 1, hidden_size]
gate_mlp = ggml_reshape_3d(ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]); // [N, 1, hidden_size]
gate_msa2 = ggml_reshape_3d(ctx, gate_msa2, gate_msa2->ne[0], 1, gate_msa2->ne[1]); // [N, 1, hidden_size]
attn_out = attn->post_attention(ctx, attn_out);
attn2_out = attn2->post_attention(ctx, attn2_out);
x = ggml_add(ctx, x, ggml_mul(ctx, attn_out, gate_msa));
x = ggml_add(ctx, x, ggml_mul(ctx, attn2_out, gate_msa2));
auto mlp_out = mlp->forward(ctx, modulate(ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp));
x = ggml_add(ctx, x, ggml_mul(ctx, mlp_out, gate_mlp));
return x;
}
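post_attention_x folds the two attention branches and the MLP back into x as three gated residual updates, with each gate broadcast over the token dimension. A scalar sketch of that update order (illustrative stand-ins, not project code):

#include <cstdio>

// Element-wise view of the gated residuals above:
//   x += gate_msa  * attn_out
//   x += gate_msa2 * attn2_out
//   x += gate_mlp  * mlp(modulate(norm2(x), shift_mlp, scale_mlp))
static float mlp_branch_of(float v) { return 0.5f * v; }  // placeholder for norm2 + modulate + Mlp

int main() {
    float x = 1.0f;
    const float attn_out = 0.2f, attn2_out = 0.1f;
    const float gate_msa = 0.9f, gate_msa2 = 0.8f, gate_mlp = 0.7f;

    x += gate_msa * attn_out;          // first self-attention branch
    x += gate_msa2 * attn2_out;        // second (MMDiT-X) self-attention branch
    x += gate_mlp * mlp_branch_of(x);  // MLP branch, gated like the attention ones

    printf("x = %f\n", x);
    return 0;
}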
struct ggml_tensor* post_attention(struct ggml_context* ctx,
struct ggml_tensor* attn_out,
struct ggml_tensor* x,
@@ -357,29 +444,52 @@ public:
// return: [N, n_token, hidden_size]
auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
if (self_attn) {
auto qkv_intermediates = pre_attention_x(ctx, x, c);
// auto qkv = qkv_intermediates.first;
// auto intermediates = qkv_intermediates.second;
// no longer a pair, but a tuple
auto qkv = std::get<0>(qkv_intermediates);
auto qkv2 = std::get<1>(qkv_intermediates);
auto intermediates = std::get<2>(qkv_intermediates);
auto qkv_intermediates = pre_attention(ctx, x, c);
auto qkv = qkv_intermediates.first;
auto intermediates = qkv_intermediates.second;
auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
auto attn2_out = ggml_nn_attention_ext(ctx, qkv2[0], qkv2[1], qkv2[2], num_heads); // [N, n_token, dim]
x = post_attention_x(ctx,
attn_out,
attn2_out,
intermediates[0],
intermediates[1],
intermediates[2],
intermediates[3],
intermediates[4],
intermediates[5]);
return x; // [N, n_token, dim]
} else {
auto qkv_intermediates = pre_attention(ctx, x, c);
auto qkv = qkv_intermediates.first;
auto intermediates = qkv_intermediates.second;
auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
x = post_attention(ctx,
attn_out,
intermediates[0],
intermediates[1],
intermediates[2],
intermediates[3],
intermediates[4]);
return x; // [N, n_token, dim]
auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
x = post_attention(ctx,
attn_out,
intermediates[0],
intermediates[1],
intermediates[2],
intermediates[3],
intermediates[4]);
return x; // [N, n_token, dim]
}
}
};
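forward() now has two call protocols: the original pre_attention/post_attention path returns a std::pair of qkv plus intermediates, while the self_attn path returns a std::tuple with a second qkv set, which is why its callers switch from .first/.second to std::get<i>. A standalone illustration of the two return shapes (types reduced to int vectors; names are illustrative):

#include <cstdio>
#include <tuple>
#include <utility>
#include <vector>

using Vec = std::vector<int>;

// pair protocol: {qkv, intermediates}
static std::pair<Vec, Vec> pre_attention_sketch() {
    return std::make_pair(Vec{1, 2, 3}, Vec{10, 11, 12, 13, 14});
}
// tuple protocol: {qkv, qkv2, intermediates}
static std::tuple<Vec, Vec, Vec> pre_attention_x_sketch() {
    return std::make_tuple(Vec{1, 2, 3}, Vec{4, 5, 6}, Vec{10, 11, 12, 13, 14, 15});
}

int main() {
    const bool self_attn = true;
    if (self_attn) {
        auto qkv_intermediates = pre_attention_x_sketch();
        auto& qkv  = std::get<0>(qkv_intermediates);
        auto& qkv2 = std::get<1>(qkv_intermediates);
        auto& ints = std::get<2>(qkv_intermediates);
        printf("tuple path: qkv=%zu qkv2=%zu intermediates=%zu\n", qkv.size(), qkv2.size(), ints.size());
    } else {
        auto qkv_intermediates = pre_attention_sketch();
        printf("pair path: qkv=%zu intermediates=%zu\n",
               qkv_intermediates.first.size(), qkv_intermediates.second.size());
    }
    return 0;
}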
__STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*> block_mixing(struct ggml_context* ctx,
struct ggml_tensor* context,
struct ggml_tensor* x,
struct ggml_tensor* c,
std::shared_ptr<DismantledBlock> context_block,
std::shared_ptr<DismantledBlock> x_block) {
__STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*>
block_mixing(struct ggml_context* ctx,
struct ggml_tensor* context,
struct ggml_tensor* x,
struct ggml_tensor* c,
std::shared_ptr<DismantledBlock> context_block,
std::shared_ptr<DismantledBlock> x_block) {
// context: [N, n_context, hidden_size]
// x: [N, n_token, hidden_size]
// c: [N, hidden_size]
@@ -387,10 +497,18 @@ __STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*> block_mixi
auto context_qkv = context_qkv_intermediates.first;
auto context_intermediates = context_qkv_intermediates.second;
auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c);
auto x_qkv = x_qkv_intermediates.first;
auto x_intermediates = x_qkv_intermediates.second;
std::vector<ggml_tensor*> x_qkv, x_qkv2, x_intermediates;
if (x_block->self_attn) {
auto x_qkv_intermediates = x_block->pre_attention_x(ctx, x, c);
x_qkv = std::get<0>(x_qkv_intermediates);
x_qkv2 = std::get<1>(x_qkv_intermediates);
x_intermediates = std::get<2>(x_qkv_intermediates);
} else {
auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c);
x_qkv = x_qkv_intermediates.first;
x_intermediates = x_qkv_intermediates.second;
}
std::vector<struct ggml_tensor*> qkv;
for (int i = 0; i < 3; i++) {
qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1));
@@ -429,13 +547,27 @@ __STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*> block_mixi
context = NULL;
}
x = x_block->post_attention(ctx,
x_attn,
x_intermediates[0],
x_intermediates[1],
x_intermediates[2],
x_intermediates[3],
x_intermediates[4]);
if (x_block->self_attn) {
auto attn2 = ggml_nn_attention_ext(ctx, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads); // [N, n_token, hidden_size]
x = x_block->post_attention_x(ctx,
x_attn,
attn2,
x_intermediates[0],
x_intermediates[1],
x_intermediates[2],
x_intermediates[3],
x_intermediates[4],
x_intermediates[5]);
} else {
x = x_block->post_attention(ctx,
x_attn,
x_intermediates[0],
x_intermediates[1],
x_intermediates[2],
x_intermediates[3],
x_intermediates[4]);
}
return {context, x};
}
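block_mixing concatenates the context and x projections along the token axis so one attention pass covers both streams, then splits the result back into a context part and an x part; when x_block->self_attn is set, a second attention over only the x tokens (x_qkv2) is computed and merged by post_attention_x. A plain-C++ sketch of the token bookkeeping (sizes are illustrative, not taken from the patch):

#include <cstdio>
#include <vector>

int main() {
    const int n_context = 154, n_token = 1024, hidden = 1536;

    // the joint attention sees n_context + n_token tokens per batch element
    const int n_joint = n_context + n_token;
    std::vector<float> attn((std::size_t)n_joint * hidden, 0.0f);  // pretend joint attention output

    // split back: the first n_context rows belong to the context stream,
    // the remaining n_token rows belong to the x stream
    const float* context_attn = attn.data();
    const float* x_attn       = attn.data() + (std::size_t)n_context * hidden;
    (void)context_attn;
    (void)x_attn;

    printf("joint tokens: %d (context %d + x %d)\n", n_joint, n_context, n_token);
    return 0;
}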
@@ -447,9 +579,10 @@ public:
float mlp_ratio = 4.0,
std::string qk_norm = "",
bool qkv_bias = false,
bool pre_only = false) {
bool pre_only = false,
bool self_attn_x = false) {
blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only));
blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false));
blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x));
}
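Note how the two flags are routed here: pre_only only ever reaches the context_block, so the context stream of the final layer can drop its post-attention/MLP path, while the new self_attn_x flag only reaches the x_block, the branch that gains the second attention in MMDiT-X. A tiny sketch of that routing (stand-in struct, not the real GGMLBlock classes):

#include <cstdio>

struct BlockArgsSketch {
    bool pre_only;
    bool self_attn;
};

int main() {
    const bool pre_only = true, self_attn_x = true;  // example JointBlock arguments
    BlockArgsSketch context_block{pre_only, /*self_attn=*/false};
    BlockArgsSketch x_block{/*pre_only=*/false, self_attn_x};
    printf("context_block: pre_only=%d self_attn=%d\n", (int)context_block.pre_only, (int)context_block.self_attn);
    printf("x_block:       pre_only=%d self_attn=%d\n", (int)x_block.pre_only, (int)x_block.self_attn);
    return 0;
}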
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
@@ -507,6 +640,7 @@ protected:
int64_t input_size = -1;
int64_t patch_size = 2;
int64_t in_channels = 16;
int64_t d_self = -1; // >=0 for MMDiT-X
int64_t depth = 24;
float mlp_ratio = 4.0f;
int64_t adm_in_channels = 2048;
@@ -561,6 +695,20 @@ public:
context_size = 4096;
context_embedder_out_dim = 2432;
qk_norm = "rms";
} else if (version == VERSION_SD3_5_2B) {
input_size = -1;
patch_size = 2;
in_channels = 16;
depth = 24;
d_self = 12;
mlp_ratio = 4.0f;
adm_in_channels = 2048;
out_channels = 16;
pos_embed_max_size = 384;
num_patchs = 147456;
context_size = 4096;
context_embedder_out_dim = 1536;
qk_norm = "rms";
}
int64_t default_out_channels = in_channels;
hidden_size = 64 * depth;
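For the new VERSION_SD3_5_2B entry, hidden_size = 64 * depth gives 64 * 24 = 1536, matching its context_embedder_out_dim, and num_patchs equals pos_embed_max_size squared. A quick arithmetic check of the values in the table above (assertions only, no model code):

#include <cassert>

int main() {
    // VERSION_SD3_5_2B ("medium"): depth 24, so hidden_size = 64 * 24 = 1536,
    // which matches context_embedder_out_dim = 1536.
    assert(64 * 24 == 1536);

    // The SD3.5 large entry above uses context_embedder_out_dim = 2432; under
    // the same hidden_size = 64 * depth rule that corresponds to a depth of 38.
    assert(2432 / 64 == 38);

    // pos_embed_max_size = 384 with num_patchs = 147456: 384 * 384 = 147456,
    // i.e. the positional table covers pos_embed_max_size^2 patch positions.
    assert(384 * 384 == 147456);
    return 0;
}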
@@ -581,15 +729,17 @@ public:
mlp_ratio,
qk_norm,
true,
i == depth - 1));
i == depth - 1,
i <= d_self));
}
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(hidden_size, patch_size, out_channels));
}
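In the block loop above, i == depth - 1 marks only the last joint block's context side as pre_only, and i <= d_self enables the extra x-branch attention; with the SD3.5-2B values (depth 24, d_self 12) that covers blocks 0 through 12, while the default d_self = -1 keeps the old behaviour for other versions. A short sketch of the per-block flags (illustrative; prints a few representative blocks):

#include <cstdio>

int main() {
    const int depth = 24;
    const long long d_self = 12;  // int64_t stand-in
    for (int i = 0; i < depth; i++) {
        const bool pre_only    = (i == depth - 1);  // final context block skips its MLP/output path
        const bool self_attn_x = (i <= d_self);     // MMDiT-X dual attention on the x_block
        if (i < 2 || i == 12 || i == 13 || i == depth - 1) {
            printf("block %2d: pre_only=%d self_attn_x=%d\n", i, (int)pre_only, (int)self_attn_x);
        }
    }
    return 0;
}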
struct ggml_tensor* cropped_pos_embed(struct ggml_context* ctx,
int64_t h,
int64_t w) {
struct ggml_tensor*
cropped_pos_embed(struct ggml_context* ctx,
int64_t h,
int64_t w) {
auto pos_embed = params["pos_embed"];
h = (h + 1) / patch_size;