377 changes: 214 additions & 163 deletions conditioner.hpp

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions diffusion_model.hpp
@@ -23,6 +23,8 @@ struct DiffusionParams {
struct ggml_tensor* vace_context = nullptr;
float vace_strength = 1.f;
std::vector<int> skip_layers = {};
std::vector<struct ggml_tensor*> extra_contexts; // for z-image-omni
std::vector<struct ggml_tensor*> ref_clip_feats; // for z-image-omni
};

struct DiffusionModel {
@@ -436,12 +438,14 @@ struct ZImageModel : public DiffusionModel {
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
struct ggml_context* output_ctx = nullptr) override {
std::vector<ggml_tensor*> contexts = {diffusion_params.context};
contexts.insert(contexts.end(), diffusion_params.extra_contexts.begin(), diffusion_params.extra_contexts.end());
return z_image.compute(n_threads,
diffusion_params.x,
diffusion_params.timesteps,
diffusion_params.context,
contexts,
diffusion_params.ref_latents,
true, // increase_ref_index
diffusion_params.ref_clip_feats,
output,
output_ctx);
}
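Note: a minimal usage sketch (not part of this PR) of how the new `extra_contexts` and `ref_clip_feats` fields feed into `ZImageModel::compute`; the `Tensor` stand-in and `collect_contexts` helper are hypothetical, the real code passes `ggml_tensor*` handles.

```cpp
#include <vector>

// Hypothetical stand-in for ggml_tensor*; the real code works on ggml tensors.
struct Tensor;

struct DiffusionParamsSketch {
    Tensor* context = nullptr;            // primary text conditioning
    std::vector<Tensor*> extra_contexts;  // additional contexts (z-image-omni)
    std::vector<Tensor*> ref_clip_feats;  // SigLIP features of reference images (z-image-omni)
};

// Mirrors the concatenation done above in ZImageModel::compute:
// the primary context goes first, followed by any extra omni contexts.
std::vector<Tensor*> collect_contexts(const DiffusionParamsSketch& p) {
    std::vector<Tensor*> contexts = {p.context};
    contexts.insert(contexts.end(), p.extra_contexts.begin(), p.extra_contexts.end());
    return contexts;
}
```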
14 changes: 13 additions & 1 deletion model.cpp
@@ -1039,6 +1039,8 @@ SDVersion ModelLoader::get_sd_version() {
bool is_xl = false;
bool is_flux = false;
bool is_flux2 = false;
bool is_z_image = false;
bool is_z_image_omni = false;
bool has_single_block_47 = false;
bool is_wan = false;
int64_t patch_embedding_channels = 0;
@@ -1071,7 +1073,10 @@ SDVersion ModelLoader::get_sd_version() {
return VERSION_OVIS_IMAGE;
}
if (tensor_storage.name.find("model.diffusion_model.cap_embedder.0.weight") != std::string::npos) {
return VERSION_Z_IMAGE;
is_z_image = true;
}
if (tensor_storage.name.find("model.diffusion_model.siglip_embedder.0.weight") != std::string::npos) {
is_z_image_omni = true;
}
if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) {
is_wan = true;
@@ -1174,6 +1179,13 @@ SDVersion ModelLoader::get_sd_version() {
return VERSION_FLUX2_KLEIN;
}

if (is_z_image) {
if (is_z_image_omni) {
return VERSION_Z_IMAGE_OMNI;
}
return VERSION_Z_IMAGE;
}

if (token_embedding_weight.ne[0] == 768) {
if (is_inpaint) {
return VERSION_SD1_INPAINT;
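For readability, a condensed sketch of the detection flow these hunks implement: `cap_embedder.0.weight` marks a Z-Image checkpoint, an additional `siglip_embedder.0.weight` upgrades it to the omni variant, and the decision is only taken after all tensor names have been scanned (the enum and function names below are placeholders, not the real `ModelLoader` API).

```cpp
#include <string>
#include <vector>

enum ZImageVariant { Z_IMAGE, Z_IMAGE_OMNI, NOT_Z_IMAGE };

// Scan all tensor names first, decide afterwards: the SigLIP marker may
// show up before or after the cap_embedder marker in the file.
ZImageVariant detect_z_image_variant(const std::vector<std::string>& tensor_names) {
    bool is_z_image      = false;
    bool is_z_image_omni = false;
    for (const std::string& name : tensor_names) {
        if (name.find("model.diffusion_model.cap_embedder.0.weight") != std::string::npos) {
            is_z_image = true;  // base Z-Image marker
        }
        if (name.find("model.diffusion_model.siglip_embedder.0.weight") != std::string::npos) {
            is_z_image_omni = true;  // extra SigLIP embedder => omni variant
        }
    }
    if (is_z_image) {
        return is_z_image_omni ? Z_IMAGE_OMNI : Z_IMAGE;
    }
    return NOT_Z_IMAGE;
}
```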
3 changes: 2 additions & 1 deletion model.h
@@ -48,6 +48,7 @@ enum SDVersion {
VERSION_FLUX2,
VERSION_FLUX2_KLEIN,
VERSION_Z_IMAGE,
VERSION_Z_IMAGE_OMNI,
VERSION_OVIS_IMAGE,
VERSION_COUNT,
};
@@ -123,7 +124,7 @@ static inline bool sd_version_is_qwen_image(SDVersion version) {
}

static inline bool sd_version_is_z_image(SDVersion version) {
if (version == VERSION_Z_IMAGE) {
if (version == VERSION_Z_IMAGE || version == VERSION_Z_IMAGE_OMNI) {
return true;
}
return false;
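A small sanity-check sketch of the widened helper (standalone copy with placeholder enum values, not the real header): both z-image variants now take the shared z-image code paths.

```cpp
#include <cassert>

enum VersionSketch { ZIMG, ZIMG_OMNI, OTHER };

static inline bool is_z_image_sketch(VersionSketch v) {
    return v == ZIMG || v == ZIMG_OMNI;
}

int main() {
    assert(is_z_image_sketch(ZIMG));       // base variant
    assert(is_z_image_sketch(ZIMG_OMNI));  // omni variant now also matches
    assert(!is_z_image_sketch(OTHER));     // everything else unchanged
    return 0;
}
```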
111 changes: 84 additions & 27 deletions rope.hpp
@@ -518,60 +518,117 @@ namespace Rope {
return (m - (a % m)) % m;
}

__STATIC_INLINE__ std::vector<std::vector<float>> gen_z_image_ids(int h,
int w,
__STATIC_INLINE__ std::vector<std::vector<float>> gen_z_image_ids(ggml_tensor* x,
const std::vector<ggml_tensor*>& contexts,
const std::vector<ggml_tensor*>& ref_latents,
const std::vector<ggml_tensor*>& siglip_feats,
int patch_size,
int bs,
int context_len,
int seq_multi_of,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index) {
int padded_context_len = context_len + bound_mod(context_len, seq_multi_of);
auto txt_ids = std::vector<std::vector<float>>(bs * padded_context_len, std::vector<float>(3, 0.0f));
for (int i = 0; i < bs * padded_context_len; i++) {
txt_ids[i][0] = (i % padded_context_len) + 1.f;
int bs) {
GGML_ASSERT(contexts.size() > ref_latents.size());
GGML_ASSERT(contexts.size() >= siglip_feats.size());
int context_cu_len = 1;
std::vector<int> context_end_pos;
std::vector<std::vector<float>> txt_ids;
for (auto context : contexts) {
int padded_context_len = static_cast<int>(context->ne[1]) + bound_mod(static_cast<int>(context->ne[1]), seq_multi_of);
auto curr_txt_ids = std::vector<std::vector<float>>(bs * padded_context_len, std::vector<float>(3, 0.0f));
for (int i = 0; i < bs * padded_context_len; i++) {
curr_txt_ids[i][0] = static_cast<float>((i % padded_context_len) + context_cu_len);
}
context_cu_len += padded_context_len;
context_end_pos.push_back(context_cu_len);
context_cu_len += 2; // for image and siglip tokens
txt_ids = concat_ids(txt_ids, curr_txt_ids, bs);
}

int axes_dim_num = 3;
int index = padded_context_len + 1;
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, index);
std::vector<std::vector<float>> img_ids;
std::vector<ggml_tensor*> all_img = ref_latents;
all_img.push_back(x);
for (int i = 0; i < all_img.size(); i++) {
int axes_dim_num = 3;
int index = context_end_pos[i];
auto curr_img_ids = gen_flux_img_ids(static_cast<int>(all_img[i]->ne[1]), static_cast<int>(all_img[i]->ne[0]), patch_size, bs, axes_dim_num, index);

int img_pad_len = bound_mod(static_cast<int>(curr_img_ids.size() / bs), seq_multi_of);
if (img_pad_len > 0) {
std::vector<std::vector<float>> img_pad_ids(bs * img_pad_len, std::vector<float>(3, 0.f));
curr_img_ids = concat_ids(curr_img_ids, img_pad_ids, bs);
}
img_ids = concat_ids(img_ids, curr_img_ids, bs);
}

std::vector<std::vector<float>> sig_ids;
for (int i = 0; i < siglip_feats.size(); i++) {
int axes_dim_num = 3;
int index = context_end_pos[i] + 1;
int h_len = static_cast<int>(siglip_feats[i]->ne[1]);
int w_len = static_cast<int>(siglip_feats[i]->ne[0]);

std::vector<std::vector<float>> curr_sig_ids(bs * h_len * w_len, std::vector<float>(axes_dim_num, 0.0));

// scale position IDs to match img resolution
std::vector<float> row_ids = linspace<float>(0, static_cast<float>(all_img[i]->ne[1]) - 1.f, h_len);
std::vector<float> col_ids = linspace<float>(0, static_cast<float>(all_img[i]->ne[0]) - 1.f, w_len);

for (int ib = 0; ib < bs; ++ib) {
for (int ih = 0; ih < h_len; ++ih) {
for (int iw = 0; iw < w_len; ++iw) {
curr_sig_ids[ib * h_len * w_len + ih * w_len + iw][0] = static_cast<float>(index);
curr_sig_ids[ib * h_len * w_len + ih * w_len + iw][1] = row_ids[ih];
curr_sig_ids[ib * h_len * w_len + ih * w_len + iw][2] = col_ids[iw];
}
}
}

int img_pad_len = bound_mod(static_cast<int>(img_ids.size() / bs), seq_multi_of);
if (img_pad_len > 0) {
std::vector<std::vector<float>> img_pad_ids(bs * img_pad_len, std::vector<float>(3, 0.f));
img_ids = concat_ids(img_ids, img_pad_ids, bs);
int sig_pad_len = bound_mod(static_cast<int>(curr_sig_ids.size() / bs), seq_multi_of);
if (sig_pad_len > 0) {
std::vector<std::vector<float>> sig_pad_ids(bs * sig_pad_len, std::vector<float>(3, 0.f));
curr_sig_ids = concat_ids(curr_sig_ids, sig_pad_ids, bs);
}
sig_ids = concat_ids(sig_ids, curr_sig_ids, bs);
}

auto ids = concat_ids(txt_ids, img_ids, bs);

// ignore ref_latents for now
if (!sig_ids.empty()) {
ids = concat_ids(ids, sig_ids, bs);
}

return ids;
}

// Generate z_image positional embeddings
__STATIC_INLINE__ std::vector<float> gen_z_image_pe(int h,
int w,
__STATIC_INLINE__ std::vector<float> gen_z_image_pe(ggml_tensor* x,
const std::vector<ggml_tensor*>& contexts,
const std::vector<ggml_tensor*>& ref_latents,
const std::vector<ggml_tensor*>& siglip_feats,
int patch_size,
int bs,
int context_len,
int seq_multi_of,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
int theta,
const std::vector<int>& axes_dim,
bool circular_h,
bool circular_w,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_z_image_ids(h, w, patch_size, bs, context_len, seq_multi_of, ref_latents, increase_ref_index);
int bs) {
std::vector<std::vector<float>> ids = gen_z_image_ids(x, contexts, ref_latents, siglip_feats, patch_size, seq_multi_of, bs);
std::vector<std::vector<int>> wrap_dims;
if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) {
int context_len = 0;
for (auto context : contexts) {
int padded_context_len = static_cast<int>(context->ne[1]) + bound_mod(static_cast<int>(context->ne[1]), seq_multi_of);
context_len += padded_context_len;
}
int h = static_cast<int>(x->ne[1]);
int w = static_cast<int>(x->ne[0]);
int pad_h = (patch_size - (h % patch_size)) % patch_size;
int pad_w = (patch_size - (w % patch_size)) % patch_size;
int h_len = (h + pad_h) / patch_size;
int w_len = (w + pad_w) / patch_size;

if (h_len > 0 && w_len > 0) {
size_t pos_len = ids.size() / bs;
wrap_dims.assign(axes_dim.size(), std::vector<int>(pos_len, 0));
size_t cursor = context_len + bound_mod(context_len, seq_multi_of); // skip text (and its padding)
size_t cursor = context_len; // skip text (and its padding)
size_t img_tokens = static_cast<size_t>(h_len) * static_cast<size_t>(w_len);
for (size_t token_i = 0; token_i < img_tokens; ++token_i) {
if (circular_h) {
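To make the index layout of `gen_z_image_ids` easier to follow, here is a self-contained sketch with hypothetical segment lengths: each text context is padded to a multiple of `seq_multi_of`, its padded tokens occupy one contiguous position range, and the two positions right after each context's end are reserved for that segment's image tokens and (for omni) SigLIP tokens.

```cpp
#include <cstdio>
#include <vector>

// Same rounding helper as Rope::bound_mod: padding needed to reach the
// next multiple of m.
static int bound_mod(int a, int m) {
    return (m - (a % m)) % m;
}

int main() {
    // Hypothetical lengths, only to illustrate the layout: two text
    // contexts and seq_multi_of = 32.
    std::vector<int> context_lens = {77, 300};
    int seq_multi_of = 32;

    int cu = 1;  // matches context_cu_len = 1 in gen_z_image_ids
    for (int len : context_lens) {
        int padded = len + bound_mod(len, seq_multi_of);
        std::printf("text segment: len=%d padded=%d positions [%d, %d)\n",
                    len, padded, cu, cu + padded);
        cu += padded;  // context_end_pos for this segment
        std::printf("  image tokens use index %d, siglip tokens index %d\n",
                    cu, cu + 1);
        cu += 2;       // reserve the image and siglip indices
    }
    return 0;
}
```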