@@ -1013,31 +1013,40 @@ bool tty_can_use_colors() {
 // Model utils
 //
 
-static inline void common_init_sampler_from_model(
+// TODO: move to common/sampling
+static void common_init_sampler_from_model(
         const llama_model * model,
         common_params_sampling & sparams) {
 
     const uint64_t config = sparams.user_sampling_config;
 
     auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
-        if (config & user_config) return;
+        if (config & user_config) {
+            return;
+        }
 
         char buf[64] = {0};
         if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
             char * end = nullptr;
             int32_t v = strtol(buf, &end, 10);
-            if (end && end != buf) dst = v;
+            if (end && end != buf) {
+                dst = v;
+            }
         }
     };
 
     auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
-        if (config & user_config) return;
+        if (config & user_config) {
+            return;
+        }
 
         char buf[128] = {0};
         if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
             char * end = nullptr;
             float v = strtof(buf, &end);
-            if (end && end != buf) dst = v;
+            if (end && end != buf) {
+                dst = v;
+            }
         }
     };
 
@@ -1065,31 +1074,122 @@ static inline void common_init_sampler_from_model(
     get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
 }
 
-struct common_init_result common_init_from_params(common_params & params) {
-    common_init_result iparams;
-    auto mparams = common_model_params_to_llama(params);
+struct common_init_result::impl {
+    impl() = default;
+    ~impl() = default;
+
+    llama_model_ptr model;
+    llama_context_ptr context;
+
+    std::vector<llama_adapter_lora_ptr> lora;
+
+    std::vector<common_sampler_ptr> samplers;
+};
+
+common_init_result::common_init_result(common_params & params) :
+    pimpl(new impl{}) {
+    const auto mparams = common_model_params_to_llama(params);
 
     llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
     if (model == NULL) {
-        LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
-                __func__, params.model.path.c_str());
-        return iparams;
+        return;
     }
 
-    common_init_sampler_from_model(model, params.sampling);
+    pimpl->model.reset(model);
 
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
+    // updates params.sampling
+    // TODO: fix naming
+    common_init_sampler_from_model(model, params.sampling);
+
     auto cparams = common_context_params_to_llama(params);
 
+    if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
+        LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
+        params.sampling.ignore_eos = false;
+    }
+
+    // initialize once
+    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
+        if (llama_vocab_is_eog(vocab, i)) {
+            LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
+            params.sampling.logit_bias_eog.push_back({i, -INFINITY});
+        }
+    }
+
+    if (params.sampling.ignore_eos) {
+        // add EOG biases to the active set of logit biases
+        params.sampling.logit_bias.insert(
+                params.sampling.logit_bias.end(),
+                params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
+    }
+
+    // if (params.sampling.penalty_last_n == -1) {
+    //     LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+    //     params.sampling.penalty_last_n = llama_n_ctx(lctx);
+    // }
+
+    // if (params.sampling.dry_penalty_last_n == -1) {
+    //     LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+    //     params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
+    // }
+
+    pimpl->samplers.resize(cparams.n_seq_max);
+
+    for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
+        pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
+    }
+
     llama_context * lctx = llama_init_from_model(model, cparams);
+    if (lctx == NULL) {
+        LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
+                __func__, params.model.path.c_str());
+        return;
+    }
+
+    pimpl->context.reset(lctx);
+}
+
+llama_model * common_init_result::model() {
+    return pimpl->model.get();
+}
+
+llama_context * common_init_result::context() {
+    return pimpl->context.get();
+}
+
+common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
+    return pimpl->samplers[seq_id].get();
+}
+
+std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
+    return pimpl->lora;
+}
+
+void common_init_result::free_context() {
+    pimpl->context.reset();
+}
+
+common_init_result_ptr common_init_from_params(common_params & params) {
+    common_init_result_ptr res(new common_init_result(params));
+
+    llama_model * model = res->model();
+    if (model == NULL) {
+        LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
+                __func__, params.model.path.c_str());
+        return res;
+    }
+
+    llama_context * lctx = res->context();
     if (lctx == NULL) {
         LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
                 __func__, params.model.path.c_str());
-        llama_model_free(model);
-        return iparams;
+        return res;
     }
 
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
     if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
         LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
         params.ctx_shift = false;
@@ -1101,10 +1201,7 @@ struct common_init_result common_init_from_params(common_params & params) {
 
         const auto cvec = common_control_vector_load(params.control_vectors);
         if (cvec.n_embd == -1) {
-            llama_free(lctx);
-            llama_model_free(model);
-
-            return iparams;
+            return res;
         }
 
         int err = llama_apply_adapter_cvec(
@@ -1115,10 +1212,7 @@ struct common_init_result common_init_from_params(common_params & params) {
                 params.control_vector_layer_start,
                 params.control_vector_layer_end);
         if (err) {
-            llama_free(lctx);
-            llama_model_free(model);
-
-            return iparams;
+            return res;
         }
     }
 
@@ -1142,10 +1236,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         }
 
         if (!ok) {
-            llama_free(lctx);
-            llama_model_free(model);
-
-            return iparams;
+            return res;
         }
     }
 
@@ -1155,9 +1246,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
         if (lora == nullptr) {
             LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
-            llama_free(lctx);
-            llama_model_free(model);
-            return iparams;
+            return res;
         }
 
         char buf[1024];
@@ -1166,43 +1255,13 @@ struct common_init_result common_init_from_params(common_params & params) {
         la.task_name = buf;
         llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
         la.prompt_prefix = buf;
-        iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
+        res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters
     }
 
     if (!params.lora_init_without_apply) {
         common_set_adapter_lora(lctx, params.lora_adapters);
     }
 
-    if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
-        LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
-        params.sampling.ignore_eos = false;
-    }
-
-    // initialize once
-    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
-        if (llama_vocab_is_eog(vocab, i)) {
-            LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
-            params.sampling.logit_bias_eog.push_back({i, -INFINITY});
-        }
-    }
-
-    if (params.sampling.ignore_eos) {
-        // add EOG biases to the active set of logit biases
-        params.sampling.logit_bias.insert(
-                params.sampling.logit_bias.end(),
-                params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
-    }
-
-    if (params.sampling.penalty_last_n == -1) {
-        LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
-        params.sampling.penalty_last_n = llama_n_ctx(lctx);
-    }
-
-    if (params.sampling.dry_penalty_last_n == -1) {
-        LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
-        params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
-    }
-
     if (params.warmup) {
         LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
 
@@ -1241,12 +1300,11 @@ struct common_init_result common_init_from_params(common_params & params) {
         llama_set_warmup(lctx, false);
     }
 
-    iparams.model.reset(model);
-    iparams.context.reset(lctx);
-
-    return iparams;
+    return res;
 }
 
+common_init_result::~common_init_result() = default;
+
 std::string get_model_endpoint() {
     const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
     // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
@@ -1255,7 +1313,9 @@ std::string get_model_endpoint() {
     std::string model_endpoint = "https://huggingface.co/";
     if (endpoint_env) {
         model_endpoint = endpoint_env;
-        if (model_endpoint.back() != '/') model_endpoint += '/';
+        if (model_endpoint.back() != '/') {
+            model_endpoint += '/';
+        }
     }
     return model_endpoint;
 }
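
For context, here is a caller-side sketch of the refactored API as it appears in this diff. It is a hypothetical example, not part of the change: the helper `run_example`, the way `params` gets populated, and the exact headers (`common.h`, `sampling.h`) are assumptions.

    #include "common.h"
    #include "sampling.h"

    // Hypothetical helper, not part of the commit: shows how a caller would use
    // the new common_init_result accessors instead of the old struct members.
    static int run_example(common_params & params) {
        // a single call now owns the model, context, per-sequence samplers and loras
        common_init_result_ptr init = common_init_from_params(params);

        llama_model   * model = init->model();
        llama_context * lctx  = init->context();
        if (model == nullptr || lctx == nullptr) {
            return 1; // load or context creation failed (already logged)
        }

        // samplers are pre-created in the constructor, one per sequence id
        common_sampler * smpl = init->sampler(0);
        (void) smpl;

        // the context can be released early while the model stays alive
        init->free_context();

        return 0; // remaining resources are freed when the pimpl is destroyed
    }

Compared to the old `common_init_result iparams` value type, the pimpl keeps the smart-pointer members out of the public surface, and the early-return paths no longer need manual `llama_free`/`llama_model_free` calls.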