diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 9914b3276b7..2a17e44ecdf 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -9,6 +9,7 @@
 #include "llama-model.h"
 
 #include <cinttypes>
+#include <cmath>
 #include <cstring>
 #include <limits>
 #include <stdexcept>
@@ -72,6 +73,43 @@ llama_context::llama_context(
         cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
     }
 
+    if (cparams.yarn_ext_factor != 0) {
+        static auto get_mscale = [](float scale, float mscale) {
+            return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
+        };
+
+        const float factor = 1.0f / cparams.rope_freq_scale;
+
+        // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
+        if (hparams.rope_yarn_log_mul != 0.0f) {
+            // note: here we assume `mscale == 1.0f`
+            // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
+            float mscale = 1.0f;
+            const float mscale_all_dims = hparams.rope_yarn_log_mul;
+
+            // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+            // special-case DEEPSEEK v2:
+            // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
+            if (model.arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
+                mscale = mscale_all_dims;
+            }
+
+            cparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
+
+            LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
+                    __func__, cparams.yarn_attn_factor, mscale, mscale_all_dims);
+        } else {
+            cparams.yarn_attn_factor = get_mscale(factor, 1.0f);
+        }
+
+        // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
+        // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
+        //
+        // ref: https://github.com/ggml-org/llama.cpp/discussions/7416
+        //      https://github.com/ggml-org/llama.cpp/pull/17945
+        cparams.yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor));
+    }
+
     cparams.yarn_attn_factor *= hparams.rope_attn_factor;
 
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index a1a32494b75..6cf9a883a6e 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -574,7 +574,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     freq_base        (cparams.rope_freq_base),
     freq_scale       (cparams.rope_freq_scale),
     ext_factor       (cparams.yarn_ext_factor),
-    attn_factor      (llama_hparams::yarn_attn_factor_adjust(cparams.yarn_attn_factor, cparams.rope_freq_scale, cparams.yarn_ext_factor)),
+    attn_factor      (cparams.yarn_attn_factor),
     beta_fast        (cparams.yarn_beta_fast),
     beta_slow        (cparams.yarn_beta_slow),
     norm_eps         (hparams.f_norm_eps),
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 277d0bcfd3c..96c9598c24c 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -3,7 +3,6 @@
 #include "ggml.h"
 
 #include <cassert>
-#include <cmath>
 
 void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
     if (dense_first) {
@@ -231,13 +230,3 @@ bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
 
     return false;
 }
-
-float llama_hparams::yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor) {
-    GGML_ASSERT(ext_factor >= 0.0f);
-
-    if (ext_factor != 0.0f) {
-        attn_factor *= 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
-    }
-
-    return attn_factor;
-}
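As a sanity check on the math moved into `llama-context.cpp` above, here is a minimal standalone sketch (not part of the patch) of how the effective `yarn_attn_factor` is derived; `freq_scale` and `yarn_log_mul` are illustrative stand-ins for `cparams.rope_freq_scale` and `hparams.rope_yarn_log_mul`, and `mscale` is assumed to be `1.0f` as in the patch:

```cpp
#include <cmath>
#include <cstdio>

// YaRN mscale, as in the patch: 0.1 * mscale * ln(scale) + 1
static float get_mscale(float scale, float mscale) {
    return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
}

int main() {
    const float freq_scale   = 0.25f; // e.g. 4x context extension (illustrative)
    const float yarn_log_mul = 0.1f;  // plays the role of mscale_all_dims (illustrative)
    const float factor       = 1.0f / freq_scale;

    // ratio of the two mscale corrections (generic, non-DeepSeek-V2 branch)
    float attn_factor = get_mscale(factor, 1.0f) / get_mscale(factor, yarn_log_mul);

    // cancel the 1 + 0.1*ln(1/freq_scale) term that ggml's rope op
    // applies internally when ext_factor != 0
    attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor));

    printf("effective yarn_attn_factor = %.4f\n", attn_factor);
    return 0;
}
```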
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index c9960e91697..aab319754e5 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -268,13 +268,6 @@ struct llama_hparams {
     // TODO: think of a better place for this function
     // TODO: pack the SWA params in a struct?
     static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
-
-    // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
-    // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
-    //
-    // ref: https://github.com/ggml-org/llama.cpp/discussions/7416
-    //      https://github.com/ggml-org/llama.cpp/pull/17945
-    static float yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor);
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 8f94c8820ce..bf3de2f2efd 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -1372,7 +1372,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
     const auto & yarn_ext_factor  = cparams.yarn_ext_factor;
     const auto & yarn_beta_fast   = cparams.yarn_beta_fast;
     const auto & yarn_beta_slow   = cparams.yarn_beta_slow;
-    const auto & yarn_attn_factor = llama_hparams::yarn_attn_factor_adjust(cparams.yarn_attn_factor, cparams.rope_freq_scale, cparams.yarn_ext_factor);
+    const auto & yarn_attn_factor = cparams.yarn_attn_factor;
 
     const auto & n_rot     = hparams.n_rot;
     const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index e4808b1e1eb..5da1dd6dbb8 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -2294,32 +2294,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         default: throw std::runtime_error("unsupported model architecture");
     }
 
-    // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
-    if (hparams.rope_yarn_log_mul != 0.0f) {
-        const float factor = 1.0f / hparams.rope_freq_scale_train;
-
-        // note: here we assume `mscale == 1.0f`
-        // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
-        float mscale = 1.0f;
-        const float mscale_all_dims = hparams.rope_yarn_log_mul;
-
-        // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
-        // special-case DEEPSEEK v2:
-        // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
-        if (arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
-            mscale = mscale_all_dims;
-        }
-
-        static auto get_mscale = [](float scale, float mscale) {
-            return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
-        };
-
-        hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
-
-        LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
-                __func__, hparams.yarn_attn_factor, mscale, mscale_all_dims);
-    }
-
     pimpl->n_bytes = ml.n_bytes;
 
     pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name();
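For completeness, a small sketch (again, not part of the patch) of why the `1.0f / (1.0f + 0.1f * logf(factor))` pre-division is correct: per the ops.cpp link above, ggml's rope op multiplies `attn_factor` by the same `1 + 0.1*ln(1/freq_scale)` term when `ext_factor != 0`, so the two cancel and the intended correction survives. The values below are illustrative:

```cpp
#include <cassert>
#include <cmath>

int main() {
    const float freq_scale = 0.125f; // illustrative
    const float intended   = 1.234f; // whatever yarn_attn_factor should end up as

    // host side (llama-context.cpp above): pre-divide by the term ggml will apply
    const float host = intended / (1.0f + 0.1f * logf(1.0f / freq_scale));

    // ggml side: the rope op multiplies attn_factor by 1 + 0.1*ln(1/freq_scale)
    // when ext_factor != 0
    const float applied = host * (1.0f + 0.1f * logf(1.0f / freq_scale));

    // the internal scaling is cancelled exactly
    assert(fabsf(applied - intended) < 1e-6f);
    return 0;
}
```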