5 changes: 0 additions & 5 deletions tools/server/server-common.cpp
@@ -757,11 +757,6 @@ json oaicompat_completion_params_parse(const json & body) {
llama_params["stop"] = json_value(body, "stop", json::array());
}

// Handle "echo" field
if (json_value(body, "echo", false)) {
throw std::runtime_error("Only no echo is supported");
}

// Params supported by OAI but unsupported by llama.cpp
static const std::vector<std::string> unsupported_params { "best_of", "suffix" };
for (const auto & param : unsupported_params) {
50 changes: 41 additions & 9 deletions tools/server/server-context.cpp
@@ -70,6 +70,7 @@ static bool server_task_type_need_logits(server_task_type task_type) {
struct server_slot {
int id;

common_params params_base;
llama_batch batch_spec = {};

// TODO: change to unique_ptrs for consistency:
@@ -107,6 +108,9 @@ struct server_slot {
// ref: https://github.com/ggml-org/llama.cpp/pull/17808
std::vector<int32_t> i_batch_dft;

// idx of prompt tokens to get logits (for echo=true)
std::vector<std::pair<int32_t, llama_token>> i_batch_prompt;

std::vector<completion_token_output> generated_token_probs;

bool has_next_token = true;
@@ -209,6 +213,12 @@ struct server_slot {
return server_task_type_need_embd(task->type);
}

bool need_prompt_logits() const {
GGML_ASSERT(task);

return task->params.echo && task->params.sampling.n_probs > 0;
}

bool need_logits() const {
GGML_ASSERT(task);

@@ -255,6 +265,12 @@ struct server_slot {
return ctx_dft;
}

std::string token_to_piece(const llama_token & token) const {
bool is_special = params_base.special
|| task->params.sampling.preserved_tokens.find(token) != task->params.sampling.preserved_tokens.end();
return common_token_to_piece(ctx, token, is_special);
}

void add_token(const completion_token_output & token) {
if (!is_processing()) {
SLT_WRN(*this, "%s", "slot is not processing\n");
@@ -724,6 +740,7 @@ struct server_context_impl {
slot.ctx = ctx;
slot.n_ctx = n_ctx_slot;
slot.mctx = mctx;
slot.params_base = params_base;
slot.prompt.tokens.has_mtmd = mctx != nullptr;

if (model_dft) {
@@ -1829,11 +1846,6 @@ struct server_context_impl {
// track if given slot can be batched with slots already in the batch
server_slot * slot_batched = nullptr;

auto accept_special_token = [&](server_slot & slot, llama_token token) {
return params_base.special ||
slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end();
};

// first, add sampled tokens from any ongoing sequences
for (auto & slot : slots) {
if (slot.state != SLOT_STATE_GENERATING) {
@@ -1919,6 +1931,8 @@ struct server_context_impl {
continue;
}

slot.i_batch_prompt.clear();

// this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
const auto & input_tokens = slot.task->tokens;
@@ -2280,9 +2294,14 @@ struct server_context_impl {
cur_tok,
slot.prompt.tokens.pos_next(),
{ slot.id },
slot.need_embd());
slot.need_embd() || slot.need_prompt_logits());
slot.prompt.tokens.push_back(cur_tok);

// track prompt tokens that need logits output
if (slot.need_prompt_logits()) {
slot.i_batch_prompt.push_back({batch.n_tokens - 1, cur_tok});
}

slot.n_prompt_tokens_processed++;

// process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
@@ -2486,11 +2505,24 @@ struct server_context_impl {
}
}

// optionally send prompt processing progress
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
// optionally send prompt processing progress
if (slot.task->params.stream && slot.task->params.return_progress) {
send_partial_response(slot, {}, true);
}

// optionally get prompt logits (echo=true)
if (!slot.i_batch_prompt.empty()) {
GGML_ASSERT(slot.task->params.stream); // TODO: support non-streaming if needed
for (auto & [tok_idx, id] : slot.i_batch_prompt) {
completion_token_output result;
result.tok = id;
result.text_to_send = slot.token_to_piece(id);
result.prob = 1.0f;
populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx);
send_partial_response(slot, result, false);
}
}
}

if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
@@ -2543,7 +2575,7 @@ struct server_context_impl {

completion_token_output result;
result.tok = id;
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
result.text_to_send = slot.token_to_piece(result.tok);
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs

if (slot.task->params.sampling.n_probs > 0) {
@@ -2594,7 +2626,7 @@ struct server_context_impl {
completion_token_output result;

result.tok = ids[i];
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
result.text_to_send = slot.token_to_piece(result.tok);
result.prob = 1.0f; // set later

// TODO: set result.probs
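
Note on the echo path above: while the prompt batch is built, every prompt token that needs logprobs gets its logits flag set and its (batch index, token id) pair recorded in slot.i_batch_prompt; after the decode, one partial response per recorded token is streamed out via populate_token_probs and send_partial_response. The following is a minimal standalone sketch of the per-token logprob mechanics only; it is not the server's implementation, and the helper names (token_logprob, prompt_logprobs) are illustrative. The only llama.cpp call it relies on is llama_get_logits_ith; pairing each token with the appropriate logits row is left to the caller.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <utility>
#include <vector>

#include "llama.h"

// log P(tok) from one row of raw logits: a numerically stable log-softmax
// evaluated at a single index
static float token_logprob(const float * logits, int32_t n_vocab, llama_token tok) {
    const float max_l = *std::max_element(logits, logits + n_vocab);
    double sum_exp = 0.0;
    for (int32_t i = 0; i < n_vocab; ++i) {
        sum_exp += std::exp((double) (logits[i] - max_l));
    }
    return (float) ((double) (logits[tok] - max_l) - std::log(sum_exp));
}

// for each recorded (batch index, token) pair, fetch the logits row that was
// requested for that batch position and compute the token's log-probability
static std::vector<float> prompt_logprobs(
        llama_context * ctx,
        int32_t n_vocab,
        const std::vector<std::pair<int32_t, llama_token>> & i_batch_prompt) {
    std::vector<float> out;
    out.reserve(i_batch_prompt.size());
    for (const auto & [batch_idx, tok] : i_batch_prompt) {
        const float * logits = llama_get_logits_ith(ctx, batch_idx);
        out.push_back(token_logprob(logits, n_vocab, tok));
    }
    return out;
}
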
9 changes: 9 additions & 0 deletions tools/server/server-task.cpp
@@ -167,6 +167,7 @@ task_params server_task::params_from_json_cmpl(
params.timings_per_token = json_value(data, "timings_per_token", false);

params.stream = json_value(data, "stream", false);
params.echo = json_value(data, "echo", false);
auto stream_opt = json_value(data, "stream_options", json::object());
params.include_usage = json_value(stream_opt, "include_usage", false);
params.cache_prompt = json_value(data, "cache_prompt", true);
@@ -221,6 +222,14 @@ task_params server_task::params_from_json_cmpl(
params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
}

if (params.echo && params.sampling.n_probs == 0) {
throw std::runtime_error("Error: echo without logprobs is not yet supported");
}

if (params.echo && params.sampling.n_probs != 0 && !params.stream) {
throw std::runtime_error("Error: echo with logprobs requires streaming to be enabled");
}

if (data.contains("lora")) {
if (data.at("lora").is_array()) {
params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora"));
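
The two checks added above pin down the current contract: echo without logprobs is rejected, and echo together with logprobs is only accepted when stream is enabled. For illustration only, a request body that passes this validation could be built like the following sketch using nlohmann::json (which the server already uses for parsing); the prompt text and logprobs value are arbitrary.

#include <cstdio>

#include <nlohmann/json.hpp>

int main() {
    // fields mirror what params_from_json_cmpl reads: "echo", "logprobs", "stream"
    nlohmann::json body = {
        { "prompt",   "The quick brown fox" },
        { "echo",     true },  // return the prompt tokens in the output as well
        { "logprobs", 5 },     // mapped to sampling.n_probs in the hunk above
        { "stream",   true }   // required when echo and logprobs are combined
    };
    // POST body.dump() to the completions endpoint of a running llama-server
    printf("%s\n", body.dump(2).c_str());
    return 0;
}
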
3 changes: 2 additions & 1 deletion tools/server/server-task.h
@@ -45,7 +45,8 @@ enum stop_type {
struct task_params {
bool stream = true;
bool include_usage = false;
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
bool echo = false; // echo the prompt in the output, useful for eval use cases
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
bool return_tokens = false;
bool return_progress = false;
