diff --git a/deepnote_toolkit/runtime_initialization.py b/deepnote_toolkit/runtime_initialization.py
index bfafaf5..04a650c 100644
--- a/deepnote_toolkit/runtime_initialization.py
+++ b/deepnote_toolkit/runtime_initialization.py
@@ -15,6 +15,7 @@
 from .set_integrations_env import set_integration_env
 from .set_notebook_path import set_notebook_path
 from .sql.spark_sql_magic import SparkSql
+from .sql.sql_utils import configure_sqlparse_limits
 
 
 def init_deepnote_runtime():
@@ -51,6 +52,13 @@ def init_deepnote_runtime():
     except Exception as e:  # pylint: disable=broad-exception-caught
         logger.error("Failed to add output middleware with a error: %s", e)
 
+    # Disable sqlparse grouping limits for large analytical queries
+    try:
+        logger.debug("Configuring sqlparse limits.")
+        configure_sqlparse_limits()
+    except Exception as e:  # pylint: disable=broad-exception-caught
+        logger.error("Failed to configure sqlparse limits with error: %s", e)
+
     # Set up psycopg2 to make long-running queries interruptible by SIGINT (interrupt kernel)
     try:
         logger.debug("Setting psycopg2.")
diff --git a/deepnote_toolkit/sql/sql_utils.py b/deepnote_toolkit/sql/sql_utils.py
index d5e24f8..8b66169 100644
--- a/deepnote_toolkit/sql/sql_utils.py
+++ b/deepnote_toolkit/sql/sql_utils.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import sqlparse
 
 
@@ -10,3 +12,29 @@ def is_single_select_query(sql_string):
 
     # Check if the query is a SELECT statement
     return parsed_queries[0].get_type() == "SELECT"
+
+
+def configure_sqlparse_limits(
+    max_grouping_tokens: Optional[int] = None,
+    max_grouping_depth: Optional[int] = None,
+) -> None:
+    """Disable or adjust sqlparse's grouping limits for large analytical queries.
+
+    sqlparse v0.5.4 started capping token count at 10,000 by default.
+    Since the toolkit runtime is isolated and users write their own queries,
+    we disable limits by default.
+
+    See: https://github.com/andialbrecht/sqlparse/blob/0.5.4/docs/source/api.rst#security-and-performance-considerations
+    """
+    try:
+        import sqlparse.engine.grouping
+
+        sqlparse.engine.grouping.MAX_GROUPING_TOKENS = max_grouping_tokens
+        sqlparse.engine.grouping.MAX_GROUPING_DEPTH = max_grouping_depth
+    except (ImportError, AttributeError):
+        pass
+
+
+def reset_sqlparse_limits() -> None:
+    """Restore sqlparse grouping limits to their built-in defaults."""
+    configure_sqlparse_limits(max_grouping_tokens=10_000, max_grouping_depth=100)
diff --git a/tests/unit/test_sql_utils.py b/tests/unit/test_sql_utils.py
new file mode 100644
index 0000000..bb9ce91
--- /dev/null
+++ b/tests/unit/test_sql_utils.py
@@ -0,0 +1,52 @@
+import pytest
+import sqlparse.engine.grouping
+from sqlparse.exceptions import SQLParseError
+
+from deepnote_toolkit.sql.sql_utils import (
+    configure_sqlparse_limits,
+    is_single_select_query,
+    reset_sqlparse_limits,
+)
+
+
+class TestSqlparseLimits:
+    @pytest.fixture(autouse=True)
+    def disable_sqlparse_limits(self):
+        """Ensure every test starts and ends with limits disabled."""
+        configure_sqlparse_limits()
+        yield
+        configure_sqlparse_limits()
+
+    @staticmethod
+    def _build_large_select(num_columns: int = 5000) -> str:
+        """Build a SELECT with enough columns to exceed the default 10,000 token limit."""
+        columns = ", ".join(f"column_{i}" for i in range(num_columns))
+        return f"SELECT {columns} FROM some_table"
+
+    def test_disables_limits_by_default(self):
+        assert sqlparse.engine.grouping.MAX_GROUPING_TOKENS is None
+        assert sqlparse.engine.grouping.MAX_GROUPING_DEPTH is None
+
+    def test_sets_custom_values(self):
+        expected_tokens = 50_000
+        expected_depth = 200
+        configure_sqlparse_limits(
+            max_grouping_tokens=expected_tokens, max_grouping_depth=expected_depth
+        )
+        assert sqlparse.engine.grouping.MAX_GROUPING_TOKENS == expected_tokens
+        assert sqlparse.engine.grouping.MAX_GROUPING_DEPTH == expected_depth
+
+    def test_restores_builtin_defaults(self):
+        reset_sqlparse_limits()
+        assert sqlparse.engine.grouping.MAX_GROUPING_TOKENS == 10_000
+        assert sqlparse.engine.grouping.MAX_GROUPING_DEPTH == 100
+
+    def test_large_query_parses_with_limits_disabled(self):
+        large_query = self._build_large_select()
+        assert is_single_select_query(large_query) is True
+
+    def test_large_query_fails_with_default_limits(self):
+        reset_sqlparse_limits()
+        large_query = self._build_large_select()
+        with pytest.raises(SQLParseError):
+            is_single_select_query(large_query)