
Commit 254098a

common : refactor common_sampler + grammar logic changes (#17937)
* common : refactor common_sampler + grammar logic changes
* tests : increase max_tokens to get needed response
* batched : fix uninitialized samplers
1 parent 3238b14 commit 254098a

File tree: 27 files changed (+370, -291 lines)

common/arg.cpp

Lines changed: 1 addition & 1 deletion
@@ -1415,7 +1415,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sampling.top_k = value;
             params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K;
         }
-    ).set_sparam());
+    ).set_sparam().set_env("LLAMA_ARG_TOP_K"));
     add_opt(common_arg(
         {"--top-p"}, "N",
         string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),

common/common.cpp

Lines changed: 125 additions & 65 deletions
@@ -1013,31 +1013,40 @@ bool tty_can_use_colors() {
 // Model utils
 //

-static inline void common_init_sampler_from_model(
+// TODO: move to common/sampling
+static void common_init_sampler_from_model(
         const llama_model * model,
         common_params_sampling & sparams) {

     const uint64_t config = sparams.user_sampling_config;

     auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
-        if (config & user_config) return;
+        if (config & user_config) {
+            return;
+        }

         char buf[64] = {0};
         if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
             char * end = nullptr;
             int32_t v = strtol(buf, &end, 10);
-            if (end && end != buf) dst = v;
+            if (end && end != buf) {
+                dst = v;
+            }
         }
     };

     auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
-        if (config & user_config) return;
+        if (config & user_config) {
+            return;
+        }

         char buf[128] = {0};
         if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
             char * end = nullptr;
             float v = strtof(buf, &end);
-            if (end && end != buf) dst = v;
+            if (end && end != buf) {
+                dst = v;
+            }
         }
     };

@@ -1065,31 +1074,122 @@ static inline void common_init_sampler_from_model(
     get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
 }

-struct common_init_result common_init_from_params(common_params & params) {
-    common_init_result iparams;
-    auto mparams = common_model_params_to_llama(params);
+struct common_init_result::impl {
+    impl() = default;
+    ~impl() = default;
+
+    llama_model_ptr model;
+    llama_context_ptr context;
+
+    std::vector<llama_adapter_lora_ptr> lora;
+
+    std::vector<common_sampler_ptr> samplers;
+};
+
+common_init_result::common_init_result(common_params & params) :
+    pimpl(new impl{}) {
+    const auto mparams = common_model_params_to_llama(params);

     llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
     if (model == NULL) {
-        LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
-                __func__, params.model.path.c_str());
-        return iparams;
+        return;
     }

-    common_init_sampler_from_model(model, params.sampling);
+    pimpl->model.reset(model);

     const llama_vocab * vocab = llama_model_get_vocab(model);

+    // updates params.sampling
+    // TODO: fix naming
+    common_init_sampler_from_model(model, params.sampling);
+
     auto cparams = common_context_params_to_llama(params);

+    if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
+        LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
+        params.sampling.ignore_eos = false;
+    }
+
+    // initialize once
+    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
+        if (llama_vocab_is_eog(vocab, i)) {
+            LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
+            params.sampling.logit_bias_eog.push_back({i, -INFINITY});
+        }
+    }
+
+    if (params.sampling.ignore_eos) {
+        // add EOG biases to the active set of logit biases
+        params.sampling.logit_bias.insert(
+                params.sampling.logit_bias.end(),
+                params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
+    }
+
+    //if (params.sampling.penalty_last_n == -1) {
+    //    LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+    //    params.sampling.penalty_last_n = llama_n_ctx(lctx);
+    //}
+
+    //if (params.sampling.dry_penalty_last_n == -1) {
+    //    LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+    //    params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
+    //}
+
+    pimpl->samplers.resize(cparams.n_seq_max);
+
+    for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
+        pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
+    }
+
     llama_context * lctx = llama_init_from_model(model, cparams);
+    if (lctx == NULL) {
+        LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
+                __func__, params.model.path.c_str());
+        return;
+    }
+
+    pimpl->context.reset(lctx);
+}
+
+llama_model * common_init_result::model() {
+    return pimpl->model.get();
+}
+
+llama_context * common_init_result::context() {
+    return pimpl->context.get();
+}
+
+common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
+    return pimpl->samplers[seq_id].get();
+}
+
+std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
+    return pimpl->lora;
+}
+
+void common_init_result::free_context() {
+    pimpl->context.reset();
+}
+
+common_init_result_ptr common_init_from_params(common_params & params) {
+    common_init_result_ptr res(new common_init_result(params));
+
+    llama_model * model = res->model();
+    if (model == NULL) {
+        LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
+                __func__, params.model.path.c_str());
+        return res;
+    }
+
+    llama_context * lctx = res->context();
     if (lctx == NULL) {
         LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
                 __func__, params.model.path.c_str());
-        llama_model_free(model);
-        return iparams;
+        return res;
     }

+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
     if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
         LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
         params.ctx_shift = false;
@@ -1101,10 +1201,7 @@ struct common_init_result common_init_from_params(common_params & params) {

         const auto cvec = common_control_vector_load(params.control_vectors);
         if (cvec.n_embd == -1) {
-            llama_free(lctx);
-            llama_model_free(model);
-
-            return iparams;
+            return res;
         }

         int err = llama_apply_adapter_cvec(
@@ -1115,10 +1212,7 @@ struct common_init_result common_init_from_params(common_params & params) {
                 params.control_vector_layer_start,
                 params.control_vector_layer_end);
         if (err) {
-            llama_free(lctx);
-            llama_model_free(model);
-
-            return iparams;
+            return res;
         }
     }

@@ -1142,10 +1236,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         }

         if (!ok) {
-            llama_free(lctx);
-            llama_model_free(model);
-
-            return iparams;
+            return res;
         }
     }

@@ -1155,9 +1246,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
         if (lora == nullptr) {
             LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
-            llama_free(lctx);
-            llama_model_free(model);
-            return iparams;
+            return res;
         }

         char buf[1024];
@@ -1166,43 +1255,13 @@ struct common_init_result common_init_from_params(common_params & params) {
         la.task_name = buf;
         llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
         la.prompt_prefix = buf;
-        iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
+        res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters
     }

     if (!params.lora_init_without_apply) {
         common_set_adapter_lora(lctx, params.lora_adapters);
     }

-    if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
-        LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
-        params.sampling.ignore_eos = false;
-    }
-
-    // initialize once
-    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
-        if (llama_vocab_is_eog(vocab, i)) {
-            LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
-            params.sampling.logit_bias_eog.push_back({i, -INFINITY});
-        }
-    }
-
-    if (params.sampling.ignore_eos) {
-        // add EOG biases to the active set of logit biases
-        params.sampling.logit_bias.insert(
-                params.sampling.logit_bias.end(),
-                params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
-    }
-
-    if (params.sampling.penalty_last_n == -1) {
-        LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
-        params.sampling.penalty_last_n = llama_n_ctx(lctx);
-    }
-
-    if (params.sampling.dry_penalty_last_n == -1) {
-        LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
-        params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
-    }
-
     if (params.warmup) {
         LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);

@@ -1241,12 +1300,11 @@ struct common_init_result common_init_from_params(common_params & params) {
         llama_set_warmup(lctx, false);
     }

-    iparams.model.reset(model);
-    iparams.context.reset(lctx);
-
-    return iparams;
+    return res;
 }

+common_init_result::~common_init_result() = default;
+
 std::string get_model_endpoint() {
     const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
     // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
@@ -1255,7 +1313,9 @@ std::string get_model_endpoint() {
     std::string model_endpoint = "https://huggingface.co/";
     if (endpoint_env) {
         model_endpoint = endpoint_env;
-        if (model_endpoint.back() != '/') model_endpoint += '/';
+        if (model_endpoint.back() != '/') {
+            model_endpoint += '/';
+        }
     }
     return model_endpoint;
 }
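Note on the first hunk above: the get_int32/get_float helpers in common_init_sampler_from_model only write a field when the matching bit in params.sampling.user_sampling_config is clear, i.e. when the user did not set that parameter on the command line, so sampling defaults embedded in the model's GGUF metadata never override explicit CLI flags. The following is a minimal standalone sketch of the metadata-read half of that pattern; it is an illustration, not part of the commit, and the key string is a caller-supplied placeholder (the real code resolves keys through llama_model_meta_key_str()).

#include <cstdlib>

#include "llama.h"

// Sketch: read a float-valued sampling default from GGUF metadata.
// Returns false (and leaves dst untouched) if the key is absent or malformed.
static bool read_meta_float(const llama_model * model, const char * key, float & dst) {
    char buf[128] = {0};
    if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) <= 0) {
        return false; // key not present in the model metadata
    }
    char * end = nullptr;
    const float v = strtof(buf, &end);
    if (end == nullptr || end == buf) {
        return false; // stored value did not parse as a number
    }
    dst = v;
    return true;
}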

common/common.h

Lines changed: 23 additions & 6 deletions
@@ -195,7 +195,6 @@ struct common_params_sampling {

     std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY

-
     std::vector<enum common_sampler_type> samplers = {
         COMMON_SAMPLER_TYPE_PENALTIES,
         COMMON_SAMPLER_TYPE_DRY,
@@ -216,6 +215,10 @@ struct common_params_sampling {
     std::vector<llama_logit_bias> logit_bias;     // logit biases to apply
     std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens

+    bool has_logit_bias() const {
+        return !logit_bias.empty();
+    }
+
     // print the parameters into a string
     std::string print() const;
 };
@@ -669,15 +672,29 @@ bool tty_can_use_colors();
 // Model utils
 //

-// note: defines object's lifetime
+struct common_sampler;
+
+// note: defines the model, context, samplers, ets. lifetimes
 struct common_init_result {
-    llama_model_ptr model;
-    llama_context_ptr context;
+    common_init_result(common_params & params);
+    ~common_init_result();

-    std::vector<llama_adapter_lora_ptr> lora;
+    llama_model * model();
+    llama_context * context();
+    common_sampler * sampler(llama_seq_id seq_id);
+
+    std::vector<llama_adapter_lora_ptr> & lora();
+
+    void free_context();
+
+private:
+    struct impl;
+    std::unique_ptr<impl> pimpl;
 };

-struct common_init_result common_init_from_params(common_params & params);
+using common_init_result_ptr = std::unique_ptr<common_init_result>;
+
+common_init_result_ptr common_init_from_params(common_params & params);

 struct llama_model_params     common_model_params_to_llama  (      common_params & params);
 struct llama_context_params   common_context_params_to_llama(const common_params & params);
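
For reference, a hypothetical caller-side sketch of the refactored interface (not taken from this commit): common_init_from_params() now returns a common_init_result_ptr, the model and context are reached through accessors instead of public smart-pointer members, and one common_sampler per sequence is owned by the result. The function name, the prompt handling, and the decode step below are placeholders; only the common_init_result accessors and the common_sampler_sample/accept calls are taken from the existing common API.

#include "common.h"
#include "sampling.h"

#include "llama.h"

int generate_one(common_params & params) {
    // load the model and context and build per-sequence samplers in one call
    common_init_result_ptr llama_init = common_init_from_params(params);

    llama_model   * model = llama_init->model();
    llama_context * ctx   = llama_init->context();
    if (model == nullptr || ctx == nullptr) {
        return 1; // failure was already logged by common_init_from_params
    }

    // previously callers accessed llama_init.model.get() / llama_init.context.get()
    // on a plain struct; the samplers are now created and owned by the result
    common_sampler * smpl = llama_init->sampler(0); // sampler for seq_id 0

    // ... tokenize the prompt and call llama_decode(ctx, batch) here ...

    // sample and accept the next token for this sequence from the last logits
    const llama_token tok = common_sampler_sample(smpl, ctx, -1);
    common_sampler_accept(smpl, tok, /*accept_grammar=*/true);

    return 0;
}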
