mtmd: Add DeepSeekOCR Support #17400
Open · sfallah wants to merge 104 commits into ggml-org:master from sfallah:sf/deepseek-ocr · +1,844 −40

Commits (104):
43a130b (sfallah) mtmd: llama.cpp DeepSeekOCR support
b6b9f02 (sfallah) loading sam tensors
85c7cda (bluebread) mtmd: fix vision model processing
578c8d7 (sfallah) Merge pull request #1 from bluebread/sf/deepseek-ocr
2aab52e (sfallah) deepseek-ocr clip-vit model impl
eab28ed (bluebread) mtmd: add DeepSeek-OCR LM support with standard attention
7630587 (bluebread) mtmd: successfully runs DeepSeek-OCR LM in llama-cli
2de3436 (bluebread) mtmd: Fix RoPE type for DeepSeek-OCR LM.
e8b2610 (bluebread) Merge branch 'sf/deepseek-ocr' of github.com:sfallah/llama.cpp into s…
97e0907 (sfallah) loading LM
13dc6fb (sfallah) Merge branch 'sf/deepseek-ocr' into sf/deepseek-ocr
b32bb5e (sfallah) Merge pull request #2 from bluebread/sf/deepseek-ocr
790bbb9 (sfallah) sam warmup working
cec9a5c (sfallah) sam erroneous return corrected
8b3d319 (sfallah) clip-vit: corrected cls_embd concat
1e08157 (sfallah) clip-vit: model convert qkv_proj split
331cea8 (sfallah) corrected combining of image encoders' results
6c0715b (bluebread) fix: update callback for ffn_moe_weighted and add callback for attn_o…
a65ddf5 (bluebread) Merge branch 'sf/deepseek-ocr' of github.com:sfallah/llama.cpp into s…
63a042f (sfallah) concat image_newline and image_seperator tokens
89afda8 (sfallah) visual_model warmup (technically) works
88032f4 (sfallah) window partitioning using standard ggml ops
1268dc3 (bluebread) Merge branch 'sf/deepseek-ocr' of github.com:sfallah/llama.cpp into s…
68b206b (sfallah) sam implementation without using CPU only ops
8bce66d (bluebread) clip: fixed warnings
5e6cf3c (bluebread) Merge branch 'sf/deepseek-ocr' of github.com:sfallah/llama.cpp into s…
7e9fbec (bluebread) mtmd: fix get_rel_pos
0f5587d (bluebread) Merge branch 'sf/deepseek-ocr' of github.com:sfallah/llama.cpp into s…
7b8d735 (bluebread) mtmd: fixed the wrong scaler for get_rel_pos
86f111f (sfallah) image encoding technically works but the output can't be checked sing…
effe669 (bluebread) mtmd: minor changed
f8f66a1 (bluebread) Merge branch 'sf/deepseek-ocr' of github.com:sfallah/llama.cpp into s…
3fcfc3a (sfallah) Merge pull request #3 from bluebread/sf/deepseek-ocr
ee8a148 (bluebread) mtmd: add native resolution support
4cfa15f (sfallah) - image encoding debugged
3f71188 (bluebread) mtmd: correct token order
a594990 (sfallah) Merge pull request #5 from bluebread/dsocr-debug
6dfda99 (sfallah) Merge branch 'sf/deepseek-ocr' into sf/deepseek-ocr
7941f5d (sfallah) Merge pull request #4 from bluebread/sf/deepseek-ocr
206f8ab (sfallah) - dynamic resizing
40e7e6e (bluebread) mtmd: quick fix token order
81533e4 (bluebread) mtmd: fix danling pointer
8810940 (sfallah) Merge pull request #6 from bluebread/sf/deepseek-ocr
a488b49 (bluebread) mtmd: SAM numerically works
ccb2f23 (bluebread) mtmd: debug CLIP-L (vit_pre_ln)
841a4a8 (bluebread) mtmd: debug CLIP-L & first working DeepSeek-OCR model
ed3b7f1 (sfallah) Merge remote-tracking branch 'sfallah/master' into sf/deepseek-ocr
5543094 (bluebread) Merge branch 'sf/deepseek-ocr' of github.com:sfallah/llama.cpp into s…
c5f4c64 (bluebread) mtmd : add --dsocr-mode CLI argument for DeepSeek-OCR resolution cont…
95239f9 (bluebread) mtmd: simplify SAM patch embedding
6b0e7cd (sfallah) Merge pull request #7 from bluebread/sf/deepseek-ocr
6634166 (sfallah) Merge branch 'master' into sf/deepseek-ocr
c914e05 (bluebread) mtmd: adapt Pillow image resizing function
e20857b (bluebread) mtmd: simplify DeepSeek-OCR dynamic resolution preprocessing
43dfc0c (bluebread) Merge branch 'sf/deepseek-ocr' of github.com:sfallah/llama.cpp into s…
b696c54 (bluebread) mtmd: remove --dsocr-mode argument
b26b507 (bluebread) mtmd: refactor code & remove unused helper functions
7451b84 (bluebread) mtmd: fix tensor names for image newlines and view separator
386ba47 (sfallah) clean up
c73748a (sfallah) Merge branch 'sf/deepseek-ocr' into sf/deepseek-ocr-cleanup
a661c52 (sfallah) reverting automatically removed spaces
0399ddf (sfallah) reverting automatically removed spaces
c89171c (bluebread) mtmd: fixed bad ocr check in Deepseek2 (LM)
2dd9924 (bluebread) Merge branch 'sf/deepseek-ocr-cleanup' of github.com:sfallah/llama.cp…
fc3f625 (bluebread) mtmd: support combined QKV projection in buid_vit
4d7d994 (sfallah) Merge pull request #8 from sfallah/sf/deepseek-ocr-cleanup
5381b9c (sfallah) using common build_attn in sam
076138a (sfallah) corrected code-branch when flash-attn disabled
d0c08e3 (bluebread) mtmd: minor fix
f5bd310 (sfallah) minor formatting and style
6687b4e (sfallah) Merge pull request #9 from sfallah/sf/deepseek-ocr-attn
5f2ee1a (sfallah) Merge branch 'ggml-org:master' into sf/deepseek-ocr
1c88647 (sfallah) fixed flake8 lint issues
d981f19 (sfallah) minor editorconfig-check fixes
705394c (sfallah) minor editorconfig-check fixes
15f2ada (bluebread) mtmd: simplify get_rel_pos
2d918b3 (bluebread) mtmd: make sam hparams configurable
5dfcc5a (bluebread) mtmd: add detailed comments for resize_bicubic_pillow
53273f8 (bluebread) mtmd: fixed wrong input setting
48c6cf2 (bluebread) mtmd: convert model in FP16
5174a1e (bluebread) mtmd: minor fix
0161406 (bluebread) mtmd: remove tweak to llama-mtmd-cli & deepseek-ocr template
ed944cd (sfallah) fix: test-1.jpg ORC issue with small (640) resolution
aaf2fd1 (sfallah) minor: editconfig-check fix
33fabf0 (sfallah) Merge branch 'master' into sf/deepseek-ocr-merge-test
d70f171 (sfallah) merge with changes from https://github.com/ggml-org/llama.cpp/pull/17909
4cbbe8a (sfallah) minor: editconfig-check fix
47f0fee (sfallah) testing deepseek-ocr
e0e69fd (sfallah) Merge remote-tracking branch 'sfallah/master' into sf/deepseek-ocr-me…
f95a6fe (sfallah) quick and (potential) dirty merge with https://github.com/ggml-org/ll…
f7736f2 (sfallah) refactoring, one single builder function and static helpers
fb3bb6a (sfallah) added deepseek-ocr test to tests.sh
1b38ccf (sfallah) Merge pull request #11 from sfallah/sf/deepseek-ocr-merge_#17965
6c36c03 (sfallah) minor formatting fixes
dc2066e (sfallah) check with fixed expected resutls
3fc61d4 (sfallah) Merge pull request #10 from sfallah/sf/deepseek-ocr-test-script
7f8621c (sfallah) minor formatting
b3bf8cb (sfallah) Merge remote-tracking branch 'sfallah/master' into sf/deepseek-ocr
8ad98ee (sfallah) editorconfig-check fix
4a4f829 (sfallah) Merge branch 'ggml-org:master' into sf/deepseek-ocr
51c3de6 (sfallah) Merge remote-tracking branch 'sfallah/master' into sf/deepseek-ocr
512b2c8 (sfallah) merge with changes from https://github.com/ggml-org/llama.cpp/pull/18042
00d2357 (sfallah) Merge remote-tracking branch 'sfallah/master' into sf/deepseek-ocr
87e4a00 (sfallah) minor
Changes to convert_hf_to_gguf.py:
```diff
@@ -711,6 +711,9 @@ def load_hparams(dir_model: Path, is_mistral_format: bool):
         if "thinker_config" in config:
             # rename for Qwen2.5-Omni
             config["text_config"] = config["thinker_config"]["text_config"]
+        if "language_config" in config:
+            # rename for DeepSeekOCR
+            config["text_config"] = config["language_config"]
         return config

     @classmethod
```
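The new branch simply aliases the OCR checkpoint's language-model sub-config to the key the rest of the converter already reads. A minimal sketch of the effect, assuming a config.json shaped like DeepSeek-OCR's (the field values below are placeholders, not the real hyperparameters):

```python
# Sketch only: placeholder values, assumed config shape.
config = {
    "model_type": "deepseek-ocr",
    "language_config": {"hidden_size": 1280, "num_hidden_layers": 12},
}

if "language_config" in config:
    # mirrors the Qwen2.5-Omni "thinker_config" case above: downstream
    # conversion code only ever reads config["text_config"]
    config["text_config"] = config["language_config"]

assert config["text_config"]["num_hidden_layers"] == 12
```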
```diff
@@ -1688,7 +1691,7 @@ class MmprojModel(ModelBase):
     preprocessor_config: dict[str, Any]
     global_config: dict[str, Any]

-    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"]
+    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "layers", "encoder_layers"]

     has_vision_encoder: bool = True  # by default
     has_audio_encoder: bool = False
```
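This one-key change matters because `MmprojModel` resolves the encoder's block count by probing these keys in order, and DeepSeek-OCR's CLIP/SAM sub-configs expose the count under the bare key `layers`. A simplified stand-in for that lookup (the real converter does this through its hparam-search helper; this sketch only illustrates the probing order):

```python
from typing import Any

n_block_keys = ["n_layers", "num_hidden_layers", "n_layer",
                "num_layers", "depth", "layers", "encoder_layers"]

def find_block_count(hparams: dict[str, Any]) -> int:
    # probe candidate keys in order, as the converter does with n_block_keys
    for key in n_block_keys:
        if key in hparams:
            return hparams[key]
    raise KeyError(f"none of {n_block_keys} present")

# a CLIP-L style sub-config that only carries the bare "layers" key
print(find_block_count({"layers": 24, "width": 1024, "heads": 16}))  # 24
```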
```diff
@@ -5956,6 +5959,68 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return []  # skip other tensors


+@ModelBase.register("DeepseekOCRForCausalLM")
+class DeepseekOCRVisionModel(MmprojModel):
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.DEEPSEEKOCR)
+        # default values below are taken from HF transformers code
+        self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6))
+        self.gguf_writer.add_vision_use_gelu(True)
+        # calculate proj_scale_factor (used by tinygemma3 test model)
+        image_seq_length = self.preprocessor_config.get("image_seq_length", 256)
+        n_per_side = int(image_seq_length ** 0.5)
+        image_size = self.hparams["image_size"]
+        patch_size = self.hparams["patch_size"]
+        proj_scale_factor = (image_size // patch_size) // n_per_side
+        if proj_scale_factor > 0 and proj_scale_factor != 4:
+            # we only need to write this if it's not the default value
+            # in this case, we are converting a test model
+            self.gguf_writer.add_vision_projector_scale_factor(proj_scale_factor)
+        # @bluebread: there's no window_size in config but just add it here anyway
+        self.gguf_writer.add_vision_window_size(self.hparams.get("window_size", 14))
+
+        # SAM configuration
+        sam_hparams = hparams['sam']
+        self.gguf_writer.add_vision_sam_layers_count(sam_hparams['layers'])
+        self.gguf_writer.add_vision_sam_embedding_length(sam_hparams['width'])
+        self.gguf_writer.add_vision_sam_head_count(sam_hparams['heads'])
+
+    def get_vision_config(self) -> dict[str, Any]:
+        vision_config: dict[str, Any] | None = self.global_config.get("vision_config")
+
+        if not vision_config:
+            raise ValueError("DeepseekOCR model requires 'vision_config' in the model configuration, but it was not found")
+
+        vision_config['sam'] = vision_config['width']['sam_vit_b']
+        vision_config.update(vision_config['width']['clip-l-14-224'])
+        vision_config['hidden_size'] = vision_config['width']
+        vision_config['num_heads'] = vision_config['heads']
+        vision_config['intermediate_size'] = vision_config['heads'] * 4
+
+        return vision_config
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        if ".embeddings." in name or 'pos_embed' in name:
+            return gguf.GGMLQuantizationType.F32
+        if ".rel_pos_h" in name or '.rel_pos_w' in name:
+            return gguf.GGMLQuantizationType.F32
+        return gguf.GGMLQuantizationType.F16
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Only process vision-related tensors, skip language model tensors
+        # Vision components: sam_model, vision_model, projector, image_newline, view_seperator
+        # Language model components to skip: lm_head, embed_tokens, layers, norm
+        if name.startswith(("lm_head.", "model.embed_tokens.", "model.layers.", "model.norm.")):
+            return []
+
+        if ".attn.rel_pos_h" in name or ".attn.rel_pos_w" in name:
+            return [(self.map_tensor_name(name, try_suffixes=("",)), data_torch)]
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @ModelBase.register("Gemma3nForConditionalGeneration")
 class Gemma3NModel(Gemma3Model):
     model_arch = gguf.MODEL_ARCH.GEMMA3N
```
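Two details in `DeepseekOCRVisionModel` are easy to misread. First, `proj_scale_factor` is pure arithmetic on the preprocessor and vision hparams; plugging in plausible base-mode numbers (assumed here for illustration: a 1024-px image, 16-px patches, 256 projected tokens — not read from the actual checkpoint) shows why the default case writes nothing. Second, `get_vision_config` flattens the nested `width` table, which holds both the `sam_vit_b` and `clip-l-14-224` sub-configs, into the flat `hidden_size`/`num_heads` keys that `MmprojModel` expects.

```python
# Worked example with assumed illustration values.
image_seq_length = 256                              # tokens after the projector
n_per_side = int(image_seq_length ** 0.5)           # 16
image_size, patch_size = 1024, 16
patches_per_side = image_size // patch_size         # 64
proj_scale_factor = patches_per_side // n_per_side  # 64 // 16 = 4
# 4 is treated as the default, so add_vision_projector_scale_factor is
# skipped; only a test model with a different ratio gets the key written.
```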
```diff
@@ -7122,6 +7187,15 @@ def prepare_tensors(self):
 class DeepseekV2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.DEEPSEEK2

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        vision_config = self.hparams.get('vision_config', {}).get('width', {})
+
+        if 'clip-l-14-224' in vision_config and 'sam_vit_b' in vision_config:
+            self.model_arch = gguf.MODEL_ARCH.DEEPSEEK2OCR
+            self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
+            self.gguf_writer.add_architecture()
+
     def set_vocab(self):
         try:
             self._set_vocab_gpt2()
```
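The `__init__` override is the arch switch: a plain DeepSeek-V2 checkpoint keeps `MODEL_ARCH.DEEPSEEK2`, while one that declares both vision towers is rewritten to `DEEPSEEK2OCR` before conversion proceeds. A hypothetical config fragment matching the shape this detection expects (the layer/width numbers are made up):

```python
# Assumed config shape; values are illustrative only.
hparams = {
    "vision_config": {
        "width": {
            "clip-l-14-224": {"layers": 24, "width": 1024, "heads": 16},
            "sam_vit_b":     {"layers": 12, "width": 768,  "heads": 12},
        },
    },
}

width_table = hparams.get('vision_config', {}).get('width', {})
is_ocr = 'clip-l-14-224' in width_table and 'sam_vit_b' in width_table
print("DEEPSEEK2OCR" if is_ocr else "DEEPSEEK2")  # DEEPSEEK2OCR
```

Note that the review comment at the bottom of this page suggests discovering the original architecture instead of keying off these `width` entries.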
```diff
@@ -7177,30 +7251,40 @@ def set_vocab(self):
             raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!")

     def set_gguf_parameters(self):
-        # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group)
-        self.hparams["num_key_value_heads"] = 1
+        is_ocr = (self.model_arch == gguf.MODEL_ARCH.DEEPSEEK2OCR)
+
+        if is_ocr:
+            self.hparams['rope_theta'] = self.hparams.get('rope_theta', 10000.0)
+            self.hparams['rms_norm_eps'] = self.hparams.get('rms_norm_eps', 1e-6)
+        else:
+            # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group)
+            self.hparams["num_key_value_heads"] = 1

         super().set_gguf_parameters()
         hparams = self.hparams

+        kv_lora_rank = hparams["kv_lora_rank"] if hparams.get("kv_lora_rank") is not None else 512
+        routed_scaling_factor = hparams.get("routed_scaling_factor", 1.0)
+        norm_topk_prob = hparams.get("norm_topk_prob", False)
         self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
         self.gguf_writer.add_vocab_size(hparams["vocab_size"])
         if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
             self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
-        self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
+        if "kv_lora_rank" in hparams and hparams["kv_lora_rank"] is not None:
+            self.gguf_writer.add_kv_lora_rank(kv_lora_rank)

         # note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
-        self.gguf_writer.add_key_length(hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"])
-        self.gguf_writer.add_value_length(hparams["kv_lora_rank"])
-        self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
-        self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])
+        if not is_ocr:
+            self.gguf_writer.add_key_length(kv_lora_rank + hparams["qk_rope_head_dim"])
+            self.gguf_writer.add_value_length(kv_lora_rank)
+            self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
+            self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])
+            self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])

         self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
         self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
         self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
-        self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
-        self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
+        self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
+        self.gguf_writer.add_expert_weights_norm(norm_topk_prob)

-        self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
```
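For the non-OCR branch, the MLA-to-MQA cell sizes are easiest to sanity-check with concrete numbers. The values below are DeepSeek-V2's commonly cited MLA dimensions, used purely as reference arithmetic (the converter always reads them from the checkpoint's config, not from constants like these):

```python
# Reference arithmetic for the MLA branch above; values assumed from
# DeepSeek-V2's published config, not hard-coded in the converter.
kv_lora_rank     = 512
qk_rope_head_dim = 64
qk_nope_head_dim = 128
v_head_dim       = 128

# stored as MQA with one large compressed KV head ...
key_length   = kv_lora_rank + qk_rope_head_dim          # 576
value_length = kv_lora_rank                             # 512
# ... and decompressed to per-head MLA sizes at attention time
key_length_mla   = qk_nope_head_dim + qk_rope_head_dim  # 192
value_length_mla = v_head_dim                           # 128
print(key_length, value_length, key_length_mla, value_length_mla)
```

The OCR branch skips these writes entirely, which lines up with the commit above that adds the DeepSeek-OCR LM with standard (non-MLA) attention.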
```diff
@@ -7209,12 +7293,18 @@ def set_gguf_parameters(self):
         # note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul
         # ref https://github.com/ggml-org/llama.cpp/pull/17945
         self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_mscale_all)
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6))

     _experts: list[dict[str, Tensor]] | None = None

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # skip vision tensors and remove "language_model." for Kimi-VL
-        if "vision_tower" in name or "multi_modal_projector" in name:
+        if ("vision_" in name
+                or "multi_modal_projector" in name
+                or "image_newline" in name
+                or "model.projector" in name
+                or "sam_model" in name
+                or "view_seperator" in name):
             return []

         if name.startswith("language_model."):
```

Collaborator (suggested change on the added `add_layer_norm_rms_eps` line): "This was surely already added by …"
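The widened filter in `DeepseekV2Model.modify_tensors` is the text-side counterpart of `DeepseekOCRVisionModel.modify_tensors` above: anything belonging to the vision towers or projector is dropped from the text GGUF and left for the mmproj file. A quick self-contained check of the substring logic; the tensor names are guesses at the DeepSeek-OCR checkpoint layout, chosen only to exercise each branch:

```python
def is_vision_tensor(name: str) -> bool:
    # same substring test as the diff above; "view_seperator" (sic)
    # deliberately matches the checkpoint's own spelling
    return ("vision_" in name
            or "multi_modal_projector" in name
            or "image_newline" in name
            or "model.projector" in name
            or "sam_model" in name
            or "view_seperator" in name)

cases = {
    "model.sam_model.blocks.0.attn.qkv.weight":           True,   # SAM tower
    "model.vision_model.encoder.layers.0.mlp.fc1.weight": True,   # CLIP tower
    "model.image_newline":                                True,   # layout token
    "model.layers.0.self_attn.q_proj.weight":             False,  # LM tensor, keep
}
for name, expected in cases.items():
    assert is_vision_tensor(name) is expected
```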
Review comment: Discover original arch instead, like here: llama.cpp/convert_hf_to_gguf.py, lines 2409 to 2413 in 87e4a00.