"""SQL statement and configuration management."""
import hashlib
import uuid
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Final, TypeAlias
import sqlglot
from mypy_extensions import mypyc_attr
from sqlglot import exp
from sqlglot.errors import ParseError
import sqlspec.exceptions
from sqlspec.core import pipeline
from sqlspec.core._pool import get_processed_state_pool, get_sql_pool
from sqlspec.core.cache import FiltersView
from sqlspec.core.compiler import OperationProfile, OperationType
from sqlspec.core.explain import ExplainFormat, ExplainOptions
from sqlspec.core.hashing import hash_filters
from sqlspec.core.parameters import (
ParameterConverter,
ParameterProcessor,
ParameterProfile,
ParameterStyle,
ParameterStyleConfig,
ParameterValidator,
structural_fingerprint,
)
from sqlspec.core.query_modifiers import (
apply_column_pruning,
apply_limit,
apply_offset,
apply_select_only,
apply_where,
create_between_condition,
create_condition,
create_in_condition,
create_not_in_condition,
expr_eq,
expr_gt,
expr_gte,
expr_ilike,
expr_is_not_null,
expr_is_null,
expr_like,
expr_lt,
expr_lte,
expr_neq,
extract_column_name,
safe_modify_with_cte,
)
from sqlspec.typing import Empty, EmptyEnum
from sqlspec.utils.logging import get_logger
from sqlspec.utils.type_guards import is_statement_filter, supports_where
if TYPE_CHECKING:
from collections.abc import Callable
from sqlglot.dialects.dialect import DialectType
from sqlspec.builder import QueryBuilder
from sqlspec.core.filters import StatementFilter
__all__ = (
"SQL",
"ProcessedState",
"Statement",
"StatementConfig",
"get_default_config",
"get_default_parameter_config",
)
logger = get_logger("sqlspec.core.statement")
RETURNS_ROWS_OPERATIONS: Final = {"SELECT", "WITH", "VALUES", "TABLE", "SHOW", "DESCRIBE", "PRAGMA"}
MODIFYING_OPERATIONS: Final = {"INSERT", "UPDATE", "DELETE", "MERGE", "UPSERT"}
_ORDER_PARTS_COUNT: Final = 2
_MAX_PARAM_COLLISION_ATTEMPTS: Final = 1000
SQL_CONFIG_SLOTS: Final = (
"dialect",
"enable_analysis",
"enable_caching",
"enable_column_pruning",
"enable_expression_simplification",
"enable_parameter_type_wrapping",
"enable_parsing",
"enable_transformations",
"enable_validation",
"execution_mode",
"execution_args",
"output_transformer",
"statement_transformers",
"parameter_config",
"parameter_converter",
"parameter_validator",
"_fingerprint_cache",
"_hash_cache",
"_is_frozen",
)
PROCESSED_STATE_SLOTS: Final = (
"compiled_sql",
"execution_parameters",
"parsed_expression",
"operation_type",
"input_named_parameters",
"applied_wrap_types",
"filter_hash",
"parameter_fingerprint",
"parameter_casts",
"parameter_profile",
"operation_profile",
"validation_errors",
"is_many",
)
@mypyc_attr(allow_interpreted_subclasses=False)
class ProcessedState:
"""Processing results for SQL statements.
Contains the compiled SQL, execution parameters, parsed expression,
operation type, and validation errors for a processed SQL statement.
"""
__slots__ = PROCESSED_STATE_SLOTS
operation_type: "OperationType"
def __init__(
self,
compiled_sql: str,
execution_parameters: Any,
parsed_expression: "exp.Expression | None" = None,
operation_type: "OperationType" = "COMMAND",
input_named_parameters: "tuple[str, ...] | None" = None,
applied_wrap_types: bool = False,
filter_hash: int = 0,
parameter_fingerprint: str | None = None,
parameter_casts: "dict[int, str] | None" = None,
validation_errors: "list[str] | None" = None,
parameter_profile: "ParameterProfile | None" = None,
operation_profile: "OperationProfile | None" = None,
is_many: bool = False,
) -> None:
self.compiled_sql = compiled_sql
self.execution_parameters = execution_parameters
self.parsed_expression = parsed_expression
self.operation_type = operation_type
self.input_named_parameters = input_named_parameters or ()
self.applied_wrap_types = applied_wrap_types
self.filter_hash = filter_hash
self.parameter_fingerprint = parameter_fingerprint
self.parameter_casts = parameter_casts or {}
self.validation_errors = validation_errors or []
self.parameter_profile = parameter_profile or ParameterProfile.empty()
self.operation_profile = operation_profile or OperationProfile.empty()
self.is_many = is_many
def __hash__(self) -> int:
return hash((self.compiled_sql, str(self.execution_parameters), self.operation_type))
def reset(self) -> None:
"""Reset processing state for reuse."""
self.compiled_sql = ""
self.execution_parameters = []
self.parsed_expression = None
self.operation_type = "COMMAND"
self.input_named_parameters = ()
self.applied_wrap_types = False
self.filter_hash = 0
self.parameter_fingerprint = None
self.parameter_casts.clear()
self.validation_errors.clear()
self.parameter_profile = ParameterProfile.empty()
self.operation_profile = OperationProfile.empty()
self.is_many = False
[docs]
@mypyc_attr(allow_interpreted_subclasses=False)
class SQL:
"""SQL statement with parameter and filter support.
Represents a SQL statement that can be compiled with parameters and filters.
Supports both positional and named parameters, statement filtering,
and various execution modes including batch operations.
"""
__slots__ = (
"_compiled_from_cache",
"_dialect",
"_filters",
"_hash",
"_is_many",
"_is_script",
"_named_parameters",
"_original_parameters",
"_pooled",
"_positional_parameters",
"_processed_state",
"_raw_expression",
"_raw_sql",
"_sql_param_counters",
"_statement_config",
)
# Type annotation for mypyc compatibility
_sql_param_counters: "dict[str, int]"
[docs]
def __init__(
self,
statement: "str | exp.Expression | 'SQL'",
*parameters: "Any | StatementFilter | list[Any | StatementFilter]",
statement_config: "StatementConfig | None" = None,
is_many: bool | None = None,
**kwargs: Any,
) -> None:
"""Initialize SQL statement.
Args:
statement: SQL string, expression, or existing SQL object
*parameters: Parameters and filters
statement_config: Configuration
is_many: Mark as execute_many operation
**kwargs: Additional parameters
"""
config = statement_config or self._create_auto_config(statement, parameters, kwargs)
self._statement_config = config
self._dialect = self._normalize_dialect(config.dialect)
self._compiled_from_cache = False
self._pooled = False
self._processed_state: EmptyEnum | ProcessedState = Empty
self._hash: int | None = None
self._filters: list[StatementFilter] = []
self._named_parameters: dict[str, Any] = {}
self._positional_parameters: list[Any] = []
self._sql_param_counters = {}
self._is_script = False
self._raw_expression: exp.Expression | None = None
if isinstance(statement, SQL):
self._init_from_sql_object(statement)
if is_many is not None:
self._is_many = is_many
else:
if isinstance(statement, str):
self._raw_sql = statement
else:
dialect = self._dialect
self._raw_sql = statement.sql(dialect=str(dialect) if dialect else None)
self._raw_expression = statement
self._is_many = is_many if is_many is not None else self._should_auto_detect_many(parameters)
self._original_parameters = parameters
self._process_parameters(*parameters, **kwargs)
def _create_auto_config(
self, _statement: "str | exp.Expression | 'SQL'", _parameters: tuple, _kwargs: "dict[str, Any]"
) -> "StatementConfig":
"""Create default StatementConfig when none provided.
Args:
_statement: The SQL statement (unused)
_parameters: Statement parameters (unused)
_kwargs: Additional keyword arguments (unused)
Returns:
Default StatementConfig instance
"""
return get_default_config()
[docs]
def reset(self) -> None:
"""Reset SQL object for reuse in pooling scenarios."""
if self._pooled and not self._compiled_from_cache and self._processed_state is not Empty:
get_processed_state_pool().release(self._processed_state)
self._compiled_from_cache = False
self._processed_state = Empty
self._hash = None
self._filters.clear()
self._named_parameters.clear()
self._positional_parameters.clear()
self._sql_param_counters.clear()
self._original_parameters = ()
self._is_many = False
self._is_script = False
self._raw_expression = None
self._raw_sql = ""
self._statement_config = get_default_config()
self._dialect = self._normalize_dialect(self._statement_config.dialect)
def _normalize_dialect(self, dialect: "DialectType") -> "str | None":
"""Convert dialect to string representation.
Args:
dialect: Dialect type, string, or None
Returns:
String representation of the dialect or None
"""
if dialect is None:
return None
if isinstance(dialect, str):
return dialect
return dialect.__class__.__name__.lower()
def _init_from_sql_object(self, sql_obj: "SQL") -> None:
"""Initialize instance attributes from existing SQL object.
Args:
sql_obj: Existing SQL object to copy from
"""
self._raw_sql = sql_obj.raw_sql
self._raw_expression = sql_obj.raw_expression
self._filters = sql_obj.filters.copy()
self._named_parameters = sql_obj.named_parameters.copy()
self._positional_parameters = sql_obj.positional_parameters.copy()
self._sql_param_counters = sql_obj._sql_param_counters.copy()
self._is_many = sql_obj.is_many
self._is_script = sql_obj.is_script
if sql_obj.is_processed:
self._processed_state = sql_obj.get_processed_state()
def _should_auto_detect_many(self, parameters: tuple) -> bool:
"""Detect execute_many mode from parameter structure.
Args:
parameters: Parameter tuple to analyze
Returns:
True if parameters indicate batch execution
"""
if len(parameters) == 1 and isinstance(parameters[0], list):
param_list = parameters[0]
if not param_list:
return False
# Optimization: Check only the first element for batch structure
# O(1) check instead of O(N) scan
first_item = param_list[0]
if isinstance(first_item, (tuple, list, dict)):
return len(param_list) > 1
return False
def _process_parameters(self, *parameters: Any, dialect: str | None = None, **kwargs: Any) -> None:
"""Process and organize parameters and filters.
Args:
*parameters: Variable parameters and filters
dialect: SQL dialect override
**kwargs: Additional named parameters
"""
if dialect is not None:
self._dialect = self._normalize_dialect(dialect)
if "is_script" in kwargs:
self._is_script = bool(kwargs.pop("is_script"))
self._filters.extend(self._extract_filters(parameters))
self._normalize_parameters(parameters)
self._named_parameters.update(kwargs)
def _extract_filters(self, parameters: "tuple[Any, ...]") -> "list[StatementFilter]":
return [p for p in parameters if is_statement_filter(p)]
def _normalize_parameters(self, parameters: "tuple[Any, ...]") -> None:
if not parameters:
return
# Optimization: Fast path for single parameter (most common case)
if len(parameters) == 1:
param = parameters[0]
# Fast check for simple types before filter check
if isinstance(param, (str, int, float, bool)) or param is None:
self._positional_parameters.append(param)
return
if is_statement_filter(param):
return
if isinstance(param, dict):
self._named_parameters.update(param)
elif isinstance(param, (list, tuple)):
if self._is_many:
self._positional_parameters = list(param)
else:
self._positional_parameters.extend(param)
else:
self._positional_parameters.append(param)
return
# Multiple parameters: check for filters
# O(N) check only if we have more than 1 param
has_filter = any(is_statement_filter(p) for p in parameters)
if not has_filter:
self._positional_parameters.extend(parameters)
return
actual_params = [p for p in parameters if not is_statement_filter(p)]
if not actual_params:
return
if len(actual_params) == 1:
param = actual_params[0]
if isinstance(param, dict):
self._named_parameters.update(param)
elif isinstance(param, (list, tuple)):
if self._is_many:
self._positional_parameters = list(param)
else:
self._positional_parameters.extend(param)
else:
self._positional_parameters.append(param)
else:
self._positional_parameters.extend(actual_params)
@property
def sql(self) -> str:
"""Get the raw SQL string."""
return self._raw_sql
@property
def raw_sql(self) -> str:
"""Get raw SQL string (public API).
Returns:
The raw SQL string
"""
return self._raw_sql
@property
def parameters(self) -> Any:
"""Get the original parameters."""
if self._named_parameters:
return self._named_parameters
return self._positional_parameters or []
@property
def positional_parameters(self) -> "list[Any]":
"""Get positional parameters (public API)."""
return self._positional_parameters or []
@property
def named_parameters(self) -> "dict[str, Any]":
"""Get named parameters (public API)."""
return self._named_parameters
@property
def original_parameters(self) -> Any:
"""Get original parameters (public API)."""
return self._original_parameters
@property
def operation_type(self) -> "OperationType":
"""SQL operation type."""
if self._processed_state is Empty:
return "COMMAND"
return self._processed_state.operation_type
@property
def statement_config(self) -> "StatementConfig":
"""Statement configuration."""
return self._statement_config
@property
def expression(self) -> "exp.Expression | None":
"""SQLGlot expression."""
if self._processed_state is not Empty:
return self._processed_state.parsed_expression
return self._raw_expression
@property
def raw_expression(self) -> "exp.Expression | None":
"""Original expression supplied at construction, if available."""
return self._raw_expression
@property
def filters(self) -> "list[StatementFilter]":
"""Applied filters."""
return self._filters.copy()
[docs]
def get_filters_view(self) -> "FiltersView":
"""Get zero-copy filters view (public API).
Returns:
Read-only view of filters without copying
"""
return FiltersView(self._filters)
@property
def is_processed(self) -> bool:
"""Check if SQL has been processed (public API)."""
return self._processed_state is not Empty
[docs]
def get_processed_state(self) -> Any:
"""Get processed state (public API)."""
return self._processed_state
@property
def dialect(self) -> "str | None":
"""SQL dialect."""
return self._dialect
@property
def statement_expression(self) -> "exp.Expression | None":
"""Get parsed statement expression (public API).
Returns:
Parsed SQLGlot expression or None if not parsed
"""
if self._processed_state is not Empty:
return self._processed_state.parsed_expression
return self._raw_expression
@property
def is_many(self) -> bool:
"""Check if this is execute_many."""
return self._is_many
@property
def is_script(self) -> bool:
"""Check if this is script execution."""
return self._is_script
@property
def validation_errors(self) -> "list[str]":
"""Validation errors."""
if self._processed_state is Empty:
return []
return self._processed_state.validation_errors.copy()
@property
def has_errors(self) -> bool:
"""Check if there are validation errors."""
return len(self.validation_errors) > 0
[docs]
def returns_rows(self) -> bool:
"""Check if statement returns rows.
Returns:
True if the SQL statement returns result rows
"""
if self._processed_state is Empty:
self.compile()
if self._processed_state is Empty:
return False
profile = self._processed_state.operation_profile
if profile.returns_rows:
return True
op_type = self._processed_state.operation_type
if op_type in RETURNS_ROWS_OPERATIONS:
return True
if self._processed_state.parsed_expression:
expr = self._processed_state.parsed_expression
if isinstance(expr, (exp.Insert, exp.Update, exp.Delete)) and expr.args.get("returning"):
return True
return False
[docs]
def is_modifying_operation(self) -> bool:
"""Check if the SQL statement is a modifying operation.
Returns:
True if the operation modifies data (INSERT/UPDATE/DELETE)
"""
if self._processed_state is Empty:
return False
profile = self._processed_state.operation_profile
if profile.modifies_rows:
return True
op_type = self._processed_state.operation_type
if op_type in MODIFYING_OPERATIONS:
return True
if self._processed_state.parsed_expression:
return isinstance(self._processed_state.parsed_expression, (exp.Insert, exp.Update, exp.Delete, exp.Merge))
return False
[docs]
def compile(self) -> "tuple[str, Any]":
"""Compile SQL statement with parameters.
Returns:
Tuple of compiled SQL string and execution parameters
"""
if self._processed_state is not Empty:
if self._compiled_from_cache:
state = self._processed_state
if state.execution_parameters is None:
self._processed_state = Empty
self._compiled_from_cache = False
else:
can_reuse = (
not self._statement_config.parameter_config.needs_static_script_compilation
and self._can_reuse_cached_state(state)
)
if can_reuse:
return self._rebind_cached_parameters(state)
self._processed_state = Empty
self._compiled_from_cache = False
else:
return self._processed_state.compiled_sql, self._processed_state.execution_parameters
if self._processed_state is Empty:
try:
config = self._statement_config
raw_sql = self._raw_sql
params = self._named_parameters or self._positional_parameters
is_many = self._is_many
param_fingerprint = structural_fingerprint(params, is_many=is_many)
compiled_result = pipeline.compile_with_pipeline(
config, raw_sql, params, is_many=is_many, expression=self._raw_expression
)
self._processed_state = self._build_processed_state(
compiled_sql=compiled_result.compiled_sql,
execution_parameters=compiled_result.execution_parameters,
parsed_expression=compiled_result.expression,
operation_type=compiled_result.operation_type,
input_named_parameters=compiled_result.input_named_parameters,
applied_wrap_types=compiled_result.applied_wrap_types,
filter_hash=hash_filters(self._filters),
parameter_fingerprint=param_fingerprint,
parameter_casts=compiled_result.parameter_casts,
parameter_profile=compiled_result.parameter_profile,
operation_profile=compiled_result.operation_profile,
validation_errors=[],
is_many=self._is_many,
)
except sqlspec.exceptions.SQLSpecError:
raise
except Exception as e:
self._processed_state = self._handle_compile_failure(e)
return self._processed_state.compiled_sql, self._processed_state.execution_parameters
def _rebind_cached_parameters(self, state: "ProcessedState") -> "tuple[str, Any]":
params = self._named_parameters or self._positional_parameters
processor = ParameterProcessor(
converter=self._statement_config.parameter_converter,
validator=self._statement_config.parameter_validator,
cache_max_size=0,
validator_cache_max_size=0,
)
rebound_params = processor._transform_cached_parameters( # pyright: ignore[reportPrivateUsage]
params,
state.parameter_profile,
self._statement_config.parameter_config,
input_named_parameters=state.input_named_parameters,
is_many=self._is_many,
apply_wrap_types=state.applied_wrap_types,
)
compiled_sql = state.compiled_sql
output_transformer = self._statement_config.output_transformer
if output_transformer:
compiled_sql, rebound_params = output_transformer(compiled_sql, rebound_params)
self._processed_state = self._build_processed_state(
compiled_sql=compiled_sql,
execution_parameters=rebound_params,
parsed_expression=state.parsed_expression,
operation_type=state.operation_type,
input_named_parameters=state.input_named_parameters,
applied_wrap_types=state.applied_wrap_types,
filter_hash=state.filter_hash,
parameter_fingerprint=state.parameter_fingerprint,
parameter_casts=state.parameter_casts,
parameter_profile=state.parameter_profile,
operation_profile=state.operation_profile,
validation_errors=state.validation_errors.copy(),
is_many=state.is_many,
)
self._compiled_from_cache = False
return compiled_sql, rebound_params
def _can_reuse_cached_state(self, state: "ProcessedState") -> bool:
cached_fingerprint = state.parameter_fingerprint
if cached_fingerprint is None:
return False
if state.is_many != self._is_many:
return False
if state.filter_hash != hash_filters(self._filters):
return False
params = self._named_parameters or self._positional_parameters
return structural_fingerprint(params, is_many=self._is_many) == cached_fingerprint
[docs]
def as_script(self) -> "SQL":
"""Create copy marked for script execution.
Returns:
New SQL instance configured for script execution
"""
original_params = self._original_parameters
config = self._statement_config
is_many = self._is_many
statement_seed = self._raw_expression or self._raw_sql
new_sql = SQL(statement_seed, *original_params, statement_config=config, is_many=is_many)
new_sql._named_parameters.update(self._named_parameters)
new_sql._positional_parameters = self._positional_parameters.copy()
new_sql._filters = self._filters.copy()
new_sql._is_script = True
return new_sql
[docs]
def copy(
self, statement: "str | exp.Expression | None" = None, parameters: Any | None = None, **kwargs: Any
) -> "SQL":
"""Create copy with modifications.
Args:
statement: New SQL statement to use
parameters: New parameters to use
**kwargs: Additional modifications
Returns:
New SQL instance with modifications applied
"""
# FAST PATH: Only parameters are changing
if statement is None and not kwargs and parameters is not None and not isinstance(parameters, (str, bytes)):
new_sql = self._create_empty_copy()
new_sql._process_parameters(*(parameters if isinstance(parameters, tuple) else (parameters,)))
return new_sql
statement_expression = self._raw_expression if statement is None else statement
new_sql = SQL(
statement_expression or self._raw_sql,
*(parameters if parameters is not None else self._original_parameters),
statement_config=self._statement_config,
is_many=self._is_many,
**kwargs,
)
if parameters is None:
new_sql._named_parameters.update(self._named_parameters)
new_sql._positional_parameters = self._positional_parameters.copy()
new_sql._filters = self._filters.copy()
return new_sql
def _create_empty_copy(self) -> "SQL":
"""Create a shell copy with shared immutable state but empty mutable state."""
new_sql = get_sql_pool().acquire()
new_sql._raw_sql = self._raw_sql
new_sql._raw_expression = self._raw_expression
new_sql._statement_config = self._statement_config
new_sql._dialect = self._dialect
new_sql._is_many = self._is_many
new_sql._is_script = self._is_script
new_sql._original_parameters = ()
new_sql._pooled = True
# Reset mutable state
new_sql._compiled_from_cache = self._processed_state is not Empty
new_sql._processed_state = self._processed_state if self._processed_state is not Empty else Empty
new_sql._hash = None
new_sql._filters = self._filters.copy()
new_sql._named_parameters = {}
new_sql._positional_parameters = []
new_sql._sql_param_counters = self._sql_param_counters.copy()
return new_sql
def _build_processed_state(
self,
*,
compiled_sql: str,
execution_parameters: Any,
parsed_expression: "exp.Expression | None",
operation_type: "OperationType",
input_named_parameters: "tuple[str, ...] | None",
applied_wrap_types: bool,
filter_hash: int,
parameter_fingerprint: str | None,
parameter_casts: "dict[int, str] | None",
parameter_profile: "ParameterProfile | None",
operation_profile: "OperationProfile | None",
validation_errors: "list[str] | None",
is_many: bool,
) -> "ProcessedState":
state = get_processed_state_pool().acquire()
ProcessedState.__init__(
state,
compiled_sql=compiled_sql,
execution_parameters=execution_parameters,
parsed_expression=parsed_expression,
operation_type=operation_type,
input_named_parameters=input_named_parameters,
applied_wrap_types=applied_wrap_types,
filter_hash=filter_hash,
parameter_fingerprint=parameter_fingerprint,
parameter_casts=parameter_casts,
validation_errors=validation_errors,
parameter_profile=parameter_profile,
operation_profile=operation_profile,
is_many=is_many,
)
return state
def _handle_compile_failure(self, error: Exception) -> ProcessedState:
import traceback
traceback.print_exc()
logger.debug("Processing failed, using fallback: %s", error)
params = self._named_parameters or self._positional_parameters
return self._build_processed_state(
compiled_sql=self._raw_sql,
execution_parameters=self._named_parameters or self._positional_parameters,
parsed_expression=None,
operation_type="COMMAND",
input_named_parameters=(),
applied_wrap_types=False,
filter_hash=hash_filters(self._filters),
parameter_fingerprint=structural_fingerprint(params, is_many=self._is_many),
parameter_casts={},
parameter_profile=ParameterProfile.empty(),
operation_profile=OperationProfile.empty(),
validation_errors=[str(error)],
is_many=self._is_many,
)
# ==========================================================================
# Parameter Generation Helpers
# ==========================================================================
def _generate_sql_param_name(self, base_name: str) -> str:
"""Generate unique parameter name with _sqlspec_ prefix.
Uses _sqlspec_ prefix to avoid collision with user-provided parameters.
Auto-generated parameters are namespaced to prevent conflicts.
Args:
base_name: The base name for the parameter (e.g., column name)
Returns:
A unique parameter name that doesn't exist in current parameters
"""
prefixed_base = f"_sqlspec_{base_name}"
current_index = self._sql_param_counters.get(prefixed_base, 0)
if prefixed_base not in self._named_parameters:
self._sql_param_counters[prefixed_base] = current_index
return prefixed_base
next_index = current_index + 1
candidate = f"{prefixed_base}_{next_index}"
while candidate in self._named_parameters:
next_index += 1
if next_index > _MAX_PARAM_COLLISION_ATTEMPTS:
return f"{prefixed_base}_{uuid.uuid4().hex[:8]}"
candidate = f"{prefixed_base}_{next_index}"
self._sql_param_counters[prefixed_base] = next_index
return candidate
def _get_or_parse_expression(self) -> exp.Expression:
"""Get the current expression or parse the raw SQL.
Prefers cached parsed expression over re-parsing raw SQL.
Returns:
The SQLGlot expression for this statement
"""
# First check processed state for parsed expression
if self._processed_state is not Empty and self._processed_state.parsed_expression is not None:
return self._processed_state.parsed_expression.copy()
# Then check statement_expression (from compilation)
if self.statement_expression is not None:
return self.statement_expression.copy()
# Then check raw_expression (from construction)
if self._raw_expression is not None:
return self._raw_expression.copy()
# Fall back to parsing if enabled
if not self._statement_config.enable_parsing:
return exp.Select().from_(f"({self._raw_sql})")
try:
return sqlglot.parse_one(self._raw_sql, dialect=self._dialect)
except ParseError:
return exp.Select().from_(f"({self._raw_sql})")
def _create_modified_copy_with_expression(self, new_expr: exp.Expression) -> "SQL":
"""Create a new SQL instance with a modified expression.
Args:
new_expr: The new SQLGlot expression
Returns:
New SQL instance with the expression and copied state
"""
new_sql = SQL(
new_expr, *self._original_parameters, statement_config=self._statement_config, is_many=self._is_many
)
new_sql._named_parameters.update(self._named_parameters)
new_sql._positional_parameters = self._positional_parameters.copy()
new_sql._filters = self._filters.copy()
new_sql._sql_param_counters = self._sql_param_counters.copy()
return new_sql
[docs]
def add_named_parameter(self, name: str, value: Any) -> "SQL":
"""Add a named parameter and return a new SQL instance.
Args:
name: Parameter name
value: Parameter value
Returns:
New SQL instance with the added parameter
"""
original_params = self._original_parameters
config = self._statement_config
is_many = self._is_many
statement_seed = self._raw_expression or self._raw_sql
new_sql = SQL(statement_seed, *original_params, statement_config=config, is_many=is_many)
new_sql._named_parameters.update(self._named_parameters)
new_sql._named_parameters[name] = value
new_sql._positional_parameters = self._positional_parameters.copy()
new_sql._filters = self._filters.copy()
return new_sql
[docs]
def where(self, condition: "str | exp.Expression") -> "SQL":
"""Add WHERE condition to the SQL statement.
Args:
condition: WHERE condition as string or SQLGlot expression
Returns:
New SQL instance with the WHERE condition applied
"""
if self.statement_expression is not None:
current_expr = self.statement_expression.copy()
elif not self._statement_config.enable_parsing:
current_expr = exp.Select().from_(f"({self._raw_sql})")
else:
try:
current_expr = sqlglot.parse_one(self._raw_sql, dialect=self._dialect)
except ParseError:
subquery_sql = f"SELECT * FROM ({self._raw_sql}) AS subquery"
current_expr = sqlglot.parse_one(subquery_sql, dialect=self._dialect)
condition_expr: exp.Expression
if isinstance(condition, str):
if not self._statement_config.enable_parsing:
condition_expr = exp.Condition(this=condition)
else:
try:
condition_expr = sqlglot.parse_one(condition, dialect=self._dialect, into=exp.Condition)
except ParseError:
condition_expr = exp.Condition(this=condition)
else:
condition_expr = condition
if isinstance(current_expr, exp.Select) or supports_where(current_expr):
new_expr = current_expr.where(condition_expr, copy=False)
else:
new_expr = exp.Select().from_(current_expr).where(condition_expr, copy=False)
original_params = self._original_parameters
config = self._statement_config
is_many = self._is_many
new_sql = SQL(new_expr, *original_params, statement_config=config, is_many=is_many)
new_sql._named_parameters.update(self._named_parameters)
new_sql._positional_parameters = self._positional_parameters.copy()
new_sql._filters = self._filters.copy()
return new_sql
# ==========================================================================
# Parameterized WHERE Methods (using shared utilities)
# ==========================================================================
[docs]
def where_eq(self, column: "str | exp.Column", value: Any) -> "SQL":
"""Add WHERE column = value condition.
Args:
column: Column name or expression
value: Value to compare against
Returns:
New SQL instance with WHERE condition applied
"""
expression = self._get_or_parse_expression()
col_name = extract_column_name(column)
param_name = self._generate_sql_param_name(col_name)
condition = create_condition(column, param_name, expr_eq)
new_expr = safe_modify_with_cte(expression, lambda e: apply_where(e, condition))
new_sql = self._create_modified_copy_with_expression(new_expr)
new_sql._named_parameters[param_name] = value
return new_sql
[docs]
def where_neq(self, column: "str | exp.Column", value: Any) -> "SQL":
"""Add WHERE column != value condition.
Args:
column: Column name or expression
value: Value to compare against
Returns:
New SQL instance with WHERE condition applied
"""
expression = self._get_or_parse_expression()
col_name = extract_column_name(column)
param_name = self._generate_sql_param_name(col_name)
condition = create_condition(column, param_name, expr_neq)
new_expr = safe_modify_with_cte(expression, lambda e: apply_where(e, condition))
new_sql = self._create_modified_copy_with_expression(new_expr)
new_sql._named_parameters[param_name] = value
return new_sql
[docs]
def where_lt(self, column: "str | exp.Column", value: Any) -> "SQL":
"""Add WHERE column < value condition.
Args:
column: Column name or expression
value: Value to compare against
Returns:
New SQL instance with WHERE condition applied
"""
expression = self._get_or_parse_expression()
col_name = extract_column_name(column)
param_name = self._generate_sql_param_name(col_name)
condition = create_condition(column, param_name, expr_lt)
new_expr = safe_modify_with_cte(expression, lambda e: apply_where(e, condition))
new_sql = self._create_modified_copy_with_expression(new_expr)
new_sql._named_parameters[param_name] = value
return new_sql
[docs]
def where_lte(self, column: "str | exp.Column", value: Any) -> "SQL":
"""Add WHERE column <= value condition.
Args:
column: Column name or expression
value: Value to compare against
Returns:
New SQL instance with WHERE condition applied
"""
expression = self._get_or_parse_expression()
col_name = extract_column_name(column)
param_name = self._generate_sql_param_name(col_name)
condition = create_condition(column, param_name, expr_lte)
new_expr = safe_modify_with_cte(expression, lambda e: apply_where(e, condition))
new_sql = self._create_modified_copy_with_expression(new_expr)
new_sql._named_parameters[param_name] = value
return new_sql
[docs]
def where_gt(self, column: "str | exp.Column", value: Any) -> "SQL":
"""Add WHERE column > value condition.
Args:
column: Column name or expression
value: Value to compare against
Returns:
New SQL instance with WHERE condition applied
"""
expression = self._get_or_parse_expression()
col_name = extract_column_name(column)
param_name = self._generate_sql_param_name(col_name)
condition = create_condition(column, param_name, expr_gt)
new_expr = safe_modify_with_cte(expression, lambda e: apply_where(e, condition))
new_sql = self._create_modified_copy_with_expression(new_expr)
new_sql._named_parameters[param_name] = value
return new_sql
[docs]
def where_gte(self, column: "str | exp.Column", value: Any) -> "SQL":
"""Add WHERE column >= value condition.
Args:
column: Column name or expression
value: Value to compare against
Returns:
New SQL instance with WHERE condition applied
"""
expression = self._get_or_parse_expression()
col_name = extract_column_name(column)
param_name = self._generate_sql_param_name(col_name)
condition = create_condition(column, param_name, expr_gte)
new_expr = safe_modify_with_cte(expression, lambda e: apply_where(e, condition))
new_sql = self._create_modified_copy_with_expression(new_expr)
new_sql._named_parameters[param_name] = value
return new_sql
[docs]
def where_like(self, column: "str | exp.Column", pattern: str) -> "SQL":
"""Add WHERE column LIKE pattern condition.
Args:
column: Column name or expression
pattern: LIKE pattern (e.g., '%search%')
Returns:
New SQL instance with WHERE condition applied
"""
expression = self._get_or_parse_expression()
col_name = extract_column_name(column)
param_name = self._generate_sql_param_name(col_name)
condition = create_condition(column, param_name, expr_like)
new_expr = safe_modify_with_cte(expression, lambda e: apply_where(e, condition))
new_sql = self._create_modified_copy_with_expression(new_expr)
new_sql._named_parameters[param_name] = pattern
return new_sql
[docs]
def where_ilike(self, column: "str | exp.Column", pattern: str) -> "SQL":
"""Add WHERE column ILIKE pattern condition (case-insensitive).
Args:
column: Column name or expression
pattern: ILIKE pattern (e.g., '%search%')
Returns:
New SQL instance with WHERE condition applied
"""
expression = self._get_or_parse_expression()
col_name = extract_column_name(column)
param_name = self._generate_sql_param_name(col_name)
condition = create_condition(column, param_name, expr_ilike)
new_expr = safe_modify_with_cte(expression, lambda e: apply_where(e, condition))
new_sql = self._create_modified_copy_with_expression(new_expr)
new_sql._named_parameters[param_name] = pattern
return new_sql
[docs]
def where_is_null(self, column: "str | exp.Column") -> "SQL":
"""Add WHERE column IS NULL condition.
Args:
column: Column name or expression
Returns:
New SQL instance with WHERE condition applied
"""
expression = self._get_or_parse_expression()
condition = create_condition(column, "_unused", expr_is_null)
new_expr = safe_modify_with_cte(expression, lambda e: apply_where(e, condition))
return self._create_modified_copy_with_expression(new_expr)
[docs]
def where_is_not_null(self, column: "str | exp.Column") -> "SQL":
"""Add WHERE column IS NOT NULL condition.
Args:
column: Column name or expression
Returns:
New SQL instance with WHERE condition applied
"""
expression = self._get_or_parse_expression()
condition = create_condition(column, "_unused", expr_is_not_null)
new_expr = safe_modify_with_cte(expression, lambda e: apply_where(e, condition))
return self._create_modified_copy_with_expression(new_expr)
[docs]
def where_in(self, column: "str | exp.Column", values: "Sequence[Any]") -> "SQL":
"""Add WHERE column IN (values) condition.
Args:
column: Column name or expression
values: Sequence of values for IN clause
Returns:
New SQL instance with WHERE condition applied
"""
if not values:
expression = self._get_or_parse_expression()
false_condition = exp.EQ(this=exp.Literal.number(1), expression=exp.Literal.number(0))
new_expr = safe_modify_with_cte(expression, lambda e: apply_where(e, false_condition))
return self._create_modified_copy_with_expression(new_expr)
expression = self._get_or_parse_expression()
col_name = extract_column_name(column)
param_names: list[str] = []
param_values: dict[str, Any] = {}
for i, val in enumerate(values):
param_name = self._generate_sql_param_name(f"{col_name}_in_{i}")
param_names.append(param_name)
param_values[param_name] = val
condition = create_in_condition(column, param_names)
new_expr = safe_modify_with_cte(expression, lambda e: apply_where(e, condition))
new_sql = self._create_modified_copy_with_expression(new_expr)
new_sql._named_parameters.update(param_values)
return new_sql
[docs]
def where_not_in(self, column: "str | exp.Column", values: "Sequence[Any]") -> "SQL":
"""Add WHERE column NOT IN (values) condition.
Args:
column: Column name or expression
values: Sequence of values for NOT IN clause
Returns:
New SQL instance with WHERE condition applied
"""
if not values:
return self
expression = self._get_or_parse_expression()
col_name = extract_column_name(column)
param_names: list[str] = []
param_values: dict[str, Any] = {}
for i, val in enumerate(values):
param_name = self._generate_sql_param_name(f"{col_name}_not_in_{i}")
param_names.append(param_name)
param_values[param_name] = val
condition = create_not_in_condition(column, param_names)
new_expr = safe_modify_with_cte(expression, lambda e: apply_where(e, condition))
new_sql = self._create_modified_copy_with_expression(new_expr)
new_sql._named_parameters.update(param_values)
return new_sql
[docs]
def where_between(self, column: "str | exp.Column", low: Any, high: Any) -> "SQL":
"""Add WHERE column BETWEEN low AND high condition.
Args:
column: Column name or expression
low: Lower bound value
high: Upper bound value
Returns:
New SQL instance with WHERE condition applied
"""
expression = self._get_or_parse_expression()
col_name = extract_column_name(column)
low_param = self._generate_sql_param_name(f"{col_name}_low")
high_param = self._generate_sql_param_name(f"{col_name}_high")
condition = create_between_condition(column, low_param, high_param)
new_expr = safe_modify_with_cte(expression, lambda e: apply_where(e, condition))
new_sql = self._create_modified_copy_with_expression(new_expr)
new_sql._named_parameters[low_param] = low
new_sql._named_parameters[high_param] = high
return new_sql
[docs]
def order_by(self, *items: "str | exp.Expression", desc: bool = False) -> "SQL":
"""Add ORDER BY clause to the SQL statement.
Args:
*items: ORDER BY expressions as strings or SQLGlot expressions
desc: Apply descending order to each item
Returns:
New SQL instance with ORDER BY applied
"""
if not items:
return self
if self.statement_expression is not None:
current_expr = self.statement_expression.copy()
elif not self._statement_config.enable_parsing:
current_expr = exp.Select().from_(f"({self._raw_sql})")
else:
try:
current_expr = sqlglot.parse_one(self._raw_sql, dialect=self._dialect)
except ParseError:
current_expr = exp.Select().from_(f"({self._raw_sql})")
def parse_order_item(order_item: str) -> exp.Expression:
normalized = order_item.strip()
if not normalized:
return exp.column(order_item)
if self._statement_config.enable_parsing:
try:
parsed = sqlglot.parse_one(normalized, dialect=self._dialect, into=exp.Ordered)
except ParseError:
parsed = None
if parsed is not None:
return parsed
parts = normalized.rsplit(None, 1)
if len(parts) == _ORDER_PARTS_COUNT and parts[1].lower() in {"asc", "desc"}:
base_expr = exp.column(parts[0]) if parts[0] else exp.column(normalized)
return base_expr.desc() if parts[1].lower() == "desc" else base_expr.asc()
return exp.column(normalized)
new_expr = current_expr
for item in items:
if isinstance(item, str):
order_expr = parse_order_item(item)
if desc and not isinstance(order_expr, exp.Ordered):
order_expr = order_expr.desc()
else:
order_expr = item.desc() if desc and not isinstance(item, exp.Ordered) else item
if isinstance(new_expr, exp.Select):
new_expr = new_expr.order_by(order_expr, copy=False)
else:
new_expr = exp.Select().from_(new_expr).order_by(order_expr)
original_params = self._original_parameters
config = self._statement_config
is_many = self._is_many
new_sql = SQL(new_expr, *original_params, statement_config=config, is_many=is_many)
new_sql._named_parameters.update(self._named_parameters)
new_sql._positional_parameters = self._positional_parameters.copy()
new_sql._filters = self._filters.copy()
return new_sql
# ==========================================================================
# Pagination Methods
# ==========================================================================
[docs]
def limit(self, value: int) -> "SQL":
"""Add LIMIT clause to the SQL statement.
Args:
value: Maximum number of rows to return
Returns:
New SQL instance with LIMIT applied
Raises:
SQLSpecError: If statement is not a SELECT
"""
expression = self._get_or_parse_expression()
new_expr = safe_modify_with_cte(expression, lambda e: apply_limit(e, value))
return self._create_modified_copy_with_expression(new_expr)
[docs]
def offset(self, value: int) -> "SQL":
"""Add OFFSET clause to the SQL statement.
Args:
value: Number of rows to skip
Returns:
New SQL instance with OFFSET applied
Raises:
SQLSpecError: If statement is not a SELECT
"""
expression = self._get_or_parse_expression()
new_expr = safe_modify_with_cte(expression, lambda e: apply_offset(e, value))
return self._create_modified_copy_with_expression(new_expr)
[docs]
def paginate(self, page: int, page_size: int) -> "SQL":
"""Add LIMIT and OFFSET for pagination.
Args:
page: Page number (1-indexed)
page_size: Number of items per page
Returns:
New SQL instance with LIMIT and OFFSET applied
Example:
# Get page 3 with 20 items per page
stmt = SQL("SELECT * FROM users").paginate(3, 20)
# Results in: SELECT * FROM users LIMIT 20 OFFSET 40
"""
if page < 1:
msg = "paginate page must be >= 1"
raise sqlspec.exceptions.SQLSpecError(msg)
if page_size < 1:
msg = "paginate page_size must be >= 1"
raise sqlspec.exceptions.SQLSpecError(msg)
offset_value = (page - 1) * page_size
return self.limit(page_size).offset(offset_value)
# ==========================================================================
# Column Projection Methods
# ==========================================================================
[docs]
def select_only(self, *columns: "str | exp.Expression", prune_columns: bool | None = None) -> "SQL":
"""Replace SELECT columns with only the specified columns.
This is useful for narrowing down the columns returned by a query
without modifying the FROM clause or WHERE conditions.
Args:
*columns: Column names or expressions to select
prune_columns: Remove unused columns from subqueries. When True,
applies SQLGlot's qualify and pushdown_projections optimizations.
Defaults to the config's enable_column_pruning setting.
Returns:
New SQL instance with only the specified columns
Example:
stmt = SQL("SELECT * FROM users WHERE active = 1")
narrow = stmt.select_only("id", "name", "email")
# Results in: SELECT id, name, email FROM users WHERE active = 1
# With column pruning on a subquery:
stmt = SQL("SELECT * FROM (SELECT id, name, email, created_at FROM users) AS u")
narrow = stmt.select_only("id", "name", prune_columns=True)
# Results in: SELECT id, name FROM (SELECT id, name FROM users) AS u
"""
if not columns:
return self
expression = self._get_or_parse_expression()
new_expr = safe_modify_with_cte(expression, lambda e: apply_select_only(e, columns))
# Determine whether to apply column pruning
should_prune = prune_columns if prune_columns is not None else self._statement_config.enable_column_pruning
if should_prune:
if self._statement_config.enable_caching:
# Use stable content-based key instead of object id (which can be reused after GC)
expr_sql = new_expr.sql(dialect=self._dialect)
cache_key = f"prune:{hashlib.blake2b(expr_sql.encode(), digest_size=8).hexdigest()}"
else:
cache_key = None
new_expr = apply_column_pruning(new_expr, dialect=self._dialect, cache_key=cache_key)
return self._create_modified_copy_with_expression(new_expr)
[docs]
def explain(self, analyze: bool = False, verbose: bool = False, format: "str | None" = None) -> "SQL":
"""Create an EXPLAIN statement for this SQL.
Wraps the current SQL statement in an EXPLAIN clause with
dialect-aware syntax generation.
Args:
analyze: Execute the statement and show actual runtime statistics
verbose: Show additional information
format: Output format (TEXT, JSON, XML, YAML, TREE, TRADITIONAL)
Returns:
New SQL instance containing the EXPLAIN statement
Examples:
Basic EXPLAIN:
stmt = SQL("SELECT * FROM users")
explain_stmt = stmt.explain()
With options:
explain_stmt = stmt.explain(analyze=True, format="json")
"""
from sqlspec.builder import Explain
fmt = None
if format is not None:
fmt = ExplainFormat(format.lower())
options = ExplainOptions(analyze=analyze, verbose=verbose, format=fmt)
explain_builder = Explain(self, dialect=self._dialect, options=options)
return explain_builder.build()
[docs]
def builder(self, dialect: "DialectType | None" = None) -> "QueryBuilder":
"""Create a query builder seeded from this SQL statement.
Args:
dialect: Optional SQL dialect override for parsing and rendering.
Returns:
QueryBuilder instance initialized with the parsed statement.
Raises:
SQLBuilderError: If the statement cannot be parsed.
Notes:
Statements outside the DML set return an ExpressionBuilder without
DML-specific helper methods.
"""
if self._is_many:
msg = "QueryBuilder does not support execute_many SQL statements."
raise sqlspec.exceptions.SQLBuilderError(msg)
from sqlspec.builder import Delete, ExpressionBuilder, Insert, Merge, Select, Update
builder_dialect = dialect or self._dialect
converter = self._statement_config.parameter_converter or ParameterConverter(
self._statement_config.parameter_validator
)
raw_params = self.parameters
converted_sql, converted_params = converter.convert_placeholder_style(
self._raw_sql, raw_params, ParameterStyle.NAMED_COLON, is_many=False
)
if self._raw_expression is not None and converted_sql == self._raw_sql and (builder_dialect == self._dialect):
expression = self._raw_expression.copy()
else:
try:
expression = sqlglot.parse_one(converted_sql, dialect=builder_dialect)
except ParseError as exc:
msg = f"Failed to parse SQL for builder: {exc}"
raise sqlspec.exceptions.SQLBuilderError(msg) from exc
base_expression = expression
ctes: list[exp.CTE] | None = None
if isinstance(expression, exp.With):
if expression.this is None:
msg = "WITH expression does not include a base statement."
raise sqlspec.exceptions.SQLBuilderError(msg)
base_expression = expression.this
ctes = list(expression.expressions)
builder: QueryBuilder
if isinstance(base_expression, (exp.Select, exp.Union, exp.Except, exp.Intersect, exp.Values)):
builder = Select(dialect=builder_dialect)
builder.set_expression(base_expression.copy())
elif isinstance(base_expression, exp.Insert):
builder = Insert(dialect=builder_dialect)
builder.set_expression(base_expression.copy())
elif isinstance(base_expression, exp.Update):
builder = Update(dialect=builder_dialect)
builder.set_expression(base_expression.copy())
elif isinstance(base_expression, exp.Delete):
builder = Delete(dialect=builder_dialect)
builder.set_expression(base_expression.copy())
elif isinstance(base_expression, exp.Merge):
builder = Merge(dialect=builder_dialect)
builder.set_expression(base_expression.copy())
else:
builder = ExpressionBuilder(base_expression.copy(), dialect=builder_dialect)
if ctes:
builder.load_ctes(ctes)
if isinstance(converted_params, Mapping):
builder.load_parameters(converted_params)
return builder
if (
converted_params
and isinstance(converted_params, Sequence)
and not isinstance(converted_params, (str, bytes, bytearray))
):
param_info = converter.validator.extract_parameters(converted_sql)
param_map: dict[str, Any] = {}
for index, param in enumerate(param_info):
if index >= len(converted_params):
break
param_name = param.name or f"param_{param.ordinal}"
param_map[param_name] = converted_params[index]
builder.load_parameters(param_map)
return builder
[docs]
def __hash__(self) -> int:
"""Hash value computation."""
if self._hash is None:
positional_tuple = tuple(self._positional_parameters)
named_tuple = tuple(sorted(self._named_parameters.items())) if self._named_parameters else ()
raw_sql = self._raw_sql
is_many = self._is_many
is_script = self._is_script
self._hash = hash((raw_sql, positional_tuple, named_tuple, is_many, is_script))
return self._hash
[docs]
def __eq__(self, other: object) -> bool:
"""Equality comparison."""
if not isinstance(other, SQL):
return False
return (
self._raw_sql == other._raw_sql
and self._positional_parameters == other._positional_parameters
and self._named_parameters == other._named_parameters
and self._is_many == other._is_many
and self._is_script == other._is_script
)
[docs]
def __repr__(self) -> str:
"""String representation."""
params_parts = []
if self._positional_parameters:
params_parts.append(f"params={self._positional_parameters}")
if self._named_parameters:
params_parts.append(f"named_params={self._named_parameters}")
params_str = f", {', '.join(params_parts)}" if params_parts else ""
flags = []
if self._is_many:
flags.append("is_many")
if self._is_script:
flags.append("is_script")
flags_str = f", {', '.join(flags)}" if flags else ""
return f"SQL({self._raw_sql!r}{params_str}{flags_str})"
[docs]
@mypyc_attr(allow_interpreted_subclasses=False)
class StatementConfig:
"""Configuration for SQL statement processing.
Controls SQL parsing, validation, transformations, parameter handling,
and other processing options for SQL statements.
"""
__slots__ = SQL_CONFIG_SLOTS
[docs]
def __init__(
self,
parameter_config: "ParameterStyleConfig | None" = None,
enable_parsing: bool = True,
enable_validation: bool = True,
enable_transformations: bool = True,
enable_analysis: bool = False,
enable_expression_simplification: bool = False,
enable_column_pruning: bool = False,
enable_parameter_type_wrapping: bool = True,
enable_caching: bool = True,
parameter_converter: "ParameterConverter | None" = None,
parameter_validator: "ParameterValidator | None" = None,
dialect: "DialectType | None" = None,
execution_mode: "str | None" = None,
execution_args: "dict[str, Any] | None" = None,
output_transformer: "Callable[[str, Any], tuple[str, Any]] | None" = None,
statement_transformers: "Sequence[Callable[[exp.Expression, Any], tuple[exp.Expression, Any]]] | None" = None,
) -> None:
"""Initialize StatementConfig.
Args:
parameter_config: Parameter style configuration
enable_parsing: Enable SQL parsing
enable_validation: Run SQL validators
enable_transformations: Apply SQL transformers
enable_analysis: Run SQL analyzers
enable_expression_simplification: Apply expression simplification
enable_column_pruning: Remove unused columns from subqueries during select_only
enable_parameter_type_wrapping: Wrap parameters with type information
enable_caching: Cache processed SQL statements
parameter_converter: Handles parameter style conversions
parameter_validator: Validates parameter usage and styles
dialect: SQL dialect
execution_mode: Special execution mode
execution_args: Arguments for special execution modes
output_transformer: Optional output transformation function
statement_transformers: Optional AST transformers executed during compilation
"""
self.enable_parsing = enable_parsing
self.enable_validation = enable_validation
self.enable_transformations = enable_transformations
self.enable_analysis = enable_analysis
self.enable_expression_simplification = enable_expression_simplification
self.enable_column_pruning = enable_column_pruning
self.enable_parameter_type_wrapping = enable_parameter_type_wrapping
self.enable_caching = enable_caching
if parameter_converter is None:
if parameter_validator is None:
parameter_validator = ParameterValidator()
self.parameter_converter = ParameterConverter(parameter_validator)
else:
self.parameter_converter = parameter_converter
if parameter_validator is None:
self.parameter_validator = self.parameter_converter.validator
else:
self.parameter_validator = parameter_validator
self.parameter_converter.validator = parameter_validator
self.parameter_config = parameter_config or ParameterStyleConfig(
default_parameter_style=ParameterStyle.QMARK, supported_parameter_styles={ParameterStyle.QMARK}
)
self.dialect = dialect
self.execution_mode = execution_mode
self.execution_args = execution_args
self.output_transformer = output_transformer
if statement_transformers:
self.statement_transformers = tuple(statement_transformers)
else:
self.statement_transformers = ()
self._fingerprint_cache: str | None = None
self._hash_cache: int | None = None
self._is_frozen = False
[docs]
def freeze(self) -> None:
"""Mark the configuration as immutable to enable caching."""
self._is_frozen = True
[docs]
def replace(self, **kwargs: Any) -> "StatementConfig":
"""Immutable update pattern.
Args:
**kwargs: Attributes to update
Returns:
New StatementConfig instance with updated attributes
"""
for key in kwargs:
if key not in SQL_CONFIG_SLOTS:
msg = f"{key!r} is not a field in {type(self).__name__}"
raise TypeError(msg)
current_kwargs: dict[str, Any] = {
"parameter_config": self.parameter_config,
"enable_parsing": self.enable_parsing,
"enable_validation": self.enable_validation,
"enable_transformations": self.enable_transformations,
"enable_analysis": self.enable_analysis,
"enable_expression_simplification": self.enable_expression_simplification,
"enable_column_pruning": self.enable_column_pruning,
"enable_parameter_type_wrapping": self.enable_parameter_type_wrapping,
"enable_caching": self.enable_caching,
"parameter_converter": self.parameter_converter,
"parameter_validator": self.parameter_validator,
"dialect": self.dialect,
"execution_mode": self.execution_mode,
"execution_args": self.execution_args,
"output_transformer": self.output_transformer,
"statement_transformers": self.statement_transformers,
}
current_kwargs.update(kwargs)
return type(self)(**current_kwargs)
[docs]
def __hash__(self) -> int:
"""Hash based on configuration settings."""
if self._hash_cache is None:
self._hash_cache = hash((
self.enable_parsing,
self.enable_validation,
self.enable_transformations,
self.enable_analysis,
self.enable_expression_simplification,
self.enable_column_pruning,
self.enable_parameter_type_wrapping,
self.enable_caching,
str(self.dialect),
self.parameter_config.hash(),
self.execution_mode,
self.output_transformer,
self.statement_transformers,
))
return self._hash_cache
[docs]
def __repr__(self) -> str:
"""String representation of the StatementConfig instance."""
field_strs = [
f"parameter_config={self.parameter_config!r}",
f"enable_parsing={self.enable_parsing!r}",
f"enable_validation={self.enable_validation!r}",
f"enable_transformations={self.enable_transformations!r}",
f"enable_analysis={self.enable_analysis!r}",
f"enable_expression_simplification={self.enable_expression_simplification!r}",
f"enable_column_pruning={self.enable_column_pruning!r}",
f"enable_parameter_type_wrapping={self.enable_parameter_type_wrapping!r}",
f"enable_caching={self.enable_caching!r}",
f"parameter_converter={self.parameter_converter!r}",
f"parameter_validator={self.parameter_validator!r}",
f"dialect={self.dialect!r}",
f"execution_mode={self.execution_mode!r}",
f"execution_args={self.execution_args!r}",
f"output_transformer={self.output_transformer!r}",
f"statement_transformers={self.statement_transformers!r}",
]
return f"{self.__class__.__name__}({', '.join(field_strs)})"
[docs]
def __eq__(self, other: object) -> bool:
"""Equality comparison."""
if not isinstance(other, type(self)):
return False
if not self._compare_parameter_configs(self.parameter_config, other.parameter_config):
return False
return (
self.enable_parsing == other.enable_parsing
and self.enable_validation == other.enable_validation
and self.enable_transformations == other.enable_transformations
and self.enable_analysis == other.enable_analysis
and self.enable_expression_simplification == other.enable_expression_simplification
and self.enable_column_pruning == other.enable_column_pruning
and self.enable_parameter_type_wrapping == other.enable_parameter_type_wrapping
and self.enable_caching == other.enable_caching
and self.dialect == other.dialect
and self.execution_mode == other.execution_mode
and self.execution_args == other.execution_args
and self.output_transformer == other.output_transformer
and self.statement_transformers == other.statement_transformers
)
def _compare_parameter_configs(self, config1: Any, config2: Any) -> bool:
"""Compare parameter configs."""
return bool(
config1.default_parameter_style == config2.default_parameter_style
and config1.supported_parameter_styles == config2.supported_parameter_styles
and config1.supported_execution_parameter_styles == config2.supported_execution_parameter_styles
)
_DEFAULT_CONFIG: "StatementConfig | None" = None
def get_default_config() -> StatementConfig:
"""Get default statement configuration.
Returns:
Cached StatementConfig singleton with default settings.
"""
global _DEFAULT_CONFIG
if _DEFAULT_CONFIG is None:
_DEFAULT_CONFIG = StatementConfig()
_DEFAULT_CONFIG.freeze()
return _DEFAULT_CONFIG
def get_default_parameter_config() -> ParameterStyleConfig:
"""Get default parameter configuration.
Returns:
ParameterStyleConfig with QMARK style as default
"""
return ParameterStyleConfig(
default_parameter_style=ParameterStyle.QMARK, supported_parameter_styles={ParameterStyle.QMARK}
)
Statement: TypeAlias = str | exp.Expression | SQL