import asyncio
import atexit
from collections.abc import Awaitable, Coroutine
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from typing import TYPE_CHECKING, Any, TypeGuard, cast, overload
from typing_extensions import Self
from sqlspec.config import (
AsyncConfigT,
AsyncDatabaseConfig,
DatabaseConfigProtocol,
DriverT,
NoPoolAsyncConfig,
NoPoolSyncConfig,
SyncConfigT,
SyncDatabaseConfig,
)
from sqlspec.core import (
CacheConfig,
get_cache_config,
get_cache_statistics,
log_cache_stats,
reset_cache_stats,
update_cache_config,
)
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.extensions.events import AsyncEventChannel, SyncEventChannel
from sqlspec.loader import SQLFileLoader
from sqlspec.observability import ObservabilityConfig, ObservabilityRuntime, TelemetryDiagnostics
from sqlspec.typing import ConnectionT
from sqlspec.utils.logging import get_logger
from sqlspec.utils.type_guards import has_name
if TYPE_CHECKING:
from pathlib import Path
from sqlspec.core import SQL
from sqlspec.typing import PoolT
# Public API of this module: only the registry class is exported.
__all__ = ("SQLSpec",)
# Module-level logger; SQLSpec uses it for debug-level diagnostics (pool cleanup, config registration, SQL loading).
logger = get_logger()
def _is_async_context_manager(obj: Any) -> TypeGuard[AbstractAsyncContextManager[Any]]:
return isinstance(obj, AbstractAsyncContextManager)
class _RuntimeConnectionContext(AbstractContextManager[ConnectionT]):
    """Wrap a sync connection context manager and report its lifecycle to observability.

    ``connection_create`` is emitted after the wrapped context yields a connection;
    ``connection_destroy`` is emitted after the wrapped context has exited.
    """

    def __init__(self, context: "AbstractContextManager[ConnectionT]", runtime: "ObservabilityRuntime") -> None:
        self._context = context
        self._runtime = runtime
        # Connection yielded by the wrapped context; None outside the ``with`` block.
        self._connection: ConnectionT | None = None

    def __enter__(self) -> ConnectionT:
        """Enter the wrapped context and emit ``connection_create``.

        Returns:
            The connection produced by the wrapped context manager.

        If the observability hook raises, the wrapped context is exited with that
        exception before it propagates, so the connection is not leaked.
        """
        connection = self._context.__enter__()
        try:
            self._runtime.emit_connection_create(connection)
        except BaseException as exc:
            # ``with`` never calls __exit__ when __enter__ raises, so release the
            # wrapped context here; otherwise the acquired connection would leak.
            self._context.__exit__(type(exc), exc, exc.__traceback__)
            raise
        self._connection = connection
        return connection

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> "bool | None":
        """Exit the wrapped context, then emit ``connection_destroy``.

        Returns:
            The wrapped context's ``__exit__`` result (truthy suppresses the exception).
        """
        try:
            return self._context.__exit__(exc_type, exc_val, exc_tb)
        finally:
            # Reset state before calling the hook so a failing hook cannot leave it stale.
            connection, self._connection = self._connection, None
            if connection is not None:
                self._runtime.emit_connection_destroy(connection)
class _RuntimeAsyncConnectionContext(AbstractAsyncContextManager[ConnectionT]):
    """Wrap an async connection context manager and report its lifecycle to observability.

    ``connection_create`` is emitted after the wrapped context yields a connection;
    ``connection_destroy`` is emitted after the wrapped context has exited.
    """

    def __init__(self, context: "AbstractAsyncContextManager[ConnectionT]", runtime: "ObservabilityRuntime") -> None:
        self._context = context
        self._runtime = runtime
        # Connection yielded by the wrapped context; None outside the ``async with`` block.
        self._connection: ConnectionT | None = None

    async def __aenter__(self) -> ConnectionT:
        """Enter the wrapped context and emit ``connection_create``.

        Returns:
            The connection produced by the wrapped async context manager.

        If the observability hook raises, the wrapped context is exited with that
        exception before it propagates, so the connection is not leaked.
        """
        connection = await self._context.__aenter__()
        try:
            self._runtime.emit_connection_create(connection)
        except BaseException as exc:
            # ``async with`` never calls __aexit__ when __aenter__ raises, so release
            # the wrapped context here; otherwise the acquired connection would leak.
            await self._context.__aexit__(type(exc), exc, exc.__traceback__)
            raise
        self._connection = connection
        return connection

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> "bool | None":
        """Exit the wrapped context, then emit ``connection_destroy``.

        Returns:
            The wrapped context's ``__aexit__`` result (truthy suppresses the exception).
        """
        try:
            return await self._context.__aexit__(exc_type, exc_val, exc_tb)
        finally:
            # Reset state before calling the hook so a failing hook cannot leave it stale.
            connection, self._connection = self._connection, None
            if connection is not None:
                self._runtime.emit_connection_destroy(connection)
class _RuntimeSessionContext(AbstractContextManager[DriverT]):
    """Wrap a sync session context manager, prepare its driver, and report lifecycle events.

    On entry the yielded session is passed through ``config._prepare_driver``, then
    ``connection_create`` (when the driver has a connection) and ``session_start``
    are emitted. On exit, ``session_end`` and ``connection_destroy`` are emitted
    after the wrapped context has exited.
    """

    def __init__(
        self,
        context: "AbstractContextManager[DriverT]",
        runtime: "ObservabilityRuntime",
        config: "DatabaseConfigProtocol[Any, Any, DriverT]",
    ) -> None:
        self._context = context
        self._runtime = runtime
        self._config = config
        # Prepared driver handed to the caller; None outside the ``with`` block.
        self._driver: DriverT | None = None

    def __enter__(self) -> DriverT:
        """Enter the wrapped context, prepare the driver, and emit start events.

        Returns:
            The driver returned by ``config._prepare_driver``.

        If preparation or a hook raises, the wrapped context is exited with that
        exception (and any emitted ``connection_create`` is balanced) before it propagates.
        """
        session = self._context.__enter__()
        announced_connection: Any = None
        try:
            driver = self._config._prepare_driver(session)  # pyright: ignore[reportPrivateUsage]
            connection = driver.connection
            if connection is not None:
                self._runtime.emit_connection_create(connection)
                announced_connection = connection
            self._runtime.emit_session_start(driver)
        except BaseException as exc:
            # ``with`` never calls __exit__ when __enter__ raises, so release the
            # wrapped session here to avoid leaking its connection.
            try:
                self._context.__exit__(type(exc), exc, exc.__traceback__)
            finally:
                if announced_connection is not None:
                    self._runtime.emit_connection_destroy(announced_connection)
            raise
        self._driver = driver
        return driver

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> "bool | None":
        """Exit the wrapped context, then emit ``session_end`` and ``connection_destroy``.

        Returns:
            The wrapped context's ``__exit__`` result (truthy suppresses the exception).
        """
        try:
            return self._context.__exit__(exc_type, exc_val, exc_tb)
        finally:
            # Reset state up front so a failing hook cannot leave a stale driver.
            driver, self._driver = self._driver, None
            if driver is not None:
                try:
                    self._runtime.emit_session_end(driver)
                finally:
                    # Emit connection_destroy even if the session_end hook fails.
                    connection = driver.connection
                    if connection is not None:
                        self._runtime.emit_connection_destroy(connection)
class _RuntimeAsyncSessionContext(AbstractAsyncContextManager[DriverT]):
    """Wrap an async session context manager, prepare its driver, and report lifecycle events.

    On entry the yielded session is passed through ``config._prepare_driver``, then
    ``connection_create`` (when the driver has a connection) and ``session_start``
    are emitted. On exit, ``session_end`` and ``connection_destroy`` are emitted
    after the wrapped context has exited.
    """

    def __init__(
        self,
        context: "AbstractAsyncContextManager[DriverT]",
        runtime: "ObservabilityRuntime",
        config: "DatabaseConfigProtocol[Any, Any, DriverT]",
    ) -> None:
        self._context = context
        self._runtime = runtime
        self._config = config
        # Prepared driver handed to the caller; None outside the ``async with`` block.
        self._driver: DriverT | None = None

    async def __aenter__(self) -> DriverT:
        """Enter the wrapped context, prepare the driver, and emit start events.

        Returns:
            The driver returned by ``config._prepare_driver``.

        If preparation or a hook raises, the wrapped context is exited with that
        exception (and any emitted ``connection_create`` is balanced) before it propagates.
        """
        session = await self._context.__aenter__()
        announced_connection: Any = None
        try:
            driver = self._config._prepare_driver(session)  # pyright: ignore[reportPrivateUsage]
            connection = driver.connection
            if connection is not None:
                self._runtime.emit_connection_create(connection)
                announced_connection = connection
            self._runtime.emit_session_start(driver)
        except BaseException as exc:
            # ``async with`` never calls __aexit__ when __aenter__ raises, so release
            # the wrapped session here to avoid leaking its connection.
            try:
                await self._context.__aexit__(type(exc), exc, exc.__traceback__)
            finally:
                if announced_connection is not None:
                    self._runtime.emit_connection_destroy(announced_connection)
            raise
        self._driver = driver
        return driver

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> "bool | None":
        """Exit the wrapped context, then emit ``session_end`` and ``connection_destroy``.

        Returns:
            The wrapped context's ``__aexit__`` result (truthy suppresses the exception).
        """
        try:
            return await self._context.__aexit__(exc_type, exc_val, exc_tb)
        finally:
            # Reset state up front so a failing hook cannot leave a stale driver.
            driver, self._driver = self._driver, None
            if driver is not None:
                try:
                    self._runtime.emit_session_end(driver)
                finally:
                    # Emit connection_destroy even if the session_end hook fails.
                    connection = driver.connection
                    if connection is not None:
                        self._runtime.emit_connection_destroy(connection)
class SQLSpec:
    """Configuration manager and registry for database connections and pools.

    Configurations are keyed by ``id(config)``. The accessor methods register an
    unknown configuration on first use. Named SQL statements are served by an
    optional ``SQLFileLoader``. Synchronous pools are closed automatically at
    interpreter exit; call ``close_all_pools`` (or use ``async with``) to close
    async pools as well.
    """

    __slots__ = ("_configs", "_instance_cache_config", "_loader_runtime", "_observability_config", "_sql_loader")

    def __init__(
        self, *, loader: "SQLFileLoader | None" = None, observability_config: "ObservabilityConfig | None" = None
    ) -> None:
        """Create an empty registry.

        Args:
            loader: Optional pre-built SQL file loader. When omitted, one is created
                lazily by the first call that needs it (``load_sql_files``,
                ``add_named_sql`` or ``get_sql``).
            observability_config: Observability settings attached to every registered
                configuration and used for the SQL loader's runtime.
        """
        self._configs: dict[int, DatabaseConfigProtocol[Any, Any, Any]] = {}
        # NOTE(review): registering a bound method keeps this instance (and its configs)
        # alive until interpreter exit, and every SQLSpec() adds another atexit handler.
        atexit.register(self._cleanup_sync_pools)
        self._instance_cache_config: CacheConfig | None = None
        self._sql_loader: SQLFileLoader | None = loader
        self._observability_config = observability_config
        self._loader_runtime = ObservabilityRuntime(observability_config, config_name="SQLFileLoader")
        if self._sql_loader is not None:
            self._sql_loader.set_observability_runtime(self._loader_runtime)

    @staticmethod
    def _get_config_name(obj: Any) -> str:
        """Get display name for configuration object.

        Strings are returned unchanged, objects with ``__name__`` (e.g. classes) use
        it, and anything else falls back to its type's name.
        """
        if isinstance(obj, str):
            return obj
        if has_name(obj):
            return obj.__name__
        return type(obj).__name__

    def _cleanup_sync_pools(self) -> None:
        """Clean up only synchronous connection pools at exit.

        Registered with ``atexit`` in ``__init__``. Failures are collected and
        reported in a debug log instead of being raised.
        """
        cleaned_count = 0
        failed_configs: list[str] = []
        for config in self._configs.values():
            if config.supports_connection_pooling and not config.is_async:
                failure = self._safe_close_pool(config)
                if failure is None:
                    cleaned_count += 1
                else:
                    failed_configs.append(failure)
        if cleaned_count or failed_configs:
            summary: dict[str, object] = {"cleaned_pools": cleaned_count, "failed_pools": len(failed_configs)}
            if failed_configs:
                summary["failures"] = failed_configs
            logger.debug("Sync pool cleanup completed.", extra=summary)

    async def close_all_pools(self) -> None:
        """Explicitly close all connection pools (async and sync).

        This method should be called before application shutdown for proper cleanup.
        Async pools are closed concurrently first, then sync pools one at a time.
        Failures are collected into the debug summary rather than raised.
        """
        # (config class name, close coroutine) pairs, in the order passed to gather.
        pending_async: list[tuple[str, Coroutine[Any, Any, None]]] = []
        sync_configs: list[DatabaseConfigProtocol[Any, Any, Any]] = []
        for config in self._configs.values():
            if not config.supports_connection_pooling:
                continue
            try:
                if config.is_async:
                    close_pool_awaitable = config.close_pool()
                    if close_pool_awaitable is not None:
                        pending_async.append(
                            (config.__class__.__name__, cast("Coroutine[Any, Any, None]", close_pool_awaitable))
                        )
                else:
                    sync_configs.append(config)  # pyright: ignore
            except Exception as e:
                logger.debug("Failed to prepare cleanup for config %s: %s", config.__class__.__name__, e)

        failures: list[str] = []
        if pending_async:
            try:
                results = await asyncio.gather(*(task for _, task in pending_async), return_exceptions=True)
            except Exception as e:
                failures.append(str(e))
            else:
                # With return_exceptions=True, per-pool errors are returned as results
                # instead of raised, so they must be inspected explicitly.
                failures.extend(
                    f"{name}: {result}"
                    for (name, _), result in zip(pending_async, results)
                    if isinstance(result, BaseException)
                )

        for config in sync_configs:
            failure = self._safe_close_pool(config)
            if failure is not None:
                failures.append(failure)

        if pending_async or sync_configs or failures:
            summary: dict[str, object] = {
                "async_pools": len(pending_async),
                "sync_pools": len(sync_configs),
                "failures": failures,
            }
            logger.debug("Pool cleanup completed.", extra=summary)

    @staticmethod
    def _safe_close_pool(config: "DatabaseConfigProtocol[Any, Any, Any]") -> "str | None":
        """Close a pool, returning an error string when it fails.

        Returns:
            ``None`` on success, otherwise ``"<ConfigClass>: <error>"``.
        """
        try:
            config.close_pool()
        except Exception as exc:  # pragma: no cover - best effort cleanup
            return f"{config.__class__.__name__}: {exc}"
        return None

    async def __aenter__(self) -> Self:
        """Async context manager entry."""
        return self

    async def __aexit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None:
        """Async context manager exit with automatic cleanup of all pools."""
        await self.close_all_pools()

    @overload
    def add_config(self, config: "SyncConfigT") -> "SyncConfigT": ...

    @overload
    def add_config(self, config: "AsyncConfigT") -> "AsyncConfigT": ...

    def add_config(self, config: "SyncConfigT | AsyncConfigT") -> "SyncConfigT | AsyncConfigT":
        """Add a configuration instance to the registry.

        The registry's observability config is attached to the configuration.

        Args:
            config: The configuration instance to add.

        Returns:
            The same configuration instance (it IS the handle).
        """
        config_id = id(config)
        if config_id in self._configs:
            logger.debug("Configuration for %s already exists. Overwriting.", config.__class__.__name__)
        config.attach_observability(self._observability_config)
        self._configs[config_id] = config
        return config

    @property
    def configs(self) -> "dict[int, DatabaseConfigProtocol[Any, Any, Any]]":
        """Access the registry of database configurations.

        Returns:
            Dictionary mapping config instance IDs to config instances.
        """
        return self._configs

    @overload
    def event_channel(self, config: "type[SyncConfigT]") -> "SyncEventChannel": ...

    @overload
    def event_channel(self, config: "type[AsyncConfigT]") -> "AsyncEventChannel": ...

    @overload
    def event_channel(
        self, config: "SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any]"
    ) -> "SyncEventChannel": ...

    @overload
    def event_channel(
        self, config: "AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]"
    ) -> "AsyncEventChannel": ...

    def event_channel(
        self,
        config: "type[SyncConfigT | AsyncConfigT] | SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]",
    ) -> "SyncEventChannel | AsyncEventChannel":
        """Create an event channel for the provided configuration.

        Returns SyncEventChannel for sync configs, AsyncEventChannel for async configs.

        Args:
            config: A registered database configuration instance or type. When a type
                is given, the first registered instance of that type is used.

        Returns:
            The appropriate event channel type for the configuration.

        Raises:
            ImproperConfigurationError: If a type is given and no instance of it is registered.
        """
        if isinstance(config, type):
            config_obj: DatabaseConfigProtocol[Any, Any, Any] | None = None
            for registered_config in self._configs.values():
                if isinstance(registered_config, config):
                    config_obj = registered_config
                    break
            if config_obj is None:
                msg = f"Configuration {self._get_config_name(config)} is not registered"
                raise ImproperConfigurationError(msg)
            if config_obj.is_async:
                return AsyncEventChannel(config_obj)  # type: ignore[arg-type]
            return SyncEventChannel(config_obj)  # type: ignore[arg-type]
        if config.is_async:
            return AsyncEventChannel(config)  # type: ignore[arg-type]
        return SyncEventChannel(config)  # type: ignore[arg-type]

    def telemetry_snapshot(self) -> "dict[str, Any]":
        """Return aggregated diagnostics across all registered configurations.

        Includes the SQL loader's metrics (when any) plus each configuration's
        lifecycle snapshot and metrics.
        """
        diagnostics = TelemetryDiagnostics()
        loader_metrics = self._loader_runtime.metrics_snapshot()
        if loader_metrics:
            diagnostics.add_metric_snapshot(loader_metrics)
        for config in self._configs.values():
            runtime = config.get_observability_runtime()
            diagnostics.add_lifecycle_snapshot(runtime.diagnostics_key, runtime.lifecycle_snapshot())
            metrics_snapshot = runtime.metrics_snapshot()
            if metrics_snapshot:
                diagnostics.add_metric_snapshot(metrics_snapshot)
        return diagnostics.snapshot()

    def _ensure_sql_loader(self) -> SQLFileLoader:
        """Return a SQLFileLoader instance configured with observability runtime.

        Creates the loader on first use.
        """
        if self._sql_loader is None:
            self._sql_loader = SQLFileLoader(runtime=self._loader_runtime)
        else:
            self._sql_loader.set_observability_runtime(self._loader_runtime)
        return self._sql_loader

    @overload
    def get_connection(
        self, config: "NoPoolSyncConfig[ConnectionT, DriverT] | SyncDatabaseConfig[ConnectionT, PoolT, DriverT]"
    ) -> "ConnectionT": ...

    @overload
    def get_connection(
        self, config: "NoPoolAsyncConfig[ConnectionT, DriverT] | AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]"
    ) -> "Awaitable[ConnectionT]": ...

    def get_connection(
        self,
        config: "NoPoolSyncConfig[ConnectionT, DriverT] | SyncDatabaseConfig[ConnectionT, PoolT, DriverT] | NoPoolAsyncConfig[ConnectionT, DriverT] | AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]",
    ) -> "ConnectionT | Awaitable[ConnectionT]":
        """Get a database connection for the specified configuration.

        Registers the configuration if it is not yet known.

        Args:
            config: The configuration instance.

        Returns:
            A database connection or an awaitable yielding a connection.
        """
        if id(config) not in self._configs:
            self.add_config(config)
        return config.create_connection()

    @overload
    def get_session(
        self, config: "NoPoolSyncConfig[ConnectionT, DriverT] | SyncDatabaseConfig[ConnectionT, PoolT, DriverT]"
    ) -> "DriverT": ...

    @overload
    def get_session(
        self, config: "NoPoolAsyncConfig[ConnectionT, DriverT] | AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]"
    ) -> "Awaitable[DriverT]": ...

    def get_session(
        self,
        config: "NoPoolSyncConfig[ConnectionT, DriverT] | SyncDatabaseConfig[ConnectionT, PoolT, DriverT] | NoPoolAsyncConfig[ConnectionT, DriverT] | AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]",
    ) -> "DriverT | Awaitable[DriverT]":
        """Get a database session (driver adapter) for the specified configuration.

        Registers the configuration if it is not yet known.

        Args:
            config: The configuration instance.

        Returns:
            A driver adapter instance or an awaitable yielding one.
        """
        if id(config) not in self._configs:
            self.add_config(config)
        connection_obj = self.get_connection(config)
        if isinstance(connection_obj, Awaitable):
            async_config = cast(
                "NoPoolAsyncConfig[ConnectionT, DriverT] | AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]", config
            )
            return self._create_driver_async(async_config, connection_obj)  # pyright: ignore
        driver = config.driver_type(  # pyright: ignore
            connection=connection_obj, statement_config=config.statement_config, driver_features=config.driver_features
        )
        return config._prepare_driver(driver)  # pyright: ignore

    async def _create_driver_async(
        self,
        config: "NoPoolAsyncConfig[ConnectionT, DriverT] | AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]",
        connection_obj: "Awaitable[ConnectionT]",
    ) -> "DriverT":
        """Await the connection, then build and prepare the config's driver around it."""
        resolved_connection = await connection_obj
        driver = config.driver_type(  # pyright: ignore
            connection=resolved_connection,
            statement_config=config.statement_config,
            driver_features=config.driver_features,
        )
        return config._prepare_driver(driver)  # pyright: ignore

    @overload
    def provide_connection(
        self,
        config: "NoPoolSyncConfig[ConnectionT, DriverT] | SyncDatabaseConfig[ConnectionT, PoolT, DriverT]",
        *args: Any,
        **kwargs: Any,
    ) -> "AbstractContextManager[ConnectionT]": ...

    @overload
    def provide_connection(
        self,
        config: "NoPoolAsyncConfig[ConnectionT, DriverT] | AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]",
        *args: Any,
        **kwargs: Any,
    ) -> "AbstractAsyncContextManager[ConnectionT]": ...

    def provide_connection(
        self,
        config: "NoPoolSyncConfig[ConnectionT, DriverT] | SyncDatabaseConfig[ConnectionT, PoolT, DriverT] | NoPoolAsyncConfig[ConnectionT, DriverT] | AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]",
        *args: Any,
        **kwargs: Any,
    ) -> "AbstractContextManager[ConnectionT] | AbstractAsyncContextManager[ConnectionT]":
        """Create and provide a database connection from the specified configuration.

        The config's context manager is wrapped so connection lifecycle events are
        reported to the config's observability runtime.

        Args:
            config: The configuration instance.
            *args: Positional arguments to pass to the config's provide_connection.
            **kwargs: Keyword arguments to pass to the config's provide_connection.

        Returns:
            A sync or async context manager yielding a connection.
        """
        if id(config) not in self._configs:
            self.add_config(config)
        connection_context = config.provide_connection(*args, **kwargs)
        runtime = config.get_observability_runtime()
        if _is_async_context_manager(connection_context):
            async_context = cast("AbstractAsyncContextManager[ConnectionT]", connection_context)
            return _RuntimeAsyncConnectionContext(async_context, runtime)
        sync_context = cast("AbstractContextManager[ConnectionT]", connection_context)
        return _RuntimeConnectionContext(sync_context, runtime)

    @overload
    def provide_session(
        self,
        config: "NoPoolSyncConfig[ConnectionT, DriverT] | SyncDatabaseConfig[ConnectionT, PoolT, DriverT]",
        *args: Any,
        **kwargs: Any,
    ) -> "AbstractContextManager[DriverT]": ...

    @overload
    def provide_session(
        self,
        config: "NoPoolAsyncConfig[ConnectionT, DriverT] | AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]",
        *args: Any,
        **kwargs: Any,
    ) -> "AbstractAsyncContextManager[DriverT]": ...

    def provide_session(
        self,
        config: "NoPoolSyncConfig[ConnectionT, DriverT] | SyncDatabaseConfig[ConnectionT, PoolT, DriverT] | NoPoolAsyncConfig[ConnectionT, DriverT] | AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]",
        *args: Any,
        **kwargs: Any,
    ) -> "AbstractContextManager[DriverT] | AbstractAsyncContextManager[DriverT]":
        """Create and provide a database session from the specified configuration.

        The config's context manager is wrapped so the yielded driver is prepared
        and session/connection lifecycle events are reported.

        Args:
            config: The configuration instance.
            *args: Positional arguments to pass to the config's provide_session.
            **kwargs: Keyword arguments to pass to the config's provide_session.

        Returns:
            A sync or async context manager yielding a driver adapter instance.
        """
        if id(config) not in self._configs:
            self.add_config(config)
        session_context = config.provide_session(*args, **kwargs)
        runtime = config.get_observability_runtime()
        if _is_async_context_manager(session_context):
            async_session = cast("AbstractAsyncContextManager[DriverT]", session_context)
            return _RuntimeAsyncSessionContext(async_session, runtime, config)
        sync_session = cast("AbstractContextManager[DriverT]", session_context)
        return _RuntimeSessionContext(sync_session, runtime, config)

    @overload
    def get_pool(
        self, config: "NoPoolSyncConfig[ConnectionT, DriverT] | NoPoolAsyncConfig[ConnectionT, DriverT]"
    ) -> "None": ...

    @overload
    def get_pool(self, config: "SyncDatabaseConfig[ConnectionT, PoolT, DriverT]") -> "type[PoolT]": ...

    @overload
    def get_pool(self, config: "AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]") -> "Awaitable[type[PoolT]]": ...

    def get_pool(
        self,
        config: "NoPoolSyncConfig[ConnectionT, DriverT] | NoPoolAsyncConfig[ConnectionT, DriverT] | SyncDatabaseConfig[ConnectionT, PoolT, DriverT] | AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]",
    ) -> "type[PoolT] | Awaitable[type[PoolT]] | None":
        """Get the connection pool for the specified configuration.

        Registers the configuration if it is not yet known.

        Args:
            config: The configuration instance.

        Returns:
            The connection pool, an awaitable yielding the pool, or None if not supported.
        """
        if id(config) not in self._configs:
            self.add_config(config)
        if config.supports_connection_pooling:
            return cast("type[PoolT] | Awaitable[type[PoolT]]", config.create_pool())
        return None

    @overload
    def close_pool(
        self, config: "NoPoolSyncConfig[ConnectionT, DriverT] | SyncDatabaseConfig[ConnectionT, PoolT, DriverT]"
    ) -> "None": ...

    @overload
    def close_pool(
        self, config: "NoPoolAsyncConfig[ConnectionT, DriverT] | AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]"
    ) -> "Awaitable[None]": ...

    def close_pool(
        self,
        config: "NoPoolSyncConfig[ConnectionT, DriverT] | SyncDatabaseConfig[ConnectionT, PoolT, DriverT] | NoPoolAsyncConfig[ConnectionT, DriverT] | AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]",
    ) -> "Awaitable[None] | None":
        """Close the connection pool for the specified configuration.

        Registers the configuration if it is not yet known.

        Args:
            config: The configuration instance.

        Returns:
            None, or an awaitable if closing an async pool.
        """
        if id(config) not in self._configs:
            self.add_config(config)
        if config.supports_connection_pooling:
            return config.close_pool()
        return None

    @staticmethod
    def get_cache_config() -> CacheConfig:
        """Get the current global cache configuration.

        Returns:
            The current cache configuration.
        """
        return get_cache_config()

    @staticmethod
    def update_cache_config(config: CacheConfig) -> None:
        """Update the global cache configuration.

        Args:
            config: The new cache configuration to apply.
        """
        update_cache_config(config)

    @staticmethod
    def get_cache_stats() -> "dict[str, Any]":
        """Get current cache statistics.

        Returns:
            Dictionary of cache statistics.
        """
        return get_cache_statistics()

    @staticmethod
    def reset_cache_stats() -> None:
        """Reset all cache statistics to zero."""
        reset_cache_stats()

    @staticmethod
    def log_cache_stats() -> None:
        """Log current cache statistics using the configured logger."""
        log_cache_stats()

    def load_sql_files(self, *paths: "str | Path") -> None:
        """Load SQL files from paths or directories.

        Args:
            *paths: One or more file paths or directory paths to load.
        """
        loader = self._ensure_sql_loader()
        loader.load_sql(*paths)
        logger.debug("Loaded SQL files: %s", paths)

    def add_named_sql(self, name: str, sql: str, dialect: "str | None" = None) -> None:
        """Add a named SQL query directly.

        Args:
            name: Name for the SQL query.
            sql: Raw SQL content.
            dialect: Optional dialect for the SQL statement.
        """
        loader = self._ensure_sql_loader()
        loader.add_named_sql(name, sql, dialect)
        logger.debug("Added named SQL: %s", name)

    def get_sql(self, name: str) -> "SQL":
        """Get a SQL object by name.

        Args:
            name: Name of the statement from SQL file comments.
                Hyphens in names are converted to underscores.

        Returns:
            SQL object ready for execution.
        """
        loader = self._ensure_sql_loader()
        return loader.get_sql(name)

    def list_sql_queries(self) -> "list[str]":
        """List all available query names.

        Returns:
            Sorted list of query names (empty when no loader exists).
        """
        if self._sql_loader is None:
            return []
        return self._sql_loader.list_queries()

    def has_sql_query(self, name: str) -> bool:
        """Check if a SQL query exists.

        Args:
            name: Query name to check.

        Returns:
            True if the query exists in the loader; False when no loader exists.
        """
        if self._sql_loader is None:
            return False
        return self._sql_loader.has_query(name)

    def clear_sql_cache(self) -> None:
        """Clear the SQL file cache (no-op when no loader exists)."""
        if self._sql_loader is not None:
            self._sql_loader.clear_cache()
            # Only report a clear when a cache actually existed.
            logger.debug("Cleared SQL cache")

    def reload_sql_files(self) -> None:
        """Clear cached SQL files so they can be loaded again.

        Note:
            This only clears the cache; call load_sql_files again to reload.
        """
        if self._sql_loader is not None:
            self._sql_loader.clear_cache()
            # Only report a clear when a cache actually existed.
            logger.debug("Cleared SQL cache for reload")

    def get_sql_files(self) -> "list[str]":
        """Get list of loaded SQL files.

        Returns:
            Sorted list of file paths (empty when no loader exists).
        """
        if self._sql_loader is None:
            return []
        return self._sql_loader.list_files()