Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/llama-graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
freq_base (cparams.rope_freq_base),
freq_scale (cparams.rope_freq_scale),
ext_factor (cparams.yarn_ext_factor),
attn_factor (llama_hparams::yarn_attn_factor_adjust(cparams.yarn_attn_factor, cparams.rope_freq_scale, cparams.yarn_ext_factor)),
attn_factor (cparams.yarn_attn_factor),
beta_fast (cparams.yarn_beta_fast),
beta_slow (cparams.yarn_beta_slow),
norm_eps (hparams.f_norm_eps),
Expand Down
10 changes: 0 additions & 10 deletions src/llama-hparams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,13 +231,3 @@ bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama

return false;
}

float llama_hparams::yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor) {
GGML_ASSERT(ext_factor >= 0.0f);

if (ext_factor != 0.0f) {
attn_factor *= 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
}

return attn_factor;
}
7 changes: 0 additions & 7 deletions src/llama-hparams.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,13 +268,6 @@ struct llama_hparams {
// TODO: think of a better place for this function
// TODO: pack the SWA params in a struct?
static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);

// when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
// https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
//
// ref: https://github.com/ggml-org/llama.cpp/discussions/7416
// https://github.com/ggml-org/llama.cpp/pull/17945
static float yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor);
};

static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
2 changes: 1 addition & 1 deletion src/llama-kv-cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1372,7 +1372,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
const auto & yarn_ext_factor = cparams.yarn_ext_factor;
const auto & yarn_beta_fast = cparams.yarn_beta_fast;
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
const auto & yarn_attn_factor = llama_hparams::yarn_attn_factor_adjust(cparams.yarn_attn_factor, cparams.rope_freq_scale, cparams.yarn_ext_factor);
const auto & yarn_attn_factor = cparams.yarn_attn_factor;

const auto & n_rot = hparams.n_rot;
const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
Expand Down
50 changes: 31 additions & 19 deletions src/llama-model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2294,30 +2294,42 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: throw std::runtime_error("unsupported model architecture");
}

// ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
if (hparams.rope_yarn_log_mul != 0.0f) {
const float factor = 1.0f / hparams.rope_freq_scale_train;

// note: here we assume `mscale == 1.0f`
// TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
float mscale = 1.0f;
const float mscale_all_dims = hparams.rope_yarn_log_mul;

// [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
// special-case DEEPSEEK v2:
// https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
if (arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
mscale = mscale_all_dims;
}

// YaRN
if (hparams.yarn_ext_factor != 0) {
static auto get_mscale = [](float scale, float mscale) {
return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
};

hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
const float factor = 1.0f / hparams.rope_freq_scale_train;

// ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
if (hparams.rope_yarn_log_mul != 0.0f) {
// note: here we assume `mscale == 1.0f`
// TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
float mscale = 1.0f;
const float mscale_all_dims = hparams.rope_yarn_log_mul;

// [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
// special-case DEEPSEEK v2:
// https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
if (arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
mscale = mscale_all_dims;
}

hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);

LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
__func__, hparams.yarn_attn_factor, mscale, mscale_all_dims);
} else {
hparams.yarn_attn_factor = get_mscale(factor, 1.0f);
}

LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
__func__, hparams.yarn_attn_factor, mscale, mscale_all_dims);
// when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
// https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
//
// ref: https://github.com/ggml-org/llama.cpp/discussions/7416
// https://github.com/ggml-org/llama.cpp/pull/17945
hparams.yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor));
}

pimpl->n_bytes = ml.n_bytes;
Expand Down
Loading