diff --git a/CMakeLists.txt b/CMakeLists.txt index 4b06d98b3..7fd8866cf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -153,6 +153,15 @@ if (LLAMA_BUILD) add_compile_definitions(GGML_USE_METAL) endif() + # Set version for mtmd (required by upstream CMakeLists.txt) + # NOTE: This is a workaround for mtmd build requirements. + # Version is set to 0.0.0 for local builds. If upstream adds version + # compatibility checks, this may need to match llama.cpp version. + if (NOT DEFINED LLAMA_BUILD_NUMBER) + set(LLAMA_BUILD_NUMBER 0) + endif() + set(LLAMA_INSTALL_VERSION 0.0.${LLAMA_BUILD_NUMBER}) + # Building llava add_subdirectory(vendor/llama.cpp/tools/mtmd) diff --git a/llama_cpp/__init__.py b/llama_cpp/__init__.py index c1dde7046..0d56603e2 100644 --- a/llama_cpp/__init__.py +++ b/llama_cpp/__init__.py @@ -1,4 +1,4 @@ from .llama_cpp import * from .llama import * -__version__ = "0.3.16" +__version__ = "0.4.0" diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 71d94ebd8..18d8bc66d 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -91,9 +91,9 @@ def __init__( logits_all: bool = False, embedding: bool = False, offload_kqv: bool = True, - flash_attn: bool = False, op_offload: Optional[bool] = None, swa_full: Optional[bool] = None, + flash_attn: Optional[bool] = None, # Sampling Params no_perf: bool = False, last_n_tokens_size: int = 64, @@ -173,7 +173,7 @@ def __init__( logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs. embedding: Embedding mode only. offload_kqv: Offload K, Q, V to GPU. - flash_attn: Use flash attention. + flash_attn: Use flash attention. None = auto, True = enabled, False = disabled. op_offload: offload host tensor operations to device swa_full: use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) no_perf: Measure performance timings. @@ -341,7 +341,16 @@ def __init__( self._logits_all = logits_all if draft_model is None else True self.context_params.embeddings = embedding # TODO: Rename to embeddings self.context_params.offload_kqv = offload_kqv - self.context_params.flash_attn = flash_attn + if flash_attn is None: + self.context_params.flash_attn_type = llama_cpp.LLAMA_FLASH_ATTN_TYPE_AUTO + elif flash_attn: + self.context_params.flash_attn_type = ( + llama_cpp.LLAMA_FLASH_ATTN_TYPE_ENABLED + ) + else: + self.context_params.flash_attn_type = ( + llama_cpp.LLAMA_FLASH_ATTN_TYPE_DISABLED + ) if op_offload is not None: self.context_params.op_offload = op_offload @@ -934,7 +943,8 @@ def generate( sample_idx += 1 if stopping_criteria is not None and stopping_criteria( - self._input_ids[: sample_idx], self._scores[sample_idx - self.n_tokens, :] + self._input_ids[:sample_idx], + self._scores[sample_idx - self.n_tokens, :], ): return tokens_or_none = yield token @@ -1041,7 +1051,9 @@ def embed( data: Union[List[List[float]], List[List[List[float]]]] = [] def decode_batch(seq_sizes: List[int]): - llama_cpp.llama_kv_self_clear(self._ctx.ctx) + mem = llama_cpp.llama_get_memory(self._ctx.ctx) + if mem is not None: + llama_cpp.llama_memory_clear(mem, True) self._ctx.decode(self._batch) self._batch.reset() @@ -1112,7 +1124,9 @@ def decode_batch(seq_sizes: List[int]): output = data[0] if isinstance(input, str) else data - llama_cpp.llama_kv_self_clear(self._ctx.ctx) + mem = llama_cpp.llama_get_memory(self._ctx.ctx) + if mem is not None: + llama_cpp.llama_memory_clear(mem, True) self.reset() if return_count: @@ -1157,9 +1171,9 @@ def _create_completion( bos_token_id: int = self.token_bos() cls_token_id: int = self._model.token_cls() sep_token_id: int = self._model.token_sep() - prefix_token_id: int = 0 # self._model.token_prefix() # TODO: Fix - middle_token_id: int = 0 # self._model.token_middle() # TODO: Fix - suffix_token_id: int = 0 # self._model.token_suffix() # TODO: Fix + prefix_token_id: int = self._model.token_prefix() + middle_token_id: int = self._model.token_middle() + suffix_token_id: int = self._model.token_suffix() add_space_prefix: bool = ( self.metadata.get("tokenizer.ggml.add_space_prefix", "true") == "true" ) @@ -1315,7 +1329,7 @@ def logit_bias_processor( if seed is not None: self.set_seed(seed) else: - self.set_seed(random.Random(self._seed).randint(0, 2 ** 32)) + self.set_seed(random.Random(self._seed).randint(0, 2**32)) finish_reason = "length" multibyte_fix = 0 @@ -2056,7 +2070,10 @@ def create_chat_completion_openai_v1( stream = kwargs.get("stream", False) # type: ignore assert isinstance(stream, bool) if stream: - return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs)) # type: ignore + return ( + ChatCompletionChunk(**chunk) + for chunk in self.create_chat_completion(*args, **kwargs) + ) # type: ignore else: return ChatCompletion(**self.create_chat_completion(*args, **kwargs)) # type: ignore except ImportError: @@ -2096,7 +2113,7 @@ def __getstate__(self): logits_all=self._logits_all, embedding=self.context_params.embeddings, offload_kqv=self.context_params.offload_kqv, - flash_attn=self.context_params.flash_attn, + flash_attn=self.context_params.flash_attn_type, op_offload=self.context_params.op_offload, swa_full=self.context_params.swa_full, # Sampling Params @@ -2316,19 +2333,23 @@ def from_pretrained( ) if additional_files: - for additonal_file_name in additional_files: + for additional_file_name in additional_files: # find the additional shard file: - matching_additional_files = [file for file in file_list if fnmatch.fnmatch(file, additonal_file_name)] + matching_additional_files = [ + file + for file in file_list + if fnmatch.fnmatch(file, additional_file_name) + ] if len(matching_additional_files) == 0: raise ValueError( - f"No file found in {repo_id} that match {additonal_file_name}\n\n" + f"No file found in {repo_id} that match {additional_file_name}\n\n" f"Available Files:\n{json.dumps(file_list)}" ) if len(matching_additional_files) > 1: raise ValueError( - f"Multiple files found in {repo_id} matching {additonal_file_name}\n\n" + f"Multiple files found in {repo_id} matching {additional_file_name}\n\n" f"Available Files:\n{json.dumps(files)}" ) diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 711d42a6a..a4f6b0d40 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -33,7 +33,11 @@ # Specify the base name of the shared library to load _lib_base_name = "llama" _override_base_path = os.environ.get("LLAMA_CPP_LIB_PATH") -_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" if _override_base_path is None else pathlib.Path(_override_base_path) +_base_path = ( + pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" + if _override_base_path is None + else pathlib.Path(_override_base_path) +) # Load the library _lib = load_shared_library(_lib_base_name, _base_path) @@ -117,6 +121,14 @@ # typedef bool (*ggml_abort_callback)(void * data); ggml_abort_callback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_void_p) +# typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data); +ggml_log_callback = ctypes.CFUNCTYPE( + None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p +) + +# typedef struct ggml_threadpool * ggml_threadpool_t; +ggml_threadpool_t = ctypes.c_void_p + # llama.h bindings _lib.llama_max_devices.argtypes = [] @@ -177,6 +189,13 @@ # typedef int32_t llama_seq_id; llama_seq_id = ctypes.c_int32 +# typedef uint32_t llama_state_seq_flags; +llama_state_seq_flags = ctypes.c_uint32 + +# State sequence flags +LLAMA_STATE_SEQ_FLAGS_SWA_ONLY = 1 +LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY = 2 + # enum llama_vocab_type { # LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab @@ -294,6 +313,7 @@ LLAMA_ROPE_TYPE_NORM = 0 LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX = 2 LLAMA_ROPE_TYPE_MROPE = GGML_ROPE_TYPE_MROPE = 8 +LLAMA_ROPE_TYPE_IMROPE = GGML_ROPE_TYPE_IMROPE = 40 LLAMA_ROPE_TYPE_VISION = GGML_ROPE_TYPE_VISION = 24 @@ -462,6 +482,14 @@ LLAMA_ATTENTION_TYPE_CAUSAL = 0 LLAMA_ATTENTION_TYPE_NON_CAUSAL = 1 +# enum llama_flash_attn_type { +# LLAMA_FLASH_ATTN_TYPE_AUTO = -1, +# LLAMA_FLASH_ATTN_TYPE_DISABLED = 0, +# LLAMA_FLASH_ATTN_TYPE_ENABLED = 1, +# }; +LLAMA_FLASH_ATTN_TYPE_AUTO = -1 +LLAMA_FLASH_ATTN_TYPE_DISABLED = 0 +LLAMA_FLASH_ATTN_TYPE_ENABLED = 1 # enum llama_split_mode { # LLAMA_SPLIT_MODE_NONE = 0, // single GPU @@ -472,6 +500,15 @@ LLAMA_SPLIT_MODE_LAYER = 1 LLAMA_SPLIT_MODE_ROW = 2 +# enum llama_params_fit_status { +# LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0, +# LLAMA_PARAMS_FIT_STATUS_FAILURE = 1, +# LLAMA_PARAMS_FIT_STATUS_ERROR = 2, +# }; +LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0 +LLAMA_PARAMS_FIT_STATUS_FAILURE = 1 +LLAMA_PARAMS_FIT_STATUS_ERROR = 2 + # typedef struct llama_token_data { # llama_token id; // token id @@ -559,6 +596,7 @@ class llama_token_data_array(ctypes.Structure): # typedef struct llama_batch { # int32_t n_tokens; + # llama_token * token; # float * embd; # llama_pos * pos; @@ -613,6 +651,22 @@ class llama_batch(ctypes.Structure): LLAMA_KV_OVERRIDE_TYPE_BOOL = 2 LLAMA_KV_OVERRIDE_TYPE_STR = 3 +# enum llama_model_meta_key { +# LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE, +# ... +# }; +LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE = 0 +LLAMA_MODEL_META_KEY_SAMPLING_TOP_K = 1 +LLAMA_MODEL_META_KEY_SAMPLING_TOP_P = 2 +LLAMA_MODEL_META_KEY_SAMPLING_MIN_P = 3 +LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY = 4 +LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD = 5 +LLAMA_MODEL_META_KEY_SAMPLING_TEMP = 6 +LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N = 7 +LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT = 8 +LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT = 9 +LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU = 10 +LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA = 11 # struct llama_model_kv_override { # enum llama_model_kv_override_type tag; @@ -688,6 +742,7 @@ class llama_model_kv_override(ctypes.Structure): # // override key-value pairs of the model meta data # const struct llama_model_kv_override * kv_overrides; + # // Keep the booleans together to avoid misalignment during copy-by-value. # bool vocab_only; // only load the vocabulary, no weights # bool use_mmap; // use mmap if possible @@ -716,7 +771,9 @@ class llama_model_params(ctypes.Structure): if TYPE_CHECKING: devices: CtypesArray[ctypes.c_void_p] # NOTE: unused - tensor_buft_overrides: CtypesArray[llama_model_tensor_buft_override] # NOTE: unused + tensor_buft_overrides: CtypesArray[ + llama_model_tensor_buft_override + ] # NOTE: unused n_gpu_layers: int split_mode: int main_gpu: int @@ -731,8 +788,8 @@ class llama_model_params(ctypes.Structure): use_extra_bufts: bool _fields_ = [ - ("devices", ctypes.c_void_p), # NOTE: unnused - ("tensor_buft_overrides", ctypes.c_void_p), # NOTE: unused + ("devices", ctypes.c_void_p), # NOTE: unnused + ("tensor_buft_overrides", ctypes.c_void_p), # NOTE: unused ("n_gpu_layers", ctypes.c_int32), ("split_mode", ctypes.c_int), ("main_gpu", ctypes.c_int32), @@ -745,6 +802,8 @@ class llama_model_params(ctypes.Structure): ("use_mlock", ctypes.c_bool), ("check_tensors", ctypes.c_bool), ("use_extra_bufts", ctypes.c_bool), + ("no_host", ctypes.c_bool), + ("no_alloc", ctypes.c_bool), ] @@ -784,6 +843,7 @@ class llama_model_params(ctypes.Structure): # ggml_abort_callback abort_callback; # void * abort_callback_data; + # // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. # bool embeddings; // if true, extract embeddings (together with logits) # bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU @@ -826,7 +886,6 @@ class llama_context_params(ctypes.Structure): abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback embeddings (bool): if true, extract embeddings (together with logits) offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU - flash_attn (bool): whether to use flash attention no_perf (bool): whether to measure performance timings op_offload (bool): offload host tensor operations to device swa_full (bool): use full-size SWA cache @@ -843,6 +902,7 @@ class llama_context_params(ctypes.Structure): rope_scaling_type: int pooling_type: int attention_type: int + flash_attn_type: int rope_freq_base: float rope_freq_scale: float yarn_ext_factor: float @@ -859,7 +919,6 @@ class llama_context_params(ctypes.Structure): abort_callback_data: ctypes.c_void_p embeddings: bool offload_kqv: bool - flash_attn: bool no_perf: bool op_offload: bool swa_full: bool @@ -875,6 +934,7 @@ class llama_context_params(ctypes.Structure): ("rope_scaling_type", ctypes.c_int), ("pooling_type", ctypes.c_int), ("attention_type", ctypes.c_int), + ("flash_attn_type", ctypes.c_int), ("rope_freq_base", ctypes.c_float), ("rope_freq_scale", ctypes.c_float), ("yarn_ext_factor", ctypes.c_float), @@ -891,7 +951,6 @@ class llama_context_params(ctypes.Structure): ("abort_callback_data", ctypes.c_void_p), ("embeddings", ctypes.c_bool), ("offload_kqv", ctypes.c_bool), - ("flash_attn", ctypes.c_bool), ("no_perf", ctypes.c_bool), ("op_offload", ctypes.c_bool), ("swa_full", ctypes.c_bool), @@ -1137,8 +1196,7 @@ def llama_backend_free(): [ctypes.c_int], None, ) -def llama_numa_init(numa: int, /): - ... +def llama_numa_init(numa: int, /): ... # // Optional: an auto threadpool gets created in ggml if not passed explicitly @@ -1146,11 +1204,26 @@ def llama_numa_init(numa: int, /): # struct llama_context * ctx, # ggml_threadpool_t threadpool, # ggml_threadpool_t threadpool_batch); -# TODO: Add llama_attach_threadpool +@ctypes_function( + "llama_attach_threadpool", + [llama_context_p_ctypes, ggml_threadpool_t, ggml_threadpool_t], + None, +) +def llama_attach_threadpool( + ctx: llama_context_p, + threadpool: ctypes.c_void_p, + threadpool_batch: ctypes.c_void_p, + /, +): + """Attach threadpools to context""" + ... # LLAMA_API void llama_detach_threadpool(struct llama_context * ctx); -# TODO: Add llama_detach_threadpool +@ctypes_function("llama_detach_threadpool", [llama_context_p_ctypes], None) +def llama_detach_threadpool(ctx: llama_context_p, /): + """Detach threadpool from context""" + ... # DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file( @@ -1164,8 +1237,7 @@ def llama_numa_init(numa: int, /): ) def llama_load_model_from_file( path_model: bytes, params: llama_model_params, / -) -> Optional[llama_model_p]: - ... +) -> Optional[llama_model_p]: ... # // Load the model from a file @@ -1230,8 +1302,7 @@ def llama_model_save_to_file(model: llama_model_p, path_model: bytes, /): [llama_model_p_ctypes], None, ) -def llama_free_model(model: llama_model_p, /): - ... +def llama_free_model(model: llama_model_p, /): ... # LLAMA_API void llama_model_free(struct llama_model * model); @@ -1240,8 +1311,7 @@ def llama_free_model(model: llama_model_p, /): [llama_model_p_ctypes], None, ) -def llama_model_free(model: llama_model_p, /): - ... +def llama_model_free(model: llama_model_p, /): ... # LLAMA_API struct llama_context * llama_init_from_model( @@ -1254,8 +1324,7 @@ def llama_model_free(model: llama_model_p, /): ) def llama_init_from_model( model: llama_model_p, params: llama_context_params, / -) -> Optional[llama_context_p]: - ... +) -> Optional[llama_context_p]: ... # DEPRECATED(LLAMA_API struct llama_context * llama_new_context_with_model( @@ -1269,8 +1338,7 @@ def llama_init_from_model( ) def llama_new_context_with_model( model: llama_model_p, params: llama_context_params, / -) -> Optional[llama_context_p]: - ... +) -> Optional[llama_context_p]: ... # // Frees all allocated memory @@ -1291,104 +1359,150 @@ def llama_free(ctx: llama_context_p, /): [], ctypes.c_int64, ) -def llama_time_us() -> int: - ... +def llama_time_us() -> int: ... # LLAMA_API size_t llama_max_devices(void); @ctypes_function("llama_max_devices", [], ctypes.c_size_t) -def llama_max_devices() -> int: - ... +def llama_max_devices() -> int: ... # LLAMA_API size_t llama_max_parallel_sequences(void); @ctypes_function("llama_max_parallel_sequences", [], ctypes.c_size_t) -def llama_max_parallel_sequences() -> int: - ... +def llama_max_parallel_sequences() -> int: ... # LLAMA_API bool llama_supports_mmap (void); @ctypes_function("llama_supports_mmap", [], ctypes.c_bool) -def llama_supports_mmap() -> bool: - ... +def llama_supports_mmap() -> bool: ... # LLAMA_API bool llama_supports_mlock (void); @ctypes_function("llama_supports_mlock", [], ctypes.c_bool) -def llama_supports_mlock() -> bool: - ... +def llama_supports_mlock() -> bool: ... # LLAMA_API bool llama_supports_gpu_offload(void); @ctypes_function("llama_supports_gpu_offload", [], ctypes.c_bool) -def llama_supports_gpu_offload() -> bool: - ... +def llama_supports_gpu_offload() -> bool: ... # LLAMA_API bool llama_supports_rpc (void); @ctypes_function("llama_supports_rpc", [], ctypes.c_bool) -def llama_supports_rpc() -> bool: +def llama_supports_rpc() -> bool: ... + + +# LLAMA_API size_t llama_max_tensor_buft_overrides(void); +@ctypes_function("llama_max_tensor_buft_overrides", [], ctypes.c_size_t) +def llama_max_tensor_buft_overrides() -> int: + """Get maximum number of tensor buffer type overrides""" + ... + + +# LLAMA_API enum llama_params_fit_status llama_params_fit( +# const char * path_model, +# struct llama_model_params * mparams, +# struct llama_context_params * cparams, +# float * tensor_split, +# struct llama_model_tensor_buft_override * tensor_buft_overrides, +# size_t margin, +# uint32_t n_ctx_min, +# enum ggml_log_level log_level); +@ctypes_function( + "llama_params_fit", + [ + ctypes.c_char_p, + ctypes.POINTER(llama_model_params), + ctypes.POINTER(llama_context_params), + ctypes.POINTER(ctypes.c_float), + ctypes.c_void_p, # tensor_buft_overrides - not fully bound + ctypes.c_size_t, # margin + ctypes.c_uint32, # n_ctx_min + ctypes.c_int, # ggml_log_level (enum) + ], + ctypes.c_int, +) +def llama_params_fit( + path_model: bytes, + mparams: CtypesPointerOrRef[llama_model_params], + cparams: CtypesPointerOrRef[llama_context_params], + tensor_split: CtypesArray[ctypes.c_float], + tensor_buft_overrides: Optional[ctypes.c_void_p], + margin: Union[ctypes.c_size_t, int], + n_ctx_min: Union[ctypes.c_uint32, int], + log_level: int, + /, +) -> int: + """Check if model parameters will fit in memory + + Args: + margin: Memory margin to leave per device in bytes + n_ctx_min: Minimum context size when trying to reduce memory + log_level: Minimum log level (ggml_log_level enum) + + Returns: + LLAMA_PARAMS_FIT_STATUS_SUCCESS (0) - found allocations that are projected to fit + LLAMA_PARAMS_FIT_STATUS_FAILURE (1) - could not find allocations that are projected to fit + LLAMA_PARAMS_FIT_STATUS_ERROR (2) - a hard error occurred + """ ... # LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx); @ctypes_function("llama_n_ctx", [llama_context_p_ctypes], ctypes.c_uint32) -def llama_n_ctx(ctx: llama_context_p, /) -> int: +def llama_n_ctx(ctx: llama_context_p, /) -> int: ... + + +# LLAMA_API uint32_t llama_n_ctx_seq(const struct llama_context * ctx); +@ctypes_function("llama_n_ctx_seq", [llama_context_p_ctypes], ctypes.c_uint32) +def llama_n_ctx_seq(ctx: llama_context_p, /) -> int: + """Get the context sequence size""" ... # LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); @ctypes_function("llama_n_batch", [llama_context_p_ctypes], ctypes.c_uint32) -def llama_n_batch(ctx: llama_context_p, /) -> int: - ... +def llama_n_batch(ctx: llama_context_p, /) -> int: ... # LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); @ctypes_function("llama_n_ubatch", [llama_context_p_ctypes], ctypes.c_uint32) -def llama_n_ubatch(ctx: llama_context_p, /) -> int: - ... +def llama_n_ubatch(ctx: llama_context_p, /) -> int: ... # LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); @ctypes_function("llama_n_seq_max", [llama_context_p_ctypes], ctypes.c_uint32) -def llama_n_seq_max(ctx: llama_context_p, /) -> int: - ... +def llama_n_seq_max(ctx: llama_context_p, /) -> int: ... # DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead"); @ctypes_function("llama_n_ctx_train", [llama_model_p_ctypes], ctypes.c_int32) -def llama_n_ctx_train(model: llama_model_p, /) -> int: - ... +def llama_n_ctx_train(model: llama_model_p, /) -> int: ... # DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead"); @ctypes_function("llama_n_embd", [llama_model_p_ctypes], ctypes.c_int32) -def llama_n_embd(model: llama_model_p, /) -> int: - ... +def llama_n_embd(model: llama_model_p, /) -> int: ... # DEPRECATED(LLAMA_API int32_t llama_n_layer (const struct llama_model * model), "use llama_model_n_layer instead"); @ctypes_function("llama_n_layer", [llama_model_p_ctypes], ctypes.c_int32) -def llama_n_layer(model: llama_model_p, /) -> int: - ... +def llama_n_layer(model: llama_model_p, /) -> int: ... # DEPRECATED(LLAMA_API int32_t llama_n_head (const struct llama_model * model), "use llama_model_n_head instead"); @ctypes_function("llama_n_head", [llama_model_p_ctypes], ctypes.c_int32) -def llama_n_head(model: llama_model_p, /) -> int: - ... +def llama_n_head(model: llama_model_p, /) -> int: ... # DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead"); @ctypes_function("llama_n_vocab", [llama_vocab_p_ctypes], ctypes.c_int32) -def llama_n_vocab(model: llama_vocab_p, /) -> int: - ... +def llama_n_vocab(model: llama_vocab_p, /) -> int: ... # LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); @ctypes_function("llama_get_model", [llama_context_p_ctypes], llama_model_p_ctypes) -def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: - ... +def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: ... # LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx); @@ -1400,74 +1514,63 @@ def llama_get_memory(ctx: llama_context_p, /) -> Optional[llama_memory_t]: # LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); @ctypes_function("llama_pooling_type", [llama_context_p_ctypes], ctypes.c_int) -def llama_pooling_type(ctx: llama_context_p, /) -> int: - ... +def llama_pooling_type(ctx: llama_context_p, /) -> int: ... # DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead"); -@ctypes_function( - "llama_get_kv_self", - [llama_context_p_ctypes], - llama_kv_cache_p_ctypes, -) -def llama_get_kv_self(ctx: llama_context_p, /) -> Optional[llama_kv_cache_p]: - """Get the KV cache for self-attention (DEPRECATED)""" - ... - - # LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); @ctypes_function("llama_model_get_vocab", [llama_model_p_ctypes], llama_vocab_p_ctypes) -def llama_model_get_vocab(model: llama_model_p, /) -> Optional[llama_vocab_p]: - ... +def llama_model_get_vocab(model: llama_model_p, /) -> Optional[llama_vocab_p]: ... # LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model); @ctypes_function("llama_model_rope_type", [llama_model_p_ctypes], ctypes.c_int) -def llama_model_rope_type(model: llama_model_p, /) -> int: - ... +def llama_model_rope_type(model: llama_model_p, /) -> int: ... # LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model); @ctypes_function("llama_model_n_ctx_train", [llama_model_p_ctypes], ctypes.c_int32) -def llama_model_n_ctx_train(model: llama_model_p, /) -> int: - ... +def llama_model_n_ctx_train(model: llama_model_p, /) -> int: ... # LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); @ctypes_function("llama_model_n_embd", [llama_model_p_ctypes], ctypes.c_int32) -def llama_model_n_embd(model: llama_model_p, /) -> int: +def llama_model_n_embd(model: llama_model_p, /) -> int: ... + + +# LLAMA_API int32_t llama_model_n_embd_inp(const struct llama_model * model); +@ctypes_function("llama_model_n_embd_inp", [llama_model_p_ctypes], ctypes.c_int32) +def llama_model_n_embd_inp(model: llama_model_p, /) -> int: + """Get the input embedding dimension""" ... # LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); @ctypes_function("llama_model_n_layer", [llama_model_p_ctypes], ctypes.c_int32) -def llama_model_n_layer(model: llama_model_p, /) -> int: - ... +def llama_model_n_layer(model: llama_model_p, /) -> int: ... # LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); @ctypes_function("llama_model_n_head", [llama_model_p_ctypes], ctypes.c_int32) -def llama_model_n_head(model: llama_model_p, /) -> int: - ... +def llama_model_n_head(model: llama_model_p, /) -> int: ... # LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); @ctypes_function("llama_model_n_head_kv", [llama_model_p_ctypes], ctypes.c_int32) -def llama_model_n_head_kv(model: llama_model_p, /) -> int: - ... +def llama_model_n_head_kv(model: llama_model_p, /) -> int: ... # LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model); @ctypes_function("llama_model_n_swa", [llama_model_p_ctypes], ctypes.c_int32) -def llama_model_n_swa(model: llama_model_p, /) -> int: - ... +def llama_model_n_swa(model: llama_model_p, /) -> int: ... # // Get the model's RoPE frequency scaling factor # LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); -@ctypes_function("llama_model_rope_freq_scale_train", [llama_model_p_ctypes], ctypes.c_float) -def llama_model_rope_freq_scale_train(model: llama_model_p, /) -> float: - ... +@ctypes_function( + "llama_model_rope_freq_scale_train", [llama_model_p_ctypes], ctypes.c_float +) +def llama_model_rope_freq_scale_train(model: llama_model_p, /) -> float: ... # // Returns the number of classifier outputs (only valid for classifier models) @@ -1481,7 +1584,9 @@ def llama_model_n_cls_out(model: llama_model_p, /) -> int: # // Returns label of classifier output by index ( Optional[bytes]: """Returns label of classifier output by index. Returns None if no label provided""" ... @@ -1489,14 +1594,12 @@ def llama_model_cls_label(model: llama_model_p, i: int, /) -> Optional[bytes]: # LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); @ctypes_function("llama_vocab_type", [llama_vocab_p_ctypes], ctypes.c_int) -def llama_vocab_type(vocab: llama_vocab_p, /) -> int: - ... +def llama_vocab_type(vocab: llama_vocab_p, /) -> int: ... # LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab); @ctypes_function("llama_vocab_n_tokens", [llama_vocab_p_ctypes], ctypes.c_int32) -def llama_vocab_n_tokens(vocab: llama_vocab_p, /) -> int: - ... +def llama_vocab_n_tokens(vocab: llama_vocab_p, /) -> int: ... # // Functions to access the model's GGUF metadata scalar values @@ -1611,8 +1714,14 @@ def llama_model_size(model: llama_model_p, /) -> int: # // Get the default chat template. Returns nullptr if not available # // If name is NULL, returns the default chat template # LLAMA_API const char * llama_model_chat_template(const struct llama_model * model, const char * name); -@ctypes_function("llama_model_chat_template", [llama_model_p_ctypes, ctypes.c_char_p], ctypes.c_char_p) -def llama_model_chat_template(model: llama_model_p, name: Optional[bytes], /) -> Optional[bytes]: +@ctypes_function( + "llama_model_chat_template", + [llama_model_p_ctypes, ctypes.c_char_p], + ctypes.c_char_p, +) +def llama_model_chat_template( + model: llama_model_p, name: Optional[bytes], / +) -> Optional[bytes]: """Get the default chat template. Returns None if not available If name is None, returns the default chat template""" ... @@ -1663,6 +1772,13 @@ def llama_model_is_recurrent(model: llama_model_p, /) -> bool: ... +# LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model); +@ctypes_function("llama_model_is_hybrid", [llama_model_p_ctypes], ctypes.c_bool) +def llama_model_is_hybrid(model: llama_model_p, /) -> bool: + """Returns true if model is hybrid (Jamba, Granite, etc.)""" + ... + + # // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.) # LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model); @ctypes_function("llama_model_is_diffusion", [llama_model_p_ctypes], ctypes.c_bool) @@ -1699,6 +1815,7 @@ def llama_model_quantize( # // Adapters # // + # // Load a LoRA adapter from file # LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init( # struct llama_model * model, @@ -1710,8 +1827,7 @@ def llama_model_quantize( ) def llama_adapter_lora_init( model: llama_model_p, path_lora: bytes, / -) -> Optional[llama_adapter_lora_p]: - ... +) -> Optional[llama_adapter_lora_p]: ... # // Manually free a LoRA adapter @@ -1722,7 +1838,80 @@ def llama_adapter_lora_init( [llama_adapter_lora_p_ctypes], None, ) -def llama_adapter_lora_free(adapter: llama_adapter_lora_p, /): +def llama_adapter_lora_free(adapter: llama_adapter_lora_p, /): ... + + +# LLAMA_API int32_t llama_adapter_meta_val_str(const struct llama_adapter_lora * adapter, const char * key, char * buf, size_t buf_size); +@ctypes_function( + "llama_adapter_meta_val_str", + [llama_adapter_lora_p_ctypes, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_size_t], + ctypes.c_int32, +) +def llama_adapter_meta_val_str( + adapter: llama_adapter_lora_p, key: bytes, buf: bytes, buf_size: int, / +) -> int: + """Get adapter metadata value as string""" + ... + + +# LLAMA_API int32_t llama_adapter_meta_count(const struct llama_adapter_lora * adapter); +@ctypes_function( + "llama_adapter_meta_count", [llama_adapter_lora_p_ctypes], ctypes.c_int32 +) +def llama_adapter_meta_count(adapter: llama_adapter_lora_p, /) -> int: + """Get number of adapter metadata pairs""" + ... + + +# LLAMA_API int32_t llama_adapter_meta_key_by_index(...); +@ctypes_function( + "llama_adapter_meta_key_by_index", + [llama_adapter_lora_p_ctypes, ctypes.c_int32, ctypes.c_char_p, ctypes.c_size_t], + ctypes.c_int32, +) +def llama_adapter_meta_key_by_index( + adapter: llama_adapter_lora_p, i: int, buf: bytes, buf_size: int, / +) -> int: + """Get adapter metadata key by index""" + ... + + +# LLAMA_API int32_t llama_adapter_meta_val_str_by_index(...); +@ctypes_function( + "llama_adapter_meta_val_str_by_index", + [llama_adapter_lora_p_ctypes, ctypes.c_int32, ctypes.c_char_p, ctypes.c_size_t], + ctypes.c_int32, +) +def llama_adapter_meta_val_str_by_index( + adapter: llama_adapter_lora_p, i: int, buf: bytes, buf_size: int, / +) -> int: + """Get adapter metadata value by index""" + ... + + +# LLAMA_API uint64_t llama_adapter_get_alora_n_invocation_tokens(...); +@ctypes_function( + "llama_adapter_get_alora_n_invocation_tokens", + [llama_adapter_lora_p_ctypes], + ctypes.c_uint64, +) +def llama_adapter_get_alora_n_invocation_tokens( + adapter: llama_adapter_lora_p, / +) -> int: + """Get alora invocation token count""" + ... + + +# LLAMA_API const llama_token * llama_adapter_get_alora_invocation_tokens(...); +@ctypes_function( + "llama_adapter_get_alora_invocation_tokens", + [llama_adapter_lora_p_ctypes], + ctypes.POINTER(llama_token), +) +def llama_adapter_get_alora_invocation_tokens( + adapter: llama_adapter_lora_p, / +) -> ctypes.Array: + """Get alora invocation tokens""" ... @@ -1825,6 +2014,7 @@ def llama_apply_adapter_cvec( # // Memory # // + # // Clear the memory contents # // If data == true, the data buffers will also be cleared together with the metadata # LLAMA_API void llama_memory_clear( @@ -1916,9 +2106,7 @@ def llama_memory_seq_cp( # LLAMA_API void llama_memory_seq_keep( # llama_memory_t mem, # llama_seq_id seq_id); -@ctypes_function( - "llama_memory_seq_keep", [llama_memory_t_ctypes, llama_seq_id], None -) +@ctypes_function("llama_memory_seq_keep", [llama_memory_t_ctypes, llama_seq_id], None) def llama_memory_seq_keep(mem: llama_memory_t, seq_id: Union[llama_seq_id, int], /): """Removes all tokens that do not belong to the specified sequence""" ... @@ -2038,260 +2226,11 @@ def llama_memory_can_shift(mem: llama_memory_t, /) -> bool: # // # // KV cache for self-attention (TODO: deprecate in favor of llama_memory) -# // - -# // Returns the number of tokens in the KV cache (slow, use only for debug) -# // If a KV cell has multiple sequences assigned to it, it will be counted multiple times -# DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx), -# "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); -@ctypes_function( - "llama_kv_self_n_tokens", [llama_context_p_ctypes], ctypes.c_int32 -) -def llama_kv_self_n_tokens(ctx: llama_context_p, /) -> int: - """Returns the number of tokens in the KV cache (slow, use only for debug) (DEPRECATED)""" - ... - - -# // Returns the number of used KV cells (i.e. have at least one sequence assigned to them) -# DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx), -# "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); -@ctypes_function( - "llama_kv_self_used_cells", [llama_context_p_ctypes], ctypes.c_int32 -) -def llama_kv_self_used_cells(ctx: llama_context_p, /) -> int: - """Returns the number of used KV cells (DEPRECATED)""" - ... - - -# // Clear the KV cache - both cell info is erased and KV data is zeroed -# DEPRECATED(LLAMA_API void llama_kv_self_clear( -# struct llama_context * ctx), -# "Use llama_memory_clear() instead"); -@ctypes_function( - "llama_kv_self_clear", [llama_context_p_ctypes], None -) -def llama_kv_self_clear(ctx: llama_context_p, /): - """Clear the KV cache (DEPRECATED)""" - ... - - -# // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) -# // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails -# // seq_id < 0 : match any sequence -# // p0 < 0 : [0, p1] -# // p1 < 0 : [p0, inf) -# DEPRECATED(LLAMA_API bool llama_kv_self_seq_rm( -# struct llama_context * ctx, -# llama_seq_id seq_id, -# llama_pos p0, -# llama_pos p1), -# "Use llama_memory_seq_rm() instead"); -@ctypes_function( - "llama_kv_self_seq_rm", - [ - llama_context_p_ctypes, - llama_seq_id, - llama_pos, - llama_pos, - ], - ctypes.c_bool, -) -def llama_kv_self_seq_rm( - ctx: llama_context_p, - seq_id: Union[llama_seq_id, int], - p0: Union[llama_pos, int], - p1: Union[llama_pos, int], - /, -) -> bool: - """Remove tokens from KV cache (DEPRECATED)""" - ... - - -# // Copy all tokens that belong to the specified sequence to another sequence -# // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence -# // p0 < 0 : [0, p1] -# // p1 < 0 : [p0, inf) -# DEPRECATED(LLAMA_API void llama_kv_self_seq_cp( -# struct llama_context * ctx, -# llama_seq_id seq_id_src, -# llama_seq_id seq_id_dst, -# llama_pos p0, -# llama_pos p1), -# "Use llama_memory_seq_cp() instead"); -@ctypes_function( - "llama_kv_self_seq_cp", - [ - llama_context_p_ctypes, - llama_seq_id, - llama_seq_id, - llama_pos, - llama_pos, - ], - None, -) -def llama_kv_self_seq_cp( - ctx: llama_context_p, - seq_id_src: Union[llama_seq_id, int], - seq_id_dst: Union[llama_seq_id, int], - p0: Union[llama_pos, int], - p1: Union[llama_pos, int], - /, -): - """Copy tokens in KV cache (DEPRECATED)""" - ... - - -# // Removes all tokens that do not belong to the specified sequence -# DEPRECATED(LLAMA_API void llama_kv_self_seq_keep( -# struct llama_context * ctx, -# llama_seq_id seq_id), -# "Use llama_memory_seq_keep() instead"); -@ctypes_function( - "llama_kv_self_seq_keep", [llama_context_p_ctypes, llama_seq_id], None -) -def llama_kv_self_seq_keep(ctx: llama_context_p, seq_id: Union[llama_seq_id, int], /): - """Keep only specified sequence in KV cache (DEPRECATED)""" - ... - - -# // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) -# // If the KV cache is RoPEd, the KV data is updated accordingly: -# // - lazily on next llama_decode() -# // p0 < 0 : [0, p1] -# // p1 < 0 : [p0, inf) -# DEPRECATED(LLAMA_API void llama_kv_self_seq_add( -# struct llama_context * ctx, -# llama_seq_id seq_id, -# llama_pos p0, -# llama_pos p1, -# llama_pos delta), -# "Use llama_memory_seq_add() instead"); -@ctypes_function( - "llama_kv_self_seq_add", - [ - llama_context_p_ctypes, - llama_seq_id, - llama_pos, - llama_pos, - llama_pos, - ], - None, -) -def llama_kv_self_seq_add( - ctx: llama_context_p, - seq_id: Union[llama_seq_id, int], - p0: Union[llama_pos, int], - p1: Union[llama_pos, int], - delta: Union[llama_pos, int], - /, -): - """Add delta to sequence positions in KV cache (DEPRECATED)""" - ... - - -# // Integer division of the positions by factor of `d > 1` -# // If the KV cache is RoPEd, the KV data is updated accordingly: -# // - lazily on next llama_decode() -# // p0 < 0 : [0, p1] -# // p1 < 0 : [p0, inf) -# DEPRECATED(LLAMA_API void llama_kv_self_seq_div( -# struct llama_context * ctx, -# llama_seq_id seq_id, -# llama_pos p0, -# llama_pos p1, -# int d), -# "Use llama_memory_seq_div() instead"); -@ctypes_function( - "llama_kv_self_seq_div", - [ - llama_context_p_ctypes, - llama_seq_id, - llama_pos, - llama_pos, - ctypes.c_int, - ], - None, -) -def llama_kv_self_seq_div( - ctx: llama_context_p, - seq_id: Union[llama_seq_id, int], - p0: Union[llama_pos, int], - p1: Union[llama_pos, int], - d: Union[ctypes.c_int, int], - /, -): - """Divide sequence positions in KV cache (DEPRECATED)""" - ... - - -# // Returns the smallest position present in the KV cache for the specified sequence -# // This is typically non-zero only for SWA caches -# // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache -# // Return -1 if the sequence is empty -# DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_min( -# struct llama_context * ctx, -# llama_seq_id seq_id), -# "Use llama_memory_seq_pos_min() instead"); -@ctypes_function( - "llama_kv_self_seq_pos_min", [llama_context_p_ctypes, llama_seq_id], llama_pos -) -def llama_kv_self_seq_pos_min( - ctx: llama_context_p, seq_id: Union[llama_seq_id, int], / -) -> int: - """Returns the smallest position in KV cache for sequence (DEPRECATED)""" - ... - - -# // Returns the largest position present in the KV cache for the specified sequence -# // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache -# // Return -1 if the sequence is empty -# DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_max( -# struct llama_context * ctx, -# llama_seq_id seq_id), -# "Use llama_memory_seq_pos_max() instead"); -@ctypes_function( - "llama_kv_self_seq_pos_max", [llama_context_p_ctypes, llama_seq_id], llama_pos -) -def llama_kv_self_seq_pos_max( - ctx: llama_context_p, seq_id: Union[llama_seq_id, int], / -) -> int: - """Returns the largest position in KV cache for sequence (DEPRECATED)""" - ... - - -# // Defragment the KV cache -# // This will be applied: -# // - lazily on next llama_decode() -# DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx), -# "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'"); -@ctypes_function("llama_kv_self_defrag", [llama_context_p_ctypes], None) -def llama_kv_self_defrag(ctx: llama_context_p, /): - """Defragment the KV cache (DEPRECATED)""" - ... - - -# // Check if the context supports KV cache shifting -# DEPRECATED(LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx), -# "use llama_memory_can_shift() instead"); -@ctypes_function("llama_kv_self_can_shift", [llama_context_p_ctypes], ctypes.c_bool) -def llama_kv_self_can_shift(ctx: llama_context_p, /) -> bool: - """Check if the context supports KV cache shifting (DEPRECATED)""" - ... - - -# // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) -# DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx), -# "simply remove this call, updates are applied lazily on the next llama_decode()"); -@ctypes_function("llama_kv_self_update", [llama_context_p_ctypes], None) -def llama_kv_self_update(ctx: llama_context_p, /): - """Apply the KV cache updates (DEPRECATED)""" - ... - - # // # // State / sessions # // + # // Returns the *actual* size in bytes of the state # // (logits, embedding and memory) # // Only use when saving the state, not when restoring it, otherwise the size may be too small. @@ -2420,8 +2359,7 @@ def llama_state_load_file( n_token_capacity: Union[ctypes.c_size_t, int], n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t], /, -) -> bool: - ... +) -> bool: ... # LLAMA_API DEPRECATED(bool llama_load_session_file( @@ -2449,8 +2387,7 @@ def llama_load_session_file( n_token_capacity: Union[ctypes.c_size_t, int], n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t], /, -) -> bool: - ... +) -> bool: ... # LLAMA_API bool llama_state_save_file( @@ -2474,8 +2411,7 @@ def llama_state_save_file( tokens: CtypesArray[llama_token], n_token_count: Union[ctypes.c_size_t, int], /, -) -> bool: - ... +) -> bool: ... # LLAMA_API DEPRECATED(bool llama_save_session_file( @@ -2500,8 +2436,7 @@ def llama_save_session_file( tokens: CtypesArray[llama_token], n_token_count: Union[ctypes.c_size_t, int], /, -) -> bool: - ... +) -> bool: ... # // Get the exact size needed to copy the state of a single sequence @@ -2599,8 +2534,7 @@ def llama_state_seq_save_file( tokens: CtypesArray[llama_token], n_token_count: Union[ctypes.c_size_t, int], /, -) -> int: - ... +) -> int: ... # LLAMA_API size_t llama_state_seq_load_file( @@ -2630,7 +2564,83 @@ def llama_state_seq_load_file( n_token_capacity: Union[ctypes.c_size_t, int], n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t], /, +) -> int: ... + + +# LLAMA_API size_t llama_state_seq_get_size_ext( +# struct llama_context * ctx, +# llama_seq_id seq_id, +# llama_state_seq_flags flags); +@ctypes_function( + "llama_state_seq_get_size_ext", + [llama_context_p_ctypes, llama_seq_id, llama_state_seq_flags], + ctypes.c_size_t, +) +def llama_state_seq_get_size_ext( + ctx: llama_context_p, + seq_id: Union[llama_seq_id, int], + flags: Union[llama_state_seq_flags, int], + /, +) -> int: + """Get size needed to copy sequence state with flags""" + ... + + +# LLAMA_API size_t llama_state_seq_get_data_ext( +# struct llama_context * ctx, +# uint8_t * dst, +# size_t size, +# llama_seq_id seq_id, +# llama_state_seq_flags flags); +@ctypes_function( + "llama_state_seq_get_data_ext", + [ + llama_context_p_ctypes, + ctypes.POINTER(ctypes.c_uint8), + ctypes.c_size_t, + llama_seq_id, + llama_state_seq_flags, + ], + ctypes.c_size_t, +) +def llama_state_seq_get_data_ext( + ctx: llama_context_p, + dst: CtypesArray[ctypes.c_uint8], + size: Union[ctypes.c_size_t, int], + seq_id: Union[llama_seq_id, int], + flags: Union[llama_state_seq_flags, int], + /, +) -> int: + """Copy sequence state to buffer with flags""" + ... + + +# LLAMA_API size_t llama_state_seq_set_data_ext( +# struct llama_context * ctx, +# const uint8_t * src, +# size_t size, +# llama_seq_id dest_seq_id, +# llama_state_seq_flags flags); +@ctypes_function( + "llama_state_seq_set_data_ext", + [ + llama_context_p_ctypes, + ctypes.POINTER(ctypes.c_uint8), + ctypes.c_size_t, + llama_seq_id, + llama_state_seq_flags, + ], + ctypes.c_size_t, +) +def llama_state_seq_set_data_ext( + ctx: llama_context_p, + src: CtypesArray[ctypes.c_uint8], + size: Union[ctypes.c_size_t, int], + dest_seq_id: Union[llama_seq_id, int], + flags: Union[llama_state_seq_flags, int], + /, ) -> int: + """Restore sequence state from buffer with flags""" ... @@ -2638,6 +2648,7 @@ def llama_state_seq_load_file( # // Decoding # // + # // Return batch for single sequence of tokens # // The sequence ID will be fixed to 0 # // The position of the tokens will be tracked automatically by llama_decode @@ -2947,14 +2958,14 @@ def llama_get_embeddings_seq( # // Vocab # // + # LLAMA_API const char * llama_vocab_get_text(const struct llama_vocab * vocab, llama_token token); @ctypes_function( "llama_vocab_get_text", [llama_vocab_p_ctypes, llama_token], ctypes.c_char_p ) def llama_vocab_get_text( vocab: llama_vocab_p, token: Union[llama_token, int], / -) -> bytes: - ... +) -> bytes: ... # LLAMA_API float llama_vocab_get_score(const struct llama_vocab * vocab, llama_token token); @@ -2963,8 +2974,7 @@ def llama_vocab_get_text( ) def llama_vocab_get_score( vocab: llama_vocab_p, token: Union[llama_token, int], / -) -> float: - ... +) -> float: ... # LLAMA_API enum llama_token_attr llama_vocab_get_attr(const struct llama_vocab * vocab, llama_token token); @@ -2973,8 +2983,7 @@ def llama_vocab_get_score( ) def llama_vocab_get_attr( vocab: llama_vocab_p, token: Union[llama_token, int], / -) -> int: - ... +) -> int: ... # // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.) @@ -3055,8 +3064,7 @@ def llama_vocab_mask(vocab: llama_vocab_p, /) -> llama_token: [llama_vocab_p_ctypes], ctypes.c_bool, ) -def llama_vocab_get_add_bos(vocab: llama_vocab_p, /) -> bool: - ... +def llama_vocab_get_add_bos(vocab: llama_vocab_p, /) -> bool: ... # LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab); @@ -3065,8 +3073,7 @@ def llama_vocab_get_add_bos(vocab: llama_vocab_p, /) -> bool: [llama_vocab_p_ctypes], ctypes.c_bool, ) -def llama_vocab_get_add_eos(vocab: llama_vocab_p, /) -> bool: - ... +def llama_vocab_get_add_eos(vocab: llama_vocab_p, /) -> bool: ... # LLAMA_API bool llama_vocab_get_add_sep(const struct llama_vocab * vocab); @@ -3075,8 +3082,7 @@ def llama_vocab_get_add_eos(vocab: llama_vocab_p, /) -> bool: [llama_vocab_p_ctypes], ctypes.c_bool, ) -def llama_vocab_get_add_sep(vocab: llama_vocab_p, /) -> bool: - ... +def llama_vocab_get_add_sep(vocab: llama_vocab_p, /) -> bool: ... # LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab); @@ -3085,8 +3091,7 @@ def llama_vocab_get_add_sep(vocab: llama_vocab_p, /) -> bool: [llama_vocab_p_ctypes], llama_token, ) -def llama_vocab_fim_pre(vocab: llama_vocab_p, /) -> llama_token: - ... +def llama_vocab_fim_pre(vocab: llama_vocab_p, /) -> llama_token: ... # LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab); @@ -3095,8 +3100,7 @@ def llama_vocab_fim_pre(vocab: llama_vocab_p, /) -> llama_token: [llama_vocab_p_ctypes], llama_token, ) -def llama_vocab_fim_suf(vocab: llama_vocab_p, /) -> llama_token: - ... +def llama_vocab_fim_suf(vocab: llama_vocab_p, /) -> llama_token: ... # LLAMA_API llama_token llama_vocab_fim_mid(const struct llama_vocab * vocab); @@ -3105,8 +3109,7 @@ def llama_vocab_fim_suf(vocab: llama_vocab_p, /) -> llama_token: [llama_vocab_p_ctypes], llama_token, ) -def llama_vocab_fim_mid(vocab: llama_vocab_p, /) -> llama_token: - ... +def llama_vocab_fim_mid(vocab: llama_vocab_p, /) -> llama_token: ... # LLAMA_API llama_token llama_vocab_fim_pad(const struct llama_vocab * vocab); @@ -3115,8 +3118,7 @@ def llama_vocab_fim_mid(vocab: llama_vocab_p, /) -> llama_token: [llama_vocab_p_ctypes], llama_token, ) -def llama_vocab_fim_pad(vocab: llama_vocab_p, /) -> llama_token: - ... +def llama_vocab_fim_pad(vocab: llama_vocab_p, /) -> llama_token: ... # LLAMA_API llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab); @@ -3125,8 +3127,7 @@ def llama_vocab_fim_pad(vocab: llama_vocab_p, /) -> llama_token: [llama_vocab_p_ctypes], llama_token, ) -def llama_vocab_fim_rep(vocab: llama_vocab_p, /) -> llama_token: - ... +def llama_vocab_fim_rep(vocab: llama_vocab_p, /) -> llama_token: ... # LLAMA_API llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab); @@ -3135,8 +3136,7 @@ def llama_vocab_fim_rep(vocab: llama_vocab_p, /) -> llama_token: [llama_vocab_p_ctypes], llama_token, ) -def llama_vocab_fim_sep(vocab: llama_vocab_p, /) -> llama_token: - ... +def llama_vocab_fim_sep(vocab: llama_vocab_p, /) -> llama_token: ... # DEPRECATED functions @@ -3148,8 +3148,7 @@ def llama_vocab_fim_sep(vocab: llama_vocab_p, /) -> llama_token: ) def llama_token_get_text( vocab: llama_vocab_p, token: Union[llama_token, int], / -) -> bytes: - ... +) -> bytes: ... # DEPRECATED(LLAMA_API float llama_token_get_score(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_score instead"); @@ -3160,8 +3159,8 @@ def llama_token_get_text( ) def llama_token_get_score( vocab: llama_vocab_p, token: Union[llama_token, int], / -) -> float: - ... +) -> float: ... + # DEPRECATED(LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_attr instead"); @ctypes_function( @@ -3171,8 +3170,8 @@ def llama_token_get_score( ) def llama_token_get_attr( vocab: llama_vocab_p, token: Union[llama_token, int], / -) -> int: - ... +) -> int: ... + # DEPRECATED(LLAMA_API bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_eog instead"); @ctypes_function( @@ -3182,8 +3181,8 @@ def llama_token_get_attr( ) def llama_token_is_eog( vocab: llama_vocab_p, token: Union[llama_token, int], / -) -> bool: - ... +) -> bool: ... + # DEPRECATED(LLAMA_API bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_control instead"); @ctypes_function( @@ -3193,8 +3192,8 @@ def llama_token_is_eog( ) def llama_token_is_control( vocab: llama_vocab_p, token: Union[llama_token, int], / -) -> bool: - ... +) -> bool: ... + # DEPRECATED(LLAMA_API llama_token llama_token_bos(const struct llama_vocab * vocab), "use llama_vocab_bos instead"); @ctypes_function( @@ -3202,8 +3201,8 @@ def llama_token_is_control( [llama_vocab_p_ctypes], llama_token, ) -def llama_token_bos(vocab: llama_vocab_p, /) -> int: - ... +def llama_token_bos(vocab: llama_vocab_p, /) -> int: ... + # DEPRECATED(LLAMA_API llama_token llama_token_eos(const struct llama_vocab * vocab), "use llama_vocab_eos instead"); @ctypes_function( @@ -3211,8 +3210,8 @@ def llama_token_bos(vocab: llama_vocab_p, /) -> int: [llama_vocab_p_ctypes], llama_token, ) -def llama_token_eos(vocab: llama_vocab_p, /) -> int: - ... +def llama_token_eos(vocab: llama_vocab_p, /) -> int: ... + # DEPRECATED(LLAMA_API llama_token llama_token_eot(const struct llama_vocab * vocab), "use llama_vocab_eot instead"); @ctypes_function( @@ -3220,8 +3219,8 @@ def llama_token_eos(vocab: llama_vocab_p, /) -> int: [llama_vocab_p_ctypes], llama_token, ) -def llama_token_eot(vocab: llama_vocab_p, /) -> int: - ... +def llama_token_eot(vocab: llama_vocab_p, /) -> int: ... + # DEPRECATED(LLAMA_API llama_token llama_token_cls(const struct llama_vocab * vocab), "use llama_vocab_cls instead"); @ctypes_function( @@ -3229,8 +3228,8 @@ def llama_token_eot(vocab: llama_vocab_p, /) -> int: [llama_vocab_p_ctypes], llama_token, ) -def llama_token_cls(vocab: llama_vocab_p, /) -> int: - ... +def llama_token_cls(vocab: llama_vocab_p, /) -> int: ... + # DEPRECATED(LLAMA_API llama_token llama_token_sep(const struct llama_vocab * vocab), "use llama_vocab_sep instead"); @ctypes_function( @@ -3238,8 +3237,7 @@ def llama_token_cls(vocab: llama_vocab_p, /) -> int: [llama_vocab_p_ctypes], llama_token, ) -def llama_token_sep(vocab: llama_vocab_p, /) -> int: - ... +def llama_token_sep(vocab: llama_vocab_p, /) -> int: ... # DEPRECATED(LLAMA_API llama_token llama_token_nl (const struct llama_vocab * vocab), "use llama_vocab_nl instead"); @@ -3248,8 +3246,7 @@ def llama_token_sep(vocab: llama_vocab_p, /) -> int: [llama_vocab_p_ctypes], llama_token, ) -def llama_token_nl(vocab: llama_vocab_p, /) -> int: - ... +def llama_token_nl(vocab: llama_vocab_p, /) -> int: ... # DEPRECATED(LLAMA_API llama_token llama_token_pad(const struct llama_vocab * vocab), "use llama_vocab_pad instead"); @@ -3258,8 +3255,7 @@ def llama_token_nl(vocab: llama_vocab_p, /) -> int: [llama_vocab_p_ctypes], llama_token, ) -def llama_token_pad(vocab: llama_vocab_p, /) -> int: - ... +def llama_token_pad(vocab: llama_vocab_p, /) -> int: ... # DEPRECATED(LLAMA_API bool llama_add_bos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_bos instead"); @@ -3268,8 +3264,8 @@ def llama_token_pad(vocab: llama_vocab_p, /) -> int: [llama_vocab_p_ctypes], ctypes.c_bool, ) -def llama_add_bos_token(vocab: llama_vocab_p, /) -> bool: - ... +def llama_add_bos_token(vocab: llama_vocab_p, /) -> bool: ... + # DEPRECATED(LLAMA_API bool llama_add_eos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_eos instead"); @ctypes_function( @@ -3277,8 +3273,7 @@ def llama_add_bos_token(vocab: llama_vocab_p, /) -> bool: [llama_vocab_p_ctypes], ctypes.c_bool, ) -def llama_add_eos_token(vocab: llama_vocab_p, /) -> bool: - ... +def llama_add_eos_token(vocab: llama_vocab_p, /) -> bool: ... # DEPRECATED(LLAMA_API llama_token llama_token_fim_pre(const struct llama_vocab * vocab), "use llama_vocab_fim_pre instead"); @@ -3287,8 +3282,8 @@ def llama_add_eos_token(vocab: llama_vocab_p, /) -> bool: [llama_vocab_p_ctypes], llama_token, ) -def llama_token_fim_pre(vocab: llama_vocab_p, /) -> llama_token: - ... +def llama_token_fim_pre(vocab: llama_vocab_p, /) -> llama_token: ... + # DEPRECATED(LLAMA_API llama_token llama_token_fim_suf(const struct llama_vocab * vocab), "use llama_vocab_fim_suf instead"); @ctypes_function( @@ -3296,8 +3291,8 @@ def llama_token_fim_pre(vocab: llama_vocab_p, /) -> llama_token: [llama_vocab_p_ctypes], llama_token, ) -def llama_token_fim_suf(vocab: llama_vocab_p, /) -> llama_token: - ... +def llama_token_fim_suf(vocab: llama_vocab_p, /) -> llama_token: ... + # DEPRECATED(LLAMA_API llama_token llama_token_fim_mid(const struct llama_vocab * vocab), "use llama_vocab_fim_mid instead"); @ctypes_function( @@ -3305,8 +3300,8 @@ def llama_token_fim_suf(vocab: llama_vocab_p, /) -> llama_token: [llama_vocab_p_ctypes], llama_token, ) -def llama_token_fim_mid(vocab: llama_vocab_p, /) -> llama_token: - ... +def llama_token_fim_mid(vocab: llama_vocab_p, /) -> llama_token: ... + # DEPRECATED(LLAMA_API llama_token llama_token_fim_pad(const struct llama_vocab * vocab), "use llama_vocab_fim_pad instead"); @ctypes_function( @@ -3314,8 +3309,8 @@ def llama_token_fim_mid(vocab: llama_vocab_p, /) -> llama_token: [llama_vocab_p_ctypes], llama_token, ) -def llama_token_fim_pad(vocab: llama_vocab_p, /) -> llama_token: - ... +def llama_token_fim_pad(vocab: llama_vocab_p, /) -> llama_token: ... + # DEPRECATED(LLAMA_API llama_token llama_token_fim_rep(const struct llama_vocab * vocab), "use llama_vocab_fim_rep instead"); @ctypes_function( @@ -3323,8 +3318,8 @@ def llama_token_fim_pad(vocab: llama_vocab_p, /) -> llama_token: [llama_vocab_p_ctypes], llama_token, ) -def llama_token_fim_rep(vocab: llama_vocab_p, /) -> llama_token: - ... +def llama_token_fim_rep(vocab: llama_vocab_p, /) -> llama_token: ... + # DEPRECATED(LLAMA_API llama_token llama_token_fim_sep(const struct llama_vocab * vocab), "use llama_vocab_fim_sep instead"); @ctypes_function( @@ -3332,8 +3327,8 @@ def llama_token_fim_rep(vocab: llama_vocab_p, /) -> llama_token: [llama_vocab_p_ctypes], llama_token, ) -def llama_token_fim_sep(vocab: llama_vocab_p, /) -> llama_token: - ... +def llama_token_fim_sep(vocab: llama_vocab_p, /) -> llama_token: ... + # // CLS is equivalent to BOS # DEPRECATED(LLAMA_API llama_token llama_vocab_cls(const struct llama_vocab * vocab), // classification @@ -3343,8 +3338,7 @@ def llama_token_fim_sep(vocab: llama_vocab_p, /) -> llama_token: [llama_vocab_p_ctypes], llama_token, ) -def llama_vocab_cls(vocab: llama_vocab_p, /) -> llama_token: - ... +def llama_vocab_cls(vocab: llama_vocab_p, /) -> llama_token: ... # // @@ -3353,6 +3347,7 @@ def llama_vocab_cls(vocab: llama_vocab_p, /) -> llama_token: # // The API is thread-safe. # // + # /// @details Convert the provided text into tokens. # /// @param tokens The tokens pointer must be large enough to hold the resulting tokens. # /// @return Returns the number of tokens on success, no more than n_tokens_max @@ -3512,6 +3507,7 @@ def llama_detokenize( # // Chat templates # // + # /// Apply chat template. Inspired by hf apply_chat_template() on python. # /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" # /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template @@ -3535,9 +3531,9 @@ def llama_detokenize( ctypes.c_char_p, # tmpl ctypes.POINTER(llama_chat_message), # chat ctypes.c_size_t, # n_msg - ctypes.c_bool, # add_ass (added) + ctypes.c_bool, # add_ass (added) ctypes.c_char_p, # buf - ctypes.c_int32, # length + ctypes.c_int32, # length ], ctypes.c_int32, ) @@ -3611,11 +3607,11 @@ def llama_chat_builtin_templates( # struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL # void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL + # // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph # //void (*apply_ggml) (struct llama_sampler * smpl, ...); # }; -class llama_sampler_i(ctypes.Structure): - ... +class llama_sampler_i(ctypes.Structure): ... # struct llama_sampler { @@ -3662,8 +3658,7 @@ class llama_sampler(ctypes.Structure): ) def llama_sampler_init( iface: ctypes.POINTER(llama_sampler_i), ctx: llama_sampler_context_t, / -) -> llama_sampler_p: - ... +) -> llama_sampler_p: ... # LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl); @@ -3672,8 +3667,7 @@ def llama_sampler_init( [llama_sampler_p_ctypes], ctypes.c_char_p, ) -def llama_sampler_name(smpl: llama_sampler_p, /) -> bytes: - ... +def llama_sampler_name(smpl: llama_sampler_p, /) -> bytes: ... # LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); @@ -3682,8 +3676,7 @@ def llama_sampler_name(smpl: llama_sampler_p, /) -> bytes: [llama_sampler_p_ctypes, llama_token], None, ) -def llama_sampler_accept(smpl: llama_sampler_p, token: Union[llama_token, int], /): - ... +def llama_sampler_accept(smpl: llama_sampler_p, token: Union[llama_token, int], /): ... # LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); @@ -3694,8 +3687,7 @@ def llama_sampler_accept(smpl: llama_sampler_p, token: Union[llama_token, int], ) def llama_sampler_apply( smpl: llama_sampler_p, cur_p: CtypesArray[llama_token_data_array], / -): - ... +): ... # LLAMA_API void llama_sampler_reset ( struct llama_sampler * smpl); @@ -3704,8 +3696,7 @@ def llama_sampler_apply( [llama_sampler_p_ctypes], None, ) -def llama_sampler_reset(smpl: llama_sampler_p, /): - ... +def llama_sampler_reset(smpl: llama_sampler_p, /): ... # LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl); @@ -3714,8 +3705,7 @@ def llama_sampler_reset(smpl: llama_sampler_p, /): [llama_sampler_p_ctypes], llama_sampler_p_ctypes, ) -def llama_sampler_clone(smpl: llama_sampler_p, /) -> llama_sampler_p: - ... +def llama_sampler_clone(smpl: llama_sampler_p, /) -> llama_sampler_p: ... # // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add) @@ -3725,21 +3715,22 @@ def llama_sampler_clone(smpl: llama_sampler_p, /) -> llama_sampler_p: [llama_sampler_p_ctypes], None, ) -def llama_sampler_free(smpl: llama_sampler_p, /): - ... +def llama_sampler_free(smpl: llama_sampler_p, /): ... # // llama_sampler_chain # // a type of llama_sampler that can chain multiple samplers one after another + # LLAMA_API struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params); @ctypes_function( "llama_sampler_chain_init", [llama_sampler_chain_params], llama_sampler_p_ctypes, ) -def llama_sampler_chain_init(params: llama_sampler_chain_params, /) -> llama_sampler_p: - ... +def llama_sampler_chain_init( + params: llama_sampler_chain_params, / +) -> llama_sampler_p: ... # // important: takes ownership of the sampler object and will free it when llama_sampler_free is called @@ -3749,8 +3740,7 @@ def llama_sampler_chain_init(params: llama_sampler_chain_params, /) -> llama_sam [llama_sampler_p_ctypes, llama_sampler_p_ctypes], None, ) -def llama_sampler_chain_add(chain: llama_sampler_p, smpl: llama_sampler_p, /): - ... +def llama_sampler_chain_add(chain: llama_sampler_p, smpl: llama_sampler_p, /): ... # LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i); @@ -3761,8 +3751,7 @@ def llama_sampler_chain_add(chain: llama_sampler_p, smpl: llama_sampler_p, /): ) def llama_sampler_chain_get( chain: llama_sampler_p, i: Union[ctypes.c_int32, int], / -) -> llama_sampler_p: - ... +) -> llama_sampler_p: ... # LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain); @@ -3771,8 +3760,7 @@ def llama_sampler_chain_get( [llama_sampler_p_ctypes], ctypes.c_int, ) -def llama_sampler_chain_n(chain: llama_sampler_p, /) -> int: - ... +def llama_sampler_chain_n(chain: llama_sampler_p, /) -> int: ... # // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed @@ -3784,39 +3772,33 @@ def llama_sampler_chain_n(chain: llama_sampler_p, /) -> int: ) def llama_sampler_chain_remove( chain: llama_sampler_p, i: Union[ctypes.c_int32, int], / -) -> llama_sampler_p: - ... +) -> llama_sampler_p: ... # // available samplers: + # LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void); @ctypes_function("llama_sampler_init_greedy", [], llama_sampler_p_ctypes) -def llama_sampler_init_greedy() -> llama_sampler_p: - ... +def llama_sampler_init_greedy() -> llama_sampler_p: ... # LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed); @ctypes_function("llama_sampler_init_dist", [ctypes.c_uint32], llama_sampler_p_ctypes) -def llama_sampler_init_dist(seed: int) -> llama_sampler_p: - ... +def llama_sampler_init_dist(seed: int) -> llama_sampler_p: ... # /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. # /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first. # DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void), # "will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)"); -@ctypes_function("llama_sampler_init_softmax", [], llama_sampler_p_ctypes) -def llama_sampler_init_softmax() -> llama_sampler_p: - ... # /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 # /// Setting k <= 0 makes this a noop # LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k); @ctypes_function("llama_sampler_init_top_k", [ctypes.c_int32], llama_sampler_p_ctypes) -def llama_sampler_init_top_k(k: int) -> llama_sampler_p: - ... +def llama_sampler_init_top_k(k: int) -> llama_sampler_p: ... # /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 @@ -3826,8 +3808,7 @@ def llama_sampler_init_top_k(k: int) -> llama_sampler_p: [ctypes.c_float, ctypes.c_size_t], llama_sampler_p_ctypes, ) -def llama_sampler_init_top_p(p: float, min_keep: int) -> llama_sampler_p: - ... +def llama_sampler_init_top_p(p: float, min_keep: int) -> llama_sampler_p: ... # /// @details Minimum P sampling as described in https://github.com/ggml-org/llama.cpp/pull/3841 @@ -3837,8 +3818,7 @@ def llama_sampler_init_top_p(p: float, min_keep: int) -> llama_sampler_p: [ctypes.c_float, ctypes.c_size_t], llama_sampler_p_ctypes, ) -def llama_sampler_init_min_p(p: float, min_keep: int) -> llama_sampler_p: - ... +def llama_sampler_init_min_p(p: float, min_keep: int) -> llama_sampler_p: ... # /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. @@ -3848,15 +3828,13 @@ def llama_sampler_init_min_p(p: float, min_keep: int) -> llama_sampler_p: [ctypes.c_float, ctypes.c_size_t], llama_sampler_p_ctypes, ) -def llama_sampler_init_typical(p: float, min_keep: int) -> llama_sampler_p: - ... +def llama_sampler_init_typical(p: float, min_keep: int) -> llama_sampler_p: ... # /// #details Updates the logits l_i` = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original value, the rest are set to -inf # LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t); @ctypes_function("llama_sampler_init_temp", [ctypes.c_float], llama_sampler_p_ctypes) -def llama_sampler_init_temp(t: float) -> llama_sampler_p: - ... +def llama_sampler_init_temp(t: float) -> llama_sampler_p: ... # /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772. @@ -3868,8 +3846,7 @@ def llama_sampler_init_temp(t: float) -> llama_sampler_p: ) def llama_sampler_init_temp_ext( t: float, delta: float, exponent: float -) -> llama_sampler_p: - ... +) -> llama_sampler_p: ... # /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335 @@ -3881,8 +3858,7 @@ def llama_sampler_init_temp_ext( ) def llama_sampler_init_xtc( p: float, t: float, min_keep: int, seed: int, / -) -> llama_sampler_p: - ... +) -> llama_sampler_p: ... # /// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641 @@ -3892,8 +3868,7 @@ def llama_sampler_init_xtc( [ctypes.c_float], llama_sampler_p_ctypes, ) -def llama_sampler_init_top_n_sigma(n: float, /) -> llama_sampler_p: - ... +def llama_sampler_init_top_n_sigma(n: float, /) -> llama_sampler_p: ... # /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. @@ -3910,8 +3885,7 @@ def llama_sampler_init_top_n_sigma(n: float, /) -> llama_sampler_p: ) def llama_sampler_init_mirostat( n_vocab: int, seed: int, tau: float, eta: float, m: int, / -) -> llama_sampler_p: - ... +) -> llama_sampler_p: ... # /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. @@ -3926,8 +3900,7 @@ def llama_sampler_init_mirostat( ) def llama_sampler_init_mirostat_v2( seed: int, tau: float, eta: float, / -) -> llama_sampler_p: - ... +) -> llama_sampler_p: ... # /// @details Intializes a GBNF grammar, see grammars/README.md for details. @@ -3942,8 +3915,7 @@ def llama_sampler_init_mirostat_v2( ) def llama_sampler_init_grammar( vocab: llama_vocab_p, grammar_str: bytes, grammar_root: bytes, / -) -> llama_sampler_p: - ... +) -> llama_sampler_p: ... # DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy( @@ -3977,8 +3949,7 @@ def llama_sampler_init_grammar_lazy( trigger_tokens: CtypesArray[llama_token], num_trigger_tokens: int, /, -) -> llama_sampler_p: - ... +) -> llama_sampler_p: ... # /// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639 @@ -4012,8 +3983,7 @@ def llama_sampler_init_grammar_lazy_patterns( trigger_tokens: CtypesArray[llama_token], num_trigger_tokens: int, /, -) -> llama_sampler_p: - ... +) -> llama_sampler_p: ... # /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first. @@ -4033,8 +4003,7 @@ def llama_sampler_init_penalties( penalty_freq: float, penalty_present: float, /, -) -> llama_sampler_p: - ... +) -> llama_sampler_p: ... # /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982 @@ -4071,8 +4040,7 @@ def llama_sampler_init_dry( seq_breakers, num_breakers: int, /, -) -> llama_sampler_p: - ... +) -> llama_sampler_p: ... # LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( @@ -4086,8 +4054,7 @@ def llama_sampler_init_dry( ) def llama_sampler_init_logit_bias( n_vocab: int, n_logit_bias: int, logit_bias: CtypesArray[llama_logit_bias], / -) -> llama_sampler_p: - ... +) -> llama_sampler_p: ... # // this sampler is meant to be used for fill-in-the-middle infilling @@ -4097,8 +4064,7 @@ def llama_sampler_init_logit_bias( [llama_vocab_p_ctypes], llama_sampler_p_ctypes, ) -def llama_sampler_init_infill(vocab: llama_vocab_p, /) -> llama_sampler_p: - ... +def llama_sampler_init_infill(vocab: llama_vocab_p, /) -> llama_sampler_p: ... # // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise @@ -4108,8 +4074,7 @@ def llama_sampler_init_infill(vocab: llama_vocab_p, /) -> llama_sampler_p: [llama_sampler_p_ctypes], ctypes.c_uint32, ) -def llama_sampler_get_seed(smpl: llama_sampler_p, /) -> int: - ... +def llama_sampler_get_seed(smpl: llama_sampler_p, /) -> int: ... # /// @details Sample and accept a token from the idx-th output of the last evaluation @@ -4121,14 +4086,14 @@ def llama_sampler_get_seed(smpl: llama_sampler_p, /) -> int: ) def llama_sampler_sample( smpl: llama_sampler_p, ctx: llama_context_p, idx: int, / -) -> int: - ... +) -> int: ... # // # // Model split # // + # /// @details Build a split GGUF final path for this chunk. # LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count); @ctypes_function( @@ -4170,16 +4135,34 @@ def llama_split_prefix( # // Print system information # LLAMA_API const char * llama_print_system_info(void); @ctypes_function("llama_print_system_info", [], ctypes.c_char_p) -def llama_print_system_info() -> bytes: +def llama_print_system_info() -> bytes: ... + + +# LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type type); +@ctypes_function("llama_flash_attn_type_name", [ctypes.c_int], ctypes.c_char_p) +def llama_flash_attn_type_name(type: int, /) -> bytes: + """Get name of flash attention type""" + ... + + +# LLAMA_API const char * llama_model_meta_key_str(enum llama_model_meta_key key); +@ctypes_function("llama_model_meta_key_str", [ctypes.c_int], ctypes.c_char_p) +def llama_model_meta_key_str(key: int, /) -> bytes: + """Get string representation of model meta key""" + ... + + +# LLAMA_API ggml_log_callback llama_log_get(void); +@ctypes_function("llama_log_get", [], ggml_log_callback) +def llama_log_get() -> ggml_log_callback: + """Get current log callback""" ... -# // Set callback for all future logging events. -# // If this is not called, or NULL is supplied, everything is output on stderr. # LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data); @ctypes_function( "llama_log_set", - [ctypes.c_void_p, ctypes.c_void_p], + [ggml_log_callback, ctypes.c_void_p], None, ) def llama_log_set( @@ -4193,7 +4176,17 @@ def llama_log_set( ... -# // +# LLAMA_API void llama_memory_breakdown_print(const struct llama_context * ctx); +@ctypes_function( + "llama_memory_breakdown_print", + [llama_context_p_ctypes], + None, +) +def llama_memory_breakdown_print(ctx: llama_context_p, /): + """Print memory breakdown for context""" + ... + + # // Performance utils # // @@ -4203,6 +4196,7 @@ def llama_log_set( # double t_p_eval_ms; # double t_eval_ms; + # int32_t n_p_eval; # int32_t n_eval; # int32_t n_reused; // number of times a ggml compute graph had been reused @@ -4222,6 +4216,7 @@ class llama_perf_context_data(ctypes.Structure): # struct llama_perf_sampler_data { # double t_sample_ms; + # int32_t n_sample; # }; class llama_perf_sampler_data(ctypes.Structure): @@ -4237,8 +4232,7 @@ class llama_perf_sampler_data(ctypes.Structure): [llama_context_p_ctypes], llama_perf_context_data, ) -def llama_perf_context(ctx: llama_context_p, /) -> llama_perf_context_data: - ... +def llama_perf_context(ctx: llama_context_p, /) -> llama_perf_context_data: ... # LLAMA_API void llama_perf_context_print(const struct llama_context * ctx); @@ -4247,8 +4241,7 @@ def llama_perf_context(ctx: llama_context_p, /) -> llama_perf_context_data: [llama_context_p_ctypes], None, ) -def llama_perf_context_print(ctx: llama_context_p, /): - ... +def llama_perf_context_print(ctx: llama_context_p, /): ... # LLAMA_API void llama_perf_context_reset( struct llama_context * ctx); @@ -4257,8 +4250,7 @@ def llama_perf_context_print(ctx: llama_context_p, /): [llama_context_p_ctypes], None, ) -def llama_perf_context_reset(ctx: llama_context_p, /): - ... +def llama_perf_context_reset(ctx: llama_context_p, /): ... # // NOTE: the following work only with samplers constructed via llama_sampler_chain_init @@ -4268,8 +4260,7 @@ def llama_perf_context_reset(ctx: llama_context_p, /): [llama_sampler_p_ctypes], llama_perf_sampler_data, ) -def llama_perf_sampler(chain: llama_sampler_p, /) -> llama_perf_sampler_data: - ... +def llama_perf_sampler(chain: llama_sampler_p, /) -> llama_perf_sampler_data: ... # LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain); @@ -4278,8 +4269,7 @@ def llama_perf_sampler(chain: llama_sampler_p, /) -> llama_perf_sampler_data: [llama_sampler_p_ctypes], None, ) -def llama_perf_sampler_print(chain: llama_sampler_p, /): - ... +def llama_perf_sampler_print(chain: llama_sampler_p, /): ... # LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain); @@ -4288,8 +4278,7 @@ def llama_perf_sampler_print(chain: llama_sampler_p, /): [llama_sampler_p_ctypes], None, ) -def llama_perf_sampler_reset(chain: llama_sampler_p, /): - ... +def llama_perf_sampler_reset(chain: llama_sampler_p, /): ... # // @@ -4298,7 +4287,10 @@ def llama_perf_sampler_reset(chain: llama_sampler_p, /): # // function that returns whether or not a given tensor contains trainable parameters # typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata); -llama_opt_param_filter = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_void_p, ctypes.c_void_p) +llama_opt_param_filter = ctypes.CFUNCTYPE( + ctypes.c_bool, ctypes.c_void_p, ctypes.c_void_p +) + # // always returns true # LLAMA_API bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata); @@ -4307,8 +4299,9 @@ def llama_perf_sampler_reset(chain: llama_sampler_p, /): [ctypes.c_void_p, ctypes.c_void_p], ctypes.c_bool, ) -def llama_opt_param_filter_all(tensor: ctypes.c_void_p, userdata: ctypes.c_void_p, /) -> bool: - ... +def llama_opt_param_filter_all( + tensor: ctypes.c_void_p, userdata: ctypes.c_void_p, / +) -> bool: ... # struct llama_opt_params { @@ -4317,6 +4310,7 @@ def llama_opt_param_filter_all(tensor: ctypes.c_void_p, userdata: ctypes.c_void_ # llama_opt_param_filter param_filter; // callback for determining which tensors contain trainable parameters # void * param_filter_ud; // userdata for determining which tensors contain trainable parameters + # ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters # void * get_opt_pars_ud; // userdata for calculating optimizer parameters # }; @@ -4325,7 +4319,10 @@ class llama_opt_params(ctypes.Structure): ("n_ctx_train", ctypes.c_uint32), ("param_filter", llama_opt_param_filter), ("param_filter_ud", ctypes.c_void_p), - ("get_opt_pars", ctypes.c_void_p), # ggml_opt_get_optimizer_params - not implemented here + ( + "get_opt_pars", + ctypes.c_void_p, + ), # ggml_opt_get_optimizer_params - not implemented here ("get_opt_pars_ud", ctypes.c_void_p), ] @@ -4336,8 +4333,9 @@ class llama_opt_params(ctypes.Structure): [llama_context_p_ctypes, llama_model_p_ctypes, llama_opt_params], None, ) -def llama_opt_init(lctx: llama_context_p, model: llama_model_p, lopt_params: llama_opt_params, /): - ... +def llama_opt_init( + lctx: llama_context_p, model: llama_model_p, lopt_params: llama_opt_params, / +): ... # LLAMA_API void llama_opt_epoch( @@ -4353,7 +4351,7 @@ def llama_opt_init(lctx: llama_context_p, model: llama_model_p, lopt_params: lla [ llama_context_p_ctypes, ctypes.c_void_p, # ggml_opt_dataset_t - ctypes.c_void_p, # ggml_opt_result_t + ctypes.c_void_p, # ggml_opt_result_t ctypes.c_void_p, # ggml_opt_result_t ctypes.c_int64, ctypes.c_void_p, # ggml_opt_epoch_callback @@ -4370,5 +4368,4 @@ def llama_opt_epoch( callback_train: ctypes.c_void_p, callback_eval: ctypes.c_void_p, /, -): - ... +): ... diff --git a/llama_cpp/mtmd_cpp.py b/llama_cpp/mtmd_cpp.py index a45f8f406..e00eb3a0b 100644 --- a/llama_cpp/mtmd_cpp.py +++ b/llama_cpp/mtmd_cpp.py @@ -39,7 +39,11 @@ # Specify the base name of the shared library to load _libmtmd_base_name = "mtmd" _libmtmd_override_path = os.environ.get("MTMD_CPP_LIB") -_libmtmd_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" if _libmtmd_override_path is None else pathlib.Path() +_libmtmd_base_path = ( + pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" + if _libmtmd_override_path is None + else pathlib.Path(_libmtmd_override_path) +) # Load the library _libmtmd = load_shared_library(_libmtmd_base_name, _libmtmd_base_path) @@ -71,17 +75,22 @@ MTMD_INPUT_CHUNK_TYPE_IMAGE = 1 MTMD_INPUT_CHUNK_TYPE_AUDIO = 2 + # Structures class mtmd_context_params(Structure): _fields_ = [ ("use_gpu", c_bool), ("print_timings", c_bool), ("n_threads", c_int), - ("verbosity", c_int), # ggml_log_level ("image_marker", c_char_p), ("media_marker", c_char_p), + ("flash_attn_type", c_int), # enum llama_flash_attn_type + ("warmup", c_bool), + ("image_min_tokens", c_int), + ("image_max_tokens", c_int), ] + class mtmd_input_text(Structure): _fields_ = [ ("text", c_char_p), @@ -89,19 +98,21 @@ class mtmd_input_text(Structure): ("parse_special", c_bool), ] + ################################################ # mtmd.h functions ################################################ + # MTMD_API const char * mtmd_default_marker(void); @ctypes_function("mtmd_default_marker", [], c_char_p) -def mtmd_default_marker() -> bytes: - ... +def mtmd_default_marker() -> bytes: ... + # MTMD_API struct mtmd_context_params mtmd_context_params_default(void); @ctypes_function("mtmd_context_params_default", [], mtmd_context_params) -def mtmd_context_params_default() -> mtmd_context_params: - ... +def mtmd_context_params_default() -> mtmd_context_params: ... + # MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname, # const struct llama_model * text_model, @@ -109,70 +120,68 @@ def mtmd_context_params_default() -> mtmd_context_params: @ctypes_function( "mtmd_init_from_file", [c_char_p, llama_cpp.llama_model_p_ctypes, mtmd_context_params], - mtmd_context_p_ctypes + mtmd_context_p_ctypes, ) def mtmd_init_from_file( mmproj_fname: bytes, text_model: llama_cpp.llama_model_p, ctx_params: mtmd_context_params, /, -) -> Optional[mtmd_context_p]: - ... +) -> Optional[mtmd_context_p]: ... + # MTMD_API void mtmd_free(mtmd_context * ctx); @ctypes_function("mtmd_free", [mtmd_context_p_ctypes], None) -def mtmd_free(ctx: mtmd_context_p, /): - ... +def mtmd_free(ctx: mtmd_context_p, /): ... + # MTMD_API bool mtmd_support_vision(mtmd_context * ctx); @ctypes_function("mtmd_support_vision", [mtmd_context_p_ctypes], c_bool) -def mtmd_support_vision(ctx: mtmd_context_p, /) -> bool: - ... +def mtmd_support_vision(ctx: mtmd_context_p, /) -> bool: ... + # MTMD_API mtmd_bitmap * mtmd_bitmap_init(uint32_t nx, uint32_t ny, const unsigned char * data); @ctypes_function( - "mtmd_bitmap_init", - [c_uint32, c_uint32, POINTER(c_uint8)], - mtmd_bitmap_p_ctypes + "mtmd_bitmap_init", [c_uint32, c_uint32, POINTER(c_uint8)], mtmd_bitmap_p_ctypes ) def mtmd_bitmap_init( nx: Union[c_uint32, int], ny: Union[c_uint32, int], data: CtypesArray[c_uint8], /, -) -> Optional[mtmd_bitmap_p]: - ... +) -> Optional[mtmd_bitmap_p]: ... + # MTMD_API void mtmd_bitmap_free(mtmd_bitmap * bitmap); @ctypes_function("mtmd_bitmap_free", [mtmd_bitmap_p_ctypes], None) -def mtmd_bitmap_free(bitmap: mtmd_bitmap_p, /): - ... +def mtmd_bitmap_free(bitmap: mtmd_bitmap_p, /): ... + # MTMD_API mtmd_input_chunks * mtmd_input_chunks_init(void); @ctypes_function("mtmd_input_chunks_init", [], mtmd_input_chunks_p_ctypes) -def mtmd_input_chunks_init() -> Optional[mtmd_input_chunks_p]: - ... +def mtmd_input_chunks_init() -> Optional[mtmd_input_chunks_p]: ... + # MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks); @ctypes_function("mtmd_input_chunks_free", [mtmd_input_chunks_p_ctypes], None) -def mtmd_input_chunks_free(chunks: mtmd_input_chunks_p, /): - ... +def mtmd_input_chunks_free(chunks: mtmd_input_chunks_p, /): ... + # MTMD_API size_t mtmd_input_chunks_size(const mtmd_input_chunks * chunks); @ctypes_function("mtmd_input_chunks_size", [mtmd_input_chunks_p_ctypes], c_size_t) -def mtmd_input_chunks_size(chunks: mtmd_input_chunks_p, /) -> int: - ... +def mtmd_input_chunks_size(chunks: mtmd_input_chunks_p, /) -> int: ... + # MTMD_API const mtmd_input_chunk * mtmd_input_chunks_get(const mtmd_input_chunks * chunks, size_t idx); @ctypes_function( "mtmd_input_chunks_get", [mtmd_input_chunks_p_ctypes, c_size_t], - mtmd_input_chunk_p_ctypes + mtmd_input_chunk_p_ctypes, ) def mtmd_input_chunks_get( chunks: mtmd_input_chunks_p, idx: Union[c_size_t, int], / -) -> Optional[mtmd_input_chunk_p]: - ... +) -> Optional[mtmd_input_chunk_p]: ... + # MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx, # mtmd_input_chunks * output, @@ -197,52 +206,53 @@ def mtmd_tokenize( bitmaps: CtypesArray[mtmd_bitmap_p_ctypes], n_bitmaps: Union[c_size_t, int], /, -) -> int: - ... +) -> int: ... + # MTMD_API size_t mtmd_input_chunk_get_n_tokens(const mtmd_input_chunk * chunk); @ctypes_function("mtmd_input_chunk_get_n_tokens", [mtmd_input_chunk_p_ctypes], c_size_t) -def mtmd_input_chunk_get_n_tokens(chunk: mtmd_input_chunk_p, /) -> int: - ... +def mtmd_input_chunk_get_n_tokens(chunk: mtmd_input_chunk_p, /) -> int: ... + # MTMD_API enum mtmd_input_chunk_type mtmd_input_chunk_get_type(const mtmd_input_chunk * chunk); @ctypes_function("mtmd_input_chunk_get_type", [mtmd_input_chunk_p_ctypes], c_int) -def mtmd_input_chunk_get_type(chunk: mtmd_input_chunk_p, /) -> int: - ... +def mtmd_input_chunk_get_type(chunk: mtmd_input_chunk_p, /) -> int: ... + # MTMD_API const llama_token * mtmd_input_chunk_get_tokens_text(const mtmd_input_chunk * chunk, size_t * n_tokens_output); @ctypes_function( "mtmd_input_chunk_get_tokens_text", [mtmd_input_chunk_p_ctypes, POINTER(c_size_t)], - POINTER(llama_cpp.llama_token) + POINTER(llama_cpp.llama_token), ) def mtmd_input_chunk_get_tokens_text( chunk: mtmd_input_chunk_p, n_tokens_output: "_Pointer[c_size_t]", / -) -> Optional["_Pointer[llama_cpp.llama_token]"]: - ... +) -> Optional["_Pointer[llama_cpp.llama_token]"]: ... + ################################################ # mtmd-helper.h functions ################################################ + # MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx, const unsigned char * buf, size_t len); @ctypes_function( "mtmd_helper_bitmap_init_from_buf", [mtmd_context_p_ctypes, POINTER(c_uint8), c_size_t], - mtmd_bitmap_p_ctypes + mtmd_bitmap_p_ctypes, ) def mtmd_helper_bitmap_init_from_buf( ctx: mtmd_context_p, buf: CtypesArray[c_uint8], length: Union[c_size_t, int], /, -) -> Optional[mtmd_bitmap_p]: - ... +) -> Optional[mtmd_bitmap_p]: ... + # MTMD_API size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks); @ctypes_function("mtmd_helper_get_n_tokens", [mtmd_input_chunks_p_ctypes], c_size_t) -def mtmd_helper_get_n_tokens(chunks: mtmd_input_chunks_p, /) -> int: - ... +def mtmd_helper_get_n_tokens(chunks: mtmd_input_chunks_p, /) -> int: ... + # MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, # struct llama_context * lctx, @@ -276,5 +286,4 @@ def mtmd_helper_eval_chunk_single( logits_last: Union[c_bool, bool], new_n_past: "_Pointer[llama_cpp.llama_pos]", /, -) -> int: - ... +) -> int: ... diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py index 13c951241..bad0d4ee7 100644 --- a/llama_cpp/server/settings.py +++ b/llama_cpp/server/settings.py @@ -103,8 +103,9 @@ class ModelSettings(BaseSettings): offload_kqv: bool = Field( default=True, description="Whether to offload kqv to the GPU." ) - flash_attn: bool = Field( - default=False, description="Whether to use flash attention." + flash_attn: Optional[bool] = Field( + default=None, + description="Use flash attention. None=auto, True=enabled, False=disabled.", ) # Sampling Params last_n_tokens_size: int = Field( diff --git a/tests/test_llama.py b/tests/test_llama.py index 0a1a9f5ad..f76190d34 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -82,7 +82,7 @@ def test_real_model(llama_cpp_model_path): cparams.n_threads = multiprocessing.cpu_count() cparams.n_threads_batch = multiprocessing.cpu_count() cparams.logits_all = False - cparams.flash_attn = True + cparams.flash_attn_type = llama_cpp.LLAMA_FLASH_ATTN_TYPE_ENABLED context = internals.LlamaContext(model=model, params=cparams) tokens = model.tokenize(b"Hello, world!", add_bos=True, special=True) diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 4227c9be4..be47fb928 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 4227c9be4268ac844921b90f31595f81236bd317 +Subproject commit be47fb9285779e900915bd8246eb9664110d4ba5