diff --git a/README-zh.md b/README-zh.md index e2d9eacdf9d..024fbbbf206 100644 --- a/README-zh.md +++ b/README-zh.md @@ -240,6 +240,54 @@ sess.close() 更多示例,请参见 [examples](examples) 和 [tests](tests)。 +
+<details>
+<summary>🧠 AI 辅助 SQL 生成</summary>
+ +chDB 可以将自然语言提示转换为 SQL。通过连接/会话字符串配置 AI 客户端参数: + +- `ai_provider`:`openai` 或 `anthropic`。当设置了 `ai_base_url` 时默认使用 OpenAI 兼容接口,否则自动检测。 +- `ai_api_key`:API 密钥;也可从环境变量 `AI_API_KEY`、`OPENAI_API_KEY` 或 `ANTHROPIC_API_KEY` 读取。 +- `ai_base_url`:OpenAI 兼容服务的自定义 Base URL。 +- `ai_model`:模型名称(如 `gpt-4o-mini`、`claude-3-opus-20240229`)。 +- `ai_temperature`:生成温度,默认 `0.0`。 +- `ai_max_tokens`:最大全量生成 token 数,默认 `1000`。 +- `ai_timeout_seconds`:请求超时时间(秒),默认 `30`。 +- `ai_system_prompt`:自定义系统提示词。 +- `ai_max_steps`:工具调用的最大步数,默认 `5`。 +- `ai_enable_schema_access`:允许 AI 查看数据库/表元数据,默认 `true`。 + +未开启 AI 或配置缺失时,调用 `generate_sql`/`ask` 会抛出 `RuntimeError`。 + +```python +import chdb + +# 使用环境变量 OPENAI_API_KEY/AI_API_KEY/ANTHROPIC_API_KEY 提供凭据 +conn = chdb.connect("file::memory:?ai_provider=openai&ai_model=gpt-4o-mini") +conn.query("CREATE TABLE nums (n UInt32) ENGINE = Memory") +conn.query("INSERT INTO nums VALUES (1), (2), (3)") + +sql = conn.generate_sql("Select all rows from nums ordered by n desc") +print(sql) # 例如:SELECT * FROM nums ORDER BY n DESC + +# ask():一键生成并执行 SQL +# `ask()` 会先调用 `generate_sql` 再执行 `query`,关键字参数会透传给 `query`。 +print(conn.ask("List the numbers table", format="Pretty")) +``` + +`Session` 同样支持以上能力;`Session.ask()` 会将关键字参数透传给 `Session.query`: + +```python +from chdb import session as chs + +with chs.Session("file::memory:?ai_provider=openai") as sess: + sess.query("CREATE TABLE users (id UInt32, name String) ENGINE = Memory") + sess.query("INSERT INTO users VALUES (1, 'alice'), (2, 'bob')") + df = sess.ask("Show all users ordered by id", format="DataFrame") + print(df) +``` + +
+ ## 演示和示例 - [Colab Notebook](https://colab.research.google.com/drive/1-zKB6oKfXeptggXi0kUX87iR8ZTSr4P3?usp=sharing) 和更多 [示例](examples) diff --git a/README.md b/README.md index 9b738449a56..d0e81a8766c 100644 --- a/README.md +++ b/README.md @@ -473,6 +473,54 @@ For more examples, see [examples](examples) and [tests](tests).
+
+<details>
+<summary>🧠 AI-assisted SQL generation</summary>
+ +chDB can translate natural language prompts into SQL. Configure the AI client through the connection (or session) string parameters: + +- `ai_provider`: `openai` or `anthropic`. Defaults to OpenAI-compatible when `ai_base_url` is set, otherwise auto-detected. +- `ai_api_key`: API key; falls back to `AI_API_KEY`, `OPENAI_API_KEY`, or `ANTHROPIC_API_KEY` env vars. +- `ai_base_url`: Custom base URL for OpenAI-compatible endpoints. +- `ai_model`: Model name (e.g., `gpt-4o-mini`, `claude-3-opus-20240229`). +- `ai_temperature`: Generation temperature (default `0.0`). +- `ai_max_tokens`: Maximum tokens to generate (default `1000`). +- `ai_timeout_seconds`: Request timeout in seconds (default `30`). +- `ai_system_prompt`: Custom system prompt to steer SQL generation. +- `ai_max_steps`: Maximum tool-calling steps (default `5`). +- `ai_enable_schema_access`: Allow the AI to inspect database/table metadata (default `true`). + +If AI is not enabled in the build or the provider is misconfigured, `generate_sql`/`ask` raise a `RuntimeError`. + +```python +import chdb + +# Use env OPENAI_API_KEY/AI_API_KEY/ANTHROPIC_API_KEY for credentials +conn = chdb.connect("file::memory:?ai_provider=openai&ai_model=gpt-4o-mini") +conn.query("CREATE TABLE nums (n UInt32) ENGINE = Memory") +conn.query("INSERT INTO nums VALUES (1), (2), (3)") + +sql = conn.generate_sql("Select all rows from nums ordered by n desc") +print(sql) # e.g., SELECT * FROM nums ORDER BY n DESC + +# ask(): one-call generate + execute +# `ask()` first calls `generate_sql` then runs `query`; keyword arguments are forwarded to `query`. 
+print(conn.ask("List the numbers table", format="Pretty")) +``` + +`Session` objects support the same helpers and defaults; `Session.ask()` forwards keyword arguments to `Session.query`: + +```python +from chdb import session as chs + +with chs.Session("file::memory:?ai_provider=openai") as sess: + sess.query("CREATE TABLE users (id UInt32, name String) ENGINE = Memory") + sess.query("INSERT INTO users VALUES (1, 'alice'), (2, 'bob')") + df = sess.ask("Show all users ordered by id", format="DataFrame") + print(df) +``` + +
+ ## Demos and Examples - [Project Documentation](https://clickhouse.com/docs/en/chdb) and [Usage Examples](https://clickhouse.com/docs/en/chdb/install/python) diff --git a/chdb/__init__.py b/chdb/__init__.py index 60557353037..0eeaedc75a4 100644 --- a/chdb/__init__.py +++ b/chdb/__init__.py @@ -214,6 +214,7 @@ def query(sql, output_format="CSV", path="", udf_path="", params=None): # alias for query sql = query + PyReader = _chdb.PyReader from . import dbapi, session, udf, utils # noqa: E402 diff --git a/chdb/build-musl.sh b/chdb/build-musl.sh index 208a491e692..27e97730a62 100755 --- a/chdb/build-musl.sh +++ b/chdb/build-musl.sh @@ -60,6 +60,7 @@ CMAKE_ARGS="-DCMAKE_BUILD_TYPE=${build_type} -DENABLE_THINLTO=0 -DENABLE_TESTS=0 -DENABLE_KAFKA=1 -DENABLE_LIBPQXX=1 -DENABLE_NATS=0 -DENABLE_AMQPCPP=0 -DENABLE_NURAFT=0 \ -DENABLE_CASSANDRA=0 -DENABLE_ODBC=0 -DENABLE_NLP=0 \ -DENABLE_LDAP=0 \ + -DENABLE_CLIENT_AI=1 \ -DUSE_MUSL=1 \ -DRust_RUSTUP_INSTALL_MISSING_TARGET=ON \ ${MYSQL} \ diff --git a/chdb/build.sh b/chdb/build.sh index f7a41f0b577..96797cca3dd 100755 --- a/chdb/build.sh +++ b/chdb/build.sh @@ -95,6 +95,7 @@ CMAKE_ARGS="-DCMAKE_BUILD_TYPE=${build_type} -DENABLE_THINLTO=0 -DENABLE_TESTS=0 -DENABLE_KAFKA=1 -DENABLE_LIBPQXX=1 -DENABLE_NATS=0 -DENABLE_AMQPCPP=0 -DENABLE_NURAFT=0 \ -DENABLE_CASSANDRA=0 -DENABLE_ODBC=0 -DENABLE_NLP=0 \ -DENABLE_LDAP=0 \ + -DENABLE_CLIENT_AI=1 \ ${MYSQL} \ ${HDFS} \ -DENABLE_LIBRARIES=0 ${RUST_FEATURES} \ diff --git a/chdb/build/build_static_lib.sh b/chdb/build/build_static_lib.sh index 5e3953216b9..a6792f47a5b 100755 --- a/chdb/build/build_static_lib.sh +++ b/chdb/build/build_static_lib.sh @@ -52,6 +52,7 @@ CMAKE_ARGS="-DCMAKE_BUILD_TYPE=${build_type} -DENABLE_THINLTO=0 -DENABLE_TESTS=0 -DENABLE_KAFKA=1 -DENABLE_LIBPQXX=1 -DENABLE_NATS=0 -DENABLE_AMQPCPP=0 -DENABLE_NURAFT=0 \ -DENABLE_CASSANDRA=0 -DENABLE_ODBC=0 -DENABLE_NLP=0 \ -DENABLE_LDAP=0 \ + -DENABLE_CLIENT_AI=1 \ ${MYSQL} \ ${HDFS} \ -DENABLE_LIBRARIES=0 ${RUST_FEATURES} \ 
diff --git a/chdb/build/build_static_lib_mac_on_linux.sh b/chdb/build/build_static_lib_mac_on_linux.sh index 0370a153530..47dc6854d4c 100755 --- a/chdb/build/build_static_lib_mac_on_linux.sh +++ b/chdb/build/build_static_lib_mac_on_linux.sh @@ -116,6 +116,7 @@ CMAKE_ARGS="-DCMAKE_BUILD_TYPE=${build_type} \ -DENABLE_KAFKA=1 -DENABLE_LIBPQXX=1 -DENABLE_NATS=0 -DENABLE_AMQPCPP=0 -DENABLE_NURAFT=0 \ -DENABLE_CASSANDRA=0 -DENABLE_ODBC=0 -DENABLE_NLP=0 \ -DENABLE_LDAP=0 \ + -DENABLE_CLIENT_AI=1 \ ${MYSQL} \ ${HDFS} \ -DENABLE_LIBRARIES=0 ${RUST_FEATURES} \ diff --git a/chdb/build_mac_on_linux.sh b/chdb/build_mac_on_linux.sh index b9b4625294b..db5aba7a903 100755 --- a/chdb/build_mac_on_linux.sh +++ b/chdb/build_mac_on_linux.sh @@ -121,6 +121,7 @@ CMAKE_ARGS="-DCMAKE_BUILD_TYPE=${build_type} \ -DENABLE_KAFKA=1 -DENABLE_LIBPQXX=1 -DENABLE_NATS=0 -DENABLE_AMQPCPP=0 -DENABLE_NURAFT=0 \ -DENABLE_CASSANDRA=0 -DENABLE_ODBC=0 -DENABLE_NLP=0 \ -DENABLE_LDAP=0 \ + -DENABLE_CLIENT_AI=1 \ ${MYSQL} \ ${HDFS} \ -DENABLE_LIBRARIES=0 ${RUST_FEATURES} \ diff --git a/chdb/session/state.py b/chdb/session/state.py index 7a193bdee0b..1cfb7ea5d7e 100644 --- a/chdb/session/state.py +++ b/chdb/session/state.py @@ -207,6 +207,21 @@ def query(self, sql, fmt="CSV", udf_path="", params=None): # alias sql = query sql = query + def generate_sql(self, prompt: str) -> str: + """Generate SQL text from a natural language prompt using the configured AI provider.""" + if self._conn is None: + raise RuntimeError("Session is closed.") + return self._conn.generate_sql(prompt) + + def ask(self, prompt: str, **kwargs): + """Generate SQL from a prompt, execute it, and return the results. + + All keyword arguments are forwarded to the underlying :meth:`query`. + """ + if self._conn is None: + raise RuntimeError("Session is closed.") + return self._conn.ask(prompt, **kwargs) + def send_query(self, sql, fmt="CSV", params=None) -> StreamingResult: """Execute a SQL query and return a streaming result iterator. 
diff --git a/chdb/state/sqlitelike.py b/chdb/state/sqlitelike.py index 5e5215e05af..be3db120770 100644 --- a/chdb/state/sqlitelike.py +++ b/chdb/state/sqlitelike.py @@ -466,6 +466,34 @@ def query(self, query: str, format: str = "CSV", params=None) -> Any: result = self._conn.query(query, format, params=params or {}) return result_func(result) + def generate_sql(self, prompt: str) -> str: + """Generate SQL text from a natural language prompt using the configured AI provider.""" + if not hasattr(self._conn, "generate_sql"): + raise RuntimeError("AI SQL generation is not available in this build.") + return self._conn.generate_sql(prompt) + + def ask(self, prompt: str, **kwargs) -> Any: + """Generate SQL from a prompt, execute it, and return the results. + + This convenience method first calls :meth:`generate_sql` to translate + a natural language prompt into SQL, then executes the generated SQL via + :meth:`query`, forwarding any keyword arguments to :meth:`query`. + + Args: + prompt (str): Natural language description of the desired query. + **kwargs: Additional keyword arguments forwarded to :meth:`query` + (for example ``format`` or ``params``). If omitted, defaults + from :meth:`query` are used. + + Returns: + Query results in the requested format (CSV by default). + + Raises: + RuntimeError: If SQL generation is unavailable or query execution fails. + """ + generated_sql = self.generate_sql(prompt) + return self.query(generated_sql, **kwargs) + def send_query(self, query: str, format: str = "CSV", params=None) -> StreamingResult: """Execute a SQL query and return a streaming result iterator. 
diff --git a/programs/local/AIQueryProcessor.cpp b/programs/local/AIQueryProcessor.cpp new file mode 100644 index 00000000000..33e39ad77b5 --- /dev/null +++ b/programs/local/AIQueryProcessor.cpp @@ -0,0 +1,120 @@ +#include "AIQueryProcessor.h" + +#include "chdb-internal.h" +#include "PybindWrapper.h" + +#include +#include + +#if USE_CLIENT_AI +#include +#include +#endif + +#include +#include +#include + +namespace py = pybind11; + +#if USE_CLIENT_AI + +AIQueryProcessor::AIQueryProcessor(chdb_connection * connection_, const DB::AIConfiguration & config_) + : connection(connection_), ai_config(config_) +{ +} + +AIQueryProcessor::~AIQueryProcessor() = default; + +namespace +{ +void applyEnvFallback(DB::AIConfiguration & config) +{ + if (config.api_key.empty()) + { + if (const char * api_key = std::getenv("AI_API_KEY")) + config.api_key = api_key; + else if (const char * openai_key = std::getenv("OPENAI_API_KEY")) + config.api_key = openai_key; + else if (const char * anthropic_key = std::getenv("ANTHROPIC_API_KEY")) + config.api_key = anthropic_key; + } +} +} + +std::string AIQueryProcessor::executeQueryForAI(const std::string & query) +{ + chdb_result * result = chdb_query_n(*connection, query.data(), query.size(), "TSV", 3); + const auto & error_msg = CHDB::chdb_result_error_string(result); + if (!error_msg.empty()) + { + std::string msg_copy(error_msg); + chdb_destroy_query_result(result); + throw std::runtime_error(msg_copy); + } + + std::string data(chdb_result_buffer(result), chdb_result_length(result)); + chdb_destroy_query_result(result); + return data; +} + +void AIQueryProcessor::initializeGenerator() +{ + if (generator) + return; + + // If a custom base URL is provided but provider is empty, default to OpenAI-compatible. + if (ai_config.provider.empty() && !ai_config.base_url.empty()) + ai_config.provider = "openai"; + + applyEnvFallback(ai_config); + + if (ai_config.api_key.empty()) + throw std::runtime_error("AI SQL generator is not configured. 
Provide ai_api_key (or set OPENAI_API_KEY/ANTHROPIC_API_KEY) when creating the connection or session."); + + auto ai_result = DB::AIClientFactory::createClient(ai_config); + + if (ai_result.no_configuration_found || !ai_result.client.has_value()) + throw std::runtime_error("AI SQL generator is not configured. Provide ai_api_key (or set OPENAI_API_KEY/ANTHROPIC_API_KEY) when creating the connection or session."); + + auto query_executor = [this](const std::string & query_text) { return executeQueryForAI(query_text); }; + std::cerr << "[chdb] AI SQL generator using provider=" << (ai_config.provider.empty() ? "" : ai_config.provider) + << ", model=" << (ai_config.model.empty() ? "" : ai_config.model) + << ", base_url=" << (ai_config.base_url.empty() ? "" : ai_config.base_url) << std::endl; + generator = std::make_unique(ai_config, std::move(ai_result.client.value()), query_executor, std::cerr); +} + +std::string AIQueryProcessor::generateSQLFromPrompt(const std::string & prompt) +{ + initializeGenerator(); + + if (!generator) + throw std::runtime_error("AI SQL generator is not configured. 
Provide ai_api_key (or set OPENAI_API_KEY/ANTHROPIC_API_KEY) when creating the connection or session."); + + std::string sql; + { + py::gil_scoped_release release; + sql = generator->generateSQL(prompt); + } + + if (sql.empty()) + throw std::runtime_error("AI did not return a SQL query."); + + return sql; +} + +std::string AIQueryProcessor::generateSQL(const std::string & prompt) +{ + return generateSQLFromPrompt(prompt); +} + +#else + +AIQueryProcessor::AIQueryProcessor(chdb_connection *, const DB::AIConfiguration &) : connection(nullptr) { } +AIQueryProcessor::~AIQueryProcessor() = default; +std::string AIQueryProcessor::executeQueryForAI(const std::string &) { return {}; } +void AIQueryProcessor::initializeGenerator() { } +std::string AIQueryProcessor::generateSQLFromPrompt(const std::string &) { return {}; } +std::string AIQueryProcessor::generateSQL(const std::string &) { return {}; } + +#endif diff --git a/programs/local/AIQueryProcessor.h b/programs/local/AIQueryProcessor.h new file mode 100644 index 00000000000..2f38c5948b4 --- /dev/null +++ b/programs/local/AIQueryProcessor.h @@ -0,0 +1,28 @@ +#pragma once + +#include "chdb.h" +#include +#include + +#include +#include + +/// AI query processor that delegates to AISQLGenerator. +class AIQueryProcessor +{ +public: + AIQueryProcessor(chdb_connection * connection_, const DB::AIConfiguration & config_); + ~AIQueryProcessor(); + + /// Generate SQL using the configured AI provider. 
+ std::string generateSQL(const std::string & prompt); + +private: + chdb_connection * connection; + std::unique_ptr generator; + DB::AIConfiguration ai_config; + + std::string executeQueryForAI(const std::string & query); + std::string generateSQLFromPrompt(const std::string & prompt); + void initializeGenerator(); +}; diff --git a/programs/local/CMakeLists.txt b/programs/local/CMakeLists.txt index 968ec379344..7d109a1e4f2 100644 --- a/programs/local/CMakeLists.txt +++ b/programs/local/CMakeLists.txt @@ -44,6 +44,7 @@ if (USE_PYTHON) PandasDataFrame.cpp PandasDataFrameBuilder.cpp PandasScan.cpp + AIQueryProcessor.cpp PyArrowStreamFactory.cpp PyArrowTable.cpp PybindWrapper.cpp @@ -156,5 +157,9 @@ if (TARGET ch_contrib::pybind11_stubs) target_compile_definitions(clickhouse-local-lib PRIVATE Py_LIMITED_API=0x03080000) endif() +if (ENABLE_CLIENT_AI AND TARGET ch_contrib::ai-sdk-cpp) + target_link_libraries(clickhouse-local-lib PRIVATE ch_contrib::ai-sdk-cpp) +endif() + # Always use internal readpassphrase target_link_libraries(clickhouse-local-lib PRIVATE readpassphrase) diff --git a/programs/local/LocalChdb.cpp b/programs/local/LocalChdb.cpp index 30730234ba1..8d5b977b532 100644 --- a/programs/local/LocalChdb.cpp +++ b/programs/local/LocalChdb.cpp @@ -11,11 +11,18 @@ #include #include #include -#include #if USE_JEMALLOC # include #endif +#if USE_CLIENT_AI +# include "AIQueryProcessor.h" +#endif + +#include +#include +#include + namespace py = pybind11; extern bool inside_main = true; @@ -310,8 +317,13 @@ connection_wrapper::build_clickhouse_args(const std::string & path, const std::m connection_wrapper::connection_wrapper(const std::string & conn_str) { + is_readonly = false; auto [path, params] = parse_connection_string(conn_str); +#if USE_CLIENT_AI + applyAIParams(params); +#endif + auto argv = build_clickhouse_args(path, params); std::vector argv_char; argv_char.reserve(argv.size()); @@ -413,6 +425,28 @@ py::object connection_wrapper::query_df(const std::string & 
query_str, const py: return df; } +std::string connection_wrapper::generate_sql(const std::string & prompt) +{ +#if USE_CLIENT_AI + if (!ai_config.has_value()) + ai_config = DB::AIConfiguration{}; + + try + { + if (!ai_processor) + ai_processor = std::make_unique(conn, *ai_config); + return ai_processor->generateSQL(prompt); + } + catch (const std::exception & e) + { + throw std::runtime_error(std::string("AI SQL generation failed: ") + e.what()); + } +#else + (void)prompt; + throw std::runtime_error("AI SQL generation is not available in this build. Rebuild with USE_CLIENT_AI enabled."); +#endif +} + streaming_query_result * connection_wrapper::send_query(const std::string & query_str, const std::string & format, const py::dict & params) { const auto parsed_params = parseParametersDict(params); @@ -494,6 +528,119 @@ void connection_wrapper::streaming_cancel_query(streaming_query_result * streami chdb_stream_cancel_query(*conn, streaming_result->get_result()); } +#if USE_CLIENT_AI +void connection_wrapper::applyAIParams(std::map & params) +{ + DB::AIConfiguration config; + + auto consume_string = [&](const std::string & target, std::string DB::AIConfiguration::* field) + { + for (auto it = params.begin(); it != params.end();) + { + if (Poco::toLower(it->first) == target) + { + if (!it->second.empty()) + { + if ((config.*field).empty()) + config.*field = it->second; + } + it = params.erase(it); + } + else + { + ++it; + } + } + }; + + auto consume_double = [&](const std::string & target, double DB::AIConfiguration::* field) + { + for (auto it = params.begin(); it != params.end();) + { + if (Poco::toLower(it->first) == target) + { + if (!it->second.empty()) + { + try + { + config.*field = std::stod(it->second); + } + catch (...) 
+ { + } + } + it = params.erase(it); + } + else + { + ++it; + } + } + }; + + auto consume_size_t = [&](const std::string & target, size_t DB::AIConfiguration::* field) + { + for (auto it = params.begin(); it != params.end();) + { + if (Poco::toLower(it->first) == target) + { + if (!it->second.empty()) + { + try + { + config.*field = static_cast(std::stoul(it->second)); + } + catch (...) + { + } + } + it = params.erase(it); + } + else + { + ++it; + } + } + }; + + auto consume_bool = [&](const std::string & target, bool DB::AIConfiguration::* field) + { + for (auto it = params.begin(); it != params.end();) + { + if (Poco::toLower(it->first) == target) + { + if (!it->second.empty()) + { + std::string val = Poco::toLower(it->second); + if (val == "1" || val == "true" || val == "yes" || val == "on") + config.*field = true; + else if (val == "0" || val == "false" || val == "no" || val == "off") + config.*field = false; + } + it = params.erase(it); + } + else + { + ++it; + } + } + }; + + consume_string("ai_api_key", &DB::AIConfiguration::api_key); + consume_string("ai_base_url", &DB::AIConfiguration::base_url); + consume_string("ai_model", &DB::AIConfiguration::model); + consume_string("ai_provider", &DB::AIConfiguration::provider); + consume_double("ai_temperature", &DB::AIConfiguration::temperature); + consume_size_t("ai_max_tokens", &DB::AIConfiguration::max_tokens); + consume_size_t("ai_timeout_seconds", &DB::AIConfiguration::timeout_seconds); + consume_string("ai_system_prompt", &DB::AIConfiguration::system_prompt); + consume_size_t("ai_max_steps", &DB::AIConfiguration::max_steps); + consume_bool("ai_enable_schema_access", &DB::AIConfiguration::enable_schema_access); + + ai_config = config; +} +#endif + void cursor_wrapper::execute(const std::string & query_str) { release_result(); @@ -652,6 +799,13 @@ PYBIND11_MODULE(_chdb, m) py::kw_only(), py::arg("params") = py::dict(), "Execute a query and return a DataFrame") +#if USE_CLIENT_AI + .def( + "generate_sql", + 
&connection_wrapper::generate_sql, + py::arg("prompt"), + "Generate SQL text from a natural language prompt using the configured AI provider") +#endif .def( "send_query", &connection_wrapper::send_query, diff --git a/programs/local/LocalChdb.h b/programs/local/LocalChdb.h index f78bdfcda82..f46cffb460b 100644 --- a/programs/local/LocalChdb.h +++ b/programs/local/LocalChdb.h @@ -5,6 +5,17 @@ #include "config.h" #include +#include + +#if USE_CLIENT_AI +# include +namespace DB +{ +class AISQLGenerator; +struct AIConfiguration; +} +class AIQueryProcessor; +#endif namespace py = pybind11; @@ -22,6 +33,10 @@ class connection_wrapper std::string db_path; bool is_memory_db; bool is_readonly; +#if USE_CLIENT_AI + std::unique_ptr ai_processor; + std::optional ai_config; +#endif public: explicit connection_wrapper(const std::string & conn_str); @@ -36,10 +51,16 @@ class connection_wrapper query_result * streaming_fetch_result(streaming_query_result * streaming_result); py::object streaming_fetch_df(streaming_query_result * streaming_result); void streaming_cancel_query(streaming_query_result * streaming_result); + std::string generate_sql(const std::string & prompt); // Move the private methods declarations here std::pair> parse_connection_string(const std::string & conn_str); std::vector build_clickhouse_args(const std::string & path, const std::map & params); + +#if USE_CLIENT_AI +private: + void applyAIParams(std::map & params); +#endif }; class local_result_wrapper diff --git a/tests/test_ai_query.py b/tests/test_ai_query.py new file mode 100644 index 00000000000..336a1dd446e --- /dev/null +++ b/tests/test_ai_query.py @@ -0,0 +1,71 @@ +import unittest +import shutil + +import chdb + +class TestAIQuery(unittest.TestCase): + def setUp(self): + base_url = "https://openrouter.ai/api" + ai_model = "z-ai/glm-4.5-air:free" + # API key should be set in environment variable `OPENAI_API_KEY` + # Explicitly set provider=openai so the engine honors the custom base URL + 
connection_str = f"file::memory:?ai_provider=openai&ai_base_url={base_url}&ai_model={ai_model}" + self.conn = chdb.connect(connection_str) + self.conn.query("CREATE TABLE users (id UInt32, name String) ENGINE = Memory") + self.conn.query("INSERT INTO users VALUES (1, 'alice'), (2, 'bob'), (3, 'carol')") + self.prompt = "List all rows from the users table ordered by id ascending." + + def tearDown(self): + self.conn.close() + + def _run_ai_or_skip(self, fn): + """ + Execute an AI-assisted call, skipping the test on common provider/configuration failures. + """ + try: + return fn() + except Exception as exc: + message = str(exc).lower() + if ( + "ai sql generation is not available" in message + or "ai sql generator is not configured" in message + or "unknown ai provider" in message + or "api key not provided" in message + # "unsupported country" "unsupported_country_region_territory" + or "unsupported" in message + or "syntax error" in message + # musl linux generic error message + or "caught an unknown exception" in message + ): + self.skipTest("AI provider not configured/enabled or not supported") + raise + + def test_ai_query_lists_users_in_order(self): + """ + Run an AI-generated query via generate_sql. If AI is unavailable/misconfigured, + the test is skipped to avoid false failures. + """ + def _gen_sql(): + generated_sql = self.conn.generate_sql(self.prompt) + return self.conn.query(generated_sql, format="DataFrame") + + df = self._run_ai_or_skip(_gen_sql) + self.assertFalse(df.empty) + self.assertListEqual(list(df["id"]), [1, 2, 3]) + self.assertListEqual(list(df["name"]), ["alice", "bob", "carol"]) + + def test_ai_ask_runs_prompt_and_returns_dataframe(self): + """ + Run a prompt end-to-end via ask() using default DataFrame output. Skip if + AI is unavailable/misconfigured. 
+ """ + df = self._run_ai_or_skip(lambda: self.conn.ask(self.prompt, format="DataFrame")) + + self.assertFalse(df.empty) + self.assertListEqual(list(df["id"]), [1, 2, 3]) + self.assertListEqual(list(df["name"]), ["alice", "bob", "carol"]) + + + +if __name__ == "__main__": + unittest.main()