1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -71,6 +71,7 @@ add_library(llama
models/gemma3n-iswa.cpp
models/glm4-moe.cpp
models/glm4.cpp
models/glm4v.cpp
models/gpt2.cpp
models/gptneox.cpp
models/granite-hybrid.cpp
19 changes: 19 additions & 0 deletions src/llama-arch.cpp
@@ -69,6 +69,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_CHATGLM, "chatglm" },
{ LLM_ARCH_GLM4, "glm4" },
{ LLM_ARCH_GLM4_MOE, "glm4moe" },
{ LLM_ARCH_GLM4V, "glm4v" },
{ LLM_ARCH_BITNET, "bitnet" },
{ LLM_ARCH_T5, "t5" },
{ LLM_ARCH_T5ENCODER, "t5encoder" },
@@ -1634,6 +1635,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" },
},
},
{
LLM_ARCH_GLM4V,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_BITNET,
{
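Note: the GLM4V map above only registers name templates; the `%d` is filled with the block index and `create_tensor` appends a `"weight"` / `"bias"` suffix when looking tensors up in the GGUF. A minimal standalone sketch of that expansion (the `expand` helper below is illustrative, not part of llama.cpp, whose real code path goes through the `LLM_TN` helpers in llama-arch.h):

```cpp
// Sketch only: how a per-layer template such as "blk.%d.attn_q" becomes a
// concrete GGUF tensor name once the layer index and suffix are applied.
#include <cstdio>
#include <string>

static std::string expand(const char * templ, int layer, const char * suffix) {
    char buf[256];
    std::snprintf(buf, sizeof(buf), templ, layer);
    return std::string(buf) + "." + suffix;
}

int main() {
    // prints "blk.0.attn_q.weight" and "blk.0.post_ffw_norm.weight"
    std::printf("%s\n", expand("blk.%d.attn_q",        0, "weight").c_str());
    std::printf("%s\n", expand("blk.%d.post_ffw_norm", 0, "weight").c_str());
    return 0;
}
```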
1 change: 1 addition & 0 deletions src/llama-arch.h
@@ -73,6 +73,7 @@ enum llm_arch {
LLM_ARCH_CHATGLM,
LLM_ARCH_GLM4,
LLM_ARCH_GLM4_MOE,
LLM_ARCH_GLM4V,
LLM_ARCH_BITNET,
LLM_ARCH_T5,
LLM_ARCH_T5ENCODER,
53 changes: 53 additions & 0 deletions src/llama-model.cpp
@@ -1723,6 +1723,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_GLM4V:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true);
switch (hparams.n_layer) {
case 40: type = LLM_TYPE_9B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_BITNET:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -5142,6 +5151,45 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
}
break;
case LLM_ARCH_GLM4V:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}

for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];

layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);

if (layer.wqkv == nullptr) {
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
}

layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);

layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);

layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0);

layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
case LLM_ARCH_NEMOTRON:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -7423,6 +7471,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_glm4_moe>(*this, params);
} break;
case LLM_ARCH_GLM4V:
{
llm = std::make_unique<llm_build_glm4v>(*this, params);
} break;
case LLM_ARCH_BITNET:
{
llm = std::make_unique<llm_build_bitnet>(*this, params);
@@ -7832,6 +7884,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
return LLAMA_ROPE_TYPE_NEOX;

case LLM_ARCH_QWEN2VL:
case LLM_ARCH_GLM4V:
return LLAMA_ROPE_TYPE_MROPE;
case LLM_ARCH_QWEN3VL:
case LLM_ARCH_QWEN3VLMOE:
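Note: in the GLM4V tensor-loading branch above, `ffn_up` is created with shape `{n_embd, n_ff * 2}` and no separate gate tensor, matching the `LLM_FFN_SWIGLU` call in the graph builder below. A hedged sketch of what a fused gate+up projection looks like once the matmul output is split per token (the half ordering is an assumption for illustration; ggml's swiglu op is authoritative):

```cpp
// Hedged sketch of the fused SwiGLU feed-forward implied by the
// {n_embd, 2*n_ff} ffn_up tensor: one matmul yields 2*n_ff values per token,
// split into a gate half and an up half and combined as silu(gate) * up.
#include <cmath>
#include <vector>

static std::vector<float> swiglu_fused(const std::vector<float> & fused /* size 2*n_ff */, int n_ff) {
    std::vector<float> out(n_ff);
    for (int i = 0; i < n_ff; ++i) {
        const float gate = fused[i];         // assumed: first half acts as the gate
        const float up   = fused[n_ff + i];  // assumed: second half is the up projection
        const float silu = gate / (1.0f + std::exp(-gate));
        out[i] = silu * up;
    }
    return out;
}
```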
133 changes: 133 additions & 0 deletions src/models/glm4v.cpp
@@ -0,0 +1,133 @@
#include "models.h"

llm_build_glm4v::llm_build_glm4v(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();

GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

ggml_tensor * cur;
ggml_tensor * inpL;

inpL = build_inp_embd(model.tok_embd);

// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv();

// M-RoPE sections from hparams
int sections[4];
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);

ggml_tensor * inp_out_ids = build_inp_out_ids();

for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;

// Pre-attention norm
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);

// self-attention
{
ggml_tensor * Qcur = nullptr;
ggml_tensor * Kcur = nullptr;
ggml_tensor * Vcur = nullptr;

if (model.layers[il].wqkv == nullptr) {
Qcur = build_lora_mm(model.layers[il].wq, cur);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
}
Kcur = build_lora_mm(model.layers[il].wk, cur);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
}
Vcur = build_lora_mm(model.layers[il].wv, cur);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
}
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
} else {
cur = build_lora_mm(model.layers[il].wqkv, cur);
cb(cur, "wqkv", il);
if (model.layers[il].bqkv) {
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
cb(cur, "bqkv", il);
}
Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1],
0 * sizeof(float) * (n_embd));
Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float),
cur->nb[1], 1 * sizeof(float) * (n_embd));
Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float),
cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa));
}

// GLM4V uses M-RoPE (multi-dimensional rotary position embeddings)
Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);

Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);

cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);

cur = build_attn(inp_attn,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
}
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
// Post-attention norm
cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "post_attn_norm", il);

// Add the input (residual connection after post-attention norm)
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);

// FF
{
// Pre-MLP norm
cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);

// MLP
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
NULL, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL, LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
cb(cur, "ffn_out", il);

// Post-MLP norm
cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "post_mlp_norm", il);
}
// Add residual connection after post-MLP norm
inpL = ggml_add(ctx0, cur, ffn_inp);
cb(inpL, "l_out", il);
}
// Final norm
cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1);

cb(cur, "result_norm", -1);
res->t_embd = cur;

// Output projection
cur = build_lora_mm(model.output, cur);

cb(cur, "result_output", -1);
res->t_logits = cur;

ggml_build_forward_expand(gf, cur);
}
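Note: the two `ggml_rope_multi` calls above consume the four `rope_sections` values read in `load_hparams`. A hedged sketch of the partitioning idea — which position stream (e.g. temporal / height / width) drives each rotary dimension pair — with hypothetical section sizes; the exact interleaving is whatever `ggml_rope_multi` implements:

```cpp
// Sketch only: how a `sections` array divides the n_rot/2 rotary dimension
// pairs among M-RoPE position streams.
#include <cstdio>

static int section_of_pair(int pair_idx, const int sections[4]) {
    int limit = 0;
    for (int s = 0; s < 4; ++s) {
        limit += sections[s];
        if (pair_idx < limit) {
            return s;
        }
    }
    return 3; // any trailing pairs fall into the last section
}

int main() {
    const int sections[4] = {16, 24, 24, 0}; // hypothetical split of 64 pairs
    const int test_pairs[] = {0, 15, 16, 40, 63};
    for (int d : test_pairs) {
        std::printf("dim pair %2d -> position stream %d\n", d, section_of_pair(d, sections));
    }
    return 0;
}
```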
4 changes: 4 additions & 0 deletions src/models/models.h
@@ -222,6 +222,10 @@ struct llm_build_glm4_moe : public llm_graph_context {
llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params);
};

struct llm_build_glm4v : public llm_graph_context {
llm_build_glm4v(const llama_model & model, const llm_graph_params & params);
};

struct llm_build_gpt2 : public llm_graph_context {
llm_build_gpt2(const llama_model & model, const llm_graph_params & params);
};
1 change: 1 addition & 0 deletions tools/mtmd/CMakeLists.txt
@@ -15,6 +15,7 @@ add_library(mtmd
clip-graph.h
models/models.h
models/cogvlm.cpp
models/glm4v.cpp
models/internvl.cpp
models/kimivl.cpp
models/llama4.cpp
11 changes: 11 additions & 0 deletions tools/mtmd/clip-impl.h
@@ -133,6 +133,15 @@
#define TN_TOK_BOI "v.boi"
#define TN_TOK_EOI "v.eoi"

// glm4v
#define TN_GLM4V_POST_CONV_LN "mm.post_conv_ln.%s"
#define TN_GLM4V_DOWNSAMPLE "mm.downsample.%s"
#define TN_GLM4V_MERGER_PROJ "mm.merger.proj.%s"
#define TN_GLM4V_MERGER_NORM "mm.merger.norm.%s"
#define TN_GLM4V_MERGER_GATE "mm.merger.gate.%s"
#define TN_GLM4V_MERGER_UP "mm.merger.up.%s"
#define TN_GLM4V_MERGER_DOWN "mm.merger.down.%s"

// align x to upper multiple of n
#define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))

@@ -164,6 +173,7 @@ enum projector_type {
PROJECTOR_TYPE_LIGHTONOCR,
PROJECTOR_TYPE_COGVLM,
PROJECTOR_TYPE_JANUS_PRO,
PROJECTOR_TYPE_GLM4V,
PROJECTOR_TYPE_UNKNOWN,
};

@@ -190,6 +200,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
{ PROJECTOR_TYPE_COGVLM, "cogvlm"},
{ PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
{ PROJECTOR_TYPE_GLM4V, "glm4v"},
};

static projector_type clip_projector_type_from_string(const std::string & str) {
16 changes: 16 additions & 0 deletions tools/mtmd/clip-model.h
@@ -267,6 +267,22 @@ struct clip_model {
ggml_tensor * mm_boi = nullptr;
ggml_tensor * mm_eoi = nullptr;

// glm4v
ggml_tensor * mm_post_conv_ln_w = nullptr;
ggml_tensor * mm_post_conv_ln_b = nullptr;
ggml_tensor * mm_downsample_w = nullptr;
ggml_tensor * mm_downsample_b = nullptr;
ggml_tensor * mm_merger_proj_w = nullptr;
ggml_tensor * mm_merger_proj_b = nullptr;
ggml_tensor * mm_merger_norm_w = nullptr;
ggml_tensor * mm_merger_norm_b = nullptr;
ggml_tensor * mm_merger_gate_w = nullptr;
ggml_tensor * mm_merger_gate_b = nullptr;
ggml_tensor * mm_merger_up_w = nullptr;
ggml_tensor * mm_merger_up_b = nullptr;
ggml_tensor * mm_merger_down_w = nullptr;
ggml_tensor * mm_merger_down_b = nullptr;

bool audio_has_avgpool() const {
return proj_type == PROJECTOR_TYPE_QWEN2A
|| proj_type == PROJECTOR_TYPE_VOXTRAL;
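Note: this diff only declares the merger tensors; no forward pass for them appears here. The `gate` / `up` / `down` naming suggests a SwiGLU-style gated MLP after the projection, so the following is an assumption-labelled sketch rather than the actual GLM4V merger:

```cpp
// Assumption-labelled sketch only: ordering, shapes and activation are
// guesses based on the mm_merger_{proj,norm,gate,up,down} field names; the
// norm step is omitted for brevity.
#include <cmath>
#include <vector>

using vec = std::vector<float>;

// y = W x + b, with W stored row-major as (out x in); out is taken from b.size()
static vec linear(const vec & W, const vec & b, const vec & x) {
    const size_t out = b.size(), in = x.size();
    vec y(out);
    for (size_t o = 0; o < out; ++o) {
        float acc = b[o];
        for (size_t i = 0; i < in; ++i) acc += W[o * in + i] * x[i];
        y[o] = acc;
    }
    return y;
}

static vec merger_forward_guess(const vec & x,
                                const vec & proj_w, const vec & proj_b,
                                const vec & gate_w, const vec & gate_b,
                                const vec & up_w,   const vec & up_b,
                                const vec & down_w, const vec & down_b) {
    vec h    = linear(proj_w, proj_b, x);   // mm.merger.proj
    vec gate = linear(gate_w, gate_b, h);   // mm.merger.gate
    vec up   = linear(up_w,   up_b,   h);   // mm.merger.up
    for (size_t i = 0; i < gate.size(); ++i) {
        gate[i] = gate[i] / (1.0f + std::exp(-gate[i])) * up[i];  // silu(gate) * up
    }
    return linear(down_w, down_b, gate);    // mm.merger.down
}
```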