
Commit ccf0bb5

eelbaz and claude committed
mtmd, llama: add GLM4V vision-language model support
Add complete GLM4V (GLM-4.6V-Flash) support including:

**Vision Encoder (mtmd):**
- Dual Conv2D patch embedding (simulating Conv3D temporal reduction)
- M-RoPE using ggml_rope_multi() with [h,w,h,w] position pattern
- 2x2 patch merger with downsample convolution
- SwiGLU-based merger FFN

**LLM Architecture (libllama):**
- New LLM_ARCH_GLM4V based on GLM4 with M-RoPE
- Uses LLAMA_ROPE_TYPE_MROPE with rope_sections from model config
- Reuses ggml_rope_multi() (same as Qwen2VL) for position encoding

Key design decisions:
- Vision encoder uses ggml_rope_multi() instead of custom RoPE
- LLM follows GLM4 structure with M-RoPE (not Qwen2VL structure)
- Minimal code: ~300 lines total across all files

Tested with GLM-4.6V-Flash producing correct image descriptions.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <[email protected]>
1 parent 380b4c9 commit ccf0bb5

File tree

12 files changed: +453 −0 lines changed

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@ add_library(llama
     models/gemma3n-iswa.cpp
     models/glm4-moe.cpp
     models/glm4.cpp
+    models/glm4v.cpp
     models/gpt2.cpp
     models/gptneox.cpp
     models/granite-hybrid.cpp

src/llama-arch.cpp

Lines changed: 19 additions & 0 deletions
@@ -69,6 +69,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_CHATGLM, "chatglm" },
     { LLM_ARCH_GLM4, "glm4" },
     { LLM_ARCH_GLM4_MOE, "glm4moe" },
+    { LLM_ARCH_GLM4V, "glm4v" },
     { LLM_ARCH_BITNET, "bitnet" },
     { LLM_ARCH_T5, "t5" },
     { LLM_ARCH_T5ENCODER, "t5encoder" },
@@ -1634,6 +1635,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
         { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" },
     },
 },
+{
+    LLM_ARCH_GLM4V,
+    {
+        { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+        { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+        { LLM_TENSOR_OUTPUT, "output" },
+        { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+        { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+        { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+        { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+        { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+        { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+        { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+        { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+        { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+    },
+},
 {
     LLM_ARCH_BITNET,
     {
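The entries registered above are printf-style name templates: when a tensor is looked up, the layer index fills the %d and a suffix such as "weight" or "bias" is appended, yielding GGUF names like blk.0.attn_q.weight (see the tn(...) calls in the llama-model.cpp hunks below). A minimal sketch of that expansion, with a hypothetical helper in place of the real tn() machinery:

```cpp
#include <cstdio>
#include <string>

// Hypothetical stand-in for how a pattern like "blk.%d.attn_q" becomes a
// concrete GGUF tensor name such as "blk.0.attn_q.weight".
static std::string format_tensor_name(const char * pattern, int layer, const char * suffix) {
    char buf[256];
    std::snprintf(buf, sizeof(buf), pattern, layer);
    return std::string(buf) + "." + suffix;
}

// format_tensor_name("blk.%d.attn_q", 0, "weight") -> "blk.0.attn_q.weight"
```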

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -73,6 +73,7 @@ enum llm_arch {
     LLM_ARCH_CHATGLM,
     LLM_ARCH_GLM4,
     LLM_ARCH_GLM4_MOE,
+    LLM_ARCH_GLM4V,
     LLM_ARCH_BITNET,
     LLM_ARCH_T5,
     LLM_ARCH_T5ENCODER,

src/llama-model.cpp

Lines changed: 53 additions & 0 deletions
@@ -1723,6 +1723,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_GLM4V:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true);
+                switch (hparams.n_layer) {
+                    case 40: type = LLM_TYPE_9B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_BITNET:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -5142,6 +5151,45 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     }
                 }
                 break;
+        case LLM_ARCH_GLM4V:
+            {
+                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                // output
+                output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                output      = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                // if output is NULL, init from the input tok embed
+                if (output == NULL) {
+                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                }
+
+                for (int i = 0; i < n_layer; ++i) {
+                    auto & layer = layers[i];
+
+                    layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                    layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                    layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
+
+                    if (layer.wqkv == nullptr) {
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                    }
+
+                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                    layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
+                    layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                    layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                    layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0);
+
+                    layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+                }
+            } break;
         case LLM_ARCH_NEMOTRON:
             {
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -7423,6 +7471,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_glm4_moe>(*this, params);
             } break;
+        case LLM_ARCH_GLM4V:
+            {
+                llm = std::make_unique<llm_build_glm4v>(*this, params);
+            } break;
         case LLM_ARCH_BITNET:
             {
                 llm = std::make_unique<llm_build_bitnet>(*this, params);
@@ -7832,6 +7884,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
             return LLAMA_ROPE_TYPE_NEOX;

         case LLM_ARCH_QWEN2VL:
+        case LLM_ARCH_GLM4V:
             return LLAMA_ROPE_TYPE_MROPE;
         case LLM_ARCH_QWEN3VL:
         case LLM_ARCH_QWEN3VLMOE:
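The last hunk makes LLM_ARCH_GLM4V report LLAMA_ROPE_TYPE_MROPE, the same rope type as Qwen2VL. A minimal sketch of checking this through the public llama.h API after converting a model — the GGUF path is a placeholder and error handling is kept to the bare minimum:

```cpp
#include <cstdio>
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    // placeholder path: a GLM4V GGUF produced by the conversion script
    llama_model * model = llama_model_load_from_file("glm4v-9b.gguf", mparams);
    if (model == NULL) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    if (llama_model_rope_type(model) == LLAMA_ROPE_TYPE_MROPE) {
        printf("model reports M-RoPE, as expected for glm4v\n");
    }

    llama_model_free(model);
    llama_backend_free();
    return 0;
}
```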

src/models/glm4v.cpp

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
+#include "models.h"
+
+llm_build_glm4v::llm_build_glm4v(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    // inp_pos - contains the positions
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    auto * inp_attn = build_attn_inp_kv();
+
+    // M-RoPE sections from hparams
+    int sections[4];
+    std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        // Pre-attention norm
+        cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        // self-attention
+        {
+            ggml_tensor * Qcur = nullptr;
+            ggml_tensor * Kcur = nullptr;
+            ggml_tensor * Vcur = nullptr;
+
+            if (model.layers[il].wqkv == nullptr) {
+                Qcur = build_lora_mm(model.layers[il].wq, cur);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                }
+                Kcur = build_lora_mm(model.layers[il].wk, cur);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                }
+                Vcur = build_lora_mm(model.layers[il].wv, cur);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                }
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+            } else {
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+                if (model.layers[il].bqkv) {
+                    cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                    cb(cur, "bqkv", il);
+                }
+                Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1],
+                                    0 * sizeof(float) * (n_embd));
+                Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float),
+                                    cur->nb[1], 1 * sizeof(float) * (n_embd));
+                Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float),
+                                    cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa));
+            }
+
+            // GLM4V uses M-RoPE (multi-dimensional rotary position embeddings)
+            Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
+
+            Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn,
+                    model.layers[il].wo, NULL,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
+        }
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+        // Post-attention norm
+        cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "post_attn_norm", il);
+
+        // Add the input (residual connection after post-attention norm)
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // FF
+        {
+            // Pre-MLP norm
+            cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            // MLP
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    NULL,                      NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL, LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
+            cb(cur, "ffn_out", il);
+
+            // Post-MLP norm
+            cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "post_mlp_norm", il);
+        }
+        // Add residual connection after post-MLP norm
+        inpL = ggml_add(ctx0, cur, ffn_inp);
+        cb(inpL, "l_out", il);
+    }
+    // Final norm
+    cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // Output projection
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
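One detail worth noting: the FFN call above passes a single fused ffn_up tensor (created with shape {n_embd, n_ff * 2} in the load_tensors hunk) plus LLM_FFN_SWIGLU and no separate gate weight, so the gate is carved out of that fused projection. A small numeric sketch of the SwiGLU idea this relies on; which half of the fused output acts as the gate is an assumption here, not taken from the glm4v code:

```cpp
#include <cmath>
#include <vector>

// SwiGLU on a fused projection: the 2*n_ff output of ffn_up is split into
// two halves; one half is gated with SiLU and multiplied into the other.
// The half ordering below is an assumption for illustration only.
static std::vector<float> swiglu_fused(const std::vector<float> & fused) {
    const size_t n_ff = fused.size() / 2;
    std::vector<float> out(n_ff);
    for (size_t i = 0; i < n_ff; ++i) {
        const float gate = fused[i];        // assumed gate half
        const float up   = fused[n_ff + i]; // assumed up half
        const float silu = gate / (1.0f + std::exp(-gate));
        out[i] = silu * up;
    }
    return out;
}
```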

src/models/models.h

Lines changed: 4 additions & 0 deletions
@@ -222,6 +222,10 @@ struct llm_build_glm4_moe : public llm_graph_context {
     llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params);
 };

+struct llm_build_glm4v : public llm_graph_context {
+    llm_build_glm4v(const llama_model & model, const llm_graph_params & params);
+};
+
 struct llm_build_gpt2 : public llm_graph_context {
     llm_build_gpt2(const llama_model & model, const llm_graph_params & params);
 };

tools/mtmd/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ add_library(mtmd
     clip-graph.h
     models/models.h
     models/cogvlm.cpp
+    models/glm4v.cpp
     models/internvl.cpp
     models/kimivl.cpp
     models/llama4.cpp

tools/mtmd/clip-impl.h

Lines changed: 11 additions & 0 deletions
@@ -133,6 +133,15 @@
 #define TN_TOK_BOI "v.boi"
 #define TN_TOK_EOI "v.eoi"

+// glm4v
+#define TN_GLM4V_POST_CONV_LN "mm.post_conv_ln.%s"
+#define TN_GLM4V_DOWNSAMPLE "mm.downsample.%s"
+#define TN_GLM4V_MERGER_PROJ "mm.merger.proj.%s"
+#define TN_GLM4V_MERGER_NORM "mm.merger.norm.%s"
+#define TN_GLM4V_MERGER_GATE "mm.merger.gate.%s"
+#define TN_GLM4V_MERGER_UP "mm.merger.up.%s"
+#define TN_GLM4V_MERGER_DOWN "mm.merger.down.%s"
+
 // align x to upper multiple of n
 #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))

@@ -164,6 +173,7 @@ enum projector_type {
     PROJECTOR_TYPE_LIGHTONOCR,
     PROJECTOR_TYPE_COGVLM,
     PROJECTOR_TYPE_JANUS_PRO,
+    PROJECTOR_TYPE_GLM4V,
     PROJECTOR_TYPE_UNKNOWN,
 };

@@ -190,6 +200,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
     { PROJECTOR_TYPE_COGVLM,    "cogvlm"},
     { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
+    { PROJECTOR_TYPE_GLM4V,     "glm4v"},
 };

 static projector_type clip_projector_type_from_string(const std::string & str) {

tools/mtmd/clip-model.h

Lines changed: 16 additions & 0 deletions
@@ -267,6 +267,22 @@ struct clip_model {
     ggml_tensor * mm_boi = nullptr;
     ggml_tensor * mm_eoi = nullptr;

+    // glm4v
+    ggml_tensor * mm_post_conv_ln_w = nullptr;
+    ggml_tensor * mm_post_conv_ln_b = nullptr;
+    ggml_tensor * mm_downsample_w = nullptr;
+    ggml_tensor * mm_downsample_b = nullptr;
+    ggml_tensor * mm_merger_proj_w = nullptr;
+    ggml_tensor * mm_merger_proj_b = nullptr;
+    ggml_tensor * mm_merger_norm_w = nullptr;
+    ggml_tensor * mm_merger_norm_b = nullptr;
+    ggml_tensor * mm_merger_gate_w = nullptr;
+    ggml_tensor * mm_merger_gate_b = nullptr;
+    ggml_tensor * mm_merger_up_w = nullptr;
+    ggml_tensor * mm_merger_up_b = nullptr;
+    ggml_tensor * mm_merger_down_w = nullptr;
+    ggml_tensor * mm_merger_down_b = nullptr;
+
     bool audio_has_avgpool() const {
         return proj_type == PROJECTOR_TYPE_QWEN2A
             || proj_type == PROJECTOR_TYPE_VOXTRAL;
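These fields only declare where the GLM4V projector weights live; the mtmd graph code that consumes them is in files not shown on this page. As a rough sketch only — with tensor roles inferred from their names and expressed with standard ggml ops, not copied from the actual glm4v clip graph — the SwiGLU merger FFN mentioned in the commit message could be wired up roughly like this:

```cpp
#include "ggml.h"
#include "clip-model.h" // for clip_model (tools/mtmd); sketch only

// Hypothetical wiring of the glm4v merger FFN over per-patch embeddings
// `cur`, using the mm_merger_* tensors declared above. The ordering of the
// proj and norm steps, and the roles of gate/up, are assumptions.
static ggml_tensor * glm4v_merger_ffn_sketch(ggml_context * ctx,
                                             ggml_tensor  * cur,
                                             const clip_model & model,
                                             float norm_eps) {
    // input projection
    cur = ggml_mul_mat(ctx, model.mm_merger_proj_w, cur);
    cur = ggml_add(ctx, cur, model.mm_merger_proj_b);

    // layer norm with learned scale/shift
    cur = ggml_norm(ctx, cur, norm_eps);
    cur = ggml_add(ctx, ggml_mul(ctx, cur, model.mm_merger_norm_w), model.mm_merger_norm_b);

    // SwiGLU: silu(gate(x)) * up(x), then project back down
    ggml_tensor * gate = ggml_add(ctx, ggml_mul_mat(ctx, model.mm_merger_gate_w, cur), model.mm_merger_gate_b);
    ggml_tensor * up   = ggml_add(ctx, ggml_mul_mat(ctx, model.mm_merger_up_w,   cur), model.mm_merger_up_b);

    cur = ggml_mul(ctx, ggml_silu(ctx, gate), up);
    cur = ggml_mul_mat(ctx, model.mm_merger_down_w, cur);
    cur = ggml_add(ctx, cur, model.mm_merger_down_b);
    return cur;
}
```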
