diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 151608d56b8..a9fd83daa16 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -8039,6 +8039,160 @@ def prepare_tensors(self):
             raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("Glm4vMoeForConditionalGeneration", "Glm4vForConditionalGeneration")
+class GLM4VisionModel(MmprojModel):
+    """Multimodal projector from:
+    - [zai-org/GLM-4.1V-9B-Thinking](https://huggingface.co/zai-org/GLM-4.1V-9B-Thinking)
+    - [zai-org/GLM-4.5V](https://huggingface.co/zai-org/GLM-4.5V)
+    - [zai-org/GLM-4.6V-Flash](https://huggingface.co/zai-org/GLM-4.6V-Flash)
+
+    ref: [#16600](https://github.com/ggml-org/llama.cpp/pull/16600)"""
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        assert self.hparams_vision is not None
+        vparams = self.hparams_vision
+        ln_eps = vparams.get("layer_norm_eps", 1e-5)
+
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GLM4V)
+        self.gguf_writer.add_vision_attention_layernorm_eps(ln_eps)
+        self.gguf_writer.add_vision_use_silu(True)
+        # GLM4V uses 2x2 spatial downsampling (kernel size matches downsample.weight shape)
+        self.gguf_writer.add_vision_spatial_merge_size(2)
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        if ".position_embd." in new_name:
+            return gguf.GGMLQuantizationType.F32
+        return super().tensor_force_quant(name, new_name, bid, n_dims)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if not name.startswith("model.visual."):
+            return []  # skip non-vision tensors
+
+        # Force eager evaluation of lazy tensors: LazyTorchTensor wraps every
+        # operation (.float(), .clone(), etc.) in another lazy tensor, so the
+        # data has to be materialized with to_eager() before it is sliced below
+        from gguf.lazy import LazyBase
+
+        # LazyTorchTensor subclasses LazyBase, so a single check covers both
+        if isinstance(data_torch, LazyBase):
+            data_torch = LazyBase.to_eager(data_torch)
+            # Verify it is now a real tensor
+            if hasattr(data_torch, 'is_meta') and data_torch.is_meta:
+                raise RuntimeError(f"{name} is still a meta tensor after to_eager()")
+
+        # GLM4V tensor name mappings: HuggingFace -> GGUF
+        # Handle patch embedding Conv3D -> two Conv2D kernels
+        if "patch_embed.proj.weight" in name:
+            # Split Conv3D [c_out, c_in, kt, kh, kw] into two Conv2D [c_out, c_in, kh, kw]
+            c1, c2, kt, kh, kw = data_torch.shape
+            del c1, c2, kh, kw  # unused
+            assert kt == 2, "GLM4V expects temporal_patch_size of 2"
+
+            # Slice the tensor (already materialized at the start of modify_tensors)
+            slice0 = data_torch[:, :, 0, ...]
+            slice1 = data_torch[:, :, 1, ...]
+ + return [ + (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight", slice0), + (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", slice1), + ] + + if "patch_embed.proj.bias" in name: + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".bias", data_torch)] + + # Position embedding + if "embeddings.position_embedding" in name: + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_POS] + ".weight", data_torch)] + + # Post-convolution layernorm (GLM4V-specific) + if "post_conv_layernorm.weight" in name: + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_POST_CONV_LN] + ".weight", data_torch)] + if "post_conv_layernorm.bias" in name: + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_POST_CONV_LN] + ".bias", data_torch)] + + # Post layernorm + if "post_layernorm.weight" in name: + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_POST_NORM] + ".weight", data_torch)] + if "post_layernorm.bias" in name: + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_POST_NORM] + ".bias", data_torch)] + + # Downsample (GLM4V-specific) + if "downsample.weight" in name: + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_DOWNSAMPLE] + ".weight", data_torch)] + if "downsample.bias" in name: + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_DOWNSAMPLE] + ".bias", data_torch)] + + # Merger (GLM4V-specific) + if "merger.proj.weight" in name: + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_MERGER_PROJ] + ".weight", data_torch)] + if "merger.proj.bias" in name: + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_MERGER_PROJ] + ".bias", data_torch)] + if "merger.post_projection_norm.weight" in name: + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_MERGER_NORM] + ".weight", data_torch)] + if "merger.post_projection_norm.bias" in name: + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_MERGER_NORM] + ".bias", data_torch)] + if "merger.gate_proj.weight" in name: + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_MERGER_GATE] + ".weight", data_torch)] + if "merger.gate_proj.bias" in name: + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_MERGER_GATE] + ".bias", data_torch)] + if "merger.up_proj.weight" in name: + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_MERGER_UP] + ".weight", data_torch)] + if "merger.up_proj.bias" in name: + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_MERGER_UP] + ".bias", data_torch)] + if "merger.down_proj.weight" in name: + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_MERGER_DOWN] + ".weight", data_torch)] + if "merger.down_proj.bias" in name: + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_MERGER_DOWN] + ".bias", data_torch)] + + # Vision transformer blocks (model.visual.blocks.{N}.*) + import re + block_match = re.match(r"model\.visual\.blocks\.(\d+)\.(.*)", name) + if block_match: + block_id = int(block_match.group(1)) + rest = block_match.group(2) + + # Attention + if rest == "attn.qkv.weight": + return [(f"v.blk.{block_id}.attn_qkv.weight", data_torch)] + if rest == "attn.qkv.bias": + return [(f"v.blk.{block_id}.attn_qkv.bias", data_torch)] + if rest == "attn.proj.weight": + return [(f"v.blk.{block_id}.attn_out.weight", data_torch)] + if rest == "attn.proj.bias": + return [(f"v.blk.{block_id}.attn_out.bias", data_torch)] + + # Layer norms + if rest == "norm1.weight": + return [(f"v.blk.{block_id}.ln1.weight", data_torch)] + if rest == "norm1.bias": + return [(f"v.blk.{block_id}.ln1.bias", data_torch)] + if rest == "norm2.weight": + return [(f"v.blk.{block_id}.ln2.weight", data_torch)] + if rest == 
"norm2.bias": + return [(f"v.blk.{block_id}.ln2.bias", data_torch)] + + # MLP (SwiGLU) + if rest == "mlp.gate_proj.weight": + return [(f"v.blk.{block_id}.ffn_gate.weight", data_torch)] + if rest == "mlp.gate_proj.bias": + return [(f"v.blk.{block_id}.ffn_gate.bias", data_torch)] + if rest == "mlp.up_proj.weight": + return [(f"v.blk.{block_id}.ffn_up.weight", data_torch)] + if rest == "mlp.up_proj.bias": + return [(f"v.blk.{block_id}.ffn_up.bias", data_torch)] + if rest == "mlp.down_proj.weight": + return [(f"v.blk.{block_id}.ffn_down.weight", data_torch)] + if rest == "mlp.down_proj.bias": + return [(f"v.blk.{block_id}.ffn_down.bias", data_torch)] + + # If we get here, tensor wasn't handled - log warning and skip + logger.warning(f"GLM4V: Unhandled vision tensor: {name}") + return [] + + @ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration") class ChatGLMModel(TextModel): model_arch = gguf.MODEL_ARCH.CHATGLM diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 2b8489c591b..907a714e27e 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -683,6 +683,13 @@ class MODEL_TENSOR(IntEnum): V_MM_UP = auto() # cogvlm V_MM_DOWN = auto() # cogvlm V_MM_GATE = auto() # cogvlm + V_MM_POST_CONV_LN = auto() # glm4v + V_MM_DOWNSAMPLE = auto() # glm4v + V_MM_MERGER_PROJ = auto() # glm4v + V_MM_MERGER_NORM = auto() # glm4v + V_MM_MERGER_GATE = auto() # glm4v + V_MM_MERGER_UP = auto() # glm4v + V_MM_MERGER_DOWN = auto() # glm4v V_TOK_BOI = auto() # cogvlm V_TOK_EOI = auto() # cogvlm # audio (mtmd) @@ -1055,6 +1062,13 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.V_MM_UP: "mm.up", MODEL_TENSOR.V_MM_DOWN: "mm.down", MODEL_TENSOR.V_MM_GATE: "mm.gate", + MODEL_TENSOR.V_MM_POST_CONV_LN: "mm.post_conv_ln", # glm4v + MODEL_TENSOR.V_MM_DOWNSAMPLE: "mm.downsample", # glm4v + MODEL_TENSOR.V_MM_MERGER_PROJ: "mm.merger.proj", # glm4v + MODEL_TENSOR.V_MM_MERGER_NORM: "mm.merger.norm", # glm4v + MODEL_TENSOR.V_MM_MERGER_GATE: "mm.merger.gate", # glm4v + MODEL_TENSOR.V_MM_MERGER_UP: "mm.merger.up", # glm4v + MODEL_TENSOR.V_MM_MERGER_DOWN: "mm.merger.down", # glm4v MODEL_TENSOR.V_TOK_BOI: "v.boi", MODEL_TENSOR.V_TOK_EOI: "v.eoi", # audio (mtmd) @@ -1133,6 +1147,13 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.V_MM_UP, MODEL_TENSOR.V_MM_DOWN, MODEL_TENSOR.V_MM_GATE, + MODEL_TENSOR.V_MM_POST_CONV_LN, + MODEL_TENSOR.V_MM_DOWNSAMPLE, + MODEL_TENSOR.V_MM_MERGER_PROJ, + MODEL_TENSOR.V_MM_MERGER_NORM, + MODEL_TENSOR.V_MM_MERGER_GATE, + MODEL_TENSOR.V_MM_MERGER_UP, + MODEL_TENSOR.V_MM_MERGER_DOWN, MODEL_TENSOR.V_TOK_BOI, MODEL_TENSOR.V_TOK_EOI, # audio @@ -3324,6 +3345,7 @@ class VisionProjectorType: VOXTRAL = "voxtral" LFM2 = "lfm2" KIMIVL = "kimivl" + GLM4V = "glm4v" LIGHTONOCR = "lightonocr" COGVLM = "cogvlm" JANUS_PRO = "janus_pro" diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4192af7c0c3..1671706fa56 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -71,6 +71,7 @@ add_library(llama models/gemma3n-iswa.cpp models/glm4-moe.cpp models/glm4.cpp + models/glm4v.cpp models/gpt2.cpp models/gptneox.cpp models/granite-hybrid.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 64ad1b77690..88241228167 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -69,6 +69,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_GLM4_MOE, "glm4moe" }, + { LLM_ARCH_GLM4V, "glm4v" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, @@ 
-1634,6 +1635,24 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" }, }, }, + { + LLM_ARCH_GLM4V, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, { LLM_ARCH_BITNET, { diff --git a/src/llama-arch.h b/src/llama-arch.h index e113180024d..88ce2bafbda 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -73,6 +73,7 @@ enum llm_arch { LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, LLM_ARCH_GLM4_MOE, + LLM_ARCH_GLM4V, LLM_ARCH_BITNET, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e4808b1e1eb..b91fb27da29 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1723,6 +1723,15 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GLM4V: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_9B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_BITNET: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -5142,6 +5151,45 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; + case LLM_ARCH_GLM4V: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + + if (layer.wqkv == nullptr) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + } + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = 
create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + } break; case LLM_ARCH_NEMOTRON: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -7423,6 +7471,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_GLM4V: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_BITNET: { llm = std::make_unique(*this, params); @@ -7832,6 +7884,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: + case LLM_ARCH_GLM4V: return LLAMA_ROPE_TYPE_MROPE; case LLM_ARCH_QWEN3VL: case LLM_ARCH_QWEN3VLMOE: diff --git a/src/models/glm4v.cpp b/src/models/glm4v.cpp new file mode 100644 index 00000000000..b3630bd01fe --- /dev/null +++ b/src/models/glm4v.cpp @@ -0,0 +1,133 @@ +#include "models.h" + +llm_build_glm4v::llm_build_glm4v(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + // M-RoPE sections from hparams + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // Pre-attention norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + + if (model.layers[il].wqkv == nullptr) { + Qcur = build_lora_mm(model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + Kcur = build_lora_mm(model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + Vcur = build_lora_mm(model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + } else { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + if (model.layers[il].bqkv) { + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + } + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], + 0 * sizeof(float) * (n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), + cur->nb[1], 1 * sizeof(float) * (n_embd)); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), + cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); + } + + // GLM4V uses M-RoPE 
(multi-dimensional rotary position embeddings) + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + // Post-attention norm + cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + + // Add the input (residual connection after post-attention norm) + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + // Pre-MLP norm + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // MLP + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + // Post-MLP norm + cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "post_mlp_norm", il); + } + // Add residual connection after post-MLP norm + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + // Final norm + cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // Output projection + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/models.h b/src/models/models.h index 6494f545018..b09e101d172 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -222,6 +222,10 @@ struct llm_build_glm4_moe : public llm_graph_context { llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_glm4v : public llm_graph_context { + llm_build_glm4v(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_gpt2 : public llm_graph_context { llm_build_gpt2(const llama_model & model, const llm_graph_params & params); }; diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 3ee42036fda..e7f3067a163 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -15,6 +15,7 @@ add_library(mtmd clip-graph.h models/models.h models/cogvlm.cpp + models/glm4v.cpp models/internvl.cpp models/kimivl.cpp models/llama4.cpp diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 1726823ec69..a7556cac051 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -133,6 +133,15 @@ #define TN_TOK_BOI "v.boi" #define TN_TOK_EOI "v.eoi" +// glm4v +#define TN_GLM4V_POST_CONV_LN "mm.post_conv_ln.%s" +#define TN_GLM4V_DOWNSAMPLE "mm.downsample.%s" +#define TN_GLM4V_MERGER_PROJ "mm.merger.proj.%s" +#define TN_GLM4V_MERGER_NORM "mm.merger.norm.%s" +#define TN_GLM4V_MERGER_GATE "mm.merger.gate.%s" +#define TN_GLM4V_MERGER_UP "mm.merger.up.%s" +#define TN_GLM4V_MERGER_DOWN "mm.merger.down.%s" + // align x to upper multiple of n #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n)) @@ -164,6 +173,7 @@ enum 
projector_type {
     PROJECTOR_TYPE_LIGHTONOCR,
     PROJECTOR_TYPE_COGVLM,
     PROJECTOR_TYPE_JANUS_PRO,
+    PROJECTOR_TYPE_GLM4V,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -190,6 +200,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
     { PROJECTOR_TYPE_COGVLM,    "cogvlm"},
     { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
+    { PROJECTOR_TYPE_GLM4V,     "glm4v"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h
index 51bcce1ebb0..72e3dc6e0cb 100644
--- a/tools/mtmd/clip-model.h
+++ b/tools/mtmd/clip-model.h
@@ -267,6 +267,22 @@ struct clip_model {
     ggml_tensor * mm_boi = nullptr;
     ggml_tensor * mm_eoi = nullptr;
 
+    // glm4v
+    ggml_tensor * mm_post_conv_ln_w = nullptr;
+    ggml_tensor * mm_post_conv_ln_b = nullptr;
+    ggml_tensor * mm_downsample_w   = nullptr;
+    ggml_tensor * mm_downsample_b   = nullptr;
+    ggml_tensor * mm_merger_proj_w  = nullptr;
+    ggml_tensor * mm_merger_proj_b  = nullptr;
+    ggml_tensor * mm_merger_norm_w  = nullptr;
+    ggml_tensor * mm_merger_norm_b  = nullptr;
+    ggml_tensor * mm_merger_gate_w  = nullptr;
+    ggml_tensor * mm_merger_gate_b  = nullptr;
+    ggml_tensor * mm_merger_up_w    = nullptr;
+    ggml_tensor * mm_merger_up_b    = nullptr;
+    ggml_tensor * mm_merger_down_w  = nullptr;
+    ggml_tensor * mm_merger_down_b  = nullptr;
+
     bool audio_has_avgpool() const {
         return proj_type == PROJECTOR_TYPE_QWEN2A
             || proj_type == PROJECTOR_TYPE_VOXTRAL;
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index bb922e30b43..a21842538b7 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -781,6 +781,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 builder = std::make_unique(ctx, img);
             } break;
+        case PROJECTOR_TYPE_GLM4V:
+            {
+                builder = std::make_unique<clip_graph_glm4v>(ctx, img);
+            } break;
         case PROJECTOR_TYPE_MINICPMV:
             {
                 builder = std::make_unique(ctx, img);
@@ -1128,6 +1132,13 @@ struct clip_model_loader {
                         LOG_WRN("%s: more info: https://github.com/ggml-org/llama.cpp/issues/16842\n\n", __func__);
                     }
                 } break;
+            case PROJECTOR_TYPE_GLM4V:
+                {
+                    // GLM4V uses spatial_merge_size = 2 from HuggingFace config
+                    hparams.n_merge = 2;
+                    // GLM4V Vision RoPE: Glm4vVisionRotaryEmbedding uses theta=10000.0 (default)
+                    hparams.rope_theta = 10000.0f;
+                } break;
             case PROJECTOR_TYPE_LLAMA4:
                 {
                     hparams.rope_theta = 10000.0f;
@@ -1432,6 +1443,20 @@
                     model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
                     model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
                 } break;
+            case PROJECTOR_TYPE_GLM4V:
+                {
+                    // GLM4V merger/projector tensors
+                    model.mm_post_conv_ln_w = get_tensor(string_format(TN_GLM4V_POST_CONV_LN, "weight"));
+                    model.mm_post_conv_ln_b = get_tensor(string_format(TN_GLM4V_POST_CONV_LN, "bias"), false);
+                    model.mm_downsample_w   = get_tensor(string_format(TN_GLM4V_DOWNSAMPLE, "weight"));
+                    model.mm_downsample_b   = get_tensor(string_format(TN_GLM4V_DOWNSAMPLE, "bias"), false);
+                    model.mm_merger_proj_w  = get_tensor(string_format(TN_GLM4V_MERGER_PROJ, "weight"));
+                    model.mm_merger_norm_w  = get_tensor(string_format(TN_GLM4V_MERGER_NORM, "weight"));
+                    model.mm_merger_norm_b  = get_tensor(string_format(TN_GLM4V_MERGER_NORM, "bias"), false);
+                    model.mm_merger_gate_w  = get_tensor(string_format(TN_GLM4V_MERGER_GATE, "weight"));
+                    model.mm_merger_up_w    = get_tensor(string_format(TN_GLM4V_MERGER_UP, "weight"));
+                    model.mm_merger_down_w  = get_tensor(string_format(TN_GLM4V_MERGER_DOWN, "weight"));
+                } break;
             case PROJECTOR_TYPE_GEMMA3:
                 {
                    model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ);
@@ -2604,6 +2629,17 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                 res_imgs->entries.push_back(std::move(img_f32));
             } break;
+        case PROJECTOR_TYPE_GLM4V:
+            {
+                // GLM4V uses fixed image_size (336) with bicubic interpolation
+                clip_image_u8 resized_image;
+                int sz = params.image_size;
+                img_tool::resize(*img, resized_image, {sz, sz}, img_tool::RESIZE_ALGO_BICUBIC);
+                clip_image_f32_ptr img_f32(clip_image_f32_init());
+                normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
+                res_imgs->entries.push_back(std::move(img_f32));
+            } break;
+
         case PROJECTOR_TYPE_JANUS_PRO:
             {
                 // Janus Pro preprocessing: pad to square with gray(127), resize to 384x384
@@ -2839,6 +2875,11 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
                 int y_patch = img->ny / (params.patch_size * 2);
                 n_patches = x_patch * y_patch;
             } break;
+        case PROJECTOR_TYPE_GLM4V:
+            {
+                // GLM4V uses spatial_merge_size=2, reducing patches by 4x
+                n_patches /= 4;
+            } break;
         case PROJECTOR_TYPE_GEMMA3:
         case PROJECTOR_TYPE_IDEFICS3:
         case PROJECTOR_TYPE_INTERNVL:
@@ -3171,6 +3212,23 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                    }
                }
+                set_input_i32("positions", positions);
+            } break;
+        case PROJECTOR_TYPE_GLM4V:
+            {
+                // GLM4V uses M-RoPE positions (same format as Qwen2VL)
+                // RoPE is applied BEFORE patch merger, so positions are for all pre-merge patches
+                const int pw = image_size_width / patch_size; // patches per row
+
+                std::vector<int32_t> positions(n_pos * 4);
+                for (int i = 0; i < n_pos; i++) {
+                    int y = i / pw; // h position
+                    int x = i % pw; // w position
+                    positions[0 * n_pos + i] = y; // chunk 0: h
+                    positions[1 * n_pos + i] = x; // chunk 1: w
+                    positions[2 * n_pos + i] = y; // chunk 2: h (repeat)
+                    positions[3 * n_pos + i] = x; // chunk 3: w (repeat)
+                }
                 set_input_i32("positions", positions);
             } break;
         case PROJECTOR_TYPE_PIXTRAL:
@@ -3341,6 +3399,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_COGVLM:
             return ctx->model.mm_4h_to_h_w->ne[1];
+        case PROJECTOR_TYPE_GLM4V:
+            return ctx->model.mm_merger_down_w->ne[1];
         default:
             GGML_ABORT("Unknown projector type");
     }
diff --git a/tools/mtmd/models/glm4v.cpp b/tools/mtmd/models/glm4v.cpp
new file mode 100644
index 00000000000..8f624a84ec2
--- /dev/null
+++ b/tools/mtmd/models/glm4v.cpp
@@ -0,0 +1,149 @@
+#include "models.h"
+
+ggml_cgraph * clip_graph_glm4v::build() {
+    GGML_ASSERT(model.patch_embeddings_0 != nullptr);
+    GGML_ASSERT(model.patch_embeddings_1 != nullptr);
+    GGML_ASSERT(model.position_embeddings != nullptr);
+    GGML_ASSERT(model.class_embedding == nullptr);
+
+    // M-RoPE input positions (same pattern as Qwen2VL)
+    // Format: [h0,h1,...,hN, w0,w1,...,wN, h0,h1,...,hN, w0,w1,...,wN]
+    ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches * 4);
+    ggml_set_name(positions, "positions");
+    ggml_set_input(positions);
+
+    // GLM4V Patch Embedding using Conv2D on raw image
+    // Reference: modeling_glm4v.py Glm4vVisionPatchEmbed.forward()
+    //
+    // HF uses Conv3d with temporal_patch_size=2 for video support.
+    // For single images, HF duplicates the frame to create [2, C, H, W] input.
+    // The Conv3d kernel is split into two temporal slices and summed.
+ // + // Since frame0 == frame1 for single images, this is equivalent to: + // conv2d(kernel_t0, img) + conv2d(kernel_t1, img) + // which is the same pattern as Qwen2VL dual conv2d. + + ggml_tensor * inp_raw = build_inp_raw(); + + // Apply both temporal kernel slices to the same image + ggml_tensor * out0 = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + ggml_tensor * out1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + cb(out0, "conv_out0", -1); + cb(out1, "conv_out1", -1); + + // Sum temporal frames (simulates Conv3d temporal reduction) + ggml_tensor * inp = ggml_add(ctx0, out0, out1); + + // Reshape from conv2d output [1, n_patches_y, n_patches_x, n_embd] to [n_embd, n_patches] + inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd); + inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); + + // Add patch embedding bias (Conv3d bias) + // ref: self.proj.bias in Glm4vVisionPatchEmbed + if (model.patch_bias != nullptr) { + inp = ggml_add(ctx0, inp, model.patch_bias); + cb(inp, "patch_bias", -1); + } + cb(inp, "patch_embed", -1); + + // post-convolution layernorm + // ref: self.post_conv_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + inp = build_norm(inp, model.mm_post_conv_ln_w, model.mm_post_conv_ln_b, NORM_TYPE_RMS, eps, -1); + cb(inp, "post_conv_ln", -1); + + // absolute position embeddings (interpolated) + // ref: self.embeddings + ggml_tensor * learned_pos_embd = resize_position_embeddings(); + inp = ggml_add(ctx0, inp, learned_pos_embd); + cb(inp, "abs_pos_embed", -1); + + // RoPE to be applied inside ViT blocks + // Uses M-RoPE (same as Qwen2VL) with [h, w, h, w] pattern at 32-dim chunks + // ref: self.rotary_pos_emb, apply_rotary_pos_emb_vision (identical to Qwen2VL) + int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; + + auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { + return ggml_rope_multi(ctx0, cur, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, + 32768, hparams.rope_theta, 1.0f, 0.0f, 1.0f, 32.0f, 1.0f); + }; + + // ViT blocks + ggml_tensor * cur = build_vit( + inp, n_patches, + NORM_TYPE_RMS, + FFN_SILU, // hidden_act is "silu" + nullptr, // absolute embeddings already added + add_pos); + + // post-ViT layernorm + cur = build_norm(cur, model.post_ln_w, model.post_ln_b, NORM_TYPE_RMS, eps, -1); + cb(cur, "post_vit_ln", -1); + + // Patch merger downsampling - EXACT HF implementation + // HF: hidden_states [576, 1536] -> view(-1, 2, 2, 1536) -> [144, 2, 2, 1536] + // -> permute(0, 3, 1, 2) -> [144, 1536, 2, 2] + // -> Conv2d(1536, 4096, kernel=2, stride=2) -> [144, 4096, 1, 1] + // -> flatten -> [144, 4096] + // + // In ggml (dimensions reversed): + // Input: [1536, 576] + // reshape_4d: [1536, 2, 2, 144] (reversed from HF [144, 2, 2, 1536]) + // Need permute to: [2, 2, 1536, 144] (reversed from HF [144, 1536, 2, 2]) + // Conv2d output: [1, 1, 4096, 144] + // Final: [4096, 144] + const int merge_size = 2; + const int num_merge_blocks = n_patches / (merge_size * merge_size); // 576 / 4 = 144 + + // Reshape to 4D: [1536, 2, 2, 144] + cur = ggml_reshape_4d(ctx0, cur, n_embd, merge_size, merge_size, num_merge_blocks); + + // Permute to [2, 2, 1536, 144] for conv2d + // ggml_permute(a,b,c,d): axis0->a, axis1->b, axis2->c, axis3->d + // From [1536, 2, 2, 144] to [2, 2, 1536, 144]: + // axis0(1536)->2, axis1(2)->0, axis2(2)->1, axis3(144)->3 + cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); + cur = ggml_cont(ctx0, cur); + 
cb(cur, "pre_downsample_permute", -1); + + // downsample conv2d - each 2x2 block -> 1 token with 4096 features + // Output: [1, 1, 4096, 144] + cur = ggml_conv_2d(ctx0, model.mm_downsample_w, cur, merge_size, merge_size, 0, 0, 1, 1); + cb(cur, "downsample_conv", -1); + + // Reshape to [4096, 144] for ggml_mul_mat + cur = ggml_reshape_2d(ctx0, cur, cur->ne[2], cur->ne[3]); + cb(cur, "post_downsample_reshape", -1); + + // patch merger FFN + // ref: class Glm4vVisionPatchMerger(nn.Module): + { + // input projection + cur = ggml_mul_mat(ctx0, model.mm_merger_proj_w, cur); + + // apply norm + GELU + cur = build_norm(cur, model.mm_merger_norm_w, model.mm_merger_norm_b, NORM_TYPE_NORMAL, 1e-5f, -1); + cur = ggml_gelu(ctx0, cur); + ggml_tensor * ffn_input = cur; + cb(cur, "merger_ffn_inp", -1); + + // gate projection + ggml_tensor * gate = ggml_mul_mat(ctx0, model.mm_merger_gate_w, ffn_input); + cb(gate, "merger_gate", -1); + + // up projection + ggml_tensor * up = ggml_mul_mat(ctx0, model.mm_merger_up_w, ffn_input); + cb(up, "merger_up", -1); + + // activation + down projection + cur = ggml_silu(ctx0, gate); + cur = ggml_mul(ctx0, cur, up); + cur = ggml_mul_mat(ctx0, model.mm_merger_down_w, cur); + cb(cur, "merger_ffn_out", -1); + } + + // build the graph + ggml_build_forward_expand(gf, cur); + + return gf; +} diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index 4b35da259ce..53aae07f3a1 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -17,6 +17,11 @@ struct clip_graph_qwen2vl : clip_graph { ggml_cgraph * build() override; }; +struct clip_graph_glm4v : clip_graph { + clip_graph_glm4v(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + struct clip_graph_qwen3vl : clip_graph { clip_graph_qwen3vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override;