142 changes: 137 additions & 5 deletions convert_hf_to_gguf.py
@@ -942,6 +942,10 @@ def _create_vocab_sentencepiece(self):
elif tokenizer.IsByte(token_id):
toktype = SentencePieceTokenTypes.BYTE

if token_id >= vocab_size:
logger.warning(f'ignore tokens from {token_id}: id is out of range, max={vocab_size - 1}')
break

tokens[token_id] = text
scores[token_id] = score
toktypes[token_id] = toktype
@@ -1131,6 +1135,7 @@ class MmprojModel(ModelBase):
preprocessor_config: dict[str, Any]
global_config: dict[str, Any]

block_count: Any = None # will be set in __init__
n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]

has_vision_encoder: bool = True # by default
@@ -1169,7 +1174,8 @@ def __init__(self, *args, **kwargs):

# TODO @ngxson : this is a hack to support both vision and audio encoders
have_multiple_encoders = self.has_audio_encoder and self.has_vision_encoder
-        self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys, True)
+        if self.block_count is None:
+            self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys, True)
self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)

# load preprocessor config
@@ -4223,6 +4229,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
@ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
class Gemma3Model(TextModel):
model_arch = gguf.MODEL_ARCH.GEMMA3
norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value

def set_vocab(self):
self._set_vocab_sentencepiece()
@@ -4244,9 +4251,8 @@ def set_gguf_parameters(self):
self.gguf_writer.add_value_length(hparams.get("head_dim", 256))
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
-        # both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3
+        # attn_logit_softcapping is removed in Gemma3
        assert hparams.get("attn_logit_softcapping") is None
-        assert hparams.get("final_logit_softcapping") is None
self.gguf_writer.add_sliding_window(hparams["sliding_window"])
self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
if hparams.get("rope_scaling") is not None:
@@ -4258,7 +4264,7 @@ def set_gguf_parameters(self):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

-        if name.startswith("language_model."):
+        if "language_model." in name:
name = name.replace("language_model.", "")

elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
@@ -4273,8 +4279,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:

# ref code in Gemma3RMSNorm
# output = output * (1.0 + self.weight.float())
+        # note: this is not the case for gemma3n
        if name.endswith("norm.weight"):
-            data_torch = data_torch + 1
+            data_torch = data_torch + self.norm_shift

return [(self.map_tensor_name(name), data_torch)]
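
A short sketch (editor's illustration, not part of the diff) of why folding norm_shift into the stored weight works: Gemma3RMSNorm scales the normalized activations by (1.0 + weight), while a plain RMSNorm kernel scales by the stored weight, so writing weight + norm_shift lets one runtime kernel serve both Gemma3 (shift 1.0) and Gemma3n (shift 0.0).

import torch

def rms_norm(x: torch.Tensor, stored_weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # a plain RMSNorm: scale by the weight exactly as stored in the GGUF file
    normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return normed * stored_weight

x = torch.randn(2, 8)
w = torch.randn(8)  # weight as found in the HF checkpoint
# Gemma3RMSNorm-style reference: output * (1.0 + weight)
reference = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * (1.0 + w)
assert torch.allclose(rms_norm(x, w + 1.0), reference)  # norm_shift folded into the weight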

@@ -4331,6 +4338,131 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
return [] # skip other tensors


@ModelBase.register("Gemma3nForConditionalGeneration")
class Gemma3NModel(Gemma3Model):
model_arch = gguf.MODEL_ARCH.GEMMA3N
norm_shift = 0.0 # matches the scale_shift value of Gemma3p5RMSNorm in the Python reference code

_altup_proj: list[Tensor] = []
_altup_unembd: list[Tensor] = []

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams["altup_num_inputs"] == 4, "Current conversion only supports 4 altup inputs"
self._altup_proj = [
torch.Tensor(), # to be replaced
torch.Tensor(), # to be replaced
torch.Tensor(), # to be replaced
]
self._altup_unembd = [
torch.Tensor(), # to be replaced
torch.Tensor(), # to be replaced
torch.Tensor(), # to be replaced
]

def set_gguf_parameters(self):
super().set_gguf_parameters()

def _stack_matrices(self, matrices: list[Tensor]) -> Tensor | None:
has_all = all(m.numel() > 0 for m in matrices)
if not has_all:
return None
else:
return torch.stack(matrices, dim=0)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.endswith("_scale"):
name = name + ".weight"

# TODO: implement self.prediction_coefs.weight.clamp_(...)

if "language_model." not in name:
return [] # skip non-language model tensors

if "embed_tokens_per_layer.weight" in name:
hidden_size_per_layer_input = 256
data_torch = data_torch * (hidden_size_per_layer_input**0.5)

if "altup_unembed_projections" in name:
data_torch = data_torch.to(device="cpu")
if ".0." in name:
self._altup_unembd[0] = data_torch
elif ".1." in name:
self._altup_unembd[1] = data_torch
elif ".2." in name:
self._altup_unembd[2] = data_torch
else:
raise ValueError(f"Unknown name: {name}")
out = self._stack_matrices(self._altup_unembd)
if out is not None:
return [(self.map_tensor_name("model.altup_unembed_projections.weight"), out)]
else:
return []

if "altup_projections" in name:
data_torch = data_torch.to(device="cpu")
if ".0." in name:
self._altup_proj[0] = data_torch
elif ".1." in name:
self._altup_proj[1] = data_torch
elif ".2." in name:
self._altup_proj[2] = data_torch
else:
raise ValueError(f"Unknown name: {name}")
out = self._stack_matrices(self._altup_proj)
if out is not None:
return [(self.map_tensor_name("model.altup_projections.weight"), out)]
else:
return []

return super().modify_tensors(data_torch, name, bid)
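
A sketch (editor's illustration, not part of the diff) of the buffering pattern used for _altup_proj and _altup_unembd: checkpoint shards yield one projection matrix at a time, so each piece is parked in a placeholder slot and a single stacked tensor is emitted only once all three matrices are present.

import torch

buffer = [torch.Tensor(), torch.Tensor(), torch.Tensor()]  # empty placeholders, numel() == 0

def on_tensor(index: int, matrix: torch.Tensor) -> torch.Tensor | None:
    buffer[index] = matrix
    if all(m.numel() > 0 for m in buffer):
        return torch.stack(buffer, dim=0)  # shape: (3, *matrix.shape)
    return None  # still waiting for the remaining pieces

assert on_tensor(0, torch.ones(2, 2)) is None
assert on_tensor(2, torch.ones(2, 2)) is None
out = on_tensor(1, torch.ones(2, 2))
assert out is not None and out.shape == (3, 2, 2)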


@ModelBase.register("Gemma3nForConditionalGeneration")
class Gemma3NMmprojModel(MmprojModel):
has_audio_encoder = False # TODO
has_vision_encoder = True
block_count = 128 # dummy, unused

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hparams["image_size"] = 768
self.hparams["patch_size"] = 3
# below are dummy values, unused
self.hparams["intermediate_size"] = 1
self.hparams["n_layers"] = 0
self.hparams["num_attention_heads"] = 0

def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GEMMA3NV)
self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6))
self.gguf_writer.add_vision_use_gelu(True)

# def tensor_force_quant(self, name, new_name, bid, n_dims):
# del bid, new_name, n_dims # unused
# return False

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

if name.startswith("model.vision_tower.timm_model"):
# process vision tensors
name = name.replace(".gamma", ".weight")
name = name.replace("model.vision_tower.timm_model.", "v.mobilenet.")

# if "conv" in name or "attn" in name:
# # check if we have 1x1 kernel (last 2 dims are 1x1)
# if data_torch.dim() == 4 and data_torch.shape[-2:] == (1, 1):
# # convert 4D conv with 1x1 kernel to 2D matrix for matmul operation
# data_torch = data_torch.squeeze(-1).squeeze(-1)

return [(name, data_torch)] # not using map_tensor_name here

return [] # skip other tensors
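
A sketch (editor's illustration, not part of the diff) of the renaming applied above: timm checkpoints store norm scale parameters under .gamma, and vision tensors are re-prefixed to v.mobilenet. without going through map_tensor_name. The tensor name in the usage check is hypothetical.

def rename_vision_tensor(name: str) -> str | None:
    if not name.startswith("model.vision_tower.timm_model"):
        return None  # skipped, as in modify_tensors above
    name = name.replace(".gamma", ".weight")
    return name.replace("model.vision_tower.timm_model.", "v.mobilenet.")

# hypothetical tensor name, for illustration only:
assert rename_vision_tensor(
    "model.vision_tower.timm_model.blocks.0.norm.gamma"
) == "v.mobilenet.blocks.0.norm.weight"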


@ModelBase.register("Starcoder2ForCausalLM")
class StarCoder2Model(TextModel):
model_arch = gguf.MODEL_ARCH.STARCODER2
70 changes: 70 additions & 0 deletions gguf-py/gguf/constants.py
@@ -313,6 +313,7 @@ class MODEL_ARCH(IntEnum):
GEMMA = auto()
GEMMA2 = auto()
GEMMA3 = auto()
GEMMA3N = auto()
STARCODER2 = auto()
RWKV6 = auto()
RWKV6QWEN2 = auto()
@@ -398,6 +399,22 @@ class MODEL_TENSOR(IntEnum):
ATTN_Q_NORM = auto()
ATTN_K_NORM = auto()
LAYER_OUT_NORM = auto()
PER_LAYER_TOKEN_EMBD = auto() # gemma3n
PER_LAYER_MODEL_PROJ = auto() # gemma3n
PER_LAYER_INP_GATE = auto() # gemma3n
PER_LAYER_PROJ = auto() # gemma3n
PER_LAYER_PROJ_NORM = auto() # gemma3n
PER_LAYER_POST_NORM = auto() # gemma3n
ALTUP_PROJ = auto() # gemma3n
ALTUP_UNEMBD_PROJ = auto() # gemma3n
ALTUP_CORRECT_COEF = auto() # gemma3n
ALTUP_CORRECT_SCALE = auto() # gemma3n
ALTUP_PREDICT_COEF = auto() # gemma3n
ALTUP_ROUTER = auto() # gemma3n
ALTUP_ROUTER_NORM = auto() # gemma3n
LAUREL_L = auto() # gemma3n
LAUREL_R = auto() # gemma3n
LAUREL_POST_NORM = auto() # gemma3n
SSM_IN = auto()
SSM_CONV1D = auto()
SSM_X = auto()
@@ -596,6 +613,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.GEMMA2: "gemma2",
MODEL_ARCH.GEMMA3: "gemma3",
MODEL_ARCH.GEMMA3N: "gemma3n",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6",
MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
@@ -681,6 +699,22 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n
MODEL_TENSOR.PER_LAYER_MODEL_PROJ: "per_layer_model_proj", # gemma3n
MODEL_TENSOR.PER_LAYER_PROJ_NORM: "per_layer_proj_norm", # gemma3n
MODEL_TENSOR.ALTUP_UNEMBD_PROJ: "altup_unembd_proj", # gemma3n
MODEL_TENSOR.ALTUP_PROJ: "altup_proj", # gemma3n
MODEL_TENSOR.PER_LAYER_INP_GATE: "blk.{bid}.inp_gate", # gemma3n
MODEL_TENSOR.PER_LAYER_PROJ: "blk.{bid}.proj", # gemma3n
MODEL_TENSOR.PER_LAYER_POST_NORM: "blk.{bid}.post_norm", # gemma3n
MODEL_TENSOR.ALTUP_CORRECT_COEF: "blk.{bid}.altup_correct_coef", # gemma3n
MODEL_TENSOR.ALTUP_CORRECT_SCALE: "blk.{bid}.altup_correct_scale", # gemma3n
MODEL_TENSOR.ALTUP_PREDICT_COEF: "blk.{bid}.altup_predict_coef", # gemma3n
MODEL_TENSOR.ALTUP_ROUTER: "blk.{bid}.altup_router", # gemma3n
MODEL_TENSOR.ALTUP_ROUTER_NORM: "blk.{bid}.altup_router_norm", # gemma3n
MODEL_TENSOR.LAUREL_L: "blk.{bid}.laurel_l", # gemma3n
MODEL_TENSOR.LAUREL_R: "blk.{bid}.laurel_r", # gemma3n
MODEL_TENSOR.LAUREL_POST_NORM: "blk.{bid}.laurel_post_norm", # gemma3n
MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
@@ -1485,6 +1519,41 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_PRE_NORM,
MODEL_TENSOR.FFN_POST_NORM,
],
MODEL_ARCH.GEMMA3N: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_PRE_NORM,
MODEL_TENSOR.FFN_POST_NORM,
# altup / laurel
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD,
MODEL_TENSOR.PER_LAYER_MODEL_PROJ,
MODEL_TENSOR.PER_LAYER_INP_GATE,
MODEL_TENSOR.PER_LAYER_PROJ,
MODEL_TENSOR.PER_LAYER_PROJ_NORM,
MODEL_TENSOR.PER_LAYER_POST_NORM,
MODEL_TENSOR.ALTUP_PROJ,
MODEL_TENSOR.ALTUP_UNEMBD_PROJ,
MODEL_TENSOR.ALTUP_CORRECT_COEF,
MODEL_TENSOR.ALTUP_CORRECT_SCALE,
MODEL_TENSOR.ALTUP_PREDICT_COEF,
MODEL_TENSOR.ALTUP_ROUTER,
MODEL_TENSOR.ALTUP_ROUTER_NORM,
MODEL_TENSOR.LAUREL_L,
MODEL_TENSOR.LAUREL_R,
MODEL_TENSOR.LAUREL_POST_NORM,
],
MODEL_ARCH.STARCODER2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@@ -2313,6 +2382,7 @@ def get_type(val: Any) -> GGUFValueType:

class VisionProjectorType:
GEMMA3 = "gemma3"
GEMMA3NV = "gemma3nv" # vision
IDEFICS3 = "idefics3"
PIXTRAL = "pixtral"
LLAMA4 = "llama4"
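
A sketch (editor's illustration, not part of the diff) of how the new TENSOR_NAMES entries behave: per-layer names are format templates whose {bid} placeholder is filled with the block index, while global tensors such as per_layer_token_embd keep a fixed name. The dict below uses plain string keys in place of the MODEL_TENSOR enum.

tensor_names = {
    "ALTUP_ROUTER": "blk.{bid}.altup_router",        # per-layer (gemma3n)
    "PER_LAYER_TOKEN_EMBD": "per_layer_token_embd",  # global (gemma3n)
}

assert tensor_names["ALTUP_ROUTER"].format(bid=7) == "blk.7.altup_router"
assert tensor_names["PER_LAYER_TOKEN_EMBD"] == "per_layer_token_embd"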
64 changes: 64 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -480,6 +480,70 @@ class TensorNameMap:
"encoder.layer.{bid}.layer_norm_2" # jina-v2-code
),

MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: (
"model.embed_tokens_per_layer", # gemma3n
),

MODEL_TENSOR.PER_LAYER_MODEL_PROJ: (
"model.per_layer_model_projection", # gemma3n
),

MODEL_TENSOR.PER_LAYER_PROJ_NORM: (
"model.per_layer_projection_norm", # gemma3n
),

MODEL_TENSOR.ALTUP_PROJ: (
"model.altup_projections", # gemma3n
),

MODEL_TENSOR.ALTUP_UNEMBD_PROJ: (
"model.altup_unembed_projections", # gemma3n
),

MODEL_TENSOR.PER_LAYER_INP_GATE: (
"model.layers.{bid}.per_layer_input_gate", # gemma3n
),

MODEL_TENSOR.PER_LAYER_PROJ: (
"model.layers.{bid}.per_layer_projection", # gemma3n
),

MODEL_TENSOR.PER_LAYER_POST_NORM: (
"model.layers.{bid}.post_per_layer_input_norm", # gemma3n
),

MODEL_TENSOR.ALTUP_CORRECT_COEF: (
"model.layers.{bid}.altup.correction_coefs", # gemma3n
),

MODEL_TENSOR.ALTUP_CORRECT_SCALE: (
"model.layers.{bid}.altup.correct_output_scale", # gemma3n
),

MODEL_TENSOR.ALTUP_PREDICT_COEF: (
"model.layers.{bid}.altup.prediction_coefs", # gemma3n
),

MODEL_TENSOR.ALTUP_ROUTER: (
"model.layers.{bid}.altup.modality_router", # gemma3n
),

MODEL_TENSOR.ALTUP_ROUTER_NORM: (
"model.layers.{bid}.altup.router_norm", # gemma3n
),

MODEL_TENSOR.LAUREL_L: (
"model.layers.{bid}.laurel.linear_left", # gemma3n
),

MODEL_TENSOR.LAUREL_R: (
"model.layers.{bid}.laurel.linear_right", # gemma3n
),

MODEL_TENSOR.LAUREL_POST_NORM: (
"model.layers.{bid}.laurel.post_laurel_norm", # gemma3n
),

MODEL_TENSOR.SSM_IN: (
"model.layers.{bid}.in_proj",
"backbone.layers.{bid}.mixer.in_proj",
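
A sketch (editor's illustration, not part of the diff) of how these mappings are consumed: for each block the converter formats the HF-side template with the block id and resolves the concrete checkpoint name to its GGUF-side name. The real TensorNameMap precomputes this lookup once per model; the loop below is a simplified stand-in.

# one gemma3n mapping from above, HF template -> GGUF template
hf_templates = {
    "model.layers.{bid}.altup.modality_router": "blk.{bid}.altup_router",
}

def map_name(hf_name: str, n_blocks: int) -> str | None:
    for bid in range(n_blocks):
        for src, dst in hf_templates.items():
            if src.format(bid=bid) == hf_name:
                return dst.format(bid=bid)
    return None  # unmapped tensors are left for the caller to handle

assert map_name("model.layers.3.altup.modality_router", n_blocks=8) == "blk.3.altup_router"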