154 changes: 154 additions & 0 deletions convert_hf_to_gguf.py
@@ -8039,6 +8039,160 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("Glm4vMoeForConditionalGeneration", "Glm4vForConditionalGeneration")
class GLM4VisionModel(MmprojModel):
"""Multimodal projector from:
- [zai-org/GLM-4.1V-9B-Thinking](https://huggingface.co/zai-org/GLM-4.1V-9B-Thinking)
- [zai-org/GLM-4.5V](https://huggingface.co/zai-org/GLM-4.5V)
- [zai-org/GLM-4.6V-Flash](https://huggingface.co/zai-org/GLM-4.6V-Flash)

ref: [#16600](https://github.com/ggml-org/llama.cpp/pull/16600)"""

def set_gguf_parameters(self):
super().set_gguf_parameters()
assert self.hparams_vision is not None
vparams = self.hparams_vision
ln_eps = vparams.get("layer_norm_eps", 1e-5)

self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GLM4V)
self.gguf_writer.add_vision_attention_layernorm_eps(ln_eps)
self.gguf_writer.add_vision_use_silu(True)
# GLM4V uses 2x2 spatial downsampling (kernel size matches downsample.weight shape)
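# A merge size of 2 means each 2x2 group of patch tokens is collapsed into one
# embedding before the merger MLP, roughly a 4x reduction in image tokens
# (assuming the runtime pairs this with the downsample conv mapped below).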
self.gguf_writer.add_vision_spatial_merge_size(2)

def tensor_force_quant(self, name, new_name, bid, n_dims):
if ".position_embd." in new_name:
return gguf.GGMLQuantizationType.F32
return super().tensor_force_quant(name, new_name, bid, n_dims)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

if not name.startswith("model.visual."):
return [] # skip non-vision tensors

# Force eager evaluation of lazy tensors.
# LazyTorchTensor wraps every operation (.float(), .clone(), slicing, ...) in a
# new lazy tensor, so to_eager() must be called explicitly before the slicing below.
from gguf.lazy import LazyBase

if isinstance(data_torch, LazyBase):  # LazyTorchTensor is a LazyBase subclass
    data_torch = LazyBase.to_eager(data_torch)
# Verify the tensor is now materialized (not a meta tensor)
if hasattr(data_torch, 'is_meta') and data_torch.is_meta:
    raise RuntimeError(f"{name} is still a meta tensor after to_eager()")

# GLM4V tensor name mappings: HuggingFace -> GGUF
# Handle patch embedding Conv3D -> two Conv2D kernels
if "patch_embed.proj.weight" in name:
# Split Conv3D [c_out, c_in, kt, kh, kw] into two Conv2D [c_out, c_in, kh, kw]
c1, c2, kt, kh, kw = data_torch.shape
del c1, c2, kh, kw # unused
assert kt == 2, "GLM4V expects temporal_patch_size of 2"

# Slice the tensor (already materialized at start of modify_tensors)
slice0 = data_torch[:, :, 0, ...]
slice1 = data_torch[:, :, 1, ...]
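# At inference the two kernels are expected to be applied to the two temporal
# frames (a still image is duplicated to fill both) and their outputs summed,
# which is equivalent to the original Conv3D with kt == 2. This mirrors how
# other temporal-patch encoders are split and is an assumption about the
# clip/mtmd runtime, not something enforced here.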

return [
(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight", slice0),
(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", slice1),
]

if "patch_embed.proj.bias" in name:
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".bias", data_torch)]

# Position embedding
if "embeddings.position_embedding" in name:
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_POS] + ".weight", data_torch)]

# Post-convolution layernorm (GLM4V-specific)
if "post_conv_layernorm.weight" in name:
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_POST_CONV_LN] + ".weight", data_torch)]
if "post_conv_layernorm.bias" in name:
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_POST_CONV_LN] + ".bias", data_torch)]

# Post layernorm
if "post_layernorm.weight" in name:
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_POST_NORM] + ".weight", data_torch)]
if "post_layernorm.bias" in name:
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_POST_NORM] + ".bias", data_torch)]

# Downsample (GLM4V-specific)
if "downsample.weight" in name:
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_DOWNSAMPLE] + ".weight", data_torch)]
if "downsample.bias" in name:
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_DOWNSAMPLE] + ".bias", data_torch)]

# Merger (GLM4V-specific)
if "merger.proj.weight" in name:
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_MERGER_PROJ] + ".weight", data_torch)]
if "merger.proj.bias" in name:
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_MERGER_PROJ] + ".bias", data_torch)]
if "merger.post_projection_norm.weight" in name:
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_MERGER_NORM] + ".weight", data_torch)]
if "merger.post_projection_norm.bias" in name:
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_MERGER_NORM] + ".bias", data_torch)]
if "merger.gate_proj.weight" in name:
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_MERGER_GATE] + ".weight", data_torch)]
if "merger.gate_proj.bias" in name:
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_MERGER_GATE] + ".bias", data_torch)]
if "merger.up_proj.weight" in name:
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_MERGER_UP] + ".weight", data_torch)]
if "merger.up_proj.bias" in name:
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_MERGER_UP] + ".bias", data_torch)]
if "merger.down_proj.weight" in name:
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_MERGER_DOWN] + ".weight", data_torch)]
if "merger.down_proj.bias" in name:
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MM_MERGER_DOWN] + ".bias", data_torch)]

# Vision transformer blocks (model.visual.blocks.{N}.*)
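# e.g. "model.visual.blocks.3.attn.qkv.weight" -> ("v.blk.3.attn_qkv.weight", tensor)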
import re
block_match = re.match(r"model\.visual\.blocks\.(\d+)\.(.*)", name)
if block_match:
block_id = int(block_match.group(1))
rest = block_match.group(2)

# Attention
if rest == "attn.qkv.weight":
return [(f"v.blk.{block_id}.attn_qkv.weight", data_torch)]
if rest == "attn.qkv.bias":
return [(f"v.blk.{block_id}.attn_qkv.bias", data_torch)]
if rest == "attn.proj.weight":
return [(f"v.blk.{block_id}.attn_out.weight", data_torch)]
if rest == "attn.proj.bias":
return [(f"v.blk.{block_id}.attn_out.bias", data_torch)]

# Layer norms
if rest == "norm1.weight":
return [(f"v.blk.{block_id}.ln1.weight", data_torch)]
if rest == "norm1.bias":
return [(f"v.blk.{block_id}.ln1.bias", data_torch)]
if rest == "norm2.weight":
return [(f"v.blk.{block_id}.ln2.weight", data_torch)]
if rest == "norm2.bias":
return [(f"v.blk.{block_id}.ln2.bias", data_torch)]

# MLP (SwiGLU)
if rest == "mlp.gate_proj.weight":
return [(f"v.blk.{block_id}.ffn_gate.weight", data_torch)]
if rest == "mlp.gate_proj.bias":
return [(f"v.blk.{block_id}.ffn_gate.bias", data_torch)]
if rest == "mlp.up_proj.weight":
return [(f"v.blk.{block_id}.ffn_up.weight", data_torch)]
if rest == "mlp.up_proj.bias":
return [(f"v.blk.{block_id}.ffn_up.bias", data_torch)]
if rest == "mlp.down_proj.weight":
return [(f"v.blk.{block_id}.ffn_down.weight", data_torch)]
if rest == "mlp.down_proj.bias":
return [(f"v.blk.{block_id}.ffn_down.bias", data_torch)]

# If we get here, tensor wasn't handled - log warning and skip
logger.warning(f"GLM4V: Unhandled vision tensor: {name}")
return []


@ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
class ChatGLMModel(TextModel):
model_arch = gguf.MODEL_ARCH.CHATGLM
22 changes: 22 additions & 0 deletions gguf-py/gguf/constants.py
@@ -683,6 +683,13 @@ class MODEL_TENSOR(IntEnum):
V_MM_UP = auto() # cogvlm
V_MM_DOWN = auto() # cogvlm
V_MM_GATE = auto() # cogvlm
V_MM_POST_CONV_LN = auto() # glm4v
V_MM_DOWNSAMPLE = auto() # glm4v
V_MM_MERGER_PROJ = auto() # glm4v
V_MM_MERGER_NORM = auto() # glm4v
V_MM_MERGER_GATE = auto() # glm4v
V_MM_MERGER_UP = auto() # glm4v
V_MM_MERGER_DOWN = auto() # glm4v
V_TOK_BOI = auto() # cogvlm
V_TOK_EOI = auto() # cogvlm
# audio (mtmd)
@@ -1055,6 +1062,13 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_MM_UP: "mm.up",
MODEL_TENSOR.V_MM_DOWN: "mm.down",
MODEL_TENSOR.V_MM_GATE: "mm.gate",
MODEL_TENSOR.V_MM_POST_CONV_LN: "mm.post_conv_ln", # glm4v
MODEL_TENSOR.V_MM_DOWNSAMPLE: "mm.downsample", # glm4v
MODEL_TENSOR.V_MM_MERGER_PROJ: "mm.merger.proj", # glm4v
MODEL_TENSOR.V_MM_MERGER_NORM: "mm.merger.norm", # glm4v
MODEL_TENSOR.V_MM_MERGER_GATE: "mm.merger.gate", # glm4v
MODEL_TENSOR.V_MM_MERGER_UP: "mm.merger.up", # glm4v
MODEL_TENSOR.V_MM_MERGER_DOWN: "mm.merger.down", # glm4v
MODEL_TENSOR.V_TOK_BOI: "v.boi",
MODEL_TENSOR.V_TOK_EOI: "v.eoi",
# audio (mtmd)
@@ -1133,6 +1147,13 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_MM_UP,
MODEL_TENSOR.V_MM_DOWN,
MODEL_TENSOR.V_MM_GATE,
MODEL_TENSOR.V_MM_POST_CONV_LN,
MODEL_TENSOR.V_MM_DOWNSAMPLE,
MODEL_TENSOR.V_MM_MERGER_PROJ,
MODEL_TENSOR.V_MM_MERGER_NORM,
MODEL_TENSOR.V_MM_MERGER_GATE,
MODEL_TENSOR.V_MM_MERGER_UP,
MODEL_TENSOR.V_MM_MERGER_DOWN,
MODEL_TENSOR.V_TOK_BOI,
MODEL_TENSOR.V_TOK_EOI,
# audio
@@ -3324,6 +3345,7 @@ class VisionProjectorType:
VOXTRAL = "voxtral"
LFM2 = "lfm2"
KIMIVL = "kimivl"
GLM4V = "glm4v"
LIGHTONOCR = "lightonocr"
COGVLM = "cogvlm"
JANUS_PRO = "janus_pro"
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -71,6 +71,7 @@ add_library(llama
models/gemma3n-iswa.cpp
models/glm4-moe.cpp
models/glm4.cpp
models/glm4v.cpp
models/gpt2.cpp
models/gptneox.cpp
models/granite-hybrid.cpp
19 changes: 19 additions & 0 deletions src/llama-arch.cpp
@@ -69,6 +69,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_CHATGLM, "chatglm" },
{ LLM_ARCH_GLM4, "glm4" },
{ LLM_ARCH_GLM4_MOE, "glm4moe" },
{ LLM_ARCH_GLM4V, "glm4v" },
{ LLM_ARCH_BITNET, "bitnet" },
{ LLM_ARCH_T5, "t5" },
{ LLM_ARCH_T5ENCODER, "t5encoder" },
@@ -1634,6 +1635,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" },
},
},
{
LLM_ARCH_GLM4V,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_BITNET,
{
1 change: 1 addition & 0 deletions src/llama-arch.h
@@ -73,6 +73,7 @@ enum llm_arch {
LLM_ARCH_CHATGLM,
LLM_ARCH_GLM4,
LLM_ARCH_GLM4_MOE,
LLM_ARCH_GLM4V,
LLM_ARCH_BITNET,
LLM_ARCH_T5,
LLM_ARCH_T5ENCODER,
53 changes: 53 additions & 0 deletions src/llama-model.cpp
@@ -1723,6 +1723,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_GLM4V:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
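// rope_sections splits the rotary dimensions into M-RoPE groups (presumably
// temporal/height/width/extra), following the Qwen2-VL convention.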
ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true);
switch (hparams.n_layer) {
case 40: type = LLM_TYPE_9B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_BITNET:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -5142,6 +5151,45 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
}
break;
case LLM_ARCH_GLM4V:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}

for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];

layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);

if (layer.wqkv == nullptr) {
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
}

layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);

layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);

layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
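// ffn_up is sized n_ff * 2 because the gate and up projections are stored fused
// in a single tensor (same layout as GLM4); the build function presumably splits it.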
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0);

layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
case LLM_ARCH_NEMOTRON:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -7423,6 +7471,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_glm4_moe>(*this, params);
} break;
case LLM_ARCH_GLM4V:
{
llm = std::make_unique<llm_build_glm4v>(*this, params);
} break;
case LLM_ARCH_BITNET:
{
llm = std::make_unique<llm_build_bitnet>(*this, params);
@@ -7832,6 +7884,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
return LLAMA_ROPE_TYPE_NEOX;

case LLM_ARCH_QWEN2VL:
case LLM_ARCH_GLM4V:
return LLAMA_ROPE_TYPE_MROPE;
case LLM_ARCH_QWEN3VL:
case LLM_ARCH_QWEN3VLMOE: