From c5e90a164525e95c455e07a0b7aefde1b8cc9226 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 24 May 2025 09:17:16 +0200 Subject: [PATCH 01/21] convertible to gguf --- convert_hf_to_gguf.py | 20 +++++++++- gguf-py/gguf/constants.py | 67 ++++++++++++++++++++++++++++++++++ gguf-py/gguf/tensor_mapping.py | 64 ++++++++++++++++++++++++++++++++ 3 files changed, 149 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 123083b9154..553d94f3a38 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -936,6 +936,10 @@ def _create_vocab_sentencepiece(self): elif tokenizer.IsByte(token_id): toktype = SentencePieceTokenTypes.BYTE + if token_id >= vocab_size: + logger.warning(f'ignore tokens from {token_id}: id is out of range, max={vocab_size - 1}') + break + tokens[token_id] = text scores[token_id] = score toktypes[token_id] = toktype @@ -4000,9 +4004,8 @@ def set_gguf_parameters(self): self.gguf_writer.add_value_length(hparams.get("head_dim", 256)) self.gguf_writer.add_file_type(self.ftype) self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers - # both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3 + # attn_logit_softcapping is removed in Gemma3 assert hparams.get("attn_logit_softcapping") is None - assert hparams.get("final_logit_softcapping") is None self.gguf_writer.add_sliding_window(hparams["sliding_window"]) self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4)) if hparams.get("rope_scaling") is not None: @@ -4087,6 +4090,19 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [] # skip other tensors +@ModelBase.register("Gemma3p5ForCausalLM") +class Gemma3NModel(Gemma3Model): + model_arch = gguf.MODEL_ARCH.GEMMA3N + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.endswith("_scale"): + name = name + ".weight" + return super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("Starcoder2ForCausalLM") class StarCoder2Model(TextModel): model_arch = gguf.MODEL_ARCH.STARCODER2 diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 58de45dfddb..afceddd8d83 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -309,6 +309,7 @@ class MODEL_ARCH(IntEnum): GEMMA = auto() GEMMA2 = auto() GEMMA3 = auto() + GEMMA3N = auto() STARCODER2 = auto() RWKV6 = auto() RWKV6QWEN2 = auto() @@ -392,6 +393,22 @@ class MODEL_TENSOR(IntEnum): ATTN_Q_NORM = auto() ATTN_K_NORM = auto() LAYER_OUT_NORM = auto() + PER_LAYER_TOKEN_EMBD = auto() # gemma3n + PER_LAYER_MODEL_PROJ = auto() # gemma3n + PER_LAYER_INP_GATE = auto() # gemma3n + PER_LAYER_PROJ = auto() # gemma3n + PER_LAYER_PROJ_NORM = auto() # gemma3n + PER_LAYER_POST_NORM = auto() # gemma3n + ALTUP_PROJ = auto() # gemma3n + ALTUP_UNEMBD_PROJ = auto() # gemma3n + ALTUP_CORRECT_COEF = auto() # gemma3n + ALTUP_CORRECT_SCALE = auto() # gemma3n + ALTUP_PREDICT_COEF = auto() # gemma3n + ALTUP_ROUTER = auto() # gemma3n + ALTUP_ROUTER_NORM = auto() # gemma3n + LAUREL_L = auto() # gemma3n + LAUREL_R = auto() # gemma3n + LAUREL_POST_NORM = auto() # gemma3n SSM_IN = auto() SSM_CONV1D = auto() SSM_X = auto() @@ -588,6 +605,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GEMMA: "gemma", MODEL_ARCH.GEMMA2: "gemma2", MODEL_ARCH.GEMMA3: "gemma3", + MODEL_ARCH.GEMMA3N: "gemma3n", MODEL_ARCH.STARCODER2: "starcoder2", 
MODEL_ARCH.RWKV6: "rwkv6", MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2", @@ -671,6 +689,22 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b", MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", + MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n + MODEL_TENSOR.PER_LAYER_MODEL_PROJ: "per_layer_model_proj", # gemma3n + MODEL_TENSOR.PER_LAYER_PROJ_NORM: "per_layer_proj_norm", # gemma3n + MODEL_TENSOR.ALTUP_UNEMBD_PROJ: "altup_unembd_proj.{bid}", # gemma3n + MODEL_TENSOR.ALTUP_PROJ: "altup_proj.{bid}", # gemma3n + MODEL_TENSOR.PER_LAYER_INP_GATE: "blk.{bid}.inp_gate", # gemma3n + MODEL_TENSOR.PER_LAYER_PROJ: "blk.{bid}.proj", # gemma3n + MODEL_TENSOR.PER_LAYER_POST_NORM: "blk.{bid}.post_norm", # gemma3n + MODEL_TENSOR.ALTUP_CORRECT_COEF: "blk.{bid}.altup_correct_coef", # gemma3n + MODEL_TENSOR.ALTUP_CORRECT_SCALE: "blk.{bid}.altup_correct_scale", # gemma3n + MODEL_TENSOR.ALTUP_PREDICT_COEF: "blk.{bid}.altup_predict_coef", # gemma3n + MODEL_TENSOR.ALTUP_ROUTER: "blk.{bid}.altup_router", # gemma3n + MODEL_TENSOR.ALTUP_ROUTER_NORM: "blk.{bid}.altup_router_norm", # gemma3n + MODEL_TENSOR.LAUREL_L: "blk.{bid}.laurel_l", # gemma3n + MODEL_TENSOR.LAUREL_R: "blk.{bid}.laurel_r", # gemma3n + MODEL_TENSOR.LAUREL_POST_NORM: "blk.{bid}.laurel_post_norm", # gemma3n MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in", MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x", @@ -1460,6 +1494,39 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_PRE_NORM, MODEL_TENSOR.FFN_POST_NORM, ], + MODEL_ARCH.GEMMA3N: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.FFN_PRE_NORM, + MODEL_TENSOR.FFN_POST_NORM, + # altup / laurel + MODEL_TENSOR.PER_LAYER_TOKEN_EMBD, + MODEL_TENSOR.PER_LAYER_MODEL_PROJ, + MODEL_TENSOR.PER_LAYER_INP_GATE, + MODEL_TENSOR.PER_LAYER_PROJ, + MODEL_TENSOR.PER_LAYER_PROJ_NORM, + MODEL_TENSOR.PER_LAYER_POST_NORM, + MODEL_TENSOR.ALTUP_PROJ, + MODEL_TENSOR.ALTUP_UNEMBD_PROJ, + MODEL_TENSOR.ALTUP_CORRECT_COEF, + MODEL_TENSOR.ALTUP_CORRECT_SCALE, + MODEL_TENSOR.ALTUP_PREDICT_COEF, + MODEL_TENSOR.ALTUP_ROUTER, + MODEL_TENSOR.ALTUP_ROUTER_NORM, + MODEL_TENSOR.LAUREL_L, + MODEL_TENSOR.LAUREL_R, + MODEL_TENSOR.LAUREL_POST_NORM, + ], MODEL_ARCH.STARCODER2: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 91a95ea48b4..b1b1cc097c0 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -461,6 +461,70 @@ class TensorNameMap: "encoder.layer.{bid}.layer_norm_2" # jina-v2-code ), + MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: ( + "model.embed_tokens_per_layer", # gemma3n + ), + + MODEL_TENSOR.PER_LAYER_MODEL_PROJ: ( + "model.per_layer_model_projection", # gemma3n + ), + + MODEL_TENSOR.PER_LAYER_PROJ_NORM: ( + "model.per_layer_projection_norm", # gemma3n + ), + + MODEL_TENSOR.ALTUP_PROJ: ( + "model.altup_projections.{bid}", # gemma3n + ), + + MODEL_TENSOR.ALTUP_UNEMBD_PROJ: ( + "model.altup_unembed_projections.{bid}", # gemma3n + ), + + MODEL_TENSOR.PER_LAYER_INP_GATE: ( + "model.layers.{bid}.per_layer_input_gate", # gemma3n + ), + + MODEL_TENSOR.PER_LAYER_PROJ: ( + "model.layers.{bid}.per_layer_projection", # gemma3n + ), + + 
MODEL_TENSOR.PER_LAYER_POST_NORM: ( + "model.layers.{bid}.post_per_layer_input_norm", # gemma3n + ), + + MODEL_TENSOR.ALTUP_CORRECT_COEF: ( + "model.layers.{bid}.altup.correction_coefs", # gemma3n + ), + + MODEL_TENSOR.ALTUP_CORRECT_SCALE: ( + "model.layers.{bid}.altup.correct_output_scale", # gemma3n + ), + + MODEL_TENSOR.ALTUP_PREDICT_COEF: ( + "model.layers.{bid}.altup.prediction_coefs", # gemma3n + ), + + MODEL_TENSOR.ALTUP_ROUTER: ( + "model.layers.{bid}.altup.modality_router", # gemma3n + ), + + MODEL_TENSOR.ALTUP_ROUTER_NORM: ( + "model.layers.{bid}.altup.router_norm", # gemma3n + ), + + MODEL_TENSOR.LAUREL_L: ( + "model.layers.{bid}.laurel.linear_left", # gemma3n + ), + + MODEL_TENSOR.LAUREL_R: ( + "model.layers.{bid}.laurel.linear_right", # gemma3n + ), + + MODEL_TENSOR.LAUREL_POST_NORM: ( + "model.layers.{bid}.laurel.post_laurel_norm", # gemma3n + ), + MODEL_TENSOR.SSM_IN: ( "model.layers.{bid}.in_proj", "backbone.layers.{bid}.mixer.in_proj", From 2e0a9dfb17b1dc6b609e55e2bd564a7866be69b2 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 24 May 2025 12:03:51 +0200 Subject: [PATCH 02/21] weight loaded ok, missing cgraph --- convert_hf_to_gguf.py | 57 ++++++++++++++++++++++++++++++ gguf-py/gguf/constants.py | 4 +-- gguf-py/gguf/tensor_mapping.py | 4 +-- src/llama-arch.cpp | 52 ++++++++++++++++++++++++++++ src/llama-arch.h | 17 +++++++++ src/llama-hparams.h | 5 +++ src/llama-model.cpp | 63 ++++++++++++++++++++++++++++++++++ src/llama-model.h | 20 +++++++++++ 8 files changed, 218 insertions(+), 4 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 553d94f3a38..83655a3439f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4094,12 +4094,69 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter class Gemma3NModel(Gemma3Model): model_arch = gguf.MODEL_ARCH.GEMMA3N + _altup_proj: list[Tensor] = [] + _altup_unembd: list[Tensor] = [] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams["altup_num_inputs"] == 4, "Current conversion only supports 4 altup inputs" + self._altup_proj = [ + torch.Tensor(), # to be replaced + torch.Tensor(), # to be replaced + torch.Tensor(), # to be replaced + ] + self._altup_unembd = [ + torch.Tensor(), # to be replaced + torch.Tensor(), # to be replaced + torch.Tensor(), # to be replaced + ] + def set_gguf_parameters(self): super().set_gguf_parameters() + def _stack_matrices(self, matrices: list[Tensor]) -> Tensor | None: + has_all = all(m.numel() > 0 for m in matrices) + if not has_all: + return None + else: + return torch.stack(matrices, dim=0) + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: if name.endswith("_scale"): name = name + ".weight" + + if "altup_unembed_projections" in name: + data_torch = data_torch.to(device="cpu") + if ".0." in name: + self._altup_unembd[0] = data_torch + elif ".1." in name: + self._altup_unembd[1] = data_torch + elif ".2." in name: + self._altup_unembd[2] = data_torch + else: + raise ValueError(f"Unknown name: {name}") + out = self._stack_matrices(self._altup_unembd) + if out is not None: + return [(self.map_tensor_name("model.altup_unembed_projections.weight"), out)] + else: + return [] + + if "altup_projections" in name: + data_torch = data_torch.to(device="cpu") + if ".0." in name: + self._altup_proj[0] = data_torch + elif ".1." in name: + self._altup_proj[1] = data_torch + elif ".2." 
in name: + self._altup_proj[2] = data_torch + else: + raise ValueError(f"Unknown name: {name}") + out = self._stack_matrices(self._altup_proj) + if out is not None: + return [(self.map_tensor_name("model.altup_projections.weight"), out)] + else: + return [] + return super().modify_tensors(data_torch, name, bid) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index afceddd8d83..eff01e81956 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -692,8 +692,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n MODEL_TENSOR.PER_LAYER_MODEL_PROJ: "per_layer_model_proj", # gemma3n MODEL_TENSOR.PER_LAYER_PROJ_NORM: "per_layer_proj_norm", # gemma3n - MODEL_TENSOR.ALTUP_UNEMBD_PROJ: "altup_unembd_proj.{bid}", # gemma3n - MODEL_TENSOR.ALTUP_PROJ: "altup_proj.{bid}", # gemma3n + MODEL_TENSOR.ALTUP_UNEMBD_PROJ: "altup_unembd_proj", # gemma3n + MODEL_TENSOR.ALTUP_PROJ: "altup_proj", # gemma3n MODEL_TENSOR.PER_LAYER_INP_GATE: "blk.{bid}.inp_gate", # gemma3n MODEL_TENSOR.PER_LAYER_PROJ: "blk.{bid}.proj", # gemma3n MODEL_TENSOR.PER_LAYER_POST_NORM: "blk.{bid}.post_norm", # gemma3n diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index b1b1cc097c0..48166feaa13 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -474,11 +474,11 @@ class TensorNameMap: ), MODEL_TENSOR.ALTUP_PROJ: ( - "model.altup_projections.{bid}", # gemma3n + "model.altup_projections", # gemma3n ), MODEL_TENSOR.ALTUP_UNEMBD_PROJ: ( - "model.altup_unembed_projections.{bid}", # gemma3n + "model.altup_unembed_projections", # gemma3n ), MODEL_TENSOR.PER_LAYER_INP_GATE: ( diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index abf436adac4..dbcdcf75096 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -41,6 +41,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA, "gemma" }, { LLM_ARCH_GEMMA2, "gemma2" }, { LLM_ARCH_GEMMA3, "gemma3" }, + { LLM_ARCH_GEMMA3N, "gemma3n" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_XVERSE, "xverse" }, @@ -892,6 +893,40 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, }, }, + { + LLM_ARCH_GEMMA3N, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + { LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "per_layer_token_embd" }, + { LLM_TENSOR_PER_LAYER_MODEL_PROJ, "per_layer_model_proj" }, + { LLM_TENSOR_PER_LAYER_PROJ_NORM, "per_layer_proj_norm" }, + { LLM_TENSOR_ALTUP_UNEMBD_PROJ, "altup_unembd_proj" }, + { LLM_TENSOR_ALTUP_PROJ, "altup_proj" }, + { LLM_TENSOR_PER_LAYER_INP_GATE, "blk.%d.inp_gate" }, + { LLM_TENSOR_PER_LAYER_PROJ, "blk.%d.proj" }, + { LLM_TENSOR_PER_LAYER_POST_NORM, "blk.%d.post_norm" }, + { LLM_TENSOR_ALTUP_CORRECT_COEF, "blk.%d.altup_correct_coef" }, + { LLM_TENSOR_ALTUP_CORRECT_SCALE, "blk.%d.altup_correct_scale" }, + { LLM_TENSOR_ALTUP_PREDICT_COEF, "blk.%d.altup_predict_coef" }, + { LLM_TENSOR_ALTUP_ROUTER, 
"blk.%d.altup_router" }, + { LLM_TENSOR_ALTUP_ROUTER_NORM, "blk.%d.altup_router_norm" }, + { LLM_TENSOR_LAUREL_L, "blk.%d.laurel_l" }, + { LLM_TENSOR_LAUREL_R, "blk.%d.laurel_r" }, + { LLM_TENSOR_LAUREL_POST_NORM, "blk.%d.laurel_post_norm" }, + }, + }, { LLM_ARCH_STARCODER2, { @@ -1681,6 +1716,23 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + // altup / laurel + {LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}}, + {LLM_TENSOR_ALTUP_UNEMBD_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_PER_LAYER_INP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_PER_LAYER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_PER_LAYER_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ALTUP_CORRECT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ALTUP_CORRECT_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ALTUP_PREDICT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ALTUP_ROUTER, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ALTUP_ROUTER_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_LAUREL_L, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_LAUREL_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // this tensor is loaded for T5, but never used {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}}, {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index 41a023da3da..7abb6a81a8d 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -45,6 +45,7 @@ enum llm_arch { LLM_ARCH_GEMMA, LLM_ARCH_GEMMA2, LLM_ARCH_GEMMA3, + LLM_ARCH_GEMMA3N, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, LLM_ARCH_XVERSE, @@ -263,6 +264,22 @@ enum llm_tensor { LLM_TENSOR_LAYER_OUT_NORM, LLM_TENSOR_POST_ATTN_NORM, LLM_TENSOR_POST_MLP_NORM, + LLM_TENSOR_PER_LAYER_TOKEN_EMBD, // gemma3n + LLM_TENSOR_PER_LAYER_MODEL_PROJ, // gemma3n + LLM_TENSOR_PER_LAYER_INP_GATE, // gemma3n + LLM_TENSOR_PER_LAYER_PROJ, // gemma3n + LLM_TENSOR_PER_LAYER_PROJ_NORM, // gemma3n + LLM_TENSOR_PER_LAYER_POST_NORM, // gemma3n + LLM_TENSOR_ALTUP_PROJ, // gemma3n + LLM_TENSOR_ALTUP_UNEMBD_PROJ, // gemma3n + LLM_TENSOR_ALTUP_CORRECT_COEF, // gemma3n + LLM_TENSOR_ALTUP_CORRECT_SCALE, // gemma3n + LLM_TENSOR_ALTUP_PREDICT_COEF, // gemma3n + LLM_TENSOR_ALTUP_ROUTER, // gemma3n + LLM_TENSOR_ALTUP_ROUTER_NORM, // gemma3n + LLM_TENSOR_LAUREL_L, // gemma3n + LLM_TENSOR_LAUREL_R, // gemma3n + LLM_TENSOR_LAUREL_POST_NORM, // gemma3n LLM_TENSOR_SSM_IN, LLM_TENSOR_SSM_CONV1D, LLM_TENSOR_SSM_X, diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 5222eedcfb0..4f3cd275dd0 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -145,6 +145,11 @@ struct llama_hparams { uint32_t n_attn_temp_floor_scale = 8192; float f_attn_temp_scale = 0.1; + // gemma3n altup + uint32_t n_altup = 4; // altup_num_inputs + uint32_t laurel_rank = 64; + uint32_t n_embd_altup = 256; + // needed by 
encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 3735e3c16f0..2d9dc4749f6 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -973,6 +973,20 @@ void llama_model::load_hparams(llama_model_loader & ml) { ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) : 1.0f / std::sqrt(float(hparams.n_embd_head_k)); } break; + case LLM_ARCH_GEMMA3N: + { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa_pattern = 5; + + hparams.rope_freq_base_train_swa = 10000.0f; + hparams.rope_freq_scale_train_swa = 1.0f; + hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k)); + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // TODO: switch (hparams.n_layer) + } break; case LLM_ARCH_STARCODER2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -2865,6 +2879,54 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); } } break; + case LLM_ARCH_GEMMA3N: + { + const int64_t n_altup = hparams.n_altup; + const int64_t laurel_rank = hparams.laurel_rank; + const int64_t n_embd_altup = hparams.n_embd_altup; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0); + + altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); + altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); + per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_altup * n_layer}, 0); + per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_altup}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + + // altup & laurel + layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_altup}, 0); + layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_altup, n_embd}, 0); + layer.per_layer_post_norm = 
create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0); + layer.altup_correct_coef = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_COEF, "weight", i), {n_altup, n_altup}, 0); + layer.altup_correct_scale = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_SCALE, "weight", i), {n_embd}, 0); + layer.altup_predict_coef = create_tensor(tn(LLM_TENSOR_ALTUP_PREDICT_COEF, "weight", i), {n_altup, n_altup * n_altup}, 0); + layer.altup_router = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER, "weight", i), {n_embd, n_altup}, 0); + layer.altup_router_norm = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER_NORM, "weight", i), {n_embd}, 0); + layer.laurel_l = create_tensor(tn(LLM_TENSOR_LAUREL_L, "weight", i), {n_embd, laurel_rank}, 0); + layer.laurel_r = create_tensor(tn(LLM_TENSOR_LAUREL_R, "weight", i), {laurel_rank, n_embd}, 0); + layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0); + } + } break; case LLM_ARCH_STARCODER2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -13677,6 +13739,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GEMMA: case LLM_ARCH_GEMMA2: case LLM_ARCH_GEMMA3: + case LLM_ARCH_GEMMA3N: case LLM_ARCH_STARCODER2: case LLM_ARCH_OPENELM: case LLM_ARCH_GPTNEOX: diff --git a/src/llama-model.h b/src/llama-model.h index cbea2cb331b..e3afa04b5ed 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -315,6 +315,19 @@ struct llama_layer { struct ggml_tensor * ffn_up_scale = nullptr; struct ggml_tensor * ffn_down_scale = nullptr; + // altup & laurel + struct ggml_tensor * per_layer_inp_gate = nullptr; + struct ggml_tensor * per_layer_proj = nullptr; + struct ggml_tensor * per_layer_post_norm = nullptr; + struct ggml_tensor * altup_correct_coef = nullptr; + struct ggml_tensor * altup_correct_scale = nullptr; + struct ggml_tensor * altup_predict_coef = nullptr; + struct ggml_tensor * altup_router = nullptr; + struct ggml_tensor * altup_router_norm = nullptr; + struct ggml_tensor * laurel_l = nullptr; + struct ggml_tensor * laurel_r = nullptr; + struct ggml_tensor * laurel_post_norm = nullptr; + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; @@ -350,6 +363,13 @@ struct llama_model { struct ggml_tensor * conv1d = nullptr; struct ggml_tensor * conv1d_b = nullptr; + // gemma3n altup + struct ggml_tensor * tok_embd_per_layer = nullptr; + struct ggml_tensor * altup_proj = nullptr; + struct ggml_tensor * altup_unembd_proj = nullptr; + struct ggml_tensor * per_layer_model_proj = nullptr; + struct ggml_tensor * per_layer_proj_norm = nullptr; + std::vector layers; llama_model_params params; From fd3b181cff3efc43f407e865a13963593e778651 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 24 May 2025 16:30:04 +0200 Subject: [PATCH 03/21] wip --- convert_hf_to_gguf.py | 7 +- src/llama-hparams.h | 1 + src/llama-model.cpp | 262 ++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 258 insertions(+), 12 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 83655a3439f..3bf52e8c894 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3983,6 +3983,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter @ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration") class Gemma3Model(TextModel): model_arch = gguf.MODEL_ARCH.GEMMA3 + norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value def set_vocab(self): self._set_vocab_sentencepiece() @@ -4032,8 +4033,9 @@ def 
modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # ref code in Gemma3RMSNorm # output = output * (1.0 + self.weight.float()) + # note: this is not the case on gemma3n if name.endswith("norm.weight"): - data_torch = data_torch + 1 + data_torch = data_torch + self.norm_shift return [(self.map_tensor_name(name), data_torch)] @@ -4093,6 +4095,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter @ModelBase.register("Gemma3p5ForCausalLM") class Gemma3NModel(Gemma3Model): model_arch = gguf.MODEL_ARCH.GEMMA3N + norm_shift = 0.0 # same value with Gemma3p5RMSNorm scale_shift on python code _altup_proj: list[Tensor] = [] _altup_unembd: list[Tensor] = [] @@ -4125,6 +4128,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.endswith("_scale"): name = name + ".weight" + # TODO: implement self.prediction_coefs.weight.clamp_(...) + if "altup_unembed_projections" in name: data_torch = data_torch.to(device="cpu") if ".0." in name: diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 4f3cd275dd0..25108fc4abf 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -147,6 +147,7 @@ struct llama_hparams { // gemma3n altup uint32_t n_altup = 4; // altup_num_inputs + uint32_t i_altup_act = 0; // altup_active_idx uint32_t laurel_rank = 64; uint32_t n_embd_altup = 256; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 2d9dc4749f6..44f9cca6796 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2914,17 +2914,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); // altup & laurel - layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_altup}, 0); - layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_altup, n_embd}, 0); - layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0); - layer.altup_correct_coef = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_COEF, "weight", i), {n_altup, n_altup}, 0); - layer.altup_correct_scale = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_SCALE, "weight", i), {n_embd}, 0); - layer.altup_predict_coef = create_tensor(tn(LLM_TENSOR_ALTUP_PREDICT_COEF, "weight", i), {n_altup, n_altup * n_altup}, 0); - layer.altup_router = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER, "weight", i), {n_embd, n_altup}, 0); - layer.altup_router_norm = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER_NORM, "weight", i), {n_embd}, 0); - layer.laurel_l = create_tensor(tn(LLM_TENSOR_LAUREL_L, "weight", i), {n_embd, laurel_rank}, 0); - layer.laurel_r = create_tensor(tn(LLM_TENSOR_LAUREL_R, "weight", i), {laurel_rank, n_embd}, 0); - layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0); + layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_altup}, 0); + layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_altup, n_embd}, 0); + layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0); + layer.altup_correct_coef = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_COEF, "weight", i), {n_altup, n_altup}, 0); + layer.altup_correct_scale = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_SCALE, "weight", i), {n_embd}, 0); + layer.altup_predict_coef = create_tensor(tn(LLM_TENSOR_ALTUP_PREDICT_COEF, "weight", i), 
{n_altup, n_altup * n_altup}, 0); + layer.altup_router = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER, "weight", i), {n_embd, n_altup}, 0); + layer.altup_router_norm = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER_NORM, "weight", i), {n_embd}, 0); + layer.laurel_l = create_tensor(tn(LLM_TENSOR_LAUREL_L, "weight", i), {n_embd, laurel_rank}, 0); + layer.laurel_r = create_tensor(tn(LLM_TENSOR_LAUREL_R, "weight", i), {laurel_rank, n_embd}, 0); + layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0); } } break; case LLM_ARCH_STARCODER2: @@ -8757,6 +8757,242 @@ struct llm_build_gemma3_iswa : public llm_graph_context { } }; +struct llm_build_gemma3n_iswa : public llm_graph_context { + llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_k; + const int64_t n_embd_altup = hparams.n_embd_altup; + const int64_t n_altup = hparams.n_altup; + const int i_altup_act = hparams.i_altup_act; + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + if (ubatch.token) { + inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); + cb(inpL, "inp_scaled", -1); + } + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // TODO: is causal == true correct? might need some changes + auto * inp_attn = build_attn_inp_kv_unified_iswa(); + + ggml_tensor * inp_per_layer; + + // equivalent to get_per_layer_inputs() in python code + { + auto inp = std::make_unique(); + if (ubatch.token) { + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + ggml_set_input(inp->tokens); + res->t_tokens = inp->tokens; + inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens); + inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens); + } else { + GGML_ABORT("TODO: support embd input"); + } + res->add_input(std::move(inp)); + } + + // equivalent to project_per_layer_inputs() in python code + // this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim + { + ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inpL); + // shape: [n_embd, n_tokens] + per_layer_proj = ggml_scale(ctx0, inp_per_layer, 1.0f / sqrtf(n_embd)); // per_layer_projection_scale + per_layer_proj = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens); + inp_per_layer = ggml_add(ctx0, inp_per_layer, per_layer_proj); + inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf(2.0)); // per_layer_input_scale + // permute to shape: [n_embd_altup, n_tokens, n_layer] + inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3)); + printf("shape of inp_per_layer: %lld %lld %lld\n", inp_per_layer->ne[0], inp_per_layer->ne[1], inp_per_layer->ne[2]); + cb(inp_per_layer, "inp_per_layer", -1); + } + + // inpL now has only 1 altup, project it to the rest of the altups + // these "added" altups will be concat to the last dim of inpL + { + ggml_tensor * altup_added = ggml_mul_mat(ctx0, model.altup_proj, inpL); // shape: [n_embd, n_tokens, n_altup - 1] + // TODO: missing new_magnitude / target_magnitude stuff + inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup] + } + + // inpL now has shape: [n_embd, n_tokens, n_altup] + // inp_per_layer now hasshape: [n_embd_altup, n_tokens, 
n_layer] + + // equivalent to compute_router_modalities() in python code + // output shape: [n_altup, n_tokens] + auto compute_router_modalities = [&](ggml_tensor * x, int il) { + ggml_tensor * router_inputs = build_norm(router_inputs, + model.layers[il].altup_router_norm, NULL, + LLM_NORM_RMS, il); + + // router_input_scale + router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float)n_embd); + + return ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs); + }; + + // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim + auto view_2d_slice = [&](ggml_tensor * x, int idx) { + return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], + ggml_row_size(x->type, x->ne[0]), + idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); + }; + + for (int il = 0; il < n_layer; ++il) { + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + ggml_tensor * cur = inpL; // [n_embd, n_tokens, n_altup] + ggml_tensor * activated = view_2d_slice(cur, i_altup_act); + + // altup predict + ggml_tensor * predictions; + { + ggml_tensor * modalities = compute_router_modalities(activated, il); // [n_altup, n_tokens] + cb(modalities, "modalities", il); + + ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities); + // first dim now having n_altup^2 elements, we reshape it to 2D (so we end up with 3D tensor) + all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens); + + // permute to [n_altup, n_embd, n_tokens] + ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3)); + predictions = ggml_mul_mat(ctx0, cur_permuted, all_coefs); // [n_altup, n_embd, n_tokens] + + // final shape must be the same as cur: [n_embd, n_tokens, n_altup] + predictions = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3)); + predictions = ggml_add(ctx0, predictions, cur); + cb(predictions, "predictions", il); + } + + // predicted value will go through self-attention and laurel + cur = view_2d_slice(predictions, i_altup_act); + cb(cur, "active_prediction", il); + + // norm + cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // laurel + ggml_tensor * laurel_out; + { + laurel_out = build_lora_mm(model.layers[il].laurel_l, cur); + laurel_out = build_lora_mm(model.layers[il].laurel_r, cur); + laurel_out = build_norm(laurel_out, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il); + laurel_out = ggml_add(ctx0, laurel_out, cur); + cb(laurel_out, "laurel_out", il); + } + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // 
⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️ + // ⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️ + // ⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️ + // CURRENTLY STUCKED HERE + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il); + } + + cur = build_norm(cur, + model.layers[il].attn_post_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + // if (il == n_layer - 1) { + // // skip computing output for unused tokens + // ggml_tensor * inp_out_ids = build_inp_out_ids(); + // cur = ggml_get_rows(ctx0, cur, inp_out_ids); + // inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + // } + + ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); + cb(sa_out, "sa_out", il); + + cur = build_norm(sa_out, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = build_norm(cur, + model.layers[il].ffn_post_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", -1); + + cur = ggml_add(ctx0, cur, sa_out); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + // TODO: move up next to build_starcoder struct llm_build_starcoder2 : public llm_graph_context { llm_build_starcoder2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { @@ -13445,6 +13681,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_GEMMA3N: + { + llm = std::make_unique(*this, params, gf); + } break; case LLM_ARCH_STARCODER2: { llm = std::make_unique(*this, params, gf); From 122a54f5245972155bfe2fd150c7e930b0d0079e Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 27 May 2025 10:54:36 +0200 Subject: [PATCH 04/21] wip --- src/llama-model.cpp | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 44f9cca6796..193ccd07333 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -8811,14 +8811,25 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3)); printf("shape of inp_per_layer: %lld %lld %lld\n", inp_per_layer->ne[0], inp_per_layer->ne[1], inp_per_layer->ne[2]); cb(inp_per_layer, "inp_per_layer", -1); + ggml_build_forward_expand(gf, inp_per_layer); } + auto calc_magnitude = [&](ggml_tensor * x) { + return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x))); + }; + // inpL now has only 1 altup, project it to the rest of the altups // these "added" altups will be concat to the last dim of inpL { - ggml_tensor * altup_added = ggml_mul_mat(ctx0, model.altup_proj, inpL); // shape: [n_embd, n_tokens, n_altup - 1] - // TODO: missing new_magnitude / target_magnitude stuff + ggml_tensor * target_magnitude = 
calc_magnitude(inpL); + ggml_tensor * altup_added = ggml_mul_mat(ctx0, inpL, model.altup_proj); // because ggml only support broadcasting A, we do (B*A)^T instead of the normal A*B + altup_added = ggml_cont(ctx0, ggml_permute(ctx0, altup_added, 1, 0, 2, 3)); // shape: [n_embd, n_tokens, n_altup - 1] + ggml_tensor * new_magnitude = calc_magnitude(altup_added); + altup_added = ggml_div(ctx0, + ggml_mul(ctx0, altup_added, target_magnitude), + new_magnitude); inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup] + cb(inpL, "inp_stacked", -1); } // inpL now has shape: [n_embd, n_tokens, n_altup] @@ -8827,7 +8838,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { // equivalent to compute_router_modalities() in python code // output shape: [n_altup, n_tokens] auto compute_router_modalities = [&](ggml_tensor * x, int il) { - ggml_tensor * router_inputs = build_norm(router_inputs, + ggml_tensor * router_inputs = build_norm(x, model.layers[il].altup_router_norm, NULL, LLM_NORM_RMS, il); @@ -8866,7 +8877,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { predictions = ggml_mul_mat(ctx0, cur_permuted, all_coefs); // [n_altup, n_embd, n_tokens] // final shape must be the same as cur: [n_embd, n_tokens, n_altup] - predictions = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3)); + predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3)); predictions = ggml_add(ctx0, predictions, cur); cb(predictions, "predictions", il); } @@ -8882,8 +8893,9 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { // laurel ggml_tensor * laurel_out; { - laurel_out = build_lora_mm(model.layers[il].laurel_l, cur); - laurel_out = build_lora_mm(model.layers[il].laurel_r, cur); + laurel_out = cur; + laurel_out = build_lora_mm(model.layers[il].laurel_l, laurel_out); + laurel_out = build_lora_mm(model.layers[il].laurel_r, laurel_out); laurel_out = build_norm(laurel_out, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il); laurel_out = ggml_add(ctx0, laurel_out, cur); cb(laurel_out, "laurel_out", il); @@ -8941,6 +8953,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { // inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); // } + cur = ggml_repeat(ctx0, cur, inpL); // DUMMY, REMOVE IT LATER + ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); cb(sa_out, "sa_out", il); @@ -8983,8 +8997,9 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { cb(cur, "result_norm", -1); res->t_embd = cur; - // lm_head - cur = build_lora_mm(model.output, cur); + // DUMMY + cur = view_2d_slice(cur, 0); + cur = build_lora_mm(model.tok_embd, cur); cb(cur, "result_output", -1); res->t_logits = cur; From ba8dbcc1aca75e520600a8a18df4f2f6bad9c0eb Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 27 May 2025 18:30:05 +0200 Subject: [PATCH 05/21] activations matched until attn_post_norm --- convert_hf_to_gguf.py | 4 ++++ src/llama-model.cpp | 42 +++++++++++++++++++++++++++++------------- 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 3bf52e8c894..c5285a3a3ff 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4130,6 +4130,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # TODO: implement self.prediction_coefs.weight.clamp_(...) 
+ if "embed_tokens_per_layer.weight" in name: + hidden_size_per_layer_input = 256 + data_torch = data_torch * (hidden_size_per_layer_input**0.5) + if "altup_unembed_projections" in name: data_torch = data_torch.to(device="cpu") if ".0." in name: diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 193ccd07333..9c753be125d 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -980,7 +980,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.rope_freq_base_train_swa = 10000.0f; hparams.rope_freq_scale_train_swa = 1.0f; - hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k)); + hparams.f_attention_scale = 32.0f / 256.0f; // == query_rescale_scalar / query_pre_attn_scalar ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -8802,16 +8802,19 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { // this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim { ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inpL); + // shape: [n_embd, n_tokens] per_layer_proj = ggml_scale(ctx0, inp_per_layer, 1.0f / sqrtf(n_embd)); // per_layer_projection_scale per_layer_proj = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens); inp_per_layer = ggml_add(ctx0, inp_per_layer, per_layer_proj); inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf(2.0)); // per_layer_input_scale + cb(inp_per_layer, "inp_per_layer", -1); + // permute to shape: [n_embd_altup, n_tokens, n_layer] inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3)); - printf("shape of inp_per_layer: %lld %lld %lld\n", inp_per_layer->ne[0], inp_per_layer->ne[1], inp_per_layer->ne[2]); - cb(inp_per_layer, "inp_per_layer", -1); ggml_build_forward_expand(gf, inp_per_layer); + + // @ngxson: matched activations ✅ } auto calc_magnitude = [&](ggml_tensor * x) { @@ -8822,14 +8825,17 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { // these "added" altups will be concat to the last dim of inpL { ggml_tensor * target_magnitude = calc_magnitude(inpL); - ggml_tensor * altup_added = ggml_mul_mat(ctx0, inpL, model.altup_proj); // because ggml only support broadcasting A, we do (B*A)^T instead of the normal A*B - altup_added = ggml_cont(ctx0, ggml_permute(ctx0, altup_added, 1, 0, 2, 3)); // shape: [n_embd, n_tokens, n_altup - 1] + // TODO: use ggml_repeat_4d for this ⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️ + ggml_tensor * inp_repeated = ggml_repeat(ctx0, inpL, ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, n_tokens, n_altup - 1)); + ggml_tensor * altup_added = ggml_mul_mat(ctx0, model.altup_proj, inp_repeated); // shape: [n_embd, n_tokens, n_altup - 1] ggml_tensor * new_magnitude = calc_magnitude(altup_added); altup_added = ggml_div(ctx0, ggml_mul(ctx0, altup_added, target_magnitude), new_magnitude); inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup] cb(inpL, "inp_stacked", -1); + + // @ngxson: matched activations ✅ } // inpL now has shape: [n_embd, n_tokens, n_altup] @@ -8845,7 +8851,9 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { // router_input_scale router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float)n_embd); - return ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs); + ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs); + return ggml_tanh(ctx0, output); // [n_altup, n_tokens] + // @ngxson: 
matched activations ✅ }; // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim @@ -8880,6 +8888,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3)); predictions = ggml_add(ctx0, predictions, cur); cb(predictions, "predictions", il); + + // @ngxson: matched activations ✅ } // predicted value will go through self-attention and laurel @@ -8917,6 +8927,14 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps); + Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps); + Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps); + + cb(Qcur, "Qcur_normed", il); + cb(Kcur, "Kcur_normed", il); + cb(Vcur, "Vcur_normed", il); + Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, @@ -8927,14 +8945,10 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + cb(Qcur, "Qcur_pos", il); + cb(Kcur, "Kcur_pos", il); - // ⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️ - // ⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️ - // ⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️ - // CURRENTLY STUCKED HERE + // SOME LAYERS DOES NOT HAVE KV ⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️ cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, @@ -8946,6 +8960,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { LLM_NORM_RMS, il); cb(cur, "attn_post_norm", il); + // @ngxson: matched activations ✅ (layer 0) + // if (il == n_layer - 1) { // // skip computing output for unused tokens // ggml_tensor * inp_out_ids = build_inp_out_ids(); From c533872e9396cd0c941308af4217e158c7e8f884 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 31 May 2025 23:26:03 +0200 Subject: [PATCH 06/21] matched activations until `l_out-19` --- src/llama-model.cpp | 104 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 81 insertions(+), 23 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 9c753be125d..b979d68f9c8 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -8781,7 +8781,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { // TODO: is causal == true correct? 
might need some changes auto * inp_attn = build_attn_inp_kv_unified_iswa(); - ggml_tensor * inp_per_layer; + ggml_tensor * inp_per_layer; // [n_embd_altup, n_tokens, n_layer] // equivalent to get_per_layer_inputs() in python code { @@ -8792,6 +8792,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { res->t_tokens = inp->tokens; inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens); inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens); + cb(inp_per_layer, "inp_per_layer_selected", -1); } else { GGML_ABORT("TODO: support embd input"); } @@ -8802,19 +8803,19 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { // this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim { ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inpL); + per_layer_proj = ggml_scale(ctx0, per_layer_proj, 1.0f / sqrtf((float)n_embd)); // per_layer_projection_scale + per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens); + per_layer_proj = build_norm(per_layer_proj, + model.per_layer_proj_norm, NULL, + LLM_NORM_RMS, -1); // [n_embd_altup, n_layer, n_tokens] + cb(per_layer_proj, "per_layer_proj", -1); - // shape: [n_embd, n_tokens] - per_layer_proj = ggml_scale(ctx0, inp_per_layer, 1.0f / sqrtf(n_embd)); // per_layer_projection_scale - per_layer_proj = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens); inp_per_layer = ggml_add(ctx0, inp_per_layer, per_layer_proj); - inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf(2.0)); // per_layer_input_scale + inp_per_layer = ggml_scale(ctx0, inp_per_layer, 1.0f / sqrtf(2.0)); // per_layer_input_scale cb(inp_per_layer, "inp_per_layer", -1); // permute to shape: [n_embd_altup, n_tokens, n_layer] inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3)); - ggml_build_forward_expand(gf, inp_per_layer); - - // @ngxson: matched activations ✅ } auto calc_magnitude = [&](ggml_tensor * x) { @@ -8834,8 +8835,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { new_magnitude); inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup] cb(inpL, "inp_stacked", -1); - - // @ngxson: matched activations ✅ } // inpL now has shape: [n_embd, n_tokens, n_altup] @@ -8853,7 +8852,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs); return ggml_tanh(ctx0, output); // [n_altup, n_tokens] - // @ngxson: matched activations ✅ }; // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim @@ -8863,6 +8861,13 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); }; + ggml_tensor * one; // containing single element 1.0f + { + one = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + one = ggml_scale(ctx0, one, 0.0f); + one = ggml_cos(ctx0, one); + } + for (int il = 0; il < n_layer; ++il) { const float freq_base_l = model.get_rope_freq_base (cparams, il); const float freq_scale_l = model.get_rope_freq_scale(cparams, il); @@ -8877,6 +8882,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { cb(modalities, "modalities", il); ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities); + cb(all_coefs, "all_coefs", il); // first dim now having n_altup^2 elements, we reshape it to 2D (so we end up with 3D tensor) all_coefs = ggml_reshape_3d(ctx0, all_coefs, 
n_altup, n_altup, n_tokens); @@ -8888,12 +8894,11 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3)); predictions = ggml_add(ctx0, predictions, cur); cb(predictions, "predictions", il); - - // @ngxson: matched activations ✅ } // predicted value will go through self-attention and laurel - cur = view_2d_slice(predictions, i_altup_act); + ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens] + cur = active_prediction; cb(cur, "active_prediction", il); // norm @@ -8948,7 +8953,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { cb(Qcur, "Qcur_pos", il); cb(Kcur, "Kcur_pos", il); - // SOME LAYERS DOES NOT HAVE KV ⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️ + // SOME LAYERS DOES NOT HAVE KV ⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️ cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, @@ -8960,8 +8965,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { LLM_NORM_RMS, il); cb(cur, "attn_post_norm", il); - // @ngxson: matched activations ✅ (layer 0) - // if (il == n_layer - 1) { // // skip computing output for unused tokens // ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -8969,18 +8972,22 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { // inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); // } - cur = ggml_repeat(ctx0, cur, inpL); // DUMMY, REMOVE IT LATER + cur = ggml_add(ctx0, cur, active_prediction); // [n_embd, n_tokens] + cb(cur, "attn_gated", il); - ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); - cb(sa_out, "sa_out", il); + ggml_tensor * attn_laurel = ggml_scale(ctx0, + ggml_add(ctx0, cur, laurel_out), + 1.0f / sqrtf(2.0f)); // [n_embd, n_tokens] + cb(attn_laurel, "attn_laurel", il); - cur = build_norm(sa_out, + cur = build_norm(attn_laurel, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "ffn_norm", il); // feed-forward network { + // missing icdf ⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️ cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, @@ -8993,10 +9000,61 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, -1); - cb(cur, "ffn_post_norm", -1); + cb(cur, "ffn_post_norm", il); - cur = ggml_add(ctx0, cur, sa_out); + ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel); // [n_embd, n_tokens] + cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il); + + ggml_tensor * corrected; // [n_embd, n_tokens, n_altup] + { + ggml_tensor * activated = attn_ffw_laurel_gated; + ggml_tensor * modalities = compute_router_modalities(activated, il); // [n_altup, n_tokens] + cb(modalities, "modalities", il); + + ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens] + cb(innovation, "innovation", il); + + ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens] + all_coefs = ggml_add(ctx0, all_coefs, one); + cb(all_coefs, "all_coefs", il); + all_coefs = ggml_cont(ctx0, ggml_transpose(ctx0, all_coefs)); // [n_tokens, n_altup] + all_coefs = ggml_reshape_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup] + + // TODO: use ggml_repeat_4d for this ⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️ + innovation = ggml_repeat(ctx0, innovation, ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, n_tokens, n_altup)); + corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup] + corrected = 
ggml_add(ctx0, corrected, predictions); // [n_embd, n_tokens, n_altup] + cb(corrected, "corrected", il); + } + + ggml_tensor * first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens] + { + first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale); + first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction); + first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens] + cb(first_prediction, "first_prediction_gated", il); + ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens] + first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens] + cb(first_prediction, "first_prediction_scaled", il); + + first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction); // [n_embd, n_tokens] + first_prediction = build_norm(first_prediction, + model.layers[il].per_layer_post_norm, NULL, + LLM_NORM_RMS, il); + cb(first_prediction, "first_prediction_out", il); + } + + // equivalent to python code: corrected_predictions[1:] += first_prediction + for (int i_alt = 1; i_alt < n_altup; ++i_alt) { + ggml_tensor * view = view_2d_slice(corrected, i_alt); // [n_embd, n_tokens] + ggml_tensor * tmp = ggml_add(ctx0, view, first_prediction); // [n_embd, n_tokens] + size_t offset = i_alt * corrected->ne[0] * corrected->ne[1] * ggml_element_size(view); + corrected = ggml_set(ctx0, corrected, tmp, + corrected->nb[1], corrected->nb[2], corrected->nb[3], + offset); // [n_embd, n_tokens, n_altup] + } + cur = corrected; // [n_embd, n_tokens, n_altup] cur = build_cvec(cur, il); cb(cur, "l_out", il); From cc80c3b87733f7e1c6b2e55932ff45ca02523abd Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 31 May 2025 23:28:26 +0200 Subject: [PATCH 07/21] small clean up --- src/llama-model.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index ddec0f685dd..4584bd8de7e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -8829,8 +8829,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { // these "added" altups will be concat to the last dim of inpL { ggml_tensor * target_magnitude = calc_magnitude(inpL); - // TODO: use ggml_repeat_4d for this ⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️ - ggml_tensor * inp_repeated = ggml_repeat(ctx0, inpL, ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, n_tokens, n_altup - 1)); + ggml_tensor * inp_repeated = ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1); ggml_tensor * altup_added = ggml_mul_mat(ctx0, model.altup_proj, inp_repeated); // shape: [n_embd, n_tokens, n_altup - 1] ggml_tensor * new_magnitude = calc_magnitude(altup_added); altup_added = ggml_div(ctx0, @@ -9023,8 +9022,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { all_coefs = ggml_cont(ctx0, ggml_transpose(ctx0, all_coefs)); // [n_tokens, n_altup] all_coefs = ggml_reshape_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup] - // TODO: use ggml_repeat_4d for this ⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️ - innovation = ggml_repeat(ctx0, innovation, ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, n_tokens, n_altup)); + innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1); corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup] corrected = ggml_add(ctx0, corrected, predictions); // [n_embd, n_tokens, n_altup] cb(corrected, "corrected", il); @@ -9074,7 +9072,7 @@ struct 
llm_build_gemma3n_iswa : public llm_graph_context { cb(cur, "result_norm", -1); res->t_embd = cur; - // DUMMY + // DUMMY ⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️ cur = view_2d_slice(cur, 0); cur = build_lora_mm(model.tok_embd, cur); From a66ac3f62321c713ff81c55ff216d27ff70c1693 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 1 Jun 2025 14:26:35 +0200 Subject: [PATCH 08/21] clean up, break into smaller fn --- src/llama-model.cpp | 296 +++++++++++++++++++++++++------------------- 1 file changed, 167 insertions(+), 129 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 4584bd8de7e..56403e76233 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -980,7 +980,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_GEMMA3N: { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.n_swa_pattern = 5; + hparams.set_swa_pattern(5); hparams.rope_freq_base_train_swa = 10000.0f; hparams.rope_freq_scale_train_swa = 1.0f; @@ -8761,15 +8761,30 @@ struct llm_build_gemma3_iswa : public llm_graph_context { }; struct llm_build_gemma3n_iswa : public llm_graph_context { - llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_k; - const int64_t n_embd_altup = hparams.n_embd_altup; - const int64_t n_altup = hparams.n_altup; - const int i_altup_act = hparams.i_altup_act; - + const llama_model & model; + ggml_cgraph * gf; + + const int64_t n_embd_head; + const int64_t n_embd_altup; + const int64_t n_altup; + const int i_altup_act; + + ggml_tensor * one; // containing single element 1.0f + + llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) + : llm_graph_context(params), + model(model), + gf(gf), + n_embd_head(model.hparams.n_embd_head_k), + n_embd_altup(model.hparams.n_embd_altup), + n_altup(model.hparams.n_altup), + i_altup_act(model.hparams.i_altup_act) { ggml_tensor * cur; ggml_tensor * inpL; + one = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + one = ggml_cos(ctx0, ggml_scale(ctx0, one, 0.0f)); // cos(0.0f) = 1.0f + inpL = build_inp_embd(model.tok_embd); // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) @@ -8784,46 +8799,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { // TODO: is causal == true correct? 
might need some changes auto * inp_attn = build_attn_inp_kv_unified_iswa(); - ggml_tensor * inp_per_layer; // [n_embd_altup, n_tokens, n_layer] - - // equivalent to get_per_layer_inputs() in python code - { - auto inp = std::make_unique(); - if (ubatch.token) { - inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); - ggml_set_input(inp->tokens); - res->t_tokens = inp->tokens; - inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens); - inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens); - cb(inp_per_layer, "inp_per_layer_selected", -1); - } else { - GGML_ABORT("TODO: support embd input"); - } - res->add_input(std::move(inp)); - } - - // equivalent to project_per_layer_inputs() in python code - // this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim - { - ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inpL); - per_layer_proj = ggml_scale(ctx0, per_layer_proj, 1.0f / sqrtf((float)n_embd)); // per_layer_projection_scale - per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens); - per_layer_proj = build_norm(per_layer_proj, - model.per_layer_proj_norm, NULL, - LLM_NORM_RMS, -1); // [n_embd_altup, n_layer, n_tokens] - cb(per_layer_proj, "per_layer_proj", -1); - - inp_per_layer = ggml_add(ctx0, inp_per_layer, per_layer_proj); - inp_per_layer = ggml_scale(ctx0, inp_per_layer, 1.0f / sqrtf(2.0)); // per_layer_input_scale - cb(inp_per_layer, "inp_per_layer", -1); - - // permute to shape: [n_embd_altup, n_tokens, n_layer] - inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3)); - } - - auto calc_magnitude = [&](ggml_tensor * x) { - return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x))); - }; + // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer] + ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs()); // inpL now has only 1 altup, project it to the rest of the altups // these "added" altups will be concat to the last dim of inpL @@ -8842,61 +8819,14 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { // inpL now has shape: [n_embd, n_tokens, n_altup] // inp_per_layer now hasshape: [n_embd_altup, n_tokens, n_layer] - // equivalent to compute_router_modalities() in python code - // output shape: [n_altup, n_tokens] - auto compute_router_modalities = [&](ggml_tensor * x, int il) { - ggml_tensor * router_inputs = build_norm(x, - model.layers[il].altup_router_norm, NULL, - LLM_NORM_RMS, il); - - // router_input_scale - router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float)n_embd); - - ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs); - return ggml_tanh(ctx0, output); // [n_altup, n_tokens] - }; - - // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim - auto view_2d_slice = [&](ggml_tensor * x, int idx) { - return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], - ggml_row_size(x->type, x->ne[0]), - idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); - }; - - ggml_tensor * one; // containing single element 1.0f - { - one = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - one = ggml_scale(ctx0, one, 0.0f); - one = ggml_cos(ctx0, one); - } - for (int il = 0; il < n_layer; ++il) { + // this block is made to be closely resemble Gemma3p5DecoderLayer on python code + const float freq_base_l = model.get_rope_freq_base (cparams, il); const float freq_scale_l = model.get_rope_freq_scale(cparams, il); 
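                // Rough data flow of one layer, as wired up below (shapes follow the inline comments):
                //   predictions = altup_predict(inpL)                              // [n_embd, n_tokens, n_altup]
                //   active      = predictions[i_altup_act]                         // [n_embd, n_tokens]
                //   x           = attn_norm(active)
                //   attn        = attn_post_norm(self_attention(x)) + active
                //   laurel_out  = laurel_post_norm(laurel_r(laurel_l(x))) + x      // low-rank residual branch
                //   h           = (attn + laurel_out) / sqrt(2)
                //   h           = ffn_post_norm(ffn(ffn_norm(h))) + h
                //   corrected   = altup correction of predictions against h        // [n_embd, n_tokens, n_altup]
                //   per-layer gate: gelu(inp_gate(corrected[i_altup_act] * correct_scale)) * inp_per_layer[il]
                //                   -> per_layer_proj -> per_layer_post_norm, added to streams 1..n_altup-1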
ggml_tensor * cur = inpL; // [n_embd, n_tokens, n_altup] - ggml_tensor * activated = view_2d_slice(cur, i_altup_act); - - // altup predict - ggml_tensor * predictions; - { - ggml_tensor * modalities = compute_router_modalities(activated, il); // [n_altup, n_tokens] - cb(modalities, "modalities", il); - - ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities); - cb(all_coefs, "all_coefs", il); - // first dim now having n_altup^2 elements, we reshape it to 2D (so we end up with 3D tensor) - all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens); - - // permute to [n_altup, n_embd, n_tokens] - ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3)); - predictions = ggml_mul_mat(ctx0, cur_permuted, all_coefs); // [n_altup, n_embd, n_tokens] - - // final shape must be the same as cur: [n_embd, n_tokens, n_altup] - predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3)); - predictions = ggml_add(ctx0, predictions, cur); - cb(predictions, "predictions", il); - } + ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup] // predicted value will go through self-attention and laurel ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens] @@ -8908,15 +8838,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { cb(cur, "attn_norm", il); // laurel - ggml_tensor * laurel_out; - { - laurel_out = cur; - laurel_out = build_lora_mm(model.layers[il].laurel_l, laurel_out); - laurel_out = build_lora_mm(model.layers[il].laurel_r, laurel_out); - laurel_out = build_norm(laurel_out, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il); - laurel_out = ggml_add(ctx0, laurel_out, cur); - cb(laurel_out, "laurel_out", il); - } + ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens] // self-attention { @@ -8978,8 +8900,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { cb(cur, "attn_gated", il); ggml_tensor * attn_laurel = ggml_scale(ctx0, - ggml_add(ctx0, cur, laurel_out), - 1.0f / sqrtf(2.0f)); // [n_embd, n_tokens] + ggml_add(ctx0, cur, laurel_out), + 1.0f / sqrtf(2.0f)); // [n_embd, n_tokens] cb(attn_laurel, "attn_laurel", il); cur = build_norm(attn_laurel, @@ -9007,29 +8929,11 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel); // [n_embd, n_tokens] cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il); - ggml_tensor * corrected; // [n_embd, n_tokens, n_altup] - { - ggml_tensor * activated = attn_ffw_laurel_gated; - ggml_tensor * modalities = compute_router_modalities(activated, il); // [n_altup, n_tokens] - cb(modalities, "modalities", il); - - ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens] - cb(innovation, "innovation", il); + ggml_tensor * corrected = altup_corrent(predictions, attn_ffw_laurel_gated, il); // [n_embd, n_tokens, n_altup] - ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens] - all_coefs = ggml_add(ctx0, all_coefs, one); - cb(all_coefs, "all_coefs", il); - all_coefs = ggml_cont(ctx0, ggml_transpose(ctx0, all_coefs)); // [n_tokens, n_altup] - all_coefs = ggml_reshape_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup] - - innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1); - corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup] - corrected = 
ggml_add(ctx0, corrected, predictions); // [n_embd, n_tokens, n_altup] - cb(corrected, "corrected", il); - } - - ggml_tensor * first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens] + ggml_tensor * first_prediction; // [n_embd, n_tokens] { + first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens] first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale); first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction); first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens] @@ -9081,6 +8985,140 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { ggml_build_forward_expand(gf, cur); } + + ggml_tensor * calc_magnitude(ggml_tensor * x) { + return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x))); + } + + // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim + ggml_tensor * view_2d_slice(ggml_tensor * x, int idx) { + return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], + ggml_row_size(x->type, x->ne[0]), + idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); + } + + // equivalent to get_per_layer_inputs() in python code + // output shape: [n_embd_altup, n_layer, n_tokens] + ggml_tensor * get_per_layer_inputs() { + auto inp = std::make_unique(); + ggml_tensor * inp_per_layer; + if (ubatch.token) { + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + ggml_set_input(inp->tokens); + res->t_tokens = inp->tokens; + inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens); + inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens); + cb(inp_per_layer, "inp_per_layer_selected", -1); + } else { + GGML_ABORT("TODO: support embd input"); + } + res->add_input(std::move(inp)); + return inp_per_layer; + } + + // equivalent to project_per_layer_inputs() in python code + // this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim + // output shape: [n_embd_altup, n_tokens, n_layer] + ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) { + const float per_layer_projection_scale = 1.0f / sqrtf((float)n_embd); + const float per_layer_input_scale = 1.0f / sqrtf(2.0f); + + ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds); + per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale); + per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens); + per_layer_proj = build_norm(per_layer_proj, + model.per_layer_proj_norm, NULL, + LLM_NORM_RMS, -1); // [n_embd_altup, n_layer, n_tokens] + cb(per_layer_proj, "per_layer_proj", -1); + + inp_per_layer = ggml_add(ctx0, inp_per_layer, per_layer_proj); + inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale); + cb(inp_per_layer, "inp_per_layer", -1); + + // permute to shape: [n_embd_altup, n_tokens, n_layer] + inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3)); + return inp_per_layer; + } + + // input cur shape: [n_altup, n_tokens] + // output shape: [n_altup, n_tokens] + ggml_tensor * laurel(ggml_tensor * cur, int il) { + ggml_tensor * tmp = cur; + tmp = build_lora_mm(model.layers[il].laurel_l, tmp); + tmp = build_lora_mm(model.layers[il].laurel_r, tmp); + tmp = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il); + tmp = ggml_add(ctx0, tmp, cur); + cb(tmp, "laurel_out", il); + return tmp; + } + + // + // altup 
functions + // + + // equivalent to compute_router_modalities() in python code + // input x shape: [n_embd, n_tokens] + // output shape: [n_altup, n_tokens] + ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il) { + ggml_tensor * router_inputs = build_norm(x, + model.layers[il].altup_router_norm, NULL, + LLM_NORM_RMS, il); + + // router_input_scale + router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float)n_embd); + + ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs); + return ggml_tanh(ctx0, output); // [n_altup, n_tokens] + } + + // input cur shape: [n_embd, n_tokens, n_altup] + // output shape: [n_embd, n_tokens, n_altup] + ggml_tensor * altup_predict(ggml_tensor * cur, int il) { + ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens] + ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens] + cb(modalities, "modalities", il); + + ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities); + cb(all_coefs, "all_coefs", il); + // first dim now having n_altup^2 elements, we reshape it to 2D (so we end up with 3D tensor) + all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens); + + // permute to [n_altup, n_embd, n_tokens] + ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3)); + ggml_tensor * predictions = ggml_mul_mat(ctx0, cur_permuted, all_coefs); // [n_altup, n_embd, n_tokens] + + // final shape must be the same as cur: [n_embd, n_tokens, n_altup] + predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3)); + predictions = ggml_add(ctx0, predictions, cur); + cb(predictions, "predictions", il); + + return predictions; + } + + // input predictions shape: [n_embd, n_tokens, n_altup] + // input activated shape: [n_embd, n_tokens] + // output shape: [n_embd, n_tokens, n_altup] + ggml_tensor * altup_corrent(ggml_tensor * predictions, ggml_tensor * activated, int il) { + ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens] + cb(modalities, "modalities", il); + + ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); + ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens] + cb(innovation, "innovation", il); + + ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens] + all_coefs = ggml_add(ctx0, all_coefs, one); + cb(all_coefs, "all_coefs", il); + all_coefs = ggml_cont(ctx0, ggml_transpose(ctx0, all_coefs)); // [n_tokens, n_altup] + all_coefs = ggml_reshape_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup] + + innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1); + ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup] + corrected = ggml_add(ctx0, corrected, predictions); // [n_embd, n_tokens, n_altup] + cb(corrected, "corrected", il); + + return corrected; + } }; // TODO: move up next to build_starcoder From 787f73fe3cc96a3785bcdbc97505ec748da10ad7 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 1 Jun 2025 16:47:17 +0200 Subject: [PATCH 09/21] cpu-only is ok, still missing icdf --- src/llama-arch.cpp | 8 ++--- src/llama-graph.cpp | 32 +++++++++++++++++ src/llama-graph.h | 11 ++++++ src/llama-model.cpp | 86 ++++++++++++++++++++++++++++++++++++++------- 4 files changed, 120 insertions(+), 17 deletions(-) diff --git 
a/src/llama-arch.cpp b/src/llama-arch.cpp index cb9410610b0..ea096a87530 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1719,18 +1719,18 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, - // altup / laurel + // altup / laurel (gemma 3n) {LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, {LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}}, + {LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ALTUP_UNEMBD_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_PER_LAYER_INP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_PER_LAYER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_PER_LAYER_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, - {LLM_TENSOR_ALTUP_CORRECT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ALTUP_CORRECT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ALTUP_CORRECT_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, - {LLM_TENSOR_ALTUP_PREDICT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ALTUP_PREDICT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ALTUP_ROUTER, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ALTUP_ROUTER_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_LAUREL_L, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index b30f6fb4f41..bafa6c563e2 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1384,6 +1384,38 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } +ggml_tensor * llm_graph_context::build_attn_reuse_cache( + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * kq_mask, + float kq_scale, + int il_reuse, + int il) const { + const auto * kv_state_iswa = static_cast(mstate); + + // TODO @ngxson : this could be wrong + const auto * kv_state = hparams.is_swa(il_reuse) ? 
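// The current layer only contributes the query here; K and V are taken as views of the
// cache entries that layer il_reuse has already written (get_k/get_v below), and nothing
// is stored back, so the reused cache is left untouched. The SWA flag of il_reuse decides
// which of the two unified caches the views are read from: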
kv_state_iswa->get_swa() : kv_state_iswa->get_base(); + + ggml_tensor * q = q_cur; + ggml_tensor * k = kv_state->get_k(ctx0, il_reuse); + ggml_tensor * v = kv_state->get_v(ctx0, il_reuse); + + ggml_tensor * cur = build_attn_mha(gf, q, k, v, nullptr, kq_mask, nullptr, kq_scale); + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur); + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const { auto inp = std::make_unique(cross); diff --git a/src/llama-graph.h b/src/llama-graph.h index d1c5dd1bf03..0aa00f79751 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -601,6 +601,17 @@ struct llm_graph_context { float kq_scale, int il) const; + // reuse cache from a previous layer, leaving no modifications to the cache + ggml_tensor * build_attn_reuse_cache( + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * kq_mask, + float kq_scale, + int il_reuse, + int il) const; + // // recurrent // diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 56403e76233..a75c8567f7c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -8768,6 +8768,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { const int64_t n_embd_altup; const int64_t n_altup; const int i_altup_act; + const int n_layer_kv = 20; // number of layers having KV ggml_tensor * one; // containing single element 1.0f @@ -8821,6 +8822,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { for (int il = 0; il < n_layer; ++il) { // this block is made to be closely resemble Gemma3p5DecoderLayer on python code + const bool has_kv = (il < n_layer_kv); const float freq_base_l = model.get_rope_freq_base (cparams, il); const float freq_scale_l = model.get_rope_freq_scale(cparams, il); @@ -8841,7 +8843,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens] // self-attention - { + if (has_kv) { // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); cb(Qcur, "Qcur", il); @@ -8877,11 +8879,36 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { cb(Qcur, "Qcur_pos", il); cb(Kcur, "Kcur_pos", il); - // SOME LAYERS DOES NOT HAVE KV ⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️ - cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il); + } else { + // no KV layers + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + + Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur_pos", il); + + // TODO: slice the KQ mask to get only output tokens + const bool is_swa = hparams.is_swa(il); + const int il_reuse = n_layer_kv - (is_swa ? 2 : 1); + const auto & kq_mask = is_swa ? 
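                // Reuse scheme for the KV-less layers (given n_layer_kv = 20 above): layers with
                // il >= 20 compute only Q and read K/V from a cache slot filled by an earlier layer,
                // il_reuse = 18 for sliding-window layers and 19 for full-attention ones, with the
                // assert below checking that the reused layer has the same SWA status; the KQ mask
                // is chosen to match as well: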
inp_attn->get_kq_mask_swa() : inp_attn->get_kq_mask(); + // make sure the reused layer has the same SWA status as the current layer + GGML_ASSERT( + (is_swa && hparams.is_swa(il_reuse)) || + (!is_swa && !hparams.is_swa(il_reuse)) + ); + cur = build_attn_reuse_cache(gf, + model.layers[il].wo, NULL, + Qcur, kq_mask, hparams.f_attention_scale, il_reuse, il); } cur = build_norm(cur, @@ -8889,13 +8916,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { LLM_NORM_RMS, il); cb(cur, "attn_post_norm", il); - // if (il == n_layer - 1) { - // // skip computing output for unused tokens - // ggml_tensor * inp_out_ids = build_inp_out_ids(); - // cur = ggml_get_rows(ctx0, cur, inp_out_ids); - // inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); - // } - cur = ggml_add(ctx0, cur, active_prediction); // [n_embd, n_tokens] cb(cur, "attn_gated", il); @@ -8967,7 +8987,41 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { inpL = cur; } - cur = inpL; + cur = inpL; // [n_embd, n_tokens, n_altup] + + // cur now has multiple altup(s), we want to merge them back to 1 altup + { + ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens] + // do a view to skip the first slice (active altup) + ggml_tensor * alt_slice = ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1, + ggml_row_size(cur->type, n_embd), + ggml_row_size(cur->type, n_embd*n_tokens), + n_embd*n_tokens*ggml_element_size(cur)); + ggml_tensor * altup_unembd = ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice); // shape: [n_embd, n_tokens, n_altup - 1] + ggml_tensor * new_magnitude = calc_magnitude(altup_unembd); + altup_unembd = ggml_div(ctx0, + ggml_mul(ctx0, altup_unembd, target_magnitude), + new_magnitude); + cb(altup_unembd, "altup_unembd", -1); + + // equivalent to torch.mean(hidden_states, dim=0) + cur = view_2d_slice(cur, 0); // [n_embd, n_tokens] + for (int i = 0; i < n_altup - 1; ++i) { + cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i)); + } + cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens] + cb(cur, "unembd_merged", -1); + } + + // cur now has shape: [n_embd, n_tokens] + + // TODO @ngxson : move this to right after the last KV layer ⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️ + { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + //inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } cur = build_norm(cur, model.output_norm, NULL, @@ -8976,10 +9030,15 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { cb(cur, "result_norm", -1); res->t_embd = cur; - // DUMMY ⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️ - cur = view_2d_slice(cur, 0); cur = build_lora_mm(model.tok_embd, cur); + { + // final logit soft-capping + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); + cur = ggml_tanh(ctx0, cur); + cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); + } + cb(cur, "result_output", -1); res->t_logits = cur; @@ -8992,6 +9051,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim ggml_tensor * view_2d_slice(ggml_tensor * x, int idx) { + GGML_ASSERT(idx < (int)x->ne[2]); return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]), idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); From 80bee4ed460faf5fc9b03bc1c86804a93355f2cf Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 1 Jun 2025 17:45:14 +0200 Subject: [PATCH 10/21] 
replace ggml_set with ggml_concat --- src/llama-model.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a75c8567f7c..10c1fda6343 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -8970,13 +8970,14 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { } // equivalent to python code: corrected_predictions[1:] += first_prediction - for (int i_alt = 1; i_alt < n_altup; ++i_alt) { - ggml_tensor * view = view_2d_slice(corrected, i_alt); // [n_embd, n_tokens] - ggml_tensor * tmp = ggml_add(ctx0, view, first_prediction); // [n_embd, n_tokens] - size_t offset = i_alt * corrected->ne[0] * corrected->ne[1] * ggml_element_size(view); - corrected = ggml_set(ctx0, corrected, tmp, - corrected->nb[1], corrected->nb[2], corrected->nb[3], - offset); // [n_embd, n_tokens, n_altup] + { + ggml_tensor * slice_first = view_2d_slice(corrected, 0); + ggml_tensor * slice_rest = ggml_view_3d(ctx0, corrected, n_embd, n_tokens, n_altup - 1, + ggml_row_size(corrected->type, n_embd), + ggml_row_size(corrected->type, n_embd*n_tokens), + n_embd*n_tokens*ggml_element_size(corrected)); + ggml_tensor * tmp = ggml_add(ctx0, slice_rest, first_prediction); // [n_embd, n_tokens, n_altup - 1] + corrected = ggml_concat(ctx0, slice_first, tmp, 2); // [n_embd, n_tokens, n_altup] } cur = corrected; // [n_embd, n_tokens, n_altup] From ee6703aef8a33052b37912ecf492c43b588a9c8b Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 1 Jun 2025 18:51:05 +0200 Subject: [PATCH 11/21] matched text output --- src/llama-model.cpp | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 10c1fda6343..e84843272b6 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -8769,6 +8769,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { const int64_t n_altup; const int i_altup_act; const int n_layer_kv = 20; // number of layers having KV + const int n_layer_sparsity = 10; // number of layers using activation sparsity + const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95) ggml_tensor * one; // containing single element 1.0f @@ -8931,13 +8933,17 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { // feed-forward network { - // missing icdf ⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️ - cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, - LLM_FFN_GELU, LLM_FFN_PAR, il); + ggml_tensor * up_proj = build_lora_mm(model.layers[il].ffn_up, cur); + ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur); + + if (il < n_layer_sparsity) { + // apply activation sparsity + gate_proj = gaussian_topk(gate_proj); + } + gate_proj = ggml_gelu(ctx0, gate_proj); + + cur = ggml_mul(ctx0, up_proj, gate_proj); + cur = build_lora_mm(model.layers[il].ffn_down, cur); cb(cur, "ffn_out", il); } @@ -9113,6 +9119,18 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { return tmp; } + // input x shape: [n_embd, n_tokens] + // output shape: [n_embd, n_tokens] + ggml_tensor * gaussian_topk(ggml_tensor * x) { + ggml_tensor * mean = ggml_mean(ctx0, x); + ggml_tensor * std = ggml_sqrt(ctx0, ggml_scale(ctx0, + ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))), + 1.0f / (float)(x->ne[0] - 1) + )); + ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul)); + return 
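        // In effect this keeps only the activations above the per-token cutoff
        // mean + 1.6449 * std (f_sparsity_std_mul is icdf(0.95), so roughly the top 5%
        // under a normal assumption), with the survivors shifted so the cutoff maps to zero: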
ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x)); + } + // // altup functions // From 4d3d0aeb78d58b775024d7b63d4ac0bec8fbd443 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 19 Jun 2025 15:55:36 +0200 Subject: [PATCH 12/21] fix merge conflict --- src/llama-graph.h | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/llama-graph.h b/src/llama-graph.h index 451fb3aadf5..f2a408157bb 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -606,7 +606,7 @@ struct llm_graph_context { llm_graph_input_attn_cross * build_attn_inp_cross() const; ggml_tensor * build_attn( - llm_graph_input_mem_hybrid * inp, + llm_graph_input_attn_cross * inp, ggml_cgraph * gf, ggml_tensor * wo, ggml_tensor * wo_b, @@ -629,6 +629,18 @@ struct llm_graph_context { int il_reuse, int il) const; + ggml_tensor * build_attn( + llm_graph_input_mem_hybrid * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + float kq_scale, + int il) const; // // recurrent // From d8589f8bb45bbcaef049dc780b8721f8113ebc7c Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 19 Jun 2025 22:34:12 +0200 Subject: [PATCH 13/21] update text conversion --- convert_hf_to_gguf.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 7db7abe3a05..0e81f46716c 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4262,7 +4262,7 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused - if name.startswith("language_model."): + if "language_model." in name: name = name.replace("language_model.", "") elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \ @@ -4336,7 +4336,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [] # skip other tensors -@ModelBase.register("Gemma3p5ForCausalLM") +@ModelBase.register("Gemma3nForConditionalGeneration") class Gemma3NModel(Gemma3Model): model_arch = gguf.MODEL_ARCH.GEMMA3N norm_shift = 0.0 # same value with Gemma3p5RMSNorm scale_shift on python code @@ -4374,6 +4374,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # TODO: implement self.prediction_coefs.weight.clamp_(...) + if "language_model." 
not in name: + return [] # skip non-language model tensors + if "embed_tokens_per_layer.weight" in name: hidden_size_per_layer_input = 256 data_torch = data_torch * (hidden_size_per_layer_input**0.5) From e5fe215414bc4f3d0ef88d770f75d66bfd9498ad Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 22 Jun 2025 00:29:14 +0200 Subject: [PATCH 14/21] add kq weighted norm --- gguf-py/gguf/constants.py | 2 ++ src/llama-arch.cpp | 2 ++ src/llama-model.cpp | 12 +++++++----- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 0ee66911eb6..bef090d31b5 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1524,7 +1524,9 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, MODEL_TENSOR.ATTN_V, MODEL_TENSOR.ATTN_OUT, MODEL_TENSOR.FFN_GATE, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index e3265ab88b0..3a7cfece342 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -939,7 +939,9 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c1d683d7d59..7fb4326bba7 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2990,6 +2990,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -9139,8 +9141,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps); - Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps); cb(Qcur, "Qcur_normed", il); @@ -9162,14 +9164,14 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0, il); } else { // no KV layers ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); cb(Qcur, "Qcur", il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", 
il); Qcur = ggml_rope_ext( @@ -9189,7 +9191,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { ); cur = build_attn_reuse_cache(gf, model.layers[il].wo, NULL, - Qcur, kq_mask, hparams.f_attention_scale, il_reuse, il); + Qcur, kq_mask, 1.0, il_reuse, il); } cur = build_norm(cur, From d42071e25c7dbd60c01b4d75d49bea38fa792c18 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 22 Jun 2025 00:54:03 +0200 Subject: [PATCH 15/21] fix f_attention_scale --- src/llama-model.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 7fb4326bba7..995b76a9e52 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1024,7 +1024,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.rope_freq_base_train_swa = 10000.0f; hparams.rope_freq_scale_train_swa = 1.0f; - hparams.f_attention_scale = 32.0f / 256.0f; // == query_rescale_scalar / query_pre_attn_scalar + hparams.f_attention_scale = 1.0f; ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -9164,7 +9164,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0, il); + Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il); } else { // no KV layers ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -9191,7 +9191,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { ); cur = build_attn_reuse_cache(gf, model.layers[il].wo, NULL, - Qcur, kq_mask, 1.0, il_reuse, il); + Qcur, kq_mask, hparams.f_attention_scale, il_reuse, il); } cur = build_norm(cur, From f4e70342c00ab967cc5f3c8ce5d5aa5cfc2373c0 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 22 Jun 2025 16:11:01 +0200 Subject: [PATCH 16/21] init --- tools/mtmd/clip-mobilenet.h | 285 ++++++++++++++++++++++++++++++++++++ 1 file changed, 285 insertions(+) create mode 100644 tools/mtmd/clip-mobilenet.h diff --git a/tools/mtmd/clip-mobilenet.h b/tools/mtmd/clip-mobilenet.h new file mode 100644 index 00000000000..3c95ba14e29 --- /dev/null +++ b/tools/mtmd/clip-mobilenet.h @@ -0,0 +1,285 @@ +#pragma once + +#include "clip.h" +#include "clip-impl.h" +#include "ggml.h" + +#include +#include +#include + +using get_tensor_fn = std::function; + +static ggml_tensor * conv2d(ggml_context * ctx, ggml_tensor * kernel, ggml_tensor * inp, int strides, bool depthwise = false) { + int p0 = 0; + int p1 = 0; + + { + const int kernel_size = kernel->ne[0]; + + auto compute_padding_length = [](int input_length, int kernel_length, int stride) { + int total_padding_length = (kernel_length - 1) - (input_length - 1) % stride; + int left_padding = total_padding_length / 2; + int right_padding = (total_padding_length + 1) / 2; + return std::make_pair(left_padding, right_padding); + }; + + auto [left, right] = compute_padding_length(inp->ne[0], kernel_size, strides); + auto [top, bottom] = compute_padding_length(inp->ne[1], kernel_size, strides); + + if (left > 0 && right > 0) { + p0 = std::min(left, right); + left -= p0; + right -= p0; + } + + if (top > 0 && bottom > 0) { + p1 = std::min(top, bottom); + top -= p1; + bottom -= p1; + } + + GGML_ASSERT(left == 0 && top == 0); + + if (right != 0 || bottom != 0) { + inp = ggml_pad(ctx, inp, right, bottom, 0, 0); + } + } + + ggml_tensor * cur; + + if (depthwise) { + cur = ggml_conv_2d_dw(ctx, + kernel, inp, + strides, strides, + p0, p1, 1, 1); + } else { + cur 
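        // Padding note for the call below: "SAME"-style padding can be asymmetric, while the
        // ggml conv ops only accept a symmetric p0/p1. The symmetric part is passed through
        // here, the leftover right/bottom padding was already applied with ggml_pad above,
        // and the GGML_ASSERT guarantees no leftover remains on the left/top side.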
= ggml_conv_2d(ctx, + kernel, inp, + strides, strides, + p0, p1, 1, 1); + } + + return cur; +} + +struct mobilenet_g3n_blk { + virtual ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur) = 0; + virtual void load_tensors(const std::string & prefix, get_tensor_fn & get_tensor) = 0; + virtual ~mobilenet_g3n_blk() = default; +}; + +// ConvNormAct +struct mobilenet_g3n_cna : mobilenet_g3n_blk { + int kernel_size = 0; + int stride = 1; + int dilation = 1; + int filters = 0; + float expand_ratio = 1.0f; + + ggml_tensor * norm = nullptr; + ggml_tensor * conv = nullptr; + + virtual void load_tensors(const std::string & prefix, get_tensor_fn & get_tensor) { + std::string tmp; + tmp = prefix + "bn.weight"; + norm = get_tensor(tmp); + tmp = prefix + "conv.weight"; + conv = get_tensor(tmp); + } + + virtual ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur) { + cur = conv2d(ctx, conv, cur, stride, true); + cur = ggml_group_norm(ctx, cur, std::min(32, filters * expand_ratio / 4), 1e-6f); + cur = ggml_mul(ctx, cur, ggml_reshape_3d(ctx, norm, 1, 1, norm->ne[0])); + return cur; + } +}; + +// EdgeResidual +struct mobilenet_g3n_er : mobilenet_g3n_blk { + int kernel_size = 0; + int stride = 1; + int filters = 0; + + ggml_tensor * norm1 = nullptr; + ggml_tensor * norm2 = nullptr; + ggml_tensor * conv_exp = nullptr; + ggml_tensor * conv_pwl = nullptr; + + mobilenet_g3n_er(int kernel_size, int filters, int stride) : + kernel_size(kernel_size), stride(stride), filters(filters) {} + + virtual void load_tensors(const std::string & prefix, get_tensor_fn & get_tensor) { + std::string tmp; + tmp = prefix + "bn1.weight"; + norm1 = get_tensor(tmp); + tmp = prefix + "bn2.weight"; + norm2 = get_tensor(tmp); + tmp = prefix + "conv_exp.weight"; + conv_exp = get_tensor(tmp); + tmp = prefix + "conv_pwl.weight"; + conv_pwl = get_tensor(tmp); + } + + virtual ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur) { + return cur; + } +}; + +// UniversalInvertedResidual +struct mobilenet_g3n_uir : mobilenet_g3n_blk { + int start_dw_kernel_size = 0; + int mid_dw_kernel_size = 0; + bool multiscale = false; + ggml_tensor * layer_scale = nullptr; + + mobilenet_g3n_cna dw_start; + mobilenet_g3n_cna dw_mid; + mobilenet_g3n_cna dw_end; + mobilenet_g3n_cna dw_proj; + + mobilenet_g3n_uir(int start_dw_kernel_size, int mid_dw_kernel_size, int filters, int stride = 1, float expand_ratio = 4.0f, bool multiscale = false) : + start_dw_kernel_size(start_dw_kernel_size), + mid_dw_kernel_size(mid_dw_kernel_size), + multiscale(multiscale) { + dw_start.stride = stride; + dw_start.filters = filters; + dw_start.expand_ratio = expand_ratio; + + dw_mid.stride = 1; + dw_mid.filters = filters; + dw_mid.expand_ratio = expand_ratio; + + dw_end.stride = 1; + dw_end.filters = filters; + dw_end.expand_ratio = expand_ratio; + + dw_proj.stride = 1; + dw_proj.filters = filters; + dw_proj.expand_ratio = 1.0f; // projection does not expand + } + + virtual void load_tensors(const std::string & prefix, get_tensor_fn & get_tensor) { + dw_start.load_tensors(prefix + "dw_start.", get_tensor); + dw_mid.load_tensors(prefix + "dw_mid.", get_tensor); + dw_end.load_tensors(prefix + "dw_end.", get_tensor); + dw_proj.load_tensors(prefix + "dw_proj.", get_tensor); + layer_scale = get_tensor(prefix + "layer_scale.weight"); + } + + virtual ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur) { + if (dw_start.conv) { + cur = dw_start.build(ctx, cur); + } + if (dw_mid.conv) { + cur = dw_mid.build(ctx, cur); + } + if (dw_end.conv) { + cur = dw_end.build(ctx, 
cur); + } + if (dw_proj.conv) { + cur = dw_proj.build(ctx, cur); + } + + cur = ggml_mul(ctx, cur, ggml_reshape_3d(ctx, layer_scale, 1, 1, layer_scale->ne[0])); + + return cur; + } +}; + +// MultiQueryAttentionBlock +struct mobilenet_g3n_mmqa : mobilenet_g3n_blk { + int num_heads = 0; + int kv_strides = 0; + int kv_dim = 0; + bool mmqa_avg_pool_kv = false; + bool multiscale = false; + + mobilenet_g3n_mmqa(int num_heads, int kv_dim, int kv_strides, + bool mmqa_avg_pool_kv = false, bool multiscale = false) : + num_heads(num_heads), kv_dim(kv_dim), kv_strides(kv_strides), + mmqa_avg_pool_kv(mmqa_avg_pool_kv), multiscale(multiscale) {} + + virtual void load_tensors(const std::string & prefix, get_tensor_fn & get_tensor) { + } + + virtual ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur) { + return cur; + } +}; + +struct mobilenet_g3n { + // mapping prefix to block, order is important + std::map blocks; + + // temporary variables + int stg_idx = 0; + int blk_idx = 0; + + mobilenet_g3n() { + // Stage 1: Edge Residuals + stg_idx = 0; blk_idx = 0; + add( new mobilenet_g3n_er(3, 128, 2)); + add(2, new mobilenet_g3n_er(3, 128, 1)); + + // Stage 2: Universal Inverted Residuals + stg_idx = 1; blk_idx = 0; + add( new mobilenet_g3n_uir(3, 5, 256, 2, 6.0f)); + add( new mobilenet_g3n_uir(5, 0, 256)); + add( new mobilenet_g3n_uir(3, 0, 256)); + add( new mobilenet_g3n_uir(5, 0, 256)); + add( new mobilenet_g3n_uir(3, 0, 256)); + + // Stage 3: Universal Inverted Residuals with Multi-Query Attention + stg_idx = 2; blk_idx = 0; + add( new mobilenet_g3n_uir(5, 5, 640, 2, 6.0f)); + add(7, new mobilenet_g3n_uir(5, 0, 640)); + add( new mobilenet_g3n_uir(0, 0, 640, 1, 1.0f)); + add(13, new mobilenet_g3n_mmqa(12, 64, 2), new mobilenet_g3n_uir(0, 0, 640, 1, 2.0f)); + add( new mobilenet_g3n_mmqa(12, 64, 2), new mobilenet_g3n_uir(0, 0, 640, 1, 2.0f, true)); + + // Stage 4: Universal Inverted Residuals with Multi-Query Attention + stg_idx = 3; blk_idx = 0; + add( new mobilenet_g3n_uir(5, 5, 1280, 2, 6.0f)); + add(18, new mobilenet_g3n_mmqa(16, 96, 1), new mobilenet_g3n_uir(0, 0, 1280, 1, 2.0f)); + add( new mobilenet_g3n_mmqa(16, 96, 1), new mobilenet_g3n_uir(0, 0, 1280, 1, 2.0f, true)); + } + + ~mobilenet_g3n() { + for (auto & blk : blocks) { + delete blk.first; + } + } + + void load_tensors(get_tensor_fn & get_tensor) { + for (auto blk : blocks) { + blk.first->load_tensors(blk.second, get_tensor); + } + } + + ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur) { + return cur; + } + + void add(mobilenet_g3n_blk * blk) { + blocks.insert({ blk, string_format("v.mobilenet.%d.%d.", stg_idx, blk_idx++) }); + } + + void add(mobilenet_g3n_blk * blk0, mobilenet_g3n_blk * blk1) { + add(blk0); + add(blk1); + } + + void add(int count, mobilenet_g3n_blk * blk) { + for (int i = 0; i < count; ++i) { + add(blk); + } + } + + void add(int count, mobilenet_g3n_blk * blk0, mobilenet_g3n_blk * blk1) { + for (int i = 0; i < count; ++i) { + add(blk0, blk1); + } + } +}; From f8f73e8626ac84ee558aca5c5970e238a1bd93b4 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 22 Jun 2025 19:43:23 +0200 Subject: [PATCH 17/21] conversion script --- convert_hf_to_gguf.py | 49 ++++++++++++++++++++++++++++++++++++++- gguf-py/gguf/constants.py | 1 + 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0e81f46716c..9de5db157d1 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1135,6 +1135,7 @@ class MmprojModel(ModelBase): preprocessor_config: dict[str, Any] 
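    # Note on the block_count change in this hunk: subclasses (such as the Gemma3n mmproj
    # model added later in this patch) can now pin block_count before MmprojModel.__init__
    # runs; when it is left as None, the usual n_block_keys lookup below still applies.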
global_config: dict[str, Any] + block_count: Any = None # will be set in __init__ n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"] has_vision_encoder: bool = True # by default @@ -1173,7 +1174,8 @@ def __init__(self, *args, **kwargs): # TODO @ngxson : this is a hack to support both vision and audio encoders have_multiple_encoders = self.has_audio_encoder and self.has_vision_encoder - self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys, True) + if self.block_count is None: + self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys, True) self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count) # load preprocessor config @@ -4416,6 +4418,51 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return super().modify_tensors(data_torch, name, bid) +@ModelBase.register("Gemma3nForConditionalGeneration") +class Gemma3NMmprojModel(MmprojModel): + has_audio_encoder = False # TODO + has_vision_encoder = True + block_count = 128 # dummy, unused + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.hparams["image_size"] = 768 + self.hparams["patch_size"] = 3 + # below are dummy values, unused + self.hparams["intermediate_size"] = 1 + self.hparams["n_layers"] = self.block_count + self.hparams["num_attention_heads"] = 0 + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GEMMA3N) + self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6)) + self.gguf_writer.add_vision_use_gelu(True) + + # def tensor_force_quant(self, name, new_name, bid, n_dims): + # del bid, new_name, n_dims # unused + # return False + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if name.startswith("model.vision_tower.timm_model"): + # process vision tensors + name = name.replace(".gamma", ".weight") + name = name.replace("model.vision_tower.timm_model.", "v.mobilenet.") + + if "conv" in name or "attn" in name: + # check if we have 1x1 kernel (last 2 dims are 1x1) + if data_torch.dim() == 4 and data_torch.shape[-2:] == (1, 1): + # convert 4D conv with 1x1 kernel to 2D matrix for matmul operation + data_torch = data_torch.squeeze(-1).squeeze(-1) + + return [(name, data_torch)] # not using map_tensor_name here + + return [] # skip other tensors + + @ModelBase.register("Starcoder2ForCausalLM") class StarCoder2Model(TextModel): model_arch = gguf.MODEL_ARCH.STARCODER2 diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index bef090d31b5..419a8bf3338 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -2382,6 +2382,7 @@ def get_type(val: Any) -> GGUFValueType: class VisionProjectorType: GEMMA3 = "gemma3" + GEMMA3N = "gemma3n" IDEFICS3 = "idefics3" PIXTRAL = "pixtral" LLAMA4 = "llama4" From 96fd71dd6683557990088cb11bf5fc9281641fee Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 23 Jun 2025 00:39:49 +0200 Subject: [PATCH 18/21] wip --- convert_hf_to_gguf.py | 16 +- gguf-py/gguf/constants.py | 2 +- tools/mtmd/clip-impl.h | 4 + tools/mtmd/clip-mobilenet.h | 417 +++++++++++++++++++++++------------- tools/mtmd/clip.cpp | 36 ++++ 5 files changed, 312 insertions(+), 163 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 9de5db157d1..127c3e4cec5 100755 --- 
a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4429,14 +4429,14 @@ def __init__(self, *args, **kwargs): self.hparams["image_size"] = 768 self.hparams["patch_size"] = 3 # below are dummy values, unused - self.hparams["intermediate_size"] = 1 - self.hparams["n_layers"] = self.block_count + self.hparams["intermediate_size"] = 1 + self.hparams["n_layers"] = 0 self.hparams["num_attention_heads"] = 0 def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams - self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GEMMA3N) + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GEMMA3NV) self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6)) self.gguf_writer.add_vision_use_gelu(True) @@ -4452,11 +4452,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter name = name.replace(".gamma", ".weight") name = name.replace("model.vision_tower.timm_model.", "v.mobilenet.") - if "conv" in name or "attn" in name: - # check if we have 1x1 kernel (last 2 dims are 1x1) - if data_torch.dim() == 4 and data_torch.shape[-2:] == (1, 1): - # convert 4D conv with 1x1 kernel to 2D matrix for matmul operation - data_torch = data_torch.squeeze(-1).squeeze(-1) + # if "conv" in name or "attn" in name: + # # check if we have 1x1 kernel (last 2 dims are 1x1) + # if data_torch.dim() == 4 and data_torch.shape[-2:] == (1, 1): + # # convert 4D conv with 1x1 kernel to 2D matrix for matmul operation + # data_torch = data_torch.squeeze(-1).squeeze(-1) return [(name, data_torch)] # not using map_tensor_name here diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 419a8bf3338..5433d08a7b6 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -2382,7 +2382,7 @@ def get_type(val: Any) -> GGUFValueType: class VisionProjectorType: GEMMA3 = "gemma3" - GEMMA3N = "gemma3n" + GEMMA3NV = "gemma3nv" # vision IDEFICS3 = "idefics3" PIXTRAL = "pixtral" LLAMA4 = "llama4" diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 62c936ed00f..4f45c22b016 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -123,6 +123,8 @@ enum projector_type { PROJECTOR_TYPE_GLM_EDGE, PROJECTOR_TYPE_QWEN2VL, PROJECTOR_TYPE_GEMMA3, + PROJECTOR_TYPE_GEMMA3NV, + //PROJECTOR_TYPE_GEMMA3NA, PROJECTOR_TYPE_IDEFICS3, PROJECTOR_TYPE_PIXTRAL, PROJECTOR_TYPE_QWEN25VL, @@ -143,6 +145,8 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_QWEN2VL, "qwen2vl_merger"}, { PROJECTOR_TYPE_QWEN25VL, "qwen2.5vl_merger"}, { PROJECTOR_TYPE_GEMMA3, "gemma3"}, + { PROJECTOR_TYPE_GEMMA3NV, "gemma3nv"}, + //{ PROJECTOR_TYPE_GEMMA3NA, "gemma3na"}, { PROJECTOR_TYPE_IDEFICS3, "idefics3"}, { PROJECTOR_TYPE_PIXTRAL, "pixtral"}, { PROJECTOR_TYPE_ULTRAVOX, "ultravox"}, diff --git a/tools/mtmd/clip-mobilenet.h b/tools/mtmd/clip-mobilenet.h index 3c95ba14e29..27cc3d5c6de 100644 --- a/tools/mtmd/clip-mobilenet.h +++ b/tools/mtmd/clip-mobilenet.h @@ -1,285 +1,394 @@ #pragma once -#include "clip.h" -#include "clip-impl.h" #include "ggml.h" +#include #include #include -#include +#include +#include -using get_tensor_fn = std::function; - -static ggml_tensor * conv2d(ggml_context * ctx, ggml_tensor * kernel, ggml_tensor * inp, int strides, bool depthwise = false) { - int p0 = 0; - int p1 = 0; - - { - const int kernel_size = kernel->ne[0]; +// mobilenet v5 implementation - auto compute_padding_length = [](int input_length, int kernel_length, int stride) { - int total_padding_length = (kernel_length - 1) - 
(input_length - 1) % stride; - int left_padding = total_padding_length / 2; - int right_padding = (total_padding_length + 1) / 2; - return std::make_pair(left_padding, right_padding); - }; +namespace mobilenet { - auto [left, right] = compute_padding_length(inp->ne[0], kernel_size, strides); - auto [top, bottom] = compute_padding_length(inp->ne[1], kernel_size, strides); +using get_tensor_fn = std::function; +using callback_fn = std::function; + +static std::string str_fmt(const char * fmt, ...) { + va_list ap; + va_list ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + int size = vsnprintf(NULL, 0, fmt, ap); + GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT + std::vector buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), buf.size()); +} - if (left > 0 && right > 0) { - p0 = std::min(left, right); - left -= p0; - right -= p0; - } +static std::string str_concat(const std::string & a, const std::string & b) { + return str_fmt("%s%s", a.c_str(), b.c_str()); // the "+" operator does not work, why? +} - if (top > 0 && bottom > 0) { - p1 = std::min(top, bottom); - top -= p1; - bottom -= p1; - } +struct v5_blk { + virtual ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur, callback_fn & cb) = 0; + virtual void load_tensors(const std::string & prefix, get_tensor_fn & get_tensor) = 0; + virtual ~v5_blk() = default; +}; - GGML_ASSERT(left == 0 && top == 0); +enum conv_type { + CONV_TYPE_NORMAL, // ggml_conv_2d + CONV_TYPE_POINTWISE, // ggml_mul_mat + CONV_TYPE_DEPTHWISE, // ggml_conv_2d_dw +}; - if (right != 0 || bottom != 0) { - inp = ggml_pad(ctx, inp, right, bottom, 0, 0); - } +static ggml_tensor * rms_norm_act_2d( + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * scale, + int n_groups, + bool apply_act, + callback_fn & cb) { + cur = ggml_group_norm(ctx, cur, n_groups, 1e-6f); + cb(cur, "rms_norm_act.norm", -1); + if (scale != nullptr) { + cur = ggml_mul(ctx, cur, ggml_reshape_3d(ctx, scale, 1, 1, scale->ne[0])); + cb(cur, "rms_norm_act.norm_scaled", -1); } - - ggml_tensor * cur; - - if (depthwise) { - cur = ggml_conv_2d_dw(ctx, - kernel, inp, - strides, strides, - p0, p1, 1, 1); - } else { - cur = ggml_conv_2d(ctx, - kernel, inp, - strides, strides, - p0, p1, 1, 1); + if (apply_act) { + cur = ggml_gelu(ctx, cur); + cb(cur, "rms_norm_act.gelu", -1); } - return cur; } -struct mobilenet_g3n_blk { - virtual ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur) = 0; - virtual void load_tensors(const std::string & prefix, get_tensor_fn & get_tensor) = 0; - virtual ~mobilenet_g3n_blk() = default; -}; - // ConvNormAct -struct mobilenet_g3n_cna : mobilenet_g3n_blk { +struct v5_cna : v5_blk { + conv_type type = CONV_TYPE_NORMAL; int kernel_size = 0; int stride = 1; int dilation = 1; - int filters = 0; + int padding = 0; + bool apply_act = false; float expand_ratio = 1.0f; + int in_chs = 0; + int out_chs = 0; // aka filters + ggml_tensor * norm = nullptr; ggml_tensor * conv = nullptr; virtual void load_tensors(const std::string & prefix, get_tensor_fn & get_tensor) { - std::string tmp; - tmp = prefix + "bn.weight"; - norm = get_tensor(tmp); - tmp = prefix + "conv.weight"; - conv = get_tensor(tmp); + norm = get_tensor(str_concat(prefix, ".bn.weight")); + conv = get_tensor(str_concat(prefix, ".conv.weight")); + + if (type == CONV_TYPE_POINTWISE) { + GGML_ASSERT(kernel_size == 1); + GGML_ASSERT(stride == 1); + GGML_ASSERT(padding == 0); + GGML_ASSERT(dilation == 1); + 
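            // A pointwise conv is just a 1x1 kernel applied at every pixel, i.e. a plain
            // channel-mixing matmul; the asserts here pin that assumption down before
            // in_chs/out_chs are read off the kernel tensor.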
GGML_ASSERT(conv->ne[0] == 1 && conv->ne[1] == 1); + } else { + GGML_ASSERT(conv->ne[0] == kernel_size && conv->ne[1] == kernel_size); + GGML_ASSERT(conv->ne[3] == norm->ne[0]); // norm size matches + } + + in_chs = conv->ne[2]; + out_chs = conv->ne[3]; } - virtual ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur) { - cur = conv2d(ctx, conv, cur, stride, true); - cur = ggml_group_norm(ctx, cur, std::min(32, filters * expand_ratio / 4), 1e-6f); - cur = ggml_mul(ctx, cur, ggml_reshape_3d(ctx, norm, 1, 1, norm->ne[0])); + virtual ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur, callback_fn & cb) { + if (type == CONV_TYPE_POINTWISE) { + cur = ggml_conv_2d(ctx, conv, cur, 1, 1, 0, 0, 1, 1); + cb(cur, "conv_norm_act.pw", -1); + } else if (type == CONV_TYPE_DEPTHWISE) { + cur = ggml_conv_2d_dw(ctx, conv, cur, + stride, stride, + padding, padding, + dilation, dilation); + cb(cur, "conv_norm_act.dw", -1); + } else { + cur = ggml_conv_2d(ctx, conv, cur, + stride, stride, + padding, padding, + dilation, dilation); + cb(cur, "conv_norm_act", -1); + } + + cur = rms_norm_act_2d(ctx, cur, norm, out_chs, apply_act, cb); + return cur; } }; // EdgeResidual -struct mobilenet_g3n_er : mobilenet_g3n_blk { +struct v5_er : v5_blk { int kernel_size = 0; int stride = 1; int filters = 0; - ggml_tensor * norm1 = nullptr; - ggml_tensor * norm2 = nullptr; + ggml_tensor * norm1 = nullptr; + ggml_tensor * norm2 = nullptr; ggml_tensor * conv_exp = nullptr; ggml_tensor * conv_pwl = nullptr; - mobilenet_g3n_er(int kernel_size, int filters, int stride) : + v5_er(int kernel_size, int filters, int stride) : kernel_size(kernel_size), stride(stride), filters(filters) {} virtual void load_tensors(const std::string & prefix, get_tensor_fn & get_tensor) { - std::string tmp; - tmp = prefix + "bn1.weight"; - norm1 = get_tensor(tmp); - tmp = prefix + "bn2.weight"; - norm2 = get_tensor(tmp); - tmp = prefix + "conv_exp.weight"; - conv_exp = get_tensor(tmp); - tmp = prefix + "conv_pwl.weight"; - conv_pwl = get_tensor(tmp); + norm1 = get_tensor(str_concat(prefix, ".bn1.weight")); + norm2 = get_tensor(str_concat(prefix, ".bn2.weight")); + conv_exp = get_tensor(str_concat(prefix, ".conv_exp.weight")); + conv_pwl = get_tensor(str_concat(prefix, ".conv_pwl.weight")); + + GGML_ASSERT(ggml_n_dims(conv_exp) == 4); // expected 4D tensor + GGML_ASSERT(ggml_n_dims(conv_pwl) == 4); // expected 4D tensor + GGML_ASSERT(conv_exp->ne[0] == kernel_size && conv_exp->ne[1] == kernel_size); + GGML_ASSERT(conv_pwl->ne[0] == 1 && conv_pwl->ne[1] == 1); } - virtual ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur) { + virtual ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur, callback_fn & cb) { + int padding = (kernel_size - 1) / 2; + cur = ggml_conv_2d(ctx, conv_exp, cur, + stride, stride, + padding, padding, + 1, 1); + cb(cur, "edge_residual.conv_exp", -1); + + int mid_chs = conv_exp->ne[3]; + cur = rms_norm_act_2d(ctx, cur, norm1, mid_chs, true, cb); + cb(cur, "edge_residual.norm1", -1); + + cur = ggml_conv_2d(ctx, conv_pwl, cur, 1, 1, 0, 0, 1, 1); + cb(cur, "edge_residual.conv_pwl", -1); + + int out_chs = conv_pwl->ne[1]; + cur = rms_norm_act_2d(ctx, cur, norm2, out_chs, false, cb); + cb(cur, "edge_residual.norm2", -1); + return cur; } }; // UniversalInvertedResidual -struct mobilenet_g3n_uir : mobilenet_g3n_blk { - int start_dw_kernel_size = 0; - int mid_dw_kernel_size = 0; +struct v5_uir : v5_blk { + int dw_kernel_size_start = 0; + int dw_kernel_size_mid = 0; bool multiscale = false; ggml_tensor * layer_scale = 
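    // UniversalInvertedResidual as rebuilt below: optional leading depthwise conv (dw_start),
    // pointwise expansion (pw_exp), optional mid depthwise conv (dw_mid), pointwise
    // projection (pw_proj), then a learned per-channel scale (layer_scale) multiplied in at
    // the end of build(); a skip connection around the block does not appear in this
    // build() yet.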
nullptr; - mobilenet_g3n_cna dw_start; - mobilenet_g3n_cna dw_mid; - mobilenet_g3n_cna dw_end; - mobilenet_g3n_cna dw_proj; + v5_cna dw_start; + v5_cna pw_exp; + v5_cna dw_mid; + v5_cna pw_proj; - mobilenet_g3n_uir(int start_dw_kernel_size, int mid_dw_kernel_size, int filters, int stride = 1, float expand_ratio = 4.0f, bool multiscale = false) : - start_dw_kernel_size(start_dw_kernel_size), - mid_dw_kernel_size(mid_dw_kernel_size), + v5_uir(int dw_kernel_size_start, int dw_kernel_size_mid, int filters, int stride = 1, float expand_ratio = 4.0f, bool multiscale = false) : + dw_kernel_size_start(dw_kernel_size_start), + dw_kernel_size_mid(dw_kernel_size_mid), multiscale(multiscale) { - dw_start.stride = stride; - dw_start.filters = filters; - dw_start.expand_ratio = expand_ratio; + GGML_UNUSED(filters); + GGML_UNUSED(expand_ratio); - dw_mid.stride = 1; - dw_mid.filters = filters; - dw_mid.expand_ratio = expand_ratio; + dw_start.type = CONV_TYPE_DEPTHWISE; + dw_start.kernel_size = dw_kernel_size_start; + dw_start.stride = !dw_kernel_size_mid ? stride : 1; + dw_start.padding = (dw_kernel_size_start - 1) / 2; - dw_end.stride = 1; - dw_end.filters = filters; - dw_end.expand_ratio = expand_ratio; + pw_exp.type = CONV_TYPE_POINTWISE; + pw_exp.kernel_size = 1; - dw_proj.stride = 1; - dw_proj.filters = filters; - dw_proj.expand_ratio = 1.0f; // projection does not expand - } + dw_mid.type = CONV_TYPE_DEPTHWISE; + dw_mid.kernel_size = dw_kernel_size_mid; + dw_mid.stride = 1; + dw_mid.padding = (dw_kernel_size_mid - 1) / 2; - virtual void load_tensors(const std::string & prefix, get_tensor_fn & get_tensor) { - dw_start.load_tensors(prefix + "dw_start.", get_tensor); - dw_mid.load_tensors(prefix + "dw_mid.", get_tensor); - dw_end.load_tensors(prefix + "dw_end.", get_tensor); - dw_proj.load_tensors(prefix + "dw_proj.", get_tensor); - layer_scale = get_tensor(prefix + "layer_scale.weight"); + pw_proj.type = CONV_TYPE_POINTWISE; + pw_proj.kernel_size = 1; } - virtual ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur) { - if (dw_start.conv) { - cur = dw_start.build(ctx, cur); + virtual void load_tensors(const std::string & prefix, get_tensor_fn & get_tensor) { + if (dw_kernel_size_start) { + dw_start.load_tensors(str_concat(prefix, ".dw_start"), get_tensor); } - if (dw_mid.conv) { - cur = dw_mid.build(ctx, cur); + pw_exp.load_tensors(str_concat(prefix, ".pw_exp"), get_tensor); + if (dw_kernel_size_mid) { + dw_mid.load_tensors(str_concat(prefix, ".dw_mid"), get_tensor); } - if (dw_end.conv) { - cur = dw_end.build(ctx, cur); + pw_proj.load_tensors(str_concat(prefix, ".pw_proj"), get_tensor); + layer_scale = get_tensor(str_concat(prefix, ".layer_scale.weight")); + } + + virtual ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur, callback_fn & cb) { + if (dw_kernel_size_start) { + cur = dw_start.build(ctx, cur, cb); } - if (dw_proj.conv) { - cur = dw_proj.build(ctx, cur); + cur = pw_exp.build(ctx, cur, cb); + if (dw_kernel_size_mid) { + cur = dw_mid.build(ctx, cur, cb); } - + cur = pw_proj.build(ctx, cur, cb); cur = ggml_mul(ctx, cur, ggml_reshape_3d(ctx, layer_scale, 1, 1, layer_scale->ne[0])); - return cur; } }; // MultiQueryAttentionBlock -struct mobilenet_g3n_mmqa : mobilenet_g3n_blk { +struct v5_mmqa : v5_blk { int num_heads = 0; int kv_strides = 0; int kv_dim = 0; bool mmqa_avg_pool_kv = false; bool multiscale = false; - mobilenet_g3n_mmqa(int num_heads, int kv_dim, int kv_strides, + v5_mmqa(int num_heads, int kv_dim, int kv_strides, bool mmqa_avg_pool_kv = false, bool multiscale = false) 
: - num_heads(num_heads), kv_dim(kv_dim), kv_strides(kv_strides), + num_heads(num_heads), kv_strides(kv_strides), kv_dim(kv_dim), mmqa_avg_pool_kv(mmqa_avg_pool_kv), multiscale(multiscale) {} virtual void load_tensors(const std::string & prefix, get_tensor_fn & get_tensor) { + // TODO + } + + virtual ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur, callback_fn & cb) { + // TODO + return cur; + } +}; + +// MobileNetV5MultiScaleFusionAdapter +struct v5_msfa : v5_blk { + v5_cna pw_exp; + v5_cna pw_proj; + + ggml_tensor * norm; + + v5_msfa() { + pw_exp.type = CONV_TYPE_POINTWISE; + pw_exp.kernel_size = 1; + + pw_proj.type = CONV_TYPE_POINTWISE; + pw_proj.kernel_size = 1; + } + + virtual void load_tensors(const std::string & prefix, get_tensor_fn & get_tensor) { + pw_exp .load_tensors(str_concat(prefix, ".ffn.pw_exp"), get_tensor); + pw_proj.load_tensors(str_concat(prefix, ".ffn.pw_proj"), get_tensor); + norm = get_tensor(str_concat(prefix, ".norm.weight")); } - virtual ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur) { + virtual ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur, callback_fn & cb) { + cur = pw_exp .build(ctx, cur, cb); + cb(cur, "msfa.ffn.pw_exp.output", -1); + cur = pw_proj.build(ctx, cur, cb); + cb(cur, "msfa.ffn.pw_proj.output", -1); + cur = ggml_mul(ctx, cur, ggml_reshape_3d(ctx, norm, 1, 1, norm->ne[0])); + cb(cur, "msfa.norm", -1); return cur; } }; -struct mobilenet_g3n { +struct v5_model { + v5_cna conv_stem; // input + v5_msfa msfa; // output + // mapping prefix to block, order is important - std::map blocks; + std::vector> blocks; // temporary variables int stg_idx = 0; int blk_idx = 0; - mobilenet_g3n() { + v5_model() { + conv_stem.type = CONV_TYPE_NORMAL; + conv_stem.kernel_size = 3; + conv_stem.stride = 2; + conv_stem.padding = 1; + } + + ~v5_model() { + for (auto & blk : blocks) { + delete blk.first; + } + } + + void load(get_tensor_fn & get_tensor) { + // Convolution Stem + conv_stem.load_tensors("v.mobilenet.conv_stem", get_tensor); + // Stage 1: Edge Residuals stg_idx = 0; blk_idx = 0; - add( new mobilenet_g3n_er(3, 128, 2)); - add(2, new mobilenet_g3n_er(3, 128, 1)); + add( new v5_er(3, 128, 2)); + add(2, new v5_er(3, 128, 1)); // Stage 2: Universal Inverted Residuals stg_idx = 1; blk_idx = 0; - add( new mobilenet_g3n_uir(3, 5, 256, 2, 6.0f)); - add( new mobilenet_g3n_uir(5, 0, 256)); - add( new mobilenet_g3n_uir(3, 0, 256)); - add( new mobilenet_g3n_uir(5, 0, 256)); - add( new mobilenet_g3n_uir(3, 0, 256)); + add( new v5_uir(3, 5, 256, 2, 6.0f)); + add( new v5_uir(5, 0, 256)); + add( new v5_uir(3, 0, 256)); + add( new v5_uir(5, 0, 256)); + add( new v5_uir(3, 0, 256)); // Stage 3: Universal Inverted Residuals with Multi-Query Attention stg_idx = 2; blk_idx = 0; - add( new mobilenet_g3n_uir(5, 5, 640, 2, 6.0f)); - add(7, new mobilenet_g3n_uir(5, 0, 640)); - add( new mobilenet_g3n_uir(0, 0, 640, 1, 1.0f)); - add(13, new mobilenet_g3n_mmqa(12, 64, 2), new mobilenet_g3n_uir(0, 0, 640, 1, 2.0f)); - add( new mobilenet_g3n_mmqa(12, 64, 2), new mobilenet_g3n_uir(0, 0, 640, 1, 2.0f, true)); + add( new v5_uir(5, 5, 640, 2, 6.0f)); + add(7, new v5_uir(5, 0, 640)); + add( new v5_uir(0, 0, 640, 1, 1.0f)); + add(13, new v5_mmqa(12, 64, 2), new v5_uir(0, 0, 640, 1, 2.0f)); + add( new v5_mmqa(12, 64, 2), new v5_uir(0, 0, 640, 1, 2.0f, true)); // Stage 4: Universal Inverted Residuals with Multi-Query Attention stg_idx = 3; blk_idx = 0; - add( new mobilenet_g3n_uir(5, 5, 1280, 2, 6.0f)); - add(18, new mobilenet_g3n_mmqa(16, 96, 1), new 
mobilenet_g3n_uir(0, 0, 1280, 1, 2.0f)); - add( new mobilenet_g3n_mmqa(16, 96, 1), new mobilenet_g3n_uir(0, 0, 1280, 1, 2.0f, true)); - } - - ~mobilenet_g3n() { - for (auto & blk : blocks) { - delete blk.first; - } - } + add( new v5_uir(5, 5, 1280, 2, 6.0f)); + add(18, new v5_mmqa(16, 96, 1), new v5_uir(0, 0, 1280, 1, 2.0f)); + add( new v5_mmqa(16, 96, 1), new v5_uir(0, 0, 1280, 1, 2.0f, true)); - void load_tensors(get_tensor_fn & get_tensor) { for (auto blk : blocks) { blk.first->load_tensors(blk.second, get_tensor); } + + // Output + msfa.load_tensors("v.mobilenet.msfa", get_tensor); } - ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur) { + ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur, callback_fn & cb) { + cur = conv_stem.build(ctx, cur, cb); + cb(cur, "conv_stem.output", -1); + + for (auto & blk : blocks) { + cur = blk.first->build(ctx, cur, cb); + cb(cur, str_concat(blk.second, ".output").c_str(), -1); + } + + cur = msfa.build(ctx, cur, cb); + cb(cur, "msfa.output", -1); + return cur; } - void add(mobilenet_g3n_blk * blk) { - blocks.insert({ blk, string_format("v.mobilenet.%d.%d.", stg_idx, blk_idx++) }); + void add(v5_blk * blk) { + blocks.emplace_back(blk, str_fmt("v.mobilenet.blocks.%d.%d", stg_idx, blk_idx++)); } - void add(mobilenet_g3n_blk * blk0, mobilenet_g3n_blk * blk1) { + void add(v5_blk * blk0, v5_blk * blk1) { add(blk0); add(blk1); } - void add(int count, mobilenet_g3n_blk * blk) { + void add(int count, v5_blk * blk) { for (int i = 0; i < count; ++i) { add(blk); } } - void add(int count, mobilenet_g3n_blk * blk0, mobilenet_g3n_blk * blk1) { + void add(int count, v5_blk * blk0, v5_blk * blk1) { for (int i = 0; i < count; ++i) { add(blk0, blk1); } } }; + +} // namespace mobilenet diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 30283d6f1f0..d5c36df419d 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -4,6 +4,7 @@ // Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch #include "clip.h" #include "clip-impl.h" +#include "clip-mobilenet.h" #include "ggml.h" #include "ggml-cpp.h" #include "ggml-cpu.h" @@ -354,6 +355,9 @@ struct clip_model { ggml_tensor * conv1d_2_b = nullptr; ggml_tensor * mm_norm_pre_w = nullptr; ggml_tensor * mm_norm_mid_w = nullptr; + + // mobilenetv5 (gemma3n) + mobilenet::v5_model mobilenetv5; }; struct clip_ctx { @@ -1536,12 +1540,25 @@ struct clip_graph { return gf; } + ggml_cgraph * build_gemma3n() { + ggml_tensor * cur = build_inp_raw(); + mobilenet::callback_fn fn_cb = std::bind(&clip_graph::cb, + this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3); + + ctx->model.mobilenetv5.build(ctx0, cur, fn_cb); + ggml_build_forward_expand(gf, cur); + + return gf; + } + private: // // utility functions // void cb(ggml_tensor * cur0, const char * name, int il) const { + printf("cb: %s, shape = [%lld, %lld, %lld, %lld]\n", + name, cur0->ne[0], cur0->ne[1], cur0->ne[2], cur0->ne[3]); if (ctx->debug_graph) { ggml_tensor * cur = ggml_cpy(ctx0, cur0, ggml_dup_tensor(ctx0, cur0)); std::string cur_name = il >= 0 ? 
std::string(name) + "_" + std::to_string(il) : name; @@ -1981,6 +1998,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { res = graph.build_whisper_enc(); } break; + case PROJECTOR_TYPE_GEMMA3NV: + { + res = graph.build_gemma3n(); + } break; default: { res = graph.build_llava(); @@ -2497,6 +2518,14 @@ struct clip_model_loader { model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ); model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N); } break; + case PROJECTOR_TYPE_GEMMA3NV: + { + mobilenet::get_tensor_fn fn_get_tensor = [&](const std::string & name) { + // printf(">> %s\n", name.c_str()); + return get_tensor(name, true); + }; + model.mobilenetv5.load(fn_get_tensor); + } break; case PROJECTOR_TYPE_IDEFICS3: { model.projection = get_tensor(TN_MM_PROJECTOR); @@ -3571,6 +3600,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im // another divide by 2 because of nn.AvgPool1d(2, stride=2) n_patches_sq = img->nx / 4; } break; + case PROJECTOR_TYPE_GEMMA3NV: + { + n_patches_sq = 256; // vision_soft_tokens_per_image + } break; default: GGML_ABORT("unsupported projector type"); } @@ -3971,6 +4004,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima set_input_i32("patches", patches); } break; case PROJECTOR_TYPE_GEMMA3: + case PROJECTOR_TYPE_GEMMA3NV: case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_INTERNVL: case PROJECTOR_TYPE_QWEN2A: @@ -4072,6 +4106,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.mm_1_b->ne[0]; case PROJECTOR_TYPE_GEMMA3: return ctx->model.mm_input_proj_w->ne[0]; + case PROJECTOR_TYPE_GEMMA3NV: + return 2048; // TODO: read this from tensor shape case PROJECTOR_TYPE_IDEFICS3: return ctx->model.projection->ne[1]; case PROJECTOR_TYPE_ULTRAVOX: From 325cbe761c274c2e50f5cd96e53acb0bc2bbb25a Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 23 Jun 2025 14:03:21 +0200 Subject: [PATCH 19/21] add MultiQueryAttentionBlock --- tools/mtmd/clip-mobilenet.h | 128 ++++++++++++++++++++++++++++++++++-- tools/mtmd/clip.cpp | 3 +- 2 files changed, 125 insertions(+), 6 deletions(-) diff --git a/tools/mtmd/clip-mobilenet.h b/tools/mtmd/clip-mobilenet.h index 27cc3d5c6de..be776d3dc78 100644 --- a/tools/mtmd/clip-mobilenet.h +++ b/tools/mtmd/clip-mobilenet.h @@ -179,13 +179,14 @@ struct v5_uir : v5_blk { int dw_kernel_size_start = 0; int dw_kernel_size_mid = 0; bool multiscale = false; - ggml_tensor * layer_scale = nullptr; v5_cna dw_start; v5_cna pw_exp; v5_cna dw_mid; v5_cna pw_proj; + ggml_tensor * layer_scale = nullptr; + v5_uir(int dw_kernel_size_start, int dw_kernel_size_mid, int filters, int stride = 1, float expand_ratio = 4.0f, bool multiscale = false) : dw_kernel_size_start(dw_kernel_size_start), dw_kernel_size_mid(dw_kernel_size_mid), @@ -244,17 +245,126 @@ struct v5_mmqa : v5_blk { bool mmqa_avg_pool_kv = false; bool multiscale = false; + ggml_tensor * k_down_conv = nullptr; + ggml_tensor * k_norm = nullptr; + ggml_tensor * k_proj = nullptr; + ggml_tensor * q_proj = nullptr; + ggml_tensor * v_down_conv = nullptr; + ggml_tensor * v_norm = nullptr; + ggml_tensor * v_proj = nullptr; + ggml_tensor * o_proj = nullptr; + ggml_tensor * layer_scale = nullptr; + ggml_tensor * norm = nullptr; + v5_mmqa(int num_heads, int kv_dim, int kv_strides, bool mmqa_avg_pool_kv = false, bool multiscale = false) : num_heads(num_heads), kv_strides(kv_strides), kv_dim(kv_dim), mmqa_avg_pool_kv(mmqa_avg_pool_kv), multiscale(multiscale) {} virtual void load_tensors(const 
std::string & prefix, get_tensor_fn & get_tensor) { - // TODO + if (kv_strides > 1) { + k_down_conv = get_tensor(str_concat(prefix, ".attn.key.down_conv.weight")); + v_down_conv = get_tensor(str_concat(prefix, ".attn.value.down_conv.weight")); + k_norm = get_tensor(str_concat(prefix, ".attn.key.norm.weight")); + v_norm = get_tensor(str_concat(prefix, ".attn.value.norm.weight")); + } + k_proj = get_tensor(str_concat(prefix, ".attn.key.proj.weight")); + q_proj = get_tensor(str_concat(prefix, ".attn.query.proj.weight")); + v_proj = get_tensor(str_concat(prefix, ".attn.value.proj.weight")); + o_proj = get_tensor(str_concat(prefix, ".attn.output.proj.weight")); + layer_scale = get_tensor(str_concat(prefix, ".layer_scale.weight")); + norm = get_tensor(str_concat(prefix, ".norm.weight")); } virtual ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur, callback_fn & cb) { - // TODO + ggml_tensor * k = nullptr; + ggml_tensor * q = nullptr; + ggml_tensor * v = nullptr; + + if (kv_strides > 1) { + k = ggml_conv_2d_dw(ctx, k_down_conv, cur, kv_strides, kv_strides, 0, 0, 1, 1); + cb(k, "mmqa.k_down_conv", -1); + k = rms_norm_act_2d(ctx, k, k_norm, kv_dim, false, cb); + k = ggml_conv_2d(ctx, k_proj, k, 1, 1, 0, 0, 1, 1); + cb(k, "mmqa.k_proj", -1); + } else { + k = ggml_conv_2d(ctx, k_proj, cur, 1, 1, 0, 0, 1, 1); + cb(k, "mmqa.k_proj", -1); + } + + if (kv_strides > 1) { + v = ggml_conv_2d_dw(ctx, v_down_conv, cur, kv_strides, kv_strides, 0, 0, 1, 1); + cb(v, "mmqa.v_down_conv", -1); + v = rms_norm_act_2d(ctx, v, v_norm, kv_dim, false, cb); + v = ggml_conv_2d(ctx, v_proj, v, 1, 1, 0, 0, 1, 1); + cb(v, "mmqa.v_proj", -1); + } else { + v = ggml_conv_2d(ctx, v_proj, cur, 1, 1, 0, 0, 1, 1); + cb(v, "mmqa.v_proj", -1); + } + + q = ggml_conv_2d(ctx, q_proj, cur, 1, 1, 0, 0, 1, 1); + cb(q, "mmqa.q_proj", -1); + + // reshape k, v, q + + q = ggml_reshape_3d(ctx, q, kv_dim, num_heads, q->ne[0] * q->ne[1]); + q = ggml_permute(ctx, q, 0, 2, 1, 3); + cb(q, "mmqa.q_reshape", -1); + + k = ggml_cont(ctx, ggml_permute(ctx, k, 1, 2, 0, 3)); + k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1] * k->ne[2]); + cb(k, "mmqa.k_reshape", -1); + + v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); + v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1] * v->ne[2]); + v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 0, 2, 3)); + cb(v, "mmqa.v_reshape", -1); + + float kq_scale = 1.0f / std::sqrt(static_cast(kv_dim)); + build_attn(ctx, o_proj, q, k, v, nullptr, kq_scale, cb); + cb(cur, "mmqa.attn_output", -1); + + return cur; + } + + ggml_tensor * build_attn( + ggml_context * ctx0, + ggml_tensor * wo, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * kq_mask, + float kq_scale, + callback_fn & cb) const { + ggml_tensor * cur; + + { + const auto n_tokens = q->ne[1]; + const auto n_head = q->ne[2]; + + ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f); + + ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); + } + + cb(cur, "kqv_out", -1); + + { + int h = std::sqrt(cur->ne[1]); + int w = h; + int c = cur->ne[0]; + cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 0, 2, 3)); + cur = ggml_reshape_3d(ctx0, cur, w, h, c); + cb(cur, "kqv_out_reshape", -1); + } + + // output projection + cur = ggml_conv_2d(ctx0, wo, cur, 1, 1, 0, 0, 1, 1); + return cur; } }; @@ -281,12 +391,20 @@ struct v5_msfa : v5_blk { } virtual ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur, 
callback_fn & cb) { - cur = pw_exp .build(ctx, cur, cb); + int target_res = pw_exp.conv->ne[2]; + + cur = ggml_upscale_ext(ctx, cur, + cur->ne[0], cur->ne[1], target_res, 1, GGML_SCALE_MODE_NEAREST); + cb(cur, "msfa.ffn.pw_exp.upscale", -1); + + cur = pw_exp.build(ctx, cur, cb); cb(cur, "msfa.ffn.pw_exp.output", -1); cur = pw_proj.build(ctx, cur, cb); cb(cur, "msfa.ffn.pw_proj.output", -1); + cur = ggml_mul(ctx, cur, ggml_reshape_3d(ctx, norm, 1, 1, norm->ne[0])); cb(cur, "msfa.norm", -1); + return cur; } }; @@ -295,7 +413,7 @@ struct v5_model { v5_cna conv_stem; // input v5_msfa msfa; // output - // mapping prefix to block, order is important + // mapping block to prefix std::vector> blocks; // temporary variables diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index d5c36df419d..260304fb911 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1545,7 +1545,7 @@ struct clip_graph { mobilenet::callback_fn fn_cb = std::bind(&clip_graph::cb, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3); - ctx->model.mobilenetv5.build(ctx0, cur, fn_cb); + cur = ctx->model.mobilenetv5.build(ctx0, cur, fn_cb); ggml_build_forward_expand(gf, cur); return gf; @@ -3380,6 +3380,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str } else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE || ctx->proj_type() == PROJECTOR_TYPE_GEMMA3 + || ctx->proj_type() == PROJECTOR_TYPE_GEMMA3NV || ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3 || ctx->proj_type() == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution ) { From 3161c0c8ed3ffd120dc42f9de8b1a8c367861275 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 25 Jun 2025 11:11:18 +0200 Subject: [PATCH 20/21] fix rms_norm_act_2d --- tools/mtmd/clip-mobilenet.h | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tools/mtmd/clip-mobilenet.h b/tools/mtmd/clip-mobilenet.h index be776d3dc78..a346ebad1e1 100644 --- a/tools/mtmd/clip-mobilenet.h +++ b/tools/mtmd/clip-mobilenet.h @@ -49,18 +49,20 @@ enum conv_type { static ggml_tensor * rms_norm_act_2d( ggml_context * ctx, ggml_tensor * cur, - ggml_tensor * scale, + ggml_tensor * weight, int n_groups, bool apply_act, callback_fn & cb) { - cur = ggml_group_norm(ctx, cur, n_groups, 1e-6f); - cb(cur, "rms_norm_act.norm", -1); - if (scale != nullptr) { - cur = ggml_mul(ctx, cur, ggml_reshape_3d(ctx, scale, 1, 1, scale->ne[0])); - cb(cur, "rms_norm_act.norm_scaled", -1); + // TODO @ngxson : prevent using ggml_cont here + cur = ggml_cont(ctx, ggml_permute(ctx, cur, 1, 2, 0, 3)); // first dim is now channels + cur = ggml_rms_norm(ctx, cur, 1e-6f); + cur = ggml_cont(ctx, ggml_permute(ctx, cur, 2, 0, 1, 3)); // back to original order + if (weight != nullptr) { + cur = ggml_mul(ctx, cur, ggml_reshape_3d(ctx, weight, 1, 1, weight->ne[0])); + cb(cur, "rms_norm_act.norm_w", -1); } if (apply_act) { - cur = ggml_gelu(ctx, cur); + cur = ggml_gelu_erf(ctx, cur); cb(cur, "rms_norm_act.gelu", -1); } return cur; From 59e60910ccac9ac5c797d0b082366a3d81757314 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 25 Jun 2025 11:56:25 +0200 Subject: [PATCH 21/21] wip --- tools/mtmd/clip-mobilenet.h | 35 ++++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/tools/mtmd/clip-mobilenet.h b/tools/mtmd/clip-mobilenet.h index a346ebad1e1..cc305cbba3f 100644 --- a/tools/mtmd/clip-mobilenet.h +++ b/tools/mtmd/clip-mobilenet.h @@ -46,6 +46,20 @@ enum conv_type { CONV_TYPE_DEPTHWISE, // 
ggml_conv_2d_dw }; +static ggml_tensor * conv2d_pw(ggml_context * ctx, ggml_tensor * a, ggml_tensor * b) { + GGML_ASSERT(a->ne[0] == 1 && a->ne[1] == 1); // pointwise conv expects 1x1 kernel + // return ggml_conv_2d(ctx, a, b, 1, 1, 0, 0, 1, 1); + int w = b->ne[0]; + int h = b->ne[1]; + int c = b->ne[2]; + GGML_ASSERT(b->ne[3] == 1); // not support batch size > 1 for now + ggml_tensor * cur = ggml_cont(ctx, ggml_permute(ctx, b, 1, 2, 0, 3)); // first dim is now channels + a = ggml_reshape_2d(ctx, a, a->ne[2], a->ne[3]); + cur = ggml_mul_mat(ctx, a, cur); + cur = ggml_cont(ctx, ggml_permute(ctx, cur, 2, 0, 1, 3)); // back to original order + return cur; +} + static ggml_tensor * rms_norm_act_2d( ggml_context * ctx, ggml_tensor * cur, @@ -53,6 +67,7 @@ static ggml_tensor * rms_norm_act_2d( int n_groups, bool apply_act, callback_fn & cb) { + GGML_UNUSED(n_groups); // also unused in python impl // TODO @ngxson : prevent using ggml_cont here cur = ggml_cont(ctx, ggml_permute(ctx, cur, 1, 2, 0, 3)); // first dim is now channels cur = ggml_rms_norm(ctx, cur, 1e-6f); @@ -105,7 +120,7 @@ struct v5_cna : v5_blk { virtual ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur, callback_fn & cb) { if (type == CONV_TYPE_POINTWISE) { - cur = ggml_conv_2d(ctx, conv, cur, 1, 1, 0, 0, 1, 1); + cur = conv2d_pw(ctx, conv, cur); cb(cur, "conv_norm_act.pw", -1); } else if (type == CONV_TYPE_DEPTHWISE) { cur = ggml_conv_2d_dw(ctx, conv, cur, @@ -165,7 +180,7 @@ struct v5_er : v5_blk { cur = rms_norm_act_2d(ctx, cur, norm1, mid_chs, true, cb); cb(cur, "edge_residual.norm1", -1); - cur = ggml_conv_2d(ctx, conv_pwl, cur, 1, 1, 0, 0, 1, 1); + cur = conv2d_pw(ctx, conv_pwl, cur); cb(cur, "edge_residual.conv_pwl", -1); int out_chs = conv_pwl->ne[1]; @@ -287,10 +302,10 @@ struct v5_mmqa : v5_blk { k = ggml_conv_2d_dw(ctx, k_down_conv, cur, kv_strides, kv_strides, 0, 0, 1, 1); cb(k, "mmqa.k_down_conv", -1); k = rms_norm_act_2d(ctx, k, k_norm, kv_dim, false, cb); - k = ggml_conv_2d(ctx, k_proj, k, 1, 1, 0, 0, 1, 1); + k = conv2d_pw(ctx, k_proj, k); cb(k, "mmqa.k_proj", -1); } else { - k = ggml_conv_2d(ctx, k_proj, cur, 1, 1, 0, 0, 1, 1); + k = conv2d_pw(ctx, k_proj, cur); cb(k, "mmqa.k_proj", -1); } @@ -298,14 +313,14 @@ struct v5_mmqa : v5_blk { v = ggml_conv_2d_dw(ctx, v_down_conv, cur, kv_strides, kv_strides, 0, 0, 1, 1); cb(v, "mmqa.v_down_conv", -1); v = rms_norm_act_2d(ctx, v, v_norm, kv_dim, false, cb); - v = ggml_conv_2d(ctx, v_proj, v, 1, 1, 0, 0, 1, 1); + v = conv2d_pw(ctx, v_proj, v); cb(v, "mmqa.v_proj", -1); } else { - v = ggml_conv_2d(ctx, v_proj, cur, 1, 1, 0, 0, 1, 1); + v = conv2d_pw(ctx, v_proj, cur); cb(v, "mmqa.v_proj", -1); } - q = ggml_conv_2d(ctx, q_proj, cur, 1, 1, 0, 0, 1, 1); + q = conv2d_pw(ctx, q_proj, cur); cb(q, "mmqa.q_proj", -1); // reshape k, v, q @@ -365,7 +380,7 @@ struct v5_mmqa : v5_blk { } // output projection - cur = ggml_conv_2d(ctx0, wo, cur, 1, 1, 0, 0, 1, 1); + cur = conv2d_pw(ctx0, wo, cur); return cur; } @@ -404,7 +419,7 @@ struct v5_msfa : v5_blk { cur = pw_proj.build(ctx, cur, cb); cb(cur, "msfa.ffn.pw_proj.output", -1); - cur = ggml_mul(ctx, cur, ggml_reshape_3d(ctx, norm, 1, 1, norm->ne[0])); + cur = rms_norm_act_2d(ctx, cur, norm, pw_proj.out_chs, false, cb); cb(cur, "msfa.norm", -1); return cur; @@ -483,6 +498,8 @@ struct v5_model { cb(cur, str_concat(blk.second, ".output").c_str(), -1); } + // TODO (IMPORTANT): mfsa also takes some intermediate results as input + cur = msfa.build(ctx, cur, cb); cb(cur, "msfa.output", -1);
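
Editor's note on the normalization used throughout clip-mobilenet.h: after the fix in patch 20, rms_norm_act_2d permutes the activation so channels come first, applies ggml_rms_norm per spatial position, multiplies by a learned per-channel weight, and optionally applies an erf-based GELU. The following standalone CPU sketch only illustrates that computation on a channels-last W*H*C float buffer; the function name, layout, and helper are illustrative assumptions, not code from the patch.

// Reference for the channel-wise RMS norm + per-channel scale + erf-GELU that
// rms_norm_act_2d computes after patch 20. Layout assumption: x[(y*W + x)*C + c].
#include <cmath>
#include <cstddef>
#include <vector>

static void rms_norm_act_ref(std::vector<float> & x, int W, int H, int C,
                             const std::vector<float> & weight, bool apply_act) {
    const float eps = 1e-6f;
    for (int p = 0; p < W * H; ++p) {
        float * px = x.data() + (size_t) p * C;
        float ss = 0.0f;
        for (int c = 0; c < C; ++c) ss += px[c] * px[c];          // sum of squares over channels
        const float inv_rms = 1.0f / std::sqrt(ss / C + eps);
        for (int c = 0; c < C; ++c) {
            float v = px[c] * inv_rms;                            // RMS-normalize this position
            if (!weight.empty()) v *= weight[c];                  // learned per-channel scale
            if (apply_act) v = 0.5f * v * (1.0f + std::erf(v / std::sqrt(2.0f))); // erf-GELU
            px[c] = v;
        }
    }
}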
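
Editor's note on conv2d_pw (patch 21): a 1x1 convolution with stride 1 and no padding is a per-pixel matrix multiply over the channel dimension, which is why the helper can route it through ggml_mul_mat instead of ggml_conv_2d. Below is a minimal CPU reference of that equivalence, assuming channels-last buffers and a [C_out x C_in] kernel; names are illustrative, not part of the patch.

// 1x1 ("pointwise") convolution as a per-pixel matrix multiply over channels.
#include <cstddef>
#include <vector>

static std::vector<float> conv2d_pw_ref(const std::vector<float> & inp, int W, int H, int C_in,
                                         const std::vector<float> & kernel, int C_out) {
    std::vector<float> out((size_t) W * H * C_out, 0.0f);
    for (int p = 0; p < W * H; ++p) {
        for (int co = 0; co < C_out; ++co) {
            float acc = 0.0f;
            for (int ci = 0; ci < C_in; ++ci) {
                acc += kernel[(size_t) co * C_in + ci] * inp[(size_t) p * C_in + ci];
            }
            out[(size_t) p * C_out + co] = acc;   // same result as a 1x1 conv, stride 1, no padding
        }
    }
    return out;
}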
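
Editor's note on v5_mmqa: the block implements multi-query attention, i.e. a single shared K/V set of width kv_dim attended to by num_heads query heads, with scores scaled by 1/sqrt(kv_dim). The sketch below shows that shared-K/V attention on plain nested vectors instead of ggml tensors; container types and the function name are illustrative assumptions.

// q[h][i] is a query vector of width kv_dim for head h at position i;
// k[j] / v[j] are the single shared key/value vectors for position j.
#include <algorithm>
#include <cmath>
#include <vector>

static std::vector<std::vector<std::vector<float>>> mqa_ref(
        const std::vector<std::vector<std::vector<float>>> & q,
        const std::vector<std::vector<float>> & k,
        const std::vector<std::vector<float>> & v) {
    const size_t n_heads = q.size();
    const size_t n_kv    = k.size();
    const size_t kv_dim  = k[0].size();
    const float  scale   = 1.0f / std::sqrt((float) kv_dim);
    std::vector<std::vector<std::vector<float>>> out(n_heads);
    for (size_t h = 0; h < n_heads; ++h) {
        out[h].resize(q[h].size());
        for (size_t i = 0; i < q[h].size(); ++i) {
            std::vector<float> s(n_kv);
            float smax = -1e30f;
            for (size_t j = 0; j < n_kv; ++j) {                  // scores against the shared K
                float d = 0.0f;
                for (size_t c = 0; c < kv_dim; ++c) d += q[h][i][c] * k[j][c];
                s[j] = d * scale;
                smax = std::max(smax, s[j]);
            }
            float sum = 0.0f;
            for (size_t j = 0; j < n_kv; ++j) { s[j] = std::exp(s[j] - smax); sum += s[j]; }
            out[h][i].assign(kv_dim, 0.0f);
            for (size_t j = 0; j < n_kv; ++j)                    // weighted sum over the shared V
                for (size_t c = 0; c < kv_dim; ++c) out[h][i][c] += (s[j] / sum) * v[j][c];
        }
    }
    return out;
}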