diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 5a67f508dfb..c110bff81ec 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -35,8 +35,8 @@ constexpr int HTTP_POLLING_SECONDS = 1;
 // state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
 enum slot_state {
     SLOT_STATE_IDLE,
-    SLOT_STATE_WAIT_OTHER,        // after assigning a task, but waiting for parent slot to process prompt
-    SLOT_STATE_STARTED,           // after assigning a task and about to process prompt
+    SLOT_STATE_WAIT_OTHER,         // after assigning a task, but waiting for parent slot to process prompt
+    SLOT_STATE_STARTED,            // after assigning a task and about to process prompt
     SLOT_STATE_PROCESSING_PROMPT,
     SLOT_STATE_DONE_PROMPT,
     SLOT_STATE_GENERATING,
@@ -529,6 +529,7 @@ struct server_context_impl {
     llama_batch batch {};
 
     bool add_bos_token = true;
+    bool has_encoder   = false; // true if model is encoder-decoder (e.g., T5, BART)
 
     int32_t n_ctx; // total context for all clients / slots
 
@@ -593,6 +594,24 @@ struct server_context_impl {
 
         n_ctx = llama_n_ctx(ctx);
         add_bos_token = llama_vocab_get_add_bos(vocab);
+        has_encoder   = llama_model_has_encoder(model);
+
+        if (has_encoder) {
+            SRV_INF("model has encoder - encoder-decoder mode enabled (e.g., T5, BART)%s\n", "");
+
+            // warn about incompatible features
+            if (params_base.ctx_shift) {
+                SRV_WRN("encoder-decoder models do not support context shift - disabling%s\n", "");
+                params_base.ctx_shift = false;
+            }
+            if (params_base.cache_ram_mib != 0) {
+                SRV_WRN("encoder-decoder models: prompt caching works differently - encoder outputs are not cached%s\n", "");
+            }
+            if (params_base.has_speculative()) {
+                SRV_WRN("encoder-decoder models do not support speculative decoding - ignoring draft model%s\n", "");
+                // Note: speculative setup continues below but won't be used for enc-dec slots
+            }
+        }
 
         if (params_base.has_speculative()) {
             SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());
@@ -726,7 +745,8 @@ struct server_context_impl {
 
             slot.mctx = mctx;
             slot.prompt.tokens.has_mtmd = mctx != nullptr;
-            if (model_dft) {
+            // speculative decoding is not supported for encoder-decoder models
+            if (model_dft && !has_encoder) {
                 slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
 
                 // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
@@ -1928,11 +1948,101 @@ struct server_context_impl {
 
                 slot.t_start_process_prompt = ggml_time_us();
                 slot.t_start_generation = 0;
-                slot.state = SLOT_STATE_PROCESSING_PROMPT;
-
                 SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n", slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens());
 
+                // encoder-decoder model handling (e.g., T5, BART, MADLAD)
+                if (has_encoder) {
+                    SLT_INF(slot, "encoder-decoder model: encoding %d tokens\n", slot.task->n_tokens());
+
+                    // clear the decoder KV cache for this slot - encoder-decoder models
+                    // don't support prefix caching, so we always start fresh
+                    llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+                    slot.prompt.tokens.clear();
+
+                    // empty prompt check
+                    if (input_tokens.empty()) {
+                        SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
+                        slot.print_timings();
+                        send_final_response(slot);
+                        slot.release();
+                        continue;
+                    }
+
+                    // get the text tokens for encoding
+                    const llama_tokens & text_tokens = input_tokens.get_text_tokens();
+
+                    // check for empty text tokens (could happen with multimodal-only input)
+                    if (text_tokens.empty()) {
+                        SLT_ERR(slot, "%s", "encoder-decoder models require text tokens\n");
+                        send_error(slot, "encoder-decoder models require text input", ERROR_TYPE_INVALID_REQUEST);
+                        slot.release();
+                        continue;
+                    }
+
+                    // build encoder batch with all prompt tokens
+                    // Note: we need to allocate a proper batch with seq_id support
+                    llama_batch batch_enc = llama_batch_init(text_tokens.size(), 0, 1);
+                    batch_enc.n_tokens = text_tokens.size();
+
+                    for (size_t i = 0; i < text_tokens.size(); i++) {
+                        batch_enc.token[i]     = text_tokens[i];
+                        batch_enc.pos[i]       = i;
+                        batch_enc.n_seq_id[i]  = 1;
+                        batch_enc.seq_id[i][0] = slot.id;
+                        batch_enc.logits[i]    = false;
+                    }
+
+                    // encode the entire prompt
+                    const int ret = llama_encode(ctx, batch_enc);
+
+                    // free the encoder batch
+                    llama_batch_free(batch_enc);
+
+                    if (ret != 0) {
+                        SLT_ERR(slot, "llama_encode() failed with error %d\n", ret);
+                        send_error(slot, "encoder failed", ERROR_TYPE_SERVER);
+                        slot.release();
+                        continue;
+                    }
+
+                    SLT_INF(slot, "encoder completed, %d tokens encoded\n", slot.task->n_tokens());
+
+                    // get decoder start token
+                    llama_token decoder_start_token = llama_model_decoder_start_token(model);
+                    if (decoder_start_token == LLAMA_TOKEN_NULL) {
+                        decoder_start_token = llama_vocab_bos(vocab);
+                    }
+
+                    SLT_DBG(slot, "decoder start token: %d '%s'\n",
+                            decoder_start_token, common_token_to_piece(ctx, decoder_start_token).c_str());
+
+                    // add decoder start token to the batch
+                    common_batch_add(batch, decoder_start_token, 0, { slot.id }, true);
+
+                    // update slot state - we've processed all prompt tokens (via encoder)
+                    // and the decoder is ready to generate
+                    slot.prompt.tokens.clear();
+                    slot.prompt.tokens.push_back(decoder_start_token);
+                    slot.n_prompt_tokens_processed = slot.task->n_tokens();
+
+                    common_sampler_reset(slot.smpl);
+
+                    slot.n_decoded = 0;
+                    slot.i_batch   = batch.n_tokens - 1;
+
+                    slot.state = SLOT_STATE_DONE_PROMPT;
+
+                    SLT_INF(slot, "encoder-decoder: prompt encoded, decoder ready%s\n", "");
+
+                    if (!slot_batched) {
+                        slot_batched = &slot;
+                    }
+
+                    continue; // skip normal prompt processing
+                }
+
+                slot.state = SLOT_STATE_PROCESSING_PROMPT;
 
                 // print prompt tokens (for debugging)
                 /*if (1) {
                     // first 16 tokens (avoid flooding logs)
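
For reviewers who want to exercise the same flow outside the server, below is a minimal standalone sketch of the encode-then-decode handoff this patch implements, written against the public `llama.h` API. It is illustrative only: the function name, greedy sampling, and single-sequence setup are assumptions of the sketch, not part of the patch. The core sequence matches the slot logic above: run `llama_encode` once over the full prompt, seed the decoder with `llama_model_decoder_start_token` (falling back to BOS), then drive `llama_decode` one token at a time.

```cpp
#include "llama.h"

#include <vector>

// Hypothetical helper, not part of this PR: encode a prompt once, then
// greedily decode up to n_predict tokens. Returns the generated tokens.
static std::vector<llama_token> encode_then_decode(
        llama_context * ctx, const llama_model * model,
        std::vector<llama_token> prompt, int n_predict) {
    const llama_vocab * vocab = llama_model_get_vocab(model);

    std::vector<llama_token> out;

    // 1) run the encoder over the whole prompt in a single batch
    llama_batch batch = llama_batch_get_one(prompt.data(), (int32_t) prompt.size());
    if (llama_encode(ctx, batch) != 0) {
        return out; // encoder failed
    }

    // 2) seed the decoder with the model's decoder start token,
    //    falling back to BOS exactly as the patch does
    llama_token tok = llama_model_decoder_start_token(model);
    if (tok == LLAMA_TOKEN_NULL) {
        tok = llama_vocab_bos(vocab);
    }

    // 3) plain greedy decode loop (the server uses its own sampler chain instead)
    llama_sampler * smpl = llama_sampler_init_greedy();

    for (int i = 0; i < n_predict; i++) {
        batch = llama_batch_get_one(&tok, 1);
        if (llama_decode(ctx, batch) != 0) {
            break; // decoder failed
        }

        tok = llama_sampler_sample(smpl, ctx, -1);
        if (llama_vocab_is_eog(vocab, tok)) {
            break; // end of generation
        }

        out.push_back(tok);
    }

    llama_sampler_free(smpl);

    return out;
}
```

Note that nothing in the loop re-runs the encoder: the cross-attention state produced by the single `llama_encode` call is held by the context and reused by every subsequent decode, which is also why the patch clears the slot's decoder KV cache up front rather than attempting to reuse a cached prefix.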