
Commit ec98e20

llama: fix early stop in params_fit if ctx is set (#18070)
1 parent 59977eb commit ec98e20

File tree: 1 file changed (+7, -4 lines)


src/llama.cpp

Lines changed: 7 additions & 4 deletions
@@ -241,6 +241,13 @@ static void llama_params_fit_impl(
                 global_surplus += memory_reduction;
                 LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
                     __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
+                if (global_surplus >= 0) {
+                    if (nd == 1) {
+                        LLAMA_LOG_INFO("%s: entire model can be fit by reducing context\n", __func__);
+                        return;
+                    }
+                    LLAMA_LOG_INFO("%s: entire model should be fit across devices by reducing context\n", __func__);
+                }
             } else {
                 LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n",
                     __func__, hp_nct, n_ctx_min);
@@ -249,10 +256,6 @@ static void llama_params_fit_impl(
                 LLAMA_LOG_INFO("%s: context size set by user to %" PRIu32 " -> no change\n", __func__, cparams->n_ctx);
             }
         }
-        if (global_surplus >= 0) {
-            LLAMA_LOG_INFO("%s: entire model can be fit across devices by reducing context\n", __func__);
-            return;
-        }
     }
 
     if (mparams->n_gpu_layers != default_mparams.n_gpu_layers) {
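What changed, in plain terms: before this commit, the surplus check sat after the whole context-handling if/else, so params_fit stopped early even on the path where the user had set the context size and no reduction took place. The fix moves the check inside the branch that actually reduces the default context, and only returns early when there is a single device (nd == 1). Below is a minimal standalone sketch of that control-flow change; the names nd and global_surplus and the log messages come from the diff above, while everything else is illustrative scaffolding, not the real llama.cpp code.

// sketch.cpp -- illustrative only, not the actual llama.cpp implementation
#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for the fitting routine; ctx_set_by_user, nd (number
// of devices) and global_surplus mirror the roles they play in the diff.
static void params_fit_sketch(bool ctx_set_by_user, int nd, int64_t global_surplus) {
    if (!ctx_set_by_user) {
        // ... the default context gets reduced here, raising global_surplus ...
        // After the fix, the early-stop check lives inside this branch, so it
        // can only fire when a context reduction actually happened:
        if (global_surplus >= 0) {
            if (nd == 1) {
                std::printf("entire model can be fit by reducing context\n");
                return; // early stop is safe: one device, context was reduced
            }
            std::printf("entire model should be fit across devices by reducing context\n");
        }
    } else {
        std::printf("context size set by user -> no change\n");
        // Before the fix, the surplus check sat after this if/else, so it
        // also returned early on this path and skipped the remaining steps.
    }
    // ... the rest of the fitting logic (e.g. the n_gpu_layers adjustment
    // that follows in the diff) runs from here ...
}

int main() {
    params_fit_sketch(/*ctx_set_by_user=*/true, /*nd=*/1, /*global_surplus=*/0);
    return 0;
}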
