Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions singlestoredb/ibis_extras/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""SingleStoreDB extensions for Ibis.

This package adds SingleStoreDB-specific features to the Ibis backend.
Features are automatically registered on import.

Usage
-----
>>> import ibis
>>> import singlestoredb.ibis_extras # Auto-registers extensions
>>>
>>> con = ibis.singlestoredb.connect(host="...", database="...")
>>>
>>> # Variable accessors (from old ibis_singlestoredb)
>>> con.show.databases()
>>> con.globals["max_connections"]
>>> con.vars["autocommit"]
>>>
>>> # Backend methods
>>> con.get_storage_info()
>>> con.get_workload_metrics()
>>>
>>> # Table methods (work on any table from SingleStoreDB)
>>> t = con.table("users")
>>> t.optimize()
>>> t.get_stats()
"""
from __future__ import annotations

import warnings

from .mixins import BackendExtensionsMixin
from .mixins import TableExtensionsMixin

__all__ = [
'BackendExtensionsMixin',
'TableExtensionsMixin',
'is_registered',
'register',
]

_registered = False


def _check_collisions(cls: type, mixin: type) -> None:
"""Check for method collisions between mixin and target class."""
mixin_attrs = {
name
for name in dir(mixin)
if not name.startswith('_') and callable(getattr(mixin, name, None))
}
mixin_props = {
name
for name in dir(mixin)
if not name.startswith('_')
and isinstance(getattr(mixin, name, None), property)
}
mixin_members = mixin_attrs | mixin_props

existing_attrs = {name for name in dir(cls) if not name.startswith('_')}

collisions = mixin_members & existing_attrs
if collisions:
warnings.warn(
f'Mixin {mixin.__name__} has methods that collide with '
f'{cls.__name__}: {collisions}',
stacklevel=3,
)


def register() -> None:
"""Register mixins on Backend and ir.Table.

This is called automatically on import, but can be called
explicitly if needed.
"""
global _registered # noqa: PLW0603
if _registered:
return

try:
import ibis.expr.types as ir
from ibis.backends.singlestoredb import Backend
except ImportError as e:
raise ImportError(
'ibis_extras requires ibis with singlestoredb backend. '
'Install with: pip install "singlestoredb[ibis]"',
) from e

# Check for collisions before adding mixins
_check_collisions(Backend, BackendExtensionsMixin)
_check_collisions(ir.Table, TableExtensionsMixin)

# Add mixin to Backend
if BackendExtensionsMixin not in Backend.__bases__:
Backend.__bases__ = (BackendExtensionsMixin,) + Backend.__bases__

# Add mixin to ir.Table
if TableExtensionsMixin not in ir.Table.__bases__:
ir.Table.__bases__ = (TableExtensionsMixin,) + ir.Table.__bases__

_registered = True


def is_registered() -> bool:
"""Check if extensions have been registered."""
return _registered


# Auto-register on import
register()
191 changes: 191 additions & 0 deletions singlestoredb/ibis_extras/mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""Mixin classes for SingleStoreDB extensions."""
from __future__ import annotations

from typing import Any
from typing import Literal
from typing import Protocol
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from contextlib import AbstractContextManager

import ibis.expr.types as ir

class _BackendProtocol(Protocol):
"""Protocol defining backend interface used by BackendExtensionsMixin."""

_client: Any

@property
def current_database(self) -> str: ...
def sql(self, query: str) -> ir.Table: ...
def raw_sql(self, query: str) -> AbstractContextManager[Any]: ...

class _TableProtocol(Protocol):
"""Protocol defining table interface used by TableExtensionsMixin."""

def get_name(self) -> str: ...
def op(self) -> Any: ...

_BackendBase: type = _BackendProtocol
_TableBase: type = _TableProtocol
else:
_BackendBase = object
_TableBase = object


def _quote_identifier(name: str) -> str:
"""Quote an identifier (table, database, column name) for safe SQL usage."""
# Escape backticks by doubling them (MySQL/SingleStore convention)
escaped = name.replace('`', '``')
return f'`{escaped}`'


def _escape_string_literal(value: str) -> str:
"""Escape a string value for use in SQL string literals."""
# Escape single quotes by doubling them, and escape backslashes
return value.replace('\\', '\\\\').replace("'", "''")


def _get_table_backend_and_db(
table: ir.Table,
*,
escape: Literal['identifier', 'literal'] | None = None,
) -> tuple[BackendExtensionsMixin, str]:
"""Get SingleStoreDB backend and database from table.

Parameters
----------
table
The Ibis table object.
escape
How to escape the database name:
- None: return unescaped
- 'identifier': escape for SQL identifiers (backticks)
- 'literal': escape for string literals (quotes)
"""
op = table.op()
if hasattr(op, 'source') and op.source.name == 'singlestoredb':
db = getattr(getattr(op, 'namespace', None), 'database', None)
db = db or op.source.current_database
if escape == 'identifier':
db = _quote_identifier(db)
elif escape == 'literal':
db = _escape_string_literal(db)
return op.source, db # type: ignore[return-value]
raise TypeError(
f'This method only works with SingleStoreDB tables, '
f"got {getattr(op.source, 'name', 'unknown')} backend",
)


class BackendExtensionsMixin(_BackendBase):
"""Mixin for SingleStoreDB Backend extensions."""

__slots__ = ()

# --- Variable/Show accessors from old ibis_singlestoredb package ---

@property
def show(self) -> Any:
"""Access to SHOW commands on the server."""
return self._client.show

@property
def globals(self) -> Any:
"""Accessor for global variables in the server."""
return self._client.globals

@property
def locals(self) -> Any:
"""Accessor for local variables in the server."""
return self._client.locals

@property
def cluster_globals(self) -> Any:
"""Accessor for cluster global variables in the server."""
return self._client.cluster_globals

@property
def cluster_locals(self) -> Any:
"""Accessor for cluster local variables in the server."""
return self._client.cluster_locals

@property
def vars(self) -> Any:
"""Accessor for variables in the server."""
return self._client.vars

@property
def cluster_vars(self) -> Any:
"""Accessor for cluster variables in the server."""
return self._client.cluster_vars

# --- New extension methods ---

def get_storage_info(self, database: str | None = None) -> ir.Table:
"""Get storage statistics for tables in a database.

Parameters
----------
database
Database name. Defaults to current database.

Returns
-------
ir.Table
Table with storage statistics.
"""
db = _escape_string_literal(database or self.current_database)
# S608: db is escaped via _escape_string_literal
query = f"""
SELECT * FROM information_schema.table_statistics
WHERE database_name = '{db}'
""" # noqa: S608
return self.sql(query)

def get_workload_metrics(self) -> ir.Table:
"""Get workload management metrics."""
return self.sql(
'SELECT * FROM information_schema.mv_workload_management_events',
)


class TableExtensionsMixin(_TableBase):
"""Mixin for ir.Table extensions (SingleStoreDB only)."""

__slots__ = ()

def optimize(self) -> None:
"""Optimize this table (SingleStoreDB only)."""
backend, db = _get_table_backend_and_db(self, escape='identifier')
table = _quote_identifier(self.get_name())
with backend.raw_sql(f'OPTIMIZE TABLE {db}.{table} FULL'):
pass

def get_stats(self) -> ir.Table:
"""Get statistics for this table (SingleStoreDB only)."""
backend, db = _get_table_backend_and_db(self, escape='literal')
table = _escape_string_literal(self.get_name())
# S608: db and table are escaped via _escape_string_literal
return backend.sql(
f"""
SELECT * FROM information_schema.table_statistics
WHERE database_name = '{db}' AND table_name = '{table}'
""", # noqa: S608
)

def get_column_statistics(self, column: str | None = None) -> ir.Table:
"""Get column statistics (SingleStoreDB only).

Parameters
----------
column
Specific column name, or None for all columns.
"""
backend, db = _get_table_backend_and_db(self, escape='identifier')
table = _quote_identifier(self.get_name())
query = f'SHOW COLUMNAR_SEGMENT_INDEX ON {db}.{table}'
if column:
query += f' COLUMNS ({_quote_identifier(column)})'
return backend.sql(query)