diff --git a/clip.hpp b/clip.hpp
index 4eec0241c..3fc656650 100644
--- a/clip.hpp
+++ b/clip.hpp
@@ -510,7 +510,7 @@ struct CLIPLayer : public GGMLBlock {
         blocks["mlp"] = std::shared_ptr<CLIPMLP>(new CLIPMLP(d_model, intermediate_size));
     }
 
-    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, bool mask = true) {
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* mask = nullptr) {
         // x: [N, n_token, d_model]
         auto self_attn = std::dynamic_pointer_cast<MultiheadAttention>(blocks["self_attn"]);
         auto layer_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm1"]);
@@ -542,8 +542,8 @@ struct CLIPEncoder : public GGMLBlock {
 
     struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                 struct ggml_tensor* x,
-                                int clip_skip = -1,
-                                bool mask = true) {
+                                struct ggml_tensor* mask = nullptr,
+                                int clip_skip = -1) {
         // x: [N, n_token, d_model]
         int layer_idx = n_layer - 1;
         // LOG_DEBUG("clip_skip %d", clip_skip);
@@ -741,16 +741,17 @@ class CLIPTextModel : public GGMLBlock {
     struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                 struct ggml_tensor* input_ids,
                                 struct ggml_tensor* tkn_embeddings,
-                                size_t max_token_idx = 0,
-                                bool return_pooled = false,
-                                int clip_skip = -1) {
+                                struct ggml_tensor* mask = nullptr,
+                                size_t max_token_idx = 0,
+                                bool return_pooled = false,
+                                int clip_skip = -1) {
         // input_ids: [N, n_token]
         auto embeddings = std::dynamic_pointer_cast<CLIPEmbeddings>(blocks["embeddings"]);
         auto encoder = std::dynamic_pointer_cast<CLIPEncoder>(blocks["encoder"]);
         auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);
 
         auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size]
-        x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true);
+        x = encoder->forward(ctx, x, mask, return_pooled ? -1 : clip_skip);
         if (return_pooled || with_final_ln) {
             x = final_layer_norm->forward(ctx, x);
         }
@@ -814,10 +815,11 @@ class CLIPVisionModel : public GGMLBlock {
 
         auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim]
         x = pre_layernorm->forward(ctx, x);
-        x = encoder->forward(ctx, x, clip_skip, false);
-        // print_ggml_tensor(x, true, "ClipVisionModel x: ");
+        x = encoder->forward(ctx, x, nullptr, clip_skip);
+
         auto last_hidden_state = x;
-        x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size]
+
+        x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size]
 
         GGML_ASSERT(x->ne[3] == 1);
         if (return_pooled) {
@@ -905,6 +907,8 @@ class CLIPVisionModelProjection : public GGMLBlock {
 struct CLIPTextModelRunner : public GGMLRunner {
     CLIPTextModel model;
 
+    std::vector<float> attention_mask_vec;
+
     CLIPTextModelRunner(ggml_backend_t backend,
                         bool offload_params_to_cpu,
                         const String2TensorStorage& tensor_storage_map,
@@ -938,6 +942,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
     struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                 struct ggml_tensor* input_ids,
                                 struct ggml_tensor* embeddings,
+                                struct ggml_tensor* mask,
                                 size_t max_token_idx = 0,
                                 bool return_pooled = false,
                                 int clip_skip = -1) {
@@ -948,7 +953,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
             input_ids = ggml_reshape_2d(ctx->ggml_ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token);
         }
 
-        return model.forward(ctx, input_ids, embeddings, max_token_idx, return_pooled, clip_skip);
+        return model.forward(ctx, input_ids, embeddings, mask, max_token_idx, return_pooled, clip_skip);
     }
 
     struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
@@ -975,9 +980,23 @@ struct CLIPTextModelRunner : public GGMLRunner {
             embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);
         }
 
+        int n_tokens = static_cast<int>(input_ids->ne[0]);
+        attention_mask_vec.resize(n_tokens * n_tokens);
+        for (int i0 = 0; i0 < n_tokens; i0++) {
+            for (int i1 = 0; i1 < n_tokens; i1++) {
+                float value = 0.f;
+                if (i0 > i1) {
+                    value = -INFINITY;
+                }
+                attention_mask_vec[i1 * n_tokens + i0] = value;
+            }
+        }
+        auto attention_mask = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, n_tokens, n_tokens);
+        set_backend_tensor_data(attention_mask, attention_mask_vec.data());
+
         auto runner_ctx = get_context();
 
-        struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, embeddings, max_token_idx, return_pooled, clip_skip);
+        struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, embeddings, attention_mask, max_token_idx, return_pooled, clip_skip);
 
         ggml_build_forward_expand(gf, hidden_states);
 
diff --git a/common.hpp b/common.hpp
index 7183eb82e..d9c823df0 100644
--- a/common.hpp
+++ b/common.hpp
@@ -317,7 +317,7 @@ class CrossAttention : public GGMLBlock {
         auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
         auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]
 
-        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, inner_dim]
+        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, inner_dim]
         x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
 
         return x;
diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index 0fb656dca..ab58ab5f0 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -604,10 +604,10 @@ int main(int argc, const char* argv[]) {
 
     if (gen_params.mask_image_path.size() > 0) {
         if (!load_sd_image_from_file(&mask_image,
-                                         gen_params.mask_image_path.c_str(),
-                                         gen_params.get_resolved_width(),
-                                         gen_params.get_resolved_height(),
-                                         1)) {
+                                     gen_params.mask_image_path.c_str(),
+                                     gen_params.get_resolved_width(),
+                                     gen_params.get_resolved_height(),
+                                     1)) {
             LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str());
             release_all_resources();
             return 1;
@@ -626,9 +626,9 @@ int main(int argc, const char* argv[]) {
 
     if (gen_params.control_image_path.size() > 0) {
         if (!load_sd_image_from_file(&control_image,
-                                         gen_params.control_image_path.c_str(),
-                                         gen_params.get_resolved_width(),
-                                         gen_params.get_resolved_height())) {
+                                     gen_params.control_image_path.c_str(),
+                                     gen_params.get_resolved_width(),
+                                     gen_params.get_resolved_height())) {
             LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str());
             release_all_resources();
             return 1;
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index fedab3809..7dac03738 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -1257,7 +1257,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
                                                              struct ggml_tensor* v,
                                                              int64_t n_head,
                                                              struct ggml_tensor* mask = nullptr,
-                                                             bool diag_mask_inf = false,
                                                              bool skip_reshape = false,
                                                              bool flash_attn = false,
                                                              float kv_scale = 1.0f) { // avoid overflow
@@ -1385,9 +1384,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
         if (mask) {
             kq = ggml_add_inplace(ctx, kq, mask);
         }
-        if (diag_mask_inf) {
-            kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
-        }
         kq = ggml_soft_max_inplace(ctx, kq);
 
         kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, L_q, d_head]
@@ -2604,7 +2600,7 @@ class MultiheadAttention : public GGMLBlock {
     // x: [N, n_token, embed_dim]
     struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                 struct ggml_tensor* x,
-                                bool mask = false) {
+                                struct ggml_tensor* mask = nullptr) {
        auto out_proj = std::dynamic_pointer_cast<Linear>(blocks[out_proj_name]);
 
        ggml_tensor* q;
@@ -2627,7 +2623,7 @@ class MultiheadAttention : public GGMLBlock {
            v = v_proj->forward(ctx, x);
        }
 
-        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, mask); // [N, n_token, embed_dim]
+        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, mask); // [N, n_token, embed_dim]
        x = out_proj->forward(ctx, x); // [N, n_token, embed_dim]
 
        return x;
diff --git a/llm.hpp b/llm.hpp
index 7feb8d3c8..315557510 100644
--- a/llm.hpp
+++ b/llm.hpp
@@ -881,7 +881,7 @@ namespace LLM {
            k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim]
            k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim]
 
-            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, false, true, false); // [N, n_token, hidden_size]
+            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, false); // [N, n_token, hidden_size]
            x = out_proj->forward(ctx, x); // [N, n_token, hidden_size]
 
            return x;
diff --git a/mmdit.hpp b/mmdit.hpp
index 086b444dc..726f60c2f 100644
--- a/mmdit.hpp
+++ b/mmdit.hpp
@@ -211,8 +211,8 @@ class SelfAttention : public GGMLBlock {
 
    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
        auto qkv = pre_attention(ctx, x);
-        x        = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
-        x        = post_attention(ctx, x);                                                                                                                 // [N, n_token, dim]
+        x        = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
+        x        = post_attention(ctx, x);                                                                                                          // [N, n_token, dim]
        return x;
    }
 };
@@ -433,8 +433,8 @@ struct DismantledBlock : public GGMLBlock {
            auto qkv2 = std::get<1>(qkv_intermediates);
            auto intermediates = std::get<2>(qkv_intermediates);
 
-            auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
-            auto attn2_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
+            auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
+            auto attn2_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
            x = post_attention_x(ctx,
                                 attn_out,
                                 attn2_out,
@@ -450,7 +450,7 @@ struct DismantledBlock : public GGMLBlock {
            auto qkv = qkv_intermediates.first;
            auto intermediates = qkv_intermediates.second;
 
-            auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
+            auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
            x = post_attention(ctx,
                               attn_out,
                               intermediates[0],
@@ -494,7 +494,7 @@ block_mixing(GGMLRunnerContext* ctx,
        qkv.push_back(ggml_concat(ctx->ggml_ctx, context_qkv[i], x_qkv[i], 1));
    }
 
-    auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_context + n_token, hidden_size]
+    auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_context + n_token, hidden_size]
 
    auto context_attn = ggml_view_3d(ctx->ggml_ctx,
                                     attn,
@@ -526,7 +526,7 @@ block_mixing(GGMLRunnerContext* ctx,
    }
 
    if (x_block->self_attn) {
-        auto attn2 = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, hidden_size]
+        auto attn2 = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, hidden_size]
 
        x = x_block->post_attention_x(ctx,
                                      x_attn,
diff --git a/rope.hpp b/rope.hpp
index 2d123b3cc..45e88c831 100644
--- a/rope.hpp
+++ b/rope.hpp
@@ -642,7 +642,7 @@ namespace Rope {
        q = apply_rope(ctx->ggml_ctx, q, pe, rope_interleaved); // [N*n_head, L, d_head]
        k = apply_rope(ctx->ggml_ctx, k, pe, rope_interleaved); // [N*n_head, L, d_head]
 
-        auto x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, v->ne[1], mask, false, true, ctx->flash_attn_enabled, kv_scale); // [N, L, n_head*d_head]
+        auto x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, v->ne[1], mask, true, ctx->flash_attn_enabled, kv_scale); // [N, L, n_head*d_head]
        return x;
    }
 }; // namespace Rope
diff --git a/vae.hpp b/vae.hpp
index fdddc8ae5..01b99e89b 100644
--- a/vae.hpp
+++ b/vae.hpp
@@ -141,7 +141,7 @@ class AttnBlock : public UnaryBlock {
            v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels]
        }
 
-        h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false);
+        h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, false);
 
        if (use_linear) {
            h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]
diff --git a/wan.hpp b/wan.hpp
index c56e1f926..81959efcf 100644
--- a/wan.hpp
+++ b/wan.hpp
@@ -572,8 +572,8 @@ namespace WAN {
            auto v = qkv_vec[2];
 
            v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w]
-            v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3));               // [t, h * w, c]
-            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false); // [t, h * w, c]
+            v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3));        // [t, h * w, c]
+            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, false); // [t, h * w, c]
 
            x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
            x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w]
@@ -1393,7 +1393,7 @@ namespace WAN {
            k = norm_k->forward(ctx, k);
            auto v = v_proj->forward(ctx, context); // [N, n_context, dim]
 
-            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
+            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
            x = o_proj->forward(ctx, x); // [N, n_token, dim]
 
            return x;
@@ -1455,8 +1455,8 @@ namespace WAN {
            k_img = norm_k_img->forward(ctx, k_img);
            auto v_img = v_img_proj->forward(ctx, context_img); // [N, context_img_len, dim]
 
-            auto img_x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k_img, v_img, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
-            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
+            auto img_x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k_img, v_img, num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
+            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
 
            x = ggml_add(ctx->ggml_ctx, x, img_x);
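
Note on the new mask path: the only non-mechanical change in this patch is the causal mask that CLIPTextModelRunner::build_graph now fills into attention_mask_vec and feeds through forward() into ggml_ext_attention_ext, replacing the removed diag_mask_inf flag. The standalone sketch below is illustrative only and not part of the patch; it assumes nothing beyond standard C++ and simply prints the n_tokens x n_tokens additive mask that the runner uploads via ggml_new_tensor_2d / set_backend_tensor_data: row i1 is the query position, column i0 the key position, and keys past the query get -INFINITY so the subsequent softmax zeroes them out.

// Minimal standalone sketch (not part of the patch): reproduces the additive
// causal mask layout that CLIPTextModelRunner::build_graph now builds.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int n_tokens = 5;  // illustrative size; the runner uses input_ids->ne[0]
    std::vector<float> mask(n_tokens * n_tokens);
    for (int i0 = 0; i0 < n_tokens; i0++) {
        for (int i1 = 0; i1 < n_tokens; i1++) {
            // key position i0 beyond query position i1 is masked out
            mask[i1 * n_tokens + i0] = (i0 > i1) ? -INFINITY : 0.f;
        }
    }
    for (int i1 = 0; i1 < n_tokens; i1++) {
        for (int i0 = 0; i0 < n_tokens; i0++) {
            std::printf("%6s", std::isinf(mask[i1 * n_tokens + i0]) ? "-inf" : "0");
        }
        std::printf("\n");
    }
    return 0;
}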