tools/server/server-context.cpp: 120 changes (115 additions, 5 deletions)
@@ -35,8 +35,8 @@ constexpr int HTTP_POLLING_SECONDS = 1;
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
enum slot_state {
SLOT_STATE_IDLE,
SLOT_STATE_WAIT_OTHER, // after assigning a task, but waiting for parent slot to process prompt
SLOT_STATE_STARTED, // after assigning a task and about to process prompt
SLOT_STATE_WAIT_OTHER, // after assigning a task, but waiting for parent slot to process prompt
SLOT_STATE_STARTED, // after assigning a task and about to process prompt
SLOT_STATE_PROCESSING_PROMPT,
SLOT_STATE_DONE_PROMPT,
SLOT_STATE_GENERATING,
@@ -529,6 +529,7 @@ struct server_context_impl {
llama_batch batch {};

bool add_bos_token = true;
bool has_encoder = false; // true if model is encoder-decoder (e.g., T5, BART)

int32_t n_ctx; // total context for all clients / slots

@@ -593,6 +594,24 @@ struct server_context_impl {
n_ctx = llama_n_ctx(ctx);

add_bos_token = llama_vocab_get_add_bos(vocab);
has_encoder = llama_model_has_encoder(model);

if (has_encoder) {
SRV_INF("model has encoder - encoder-decoder mode enabled (e.g., T5, BART)%s\n", "");

// warn about incompatible features
if (params_base.ctx_shift) {
SRV_WRN("encoder-decoder models do not support context shift - disabling%s\n", "");
params_base.ctx_shift = false;
}
if (params_base.cache_ram_mib != 0) {
SRV_WRN("encoder-decoder models: prompt caching works differently - encoder outputs are not cached%s\n", "");
}
if (params_base.has_speculative()) {
SRV_WRN("encoder-decoder models do not support speculative decoding - ignoring draft model%s\n", "");
// Note: speculative setup continues below but won't be used for enc-dec slots
}
}

if (params_base.has_speculative()) {
SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());
@@ -726,7 +745,8 @@ struct server_context_impl {
slot.mctx = mctx;
slot.prompt.tokens.has_mtmd = mctx != nullptr;

if (model_dft) {
// speculative decoding is not supported for encoder-decoder models
if (model_dft && !has_encoder) {
slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);

// TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
@@ -1928,11 +1948,101 @@ struct server_context_impl {
slot.t_start_process_prompt = ggml_time_us();
slot.t_start_generation = 0;

slot.state = SLOT_STATE_PROCESSING_PROMPT;

SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n",
slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens());

// encoder-decoder model handling (e.g., T5, BART, MADLAD)
if (has_encoder) {
SLT_INF(slot, "encoder-decoder model: encoding %d tokens\n", slot.task->n_tokens());

// clear the decoder KV cache for this slot - encoder-decoder models
// don't support prefix caching, so we always start fresh
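// passing (-1, -1) for the position range removes every position in this slot's sequence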
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
slot.prompt.tokens.clear();

// empty prompt check
if (input_tokens.empty()) {
SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
slot.print_timings();
send_final_response(slot);
slot.release();
continue;
}

// get the text tokens for encoding
const llama_tokens & text_tokens = input_tokens.get_text_tokens();

// check for empty text tokens (could happen with multimodal-only input)
if (text_tokens.empty()) {
SLT_ERR(slot, "%s", "encoder-decoder models require text tokens\n");
send_error(slot, "encoder-decoder models require text input", ERROR_TYPE_INVALID_REQUEST);
slot.release();
continue;
}

// build encoder batch with all prompt tokens
// Note: we need to allocate a proper batch with seq_id support
llama_batch batch_enc = llama_batch_init(text_tokens.size(), 0, 1);
batch_enc.n_tokens = text_tokens.size();

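// fill tokens, positions and sequence ids by hand; the encoder pass produces embeddings, not logits, so none are requested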
for (size_t i = 0; i < text_tokens.size(); i++) {
batch_enc.token[i] = text_tokens[i];
batch_enc.pos[i] = i;
batch_enc.n_seq_id[i] = 1;
batch_enc.seq_id[i][0] = slot.id;
batch_enc.logits[i] = false;
}

// encode the entire prompt
const int ret = llama_encode(ctx, batch_enc);

// free the encoder batch
llama_batch_free(batch_enc);
if (ret != 0) {
SLT_ERR(slot, "llama_encode() failed with error %d\n", ret);
send_error(slot, "encoder failed", ERROR_TYPE_SERVER);
slot.release();
continue;
}

SLT_INF(slot, "encoder completed, %d tokens encoded\n", slot.task->n_tokens());

// get decoder start token
llama_token decoder_start_token = llama_model_decoder_start_token(model);
if (decoder_start_token == LLAMA_TOKEN_NULL) {
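// the model does not define a decoder start token - fall back to BOS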
decoder_start_token = llama_vocab_bos(vocab);
}

SLT_DBG(slot, "decoder start token: %d '%s'\n",
decoder_start_token, common_token_to_piece(ctx, decoder_start_token).c_str());

// add decoder start token to the batch
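// position 0, sequence id = slot.id, logits enabled so the first decoder token can be sampled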
common_batch_add(batch, decoder_start_token, 0, { slot.id }, true);

// update slot state - we've processed all prompt tokens (via encoder)
// and the decoder is ready to generate
slot.prompt.tokens.clear();
slot.prompt.tokens.push_back(decoder_start_token);
slot.n_prompt_tokens_processed = slot.task->n_tokens();

common_sampler_reset(slot.smpl);

slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;

slot.state = SLOT_STATE_DONE_PROMPT;

SLT_INF(slot, "encoder-decoder: prompt encoded, decoder ready%s\n", "");

if (!slot_batched) {
slot_batched = &slot;
}

continue; // skip normal prompt processing
}

slot.state = SLOT_STATE_PROCESSING_PROMPT;

// print prompt tokens (for debugging)
/*if (1) {
// first 16 tokens (avoid flooding logs)