common/arg.cpp (2 changes: 1 addition & 1 deletion)
@@ -1359,7 +1359,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sampling.top_k = value;
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K;
}
).set_sparam());
).set_sparam().set_env("LLAMA_ARG_TOP_K"));
add_opt(common_arg(
{"--top-p"}, "N",
string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),
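The arg.cpp change wires the --top-k option to the LLAMA_ARG_TOP_K environment variable via set_env(), so the value can be picked up from the environment when the flag is not passed on the command line. Below is a hypothetical standalone sketch of that kind of fallback; the actual dispatch happens generically inside the argument parser in common/arg.cpp, and the helper name here is invented for illustration only.

```cpp
// Hypothetical illustration of the env-var fallback enabled by set_env("LLAMA_ARG_TOP_K");
// the real parsing is handled generically by common_arg in common/arg.cpp.
#include <cstdint>
#include <cstdlib>

static bool top_k_from_env(int32_t & top_k) {
    const char * val = std::getenv("LLAMA_ARG_TOP_K"); // variable name taken from the diff
    if (val == nullptr || val[0] == '\0') {
        return false; // not set: keep the default or model-provided value
    }
    top_k = static_cast<int32_t>(std::strtol(val, nullptr, 10));
    return true;
}
```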
common/common.cpp (190 changes: 125 additions & 65 deletions)
@@ -1013,31 +1013,40 @@ bool tty_can_use_colors() {
// Model utils
//

static inline void common_init_sampler_from_model(
// TODO: move to common/sampling
static void common_init_sampler_from_model(
const llama_model * model,
common_params_sampling & sparams) {

const uint64_t config = sparams.user_sampling_config;

auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
if (config & user_config) return;
if (config & user_config) {
return;
}

char buf[64] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
int32_t v = strtol(buf, &end, 10);
if (end && end != buf) dst = v;
if (end && end != buf) {
dst = v;
}
}
};

auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
if (config & user_config) return;
if (config & user_config) {
return;
}

char buf[128] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
float v = strtof(buf, &end);
if (end && end != buf) dst = v;
if (end && end != buf) {
dst = v;
}
}
};

@@ -1065,31 +1074,122 @@ static inline void common_init_sampler_from_model(
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
}

struct common_init_result common_init_from_params(common_params & params) {
common_init_result iparams;
auto mparams = common_model_params_to_llama(params);
struct common_init_result::impl {
impl() = default;
~impl() = default;

llama_model_ptr model;
llama_context_ptr context;

std::vector<llama_adapter_lora_ptr> lora;

std::vector<common_sampler_ptr> samplers;
};

common_init_result::common_init_result(common_params & params) :
pimpl(new impl{}) {
const auto mparams = common_model_params_to_llama(params);

llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
if (model == NULL) {
LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
__func__, params.model.path.c_str());
return iparams;
return;
}

common_init_sampler_from_model(model, params.sampling);
pimpl->model.reset(model);

const llama_vocab * vocab = llama_model_get_vocab(model);

// updates params.sampling
// TODO: fix naming
common_init_sampler_from_model(model, params.sampling);

auto cparams = common_context_params_to_llama(params);

if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
params.sampling.ignore_eos = false;
}

// initialize once
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
if (llama_vocab_is_eog(vocab, i)) {
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
}
}

if (params.sampling.ignore_eos) {
// add EOG biases to the active set of logit biases
params.sampling.logit_bias.insert(
params.sampling.logit_bias.end(),
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
}

//if (params.sampling.penalty_last_n == -1) {
// LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
// params.sampling.penalty_last_n = llama_n_ctx(lctx);
//}

//if (params.sampling.dry_penalty_last_n == -1) {
// LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
// params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
//}

pimpl->samplers.resize(cparams.n_seq_max);

for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
}

llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
__func__, params.model.path.c_str());
return;
}

pimpl->context.reset(lctx);
}

llama_model * common_init_result::model() {
return pimpl->model.get();
}

llama_context * common_init_result::context() {
return pimpl->context.get();
}

common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
return pimpl->samplers[seq_id].get();
}

std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
return pimpl->lora;
}

void common_init_result::free_context() {
pimpl->context.reset();
}

common_init_result_ptr common_init_from_params(common_params & params) {
common_init_result_ptr res(new common_init_result(params));

llama_model * model = res->model();
if (model == NULL) {
LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
__func__, params.model.path.c_str());
return res;
}

llama_context * lctx = res->context();
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
__func__, params.model.path.c_str());
llama_model_free(model);
return iparams;
return res;
}

const llama_vocab * vocab = llama_model_get_vocab(model);

if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
params.ctx_shift = false;
@@ -1101,10 +1201,7 @@ struct common_init_result common_init_from_params(common_params & params) {

const auto cvec = common_control_vector_load(params.control_vectors);
if (cvec.n_embd == -1) {
llama_free(lctx);
llama_model_free(model);

return iparams;
return res;
}

int err = llama_apply_adapter_cvec(
@@ -1115,10 +1212,7 @@ struct common_init_result common_init_from_params(common_params & params) {
params.control_vector_layer_start,
params.control_vector_layer_end);
if (err) {
llama_free(lctx);
llama_model_free(model);

return iparams;
return res;
}
}

@@ -1142,10 +1236,7 @@ struct common_init_result common_init_from_params(common_params & params) {
}

if (!ok) {
llama_free(lctx);
llama_model_free(model);

return iparams;
return res;
}
}

@@ -1155,9 +1246,7 @@ struct common_init_result common_init_from_params(common_params & params) {
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
if (lora == nullptr) {
LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
llama_free(lctx);
llama_model_free(model);
return iparams;
return res;
}

char buf[1024];
@@ -1166,43 +1255,13 @@ struct common_init_result common_init_from_params(common_params & params) {
la.task_name = buf;
llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
la.prompt_prefix = buf;
iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters
}

if (!params.lora_init_without_apply) {
common_set_adapter_lora(lctx, params.lora_adapters);
}

if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
params.sampling.ignore_eos = false;
}

// initialize once
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
if (llama_vocab_is_eog(vocab, i)) {
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
}
}

if (params.sampling.ignore_eos) {
// add EOG biases to the active set of logit biases
params.sampling.logit_bias.insert(
params.sampling.logit_bias.end(),
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
}

if (params.sampling.penalty_last_n == -1) {
LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
params.sampling.penalty_last_n = llama_n_ctx(lctx);
}

if (params.sampling.dry_penalty_last_n == -1) {
LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
}

if (params.warmup) {
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);

@@ -1241,12 +1300,11 @@ struct common_init_result common_init_from_params(common_params & params) {
llama_set_warmup(lctx, false);
}

iparams.model.reset(model);
iparams.context.reset(lctx);

return iparams;
return res;
}

common_init_result::~common_init_result() = default;

std::string get_model_endpoint() {
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
// We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
Expand All @@ -1255,7 +1313,9 @@ std::string get_model_endpoint() {
std::string model_endpoint = "https://huggingface.co/";
if (endpoint_env) {
model_endpoint = endpoint_env;
if (model_endpoint.back() != '/') model_endpoint += '/';
if (model_endpoint.back() != '/') {
model_endpoint += '/';
}
}
return model_endpoint;
}
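In the refactored common.cpp the model, context, LoRA adapters, and samplers are owned by the pimpl's smart pointers, which is why the error paths now simply return res (or return from the constructor) instead of calling llama_free() and llama_model_free() by hand. The constructor also creates one sampler chain per sequence, cparams.n_seq_max of them, retrieved later by sequence id. A minimal, illustrative sketch of how a multi-sequence caller might drive those samplers follows; the loop and the use of the sequence id as the batch output index are assumptions, not part of the PR.

```cpp
#include "common.h"
#include "sampling.h"

// Illustrative only: drive one sampler per sequence id, as created in the constructor above.
// common_sampler_sample()/common_sampler_accept() are existing helpers from common/sampling.h;
// treating the sequence id as the batch output index is an assumption made for brevity.
static void sample_all_sequences(common_init_result & init, llama_context * ctx, int n_parallel) {
    for (int s = 0; s < n_parallel; ++s) {
        common_sampler * smpl = init.sampler(s);                      // accessor added by this PR
        const llama_token tok = common_sampler_sample(smpl, ctx, s);  // sample from output index s
        common_sampler_accept(smpl, tok, /*accept_grammar =*/ true);
        // ... append tok to sequence s and continue decoding ...
    }
}
```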
common/common.h (29 changes: 23 additions & 6 deletions)
@@ -195,7 +195,6 @@ struct common_params_sampling {

std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY


std::vector<enum common_sampler_type> samplers = {
COMMON_SAMPLER_TYPE_PENALTIES,
COMMON_SAMPLER_TYPE_DRY,
@@ -216,6 +215,10 @@ struct common_params_sampling {
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens

bool has_logit_bias() const {
return !logit_bias.empty();
}

// print the parameters into a string
std::string print() const;
};
@@ -669,15 +672,29 @@ bool tty_can_use_colors();
// Model utils
//

// note: defines object's lifetime
struct common_sampler;

// note: defines the lifetimes of the model, context, samplers, etc.
struct common_init_result {
llama_model_ptr model;
llama_context_ptr context;
common_init_result(common_params & params);
~common_init_result();

std::vector<llama_adapter_lora_ptr> lora;
llama_model * model();
llama_context * context();
common_sampler * sampler(llama_seq_id seq_id);

std::vector<llama_adapter_lora_ptr> & lora();

void free_context();

private:
struct impl;
std::unique_ptr<impl> pimpl;
};

struct common_init_result common_init_from_params(common_params & params);
using common_init_result_ptr = std::unique_ptr<common_init_result>;

common_init_result_ptr common_init_from_params(common_params & params);

struct llama_model_params common_model_params_to_llama ( common_params & params);
struct llama_context_params common_context_params_to_llama(const common_params & params);
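Taken together, the common.h changes replace the old aggregate with public llama_model_ptr/llama_context_ptr members by an opaque, pimpl-backed common_init_result handed out through a unique_ptr. A hedged caller-side sketch follows, assuming an already populated common_params; only the common_init_* names and accessors come from the PR, while the function name and error handling are illustrative.

```cpp
#include "common.h"

// Illustrative caller of the refactored initialization API.
static int run_with_params(common_params & params) {
    common_init_result_ptr init = common_init_from_params(params);

    llama_model   * model = init->model();
    llama_context * ctx   = init->context();
    if (model == nullptr || ctx == nullptr) {
        // load or context creation failed; the pimpl's smart pointers
        // release whatever was created, no manual llama_free() needed
        return 1;
    }

    common_sampler * smpl = init->sampler(0); // sampler chain for sequence id 0
    // ... tokenize the prompt, llama_decode(), sample with smpl ...

    init->free_context(); // optionally drop the context before the model goes away
    return 0;
}
```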