43 changes: 31 additions & 12 deletions clip.hpp
@@ -510,7 +510,7 @@ struct CLIPLayer : public GGMLBlock {
blocks["mlp"] = std::shared_ptr<GGMLBlock>(new CLIPMLP(d_model, intermediate_size));
}

struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, bool mask = true) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* mask = nullptr) {
// x: [N, n_token, d_model]
auto self_attn = std::dynamic_pointer_cast<MultiheadAttention>(blocks["self_attn"]);
auto layer_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm1"]);
@@ -542,8 +542,8 @@ struct CLIPEncoder : public GGMLBlock {

struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
int clip_skip = -1,
bool mask = true) {
struct ggml_tensor* mask = nullptr,
int clip_skip = -1) {
// x: [N, n_token, d_model]
int layer_idx = n_layer - 1;
// LOG_DEBUG("clip_skip %d", clip_skip);
@@ -741,16 +741,17 @@ class CLIPTextModel : public GGMLBlock {
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* input_ids,
struct ggml_tensor* tkn_embeddings,
size_t max_token_idx = 0,
bool return_pooled = false,
int clip_skip = -1) {
struct ggml_tensor* mask = nullptr,
size_t max_token_idx = 0,
bool return_pooled = false,
int clip_skip = -1) {
// input_ids: [N, n_token]
auto embeddings = std::dynamic_pointer_cast<CLIPEmbeddings>(blocks["embeddings"]);
auto encoder = std::dynamic_pointer_cast<CLIPEncoder>(blocks["encoder"]);
auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);

auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size]
x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true);
x = encoder->forward(ctx, x, mask, return_pooled ? -1 : clip_skip);
if (return_pooled || with_final_ln) {
x = final_layer_norm->forward(ctx, x);
}
@@ -814,10 +815,11 @@ class CLIPVisionModel : public GGMLBlock {

auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim]
x = pre_layernorm->forward(ctx, x);
x = encoder->forward(ctx, x, clip_skip, false);
// print_ggml_tensor(x, true, "ClipVisionModel x: ");
x = encoder->forward(ctx, x, nullptr, clip_skip);

auto last_hidden_state = x;
x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size]

x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size]

GGML_ASSERT(x->ne[3] == 1);
if (return_pooled) {
@@ -905,6 +907,8 @@ class CLIPVisionModelProjection : public GGMLBlock {
struct CLIPTextModelRunner : public GGMLRunner {
CLIPTextModel model;

std::vector<float> attention_mask_vec;

CLIPTextModelRunner(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map,
@@ -938,6 +942,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* input_ids,
struct ggml_tensor* embeddings,
struct ggml_tensor* mask,
size_t max_token_idx = 0,
bool return_pooled = false,
int clip_skip = -1) {
@@ -948,7 +953,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
input_ids = ggml_reshape_2d(ctx->ggml_ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token);
}

return model.forward(ctx, input_ids, embeddings, max_token_idx, return_pooled, clip_skip);
return model.forward(ctx, input_ids, embeddings, mask, max_token_idx, return_pooled, clip_skip);
}

struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
@@ -975,9 +980,23 @@ struct CLIPTextModelRunner : public GGMLRunner {
embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);
}

int n_tokens = static_cast<int>(input_ids->ne[0]);
attention_mask_vec.resize(n_tokens * n_tokens);
for (int i0 = 0; i0 < n_tokens; i0++) {
for (int i1 = 0; i1 < n_tokens; i1++) {
float value = 0.f;
if (i0 > i1) {
value = -INFINITY;
}
attention_mask_vec[i1 * n_tokens + i0] = value;
}
}
auto attention_mask = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, n_tokens, n_tokens);
set_backend_tensor_data(attention_mask, attention_mask_vec.data());

auto runner_ctx = get_context();

struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, embeddings, max_token_idx, return_pooled, clip_skip);
struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, embeddings, attention_mask, max_token_idx, return_pooled, clip_skip);

ggml_build_forward_expand(gf, hidden_states);

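Note (reading aid, not part of the patch): the loop added to CLIPTextModelRunner::build_graph fills attention_mask_vec in row-major order with i1 as the query row and i0 as the key column, giving the usual causal pattern (0 where i0 <= i1, -INFINITY where i0 > i1); keeping the vector as a class member presumably keeps the buffer alive until set_backend_tensor_data's copy happens at compute time. The standalone sketch below reproduces that loop for n_tokens = 4 and prints the pattern; the query/key reading of i1/i0 is an assumption drawn from how the mask is later added to the kq logits.

// Standalone sketch: reproduces the causal-mask loop from the diff for a small n_tokens.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int n_tokens = 4;
    std::vector<float> attention_mask_vec(n_tokens * n_tokens);
    for (int i0 = 0; i0 < n_tokens; i0++) {      // key position (column)
        for (int i1 = 0; i1 < n_tokens; i1++) {  // query position (row)
            float value = 0.f;
            if (i0 > i1) {                       // a query may not attend to later keys
                value = -INFINITY;
            }
            attention_mask_vec[i1 * n_tokens + i0] = value;
        }
    }
    for (int i1 = 0; i1 < n_tokens; i1++) {      // one printed row per query position
        for (int i0 = 0; i0 < n_tokens; i0++) {
            printf("%8.0f", attention_mask_vec[i1 * n_tokens + i0]);
        }
        printf("\n");
    }
    return 0;
}

Running it prints a lower-triangular block of zeros with -inf above the diagonal, i.e. each token attends only to itself and earlier tokens.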
2 changes: 1 addition & 1 deletion common.hpp
@@ -317,7 +317,7 @@ class CrossAttention : public GGMLBlock {
auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]

x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, inner_dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, inner_dim]

x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
return x;
14 changes: 7 additions & 7 deletions examples/cli/main.cpp
@@ -604,10 +604,10 @@ int main(int argc, const char* argv[]) {

if (gen_params.mask_image_path.size() > 0) {
if (!load_sd_image_from_file(&mask_image,
gen_params.mask_image_path.c_str(),
gen_params.get_resolved_width(),
gen_params.get_resolved_height(),
1)) {
gen_params.mask_image_path.c_str(),
gen_params.get_resolved_width(),
gen_params.get_resolved_height(),
1)) {
LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str());
release_all_resources();
return 1;
@@ -626,9 +626,9 @@

if (gen_params.control_image_path.size() > 0) {
if (!load_sd_image_from_file(&control_image,
gen_params.control_image_path.c_str(),
gen_params.get_resolved_width(),
gen_params.get_resolved_height())) {
gen_params.control_image_path.c_str(),
gen_params.get_resolved_width(),
gen_params.get_resolved_height())) {
LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str());
release_all_resources();
return 1;
8 changes: 2 additions & 6 deletions ggml_extend.hpp
@@ -1257,7 +1257,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
struct ggml_tensor* v,
int64_t n_head,
struct ggml_tensor* mask = nullptr,
bool diag_mask_inf = false,
bool skip_reshape = false,
bool flash_attn = false,
float kv_scale = 1.0f) { // avoid overflow
@@ -1385,9 +1384,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
if (mask) {
kq = ggml_add_inplace(ctx, kq, mask);
}
if (diag_mask_inf) {
kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
}
kq = ggml_soft_max_inplace(ctx, kq);

kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, L_q, d_head]
@@ -2604,7 +2600,7 @@ class MultiheadAttention : public GGMLBlock {
// x: [N, n_token, embed_dim]
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
bool mask = false) {
struct ggml_tensor* mask = nullptr) {
auto out_proj = std::dynamic_pointer_cast<Linear>(blocks[out_proj_name]);

ggml_tensor* q;
@@ -2627,7 +2623,7 @@
v = v_proj->forward(ctx, x);
}

x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, mask); // [N, n_token, embed_dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, mask); // [N, n_token, embed_dim]

x = out_proj->forward(ctx, x); // [N, n_token, embed_dim]
return x;
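Note (reading aid, not part of the patch): dropping the diag_mask_inf flag from ggml_ext_attention_ext works because the retained code path already covers it — the mask tensor is added to the kq logits just before ggml_soft_max_inplace, and an entry of -INFINITY drives the corresponding attention weight to zero. The plain-C++ sketch below (no ggml, illustrative values only) walks through one query row under that assumption.

// Standalone sketch: adding -inf to a logit before softmax zeroes its weight,
// which is how the explicit mask tensor replaces the removed diag_mask_inf path.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

static std::vector<float> softmax(std::vector<float> logits) {
    float max_val = -INFINITY;
    for (float v : logits) max_val = std::max(max_val, v);
    float sum = 0.f;
    for (float& v : logits) {
        v = std::exp(v - max_val);  // exp(-inf) == 0, so masked entries vanish
        sum += v;
    }
    for (float& v : logits) v /= sum;
    return logits;
}

int main() {
    std::vector<float> kq_row = {0.5f, 1.0f, 2.0f, 1.5f};          // raw q*k scores for one query
    std::vector<float> mask   = {0.f, 0.f, -INFINITY, -INFINITY};  // causal mask row: last two keys are "future"
    for (size_t i = 0; i < kq_row.size(); i++) {
        kq_row[i] += mask[i];                    // mirrors kq = ggml_add_inplace(ctx, kq, mask)
    }
    std::vector<float> weights = softmax(kq_row);  // mirrors ggml_soft_max_inplace(ctx, kq)
    for (float w : weights) {
        printf("%.3f ", w);                      // masked positions print 0.000
    }
    printf("\n");
    return 0;
}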
2 changes: 1 addition & 1 deletion llm.hpp
@@ -881,7 +881,7 @@ namespace LLM {
k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim]
k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim]

x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, false, true, false); // [N, n_token, hidden_size]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, false); // [N, n_token, hidden_size]

x = out_proj->forward(ctx, x); // [N, n_token, hidden_size]
return x;
14 changes: 7 additions & 7 deletions mmdit.hpp
@@ -211,8 +211,8 @@ class SelfAttention : public GGMLBlock {
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
auto qkv = pre_attention(ctx, x);
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim]
return x;
}
};
@@ -433,8 +433,8 @@ struct DismantledBlock : public GGMLBlock {
auto qkv2 = std::get<1>(qkv_intermediates);
auto intermediates = std::get<2>(qkv_intermediates);

auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
auto attn2_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
auto attn2_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = post_attention_x(ctx,
attn_out,
attn2_out,
@@ -450,7 +450,7 @@
auto qkv = qkv_intermediates.first;
auto intermediates = qkv_intermediates.second;

auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = post_attention(ctx,
attn_out,
intermediates[0],
@@ -494,7 +494,7 @@ block_mixing(GGMLRunnerContext* ctx,
qkv.push_back(ggml_concat(ctx->ggml_ctx, context_qkv[i], x_qkv[i], 1));
}

auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_context + n_token, hidden_size]
auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_context + n_token, hidden_size]

auto context_attn = ggml_view_3d(ctx->ggml_ctx,
attn,
@@ -526,7 +526,7 @@ }
}

if (x_block->self_attn) {
auto attn2 = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, hidden_size]
auto attn2 = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, hidden_size]

x = x_block->post_attention_x(ctx,
x_attn,
2 changes: 1 addition & 1 deletion rope.hpp
@@ -642,7 +642,7 @@ namespace Rope {
q = apply_rope(ctx->ggml_ctx, q, pe, rope_interleaved); // [N*n_head, L, d_head]
k = apply_rope(ctx->ggml_ctx, k, pe, rope_interleaved); // [N*n_head, L, d_head]

auto x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, v->ne[1], mask, false, true, ctx->flash_attn_enabled, kv_scale); // [N, L, n_head*d_head]
auto x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, v->ne[1], mask, true, ctx->flash_attn_enabled, kv_scale); // [N, L, n_head*d_head]
return x;
}
}; // namespace Rope
2 changes: 1 addition & 1 deletion vae.hpp
@@ -141,7 +141,7 @@ class AttnBlock : public UnaryBlock {
v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels]
}

h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false);
h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, false);

if (use_linear) {
h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]
10 changes: 5 additions & 5 deletions wan.hpp
@@ -572,8 +572,8 @@ namespace WAN {
auto v = qkv_vec[2];
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w]

v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false); // [t, h * w, c]
v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, false); // [t, h * w, c]

x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w]
@@ -1393,7 +1393,7 @@ namespace WAN {
k = norm_k->forward(ctx, k);
auto v = v_proj->forward(ctx, context); // [N, n_context, dim]

x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]

x = o_proj->forward(ctx, x); // [N, n_token, dim]
return x;
@@ -1455,8 +1455,8 @@ namespace WAN {
k_img = norm_k_img->forward(ctx, k_img);
auto v_img = v_img_proj->forward(ctx, context_img); // [N, context_img_len, dim]

auto img_x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k_img, v_img, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
auto img_x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k_img, v_img, num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]

x = ggml_add(ctx->ggml_ctx, x, img_x);
