diff --git a/CHANGELOG.md b/CHANGELOG.md
index 517a60bfc..affa0e634 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,22 +7,53 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 ## [Unreleased]
 
 ### Added
+
 - New feature: Support for macOS and Linux.
 - Documentation: Added API documentation in the Wiki.
+- Added support for SQL Server spatial data types (geography, geometry, hierarchyid) via SQL_SS_UDT type handling.
+- Added `SQLTypeCode` class for dual-compatible type codes in `cursor.description`.
 
 ### Changed
+
 - Improved error handling in the connection module.
+- Enhanced `cursor.description[i][1]` to return `SQLTypeCode` objects that compare equal to both SQL type integers and Python types, maintaining full backwards compatibility while aligning with DB-API 2.0.
 
 ### Fixed
+
 - Bug fix: Resolved issue with connection timeout.
+- Fixed `cursor.description` type handling for better DB-API 2.0 compliance (Issue #352).
+
+### SQLTypeCode Usage
+
+The `type_code` field in `cursor.description` now returns `SQLTypeCode` objects that support both comparison styles:
+
+```python
+cursor.execute("SELECT id, name FROM users")
+desc = cursor.description
+
+# Style 1: Compare with Python types (backwards compatible with pandas, etc.)
+if desc[0][1] == int:
+    print("Integer column")
+
+# Style 2: Compare with SQL type codes (DB-API 2.0 compliant)
+from mssql_python.constants import ConstantsDDBC as sql_types
+if desc[0][1] == sql_types.SQL_INTEGER.value:  # or just == 4
+    print("Integer column")
+
+# Get the raw SQL type code
+type_code = int(desc[0][1])  # Returns 4 for SQL_INTEGER
+```
 
 ## [1.0.0-alpha] - 2025-02-24
 
 ### Added
+
 - Initial release of the mssql-python driver for SQL Server.
 
 ### Changed
+
 - N/A
 
 ### Fixed
-- N/A
\ No newline at end of file
+
+- N/A
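Editor's note: the usage block above covers description comparisons; the new spatial-type support follows the same converter pattern. A minimal, hedged sketch (the `spatial_demo` table and connection string are hypothetical, and UDT values arrive as SQL Server's internal serialization, not WKB):

```python
import mssql_python
from mssql_python.constants import ConstantsDDBC

conn = mssql_python.connect("Server=localhost;Database=demo;Trusted_Connection=yes;")

# geometry/geography/hierarchyid columns surface as raw bytes via SQL_SS_UDT (-151);
# a converter can wrap or decode them as needed.
conn.add_output_converter(ConstantsDDBC.SQL_SS_UDT.value,
                          lambda v: None if v is None else bytes(v))

cursor = conn.cursor()
cursor.execute("SELECT geom FROM spatial_demo")
blob = cursor.fetchone()[0]  # bytes, or None for NULL
```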
diff --git a/mssql_python/__init__.py b/mssql_python/__init__.py
index 2bcac47bb..a8ac18d28 100644
--- a/mssql_python/__init__.py
+++ b/mssql_python/__init__.py
@@ -59,7 +59,7 @@ from .connection_string_builder import _ConnectionStringBuilder
 
 # Cursor Objects
-from .cursor import Cursor
+from .cursor import Cursor, SQLTypeCode
 
 # Logging Configuration (Simplified single-level DEBUG system)
 from .logging import logger, setup_logging, driver_logger
diff --git a/mssql_python/connection.py b/mssql_python/connection.py
index ba79e2a3f..64c02bed2 100644
--- a/mssql_python/connection.py
+++ b/mssql_python/connection.py
@@ -47,6 +47,7 @@
 if TYPE_CHECKING:
     from mssql_python.row import Row
+    from mssql_python.cursor import SQLTypeCode
 
 # Add SQL_WMETADATA constant for metadata decoding configuration
 SQL_WMETADATA: int = -99  # Special flag for column name decoding
@@ -923,7 +924,9 @@ def cursor(self) -> Cursor:
         logger.debug("cursor: Cursor created successfully - total_cursors=%d", len(self._cursors))
         return cursor
 
-    def add_output_converter(self, sqltype: int, func: Callable[[Any], Any]) -> None:
+    def add_output_converter(
+        self, sqltype: "Union[int, SQLTypeCode, type]", func: Callable[[Any], Any]
+    ) -> None:
         """
         Register an output converter function that will be called whenever a value with
         the given SQL type is read from the database.
@@ -936,32 +939,47 @@ def add_output_converter(self, sqltype: int, func: Callable[[Any], Any]) -> None
         vulnerabilities. This API should never be exposed to untrusted or external input.
 
         Args:
-            sqltype (int): The integer SQL type value to convert, which can be one of the
-                defined standard constants (e.g. SQL_VARCHAR) or a database-specific
-                value (e.g. -151 for the SQL Server 2008 geometry data type).
+            sqltype (int, SQLTypeCode, or type): The integer SQL type value to convert, which can be
+                one of the defined standard constants (e.g. SQL_VARCHAR) or a
+                database-specific value (e.g. -151 for the SQL Server 2008
+                geometry data type). Also accepts SQLTypeCode objects (from
+                cursor.description) or Python types (e.g., str, int) for
+                backward compatibility.
             func (callable): The converter function which will be called with a single parameter, the
                 value, and should return the converted value. If the value is NULL
-                then the parameter passed to the function will be None, otherwise it
-                will be a bytes object.
+                then the parameter passed to the function will be None. For string/binary
+                columns, the value will be bytes (UTF-16LE encoded for strings). For other
+                types (int, decimal.Decimal, datetime, etc.), the value will be the native
+                Python object.
 
         Returns:
             None
         """
+        # Handle SQLTypeCode objects (from cursor.description) by converting to int
+        if hasattr(sqltype, "type_code"):
+            sqltype = sqltype.type_code
         with self._converters_lock:
             self._output_converters[sqltype] = func
             # Pass to the underlying connection if native implementation supports it
-            if hasattr(self._conn, "add_output_converter"):
+            # Only forward int type codes to native layer; Python type keys are handled
+            # only in our Python-side dictionary
+            if isinstance(sqltype, int) and hasattr(self._conn, "add_output_converter"):
                 self._conn.add_output_converter(sqltype, func)
         logger.info(f"Added output converter for SQL type {sqltype}")
 
-    def get_output_converter(self, sqltype: Union[int, type]) -> Optional[Callable[[Any], Any]]:
+    def get_output_converter(
+        self, sqltype: "Union[int, SQLTypeCode, type]"
+    ) -> Optional[Callable[[Any], Any]]:
         """
         Get the output converter function for the specified SQL type.
         Thread-safe implementation that protects the converters dictionary with a lock.
 
         Args:
-            sqltype (int or type): The SQL type value or Python type to get the converter for
+            sqltype (int, SQLTypeCode, or type): The SQL type value to get the converter for.
+                Also accepts SQLTypeCode objects (from cursor.description), which are
+                automatically converted to their integer type code, or Python types
+                (e.g., str, int) for backward compatibility.
 
         Returns:
             callable or None: The converter function or None if no converter is registered
@@ -970,26 +988,43 @@ def get_output_converter(self, sqltype: Union[int, type]) -> Optional[Callable[[
         ⚠️ The returned converter function will be executed on database values.
         Only use converters from trusted sources.
         """
+        # Handle SQLTypeCode objects (from cursor.description) by converting to int
+        # SQLTypeCode has a type_code attribute and supports int() conversion
+        original_sqltype = sqltype
+        if hasattr(sqltype, "type_code"):
+            sqltype = sqltype.type_code
         with self._converters_lock:
-            return self._output_converters.get(sqltype)
-
-    def remove_output_converter(self, sqltype: Union[int, type]) -> None:
+            result = self._output_converters.get(sqltype)
+            # If int lookup misses for an SQLTypeCode, also try its python_type
+            # to preserve backward compatibility with converters registered by Python type
+            if result is None and hasattr(original_sqltype, "python_type"):
+                result = self._output_converters.get(original_sqltype.python_type)
+            return result
+
+    def remove_output_converter(self, sqltype: "Union[int, SQLTypeCode, type]") -> None:
         """
         Remove the output converter function for the specified SQL type.
         Thread-safe implementation that protects the converters dictionary with a lock.
 
         Args:
-            sqltype (int or type): The SQL type value to remove the converter for
+            sqltype (int, SQLTypeCode, or type): The SQL type value to remove the converter for.
+                Also accepts SQLTypeCode objects (from cursor.description) or Python types
+                (e.g., str, int) for backward compatibility.
 
         Returns:
             None
         """
+        # Handle SQLTypeCode objects (from cursor.description) by converting to int
+        if hasattr(sqltype, "type_code"):
+            sqltype = sqltype.type_code
         with self._converters_lock:
             if sqltype in self._output_converters:
                 del self._output_converters[sqltype]
             # Pass to the underlying connection if native implementation supports it
-            if hasattr(self._conn, "remove_output_converter"):
+            # Only forward int type codes to native layer; Python type keys are handled
+            # only in our Python-side dictionary
+            if isinstance(sqltype, int) and hasattr(self._conn, "remove_output_converter"):
                 self._conn.remove_output_converter(sqltype)
         logger.info(f"Removed output converter for SQL type {sqltype}")
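Since `cursor.description[i][1]` is now a `SQLTypeCode`, it can be fed straight into these converter APIs. A small sketch, assuming an open connection `conn` and cursor `cursor`:

```python
cursor.execute("SELECT name FROM sys.databases")
name_type = cursor.description[0][1]  # SQLTypeCode(-9, str)

conn.add_output_converter(int(name_type), lambda v: v)   # register by raw code
assert conn.get_output_converter(name_type) is not None  # lookup accepts SQLTypeCode
conn.remove_output_converter(name_type)                  # and so does removal
```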
diff --git a/mssql_python/constants.py b/mssql_python/constants.py
index 03d40c833..c24822760 100644
--- a/mssql_python/constants.py
+++ b/mssql_python/constants.py
@@ -114,7 +114,12 @@ class ConstantsDDBC(Enum):
     SQL_FETCH_ABSOLUTE = 5
     SQL_FETCH_RELATIVE = 6
     SQL_FETCH_BOOKMARK = 8
+    # NOTE: The following SQL Server-specific type constants MUST stay in sync with
+    # the corresponding values in mssql_python/pybind/ddbc_bindings.cpp
     SQL_DATETIMEOFFSET = -155
+    SQL_SS_TIME2 = -154  # SQL Server TIME(n) type
+    SQL_SS_UDT = -151  # SQL Server User-Defined Types (geometry, geography, hierarchyid)
+    SQL_SS_XML = -152  # SQL Server XML type
     SQL_C_SS_TIMESTAMPOFFSET = 0x4001
     SQL_SCOPE_CURROW = 0
     SQL_BEST_ROWID = 1
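The new constants make type checks against `cursor.description` self-documenting. A hedged sketch (the `spatial_demo` table is hypothetical):

```python
from mssql_python.constants import ConstantsDDBC

cursor.execute("SELECT geom FROM spatial_demo")
if cursor.description[0][1] == ConstantsDDBC.SQL_SS_UDT.value:  # i.e. -151
    print("UDT column (geometry/geography/hierarchyid); fetched as bytes")
```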
diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py
index 84bb650d5..02a5b7338 100644
--- a/mssql_python/cursor.py
+++ b/mssql_python/cursor.py
@@ -48,6 +48,116 @@
 MONEY_MAX: decimal.Decimal = decimal.Decimal("922337203685477.5807")
 
 
+class SQLTypeCode:
+    """
+    A dual-compatible type code that compares equal to both SQL type integers and Python types.
+
+    This class maintains backwards compatibility with code that checks
+    `cursor.description[i][1] == str` while also supporting DB-API 2.0
+    compliant code that checks `cursor.description[i][1] == -9`.
+
+    Examples:
+        >>> type_code = SQLTypeCode(-9, str)
+        >>> type_code == str  # Backwards compatible with pandas, etc.
+        True
+        >>> type_code == -9  # DB-API 2.0 compliant
+        True
+        >>> int(type_code)  # Get the raw SQL type code
+        -9
+    """
+
+    # SQL type code to Python type mapping (class-level cache)
+    _type_map = None
+
+    def __init__(self, type_code: int, python_type: Optional[type] = None):
+        self.type_code = type_code
+        # If python_type not provided, look it up from the mapping
+        if python_type is None:
+            python_type = self._get_python_type(type_code)
+        self.python_type = python_type
+
+    @classmethod
+    def _get_type_map(cls):
+        """Lazily build the SQL to Python type mapping."""
+        if cls._type_map is None:
+            cls._type_map = {
+                ddbc_sql_const.SQL_CHAR.value: str,
+                ddbc_sql_const.SQL_VARCHAR.value: str,
+                ddbc_sql_const.SQL_LONGVARCHAR.value: str,
+                ddbc_sql_const.SQL_WCHAR.value: str,
+                ddbc_sql_const.SQL_WVARCHAR.value: str,
+                ddbc_sql_const.SQL_WLONGVARCHAR.value: str,
+                ddbc_sql_const.SQL_INTEGER.value: int,
+                ddbc_sql_const.SQL_REAL.value: float,
+                ddbc_sql_const.SQL_FLOAT.value: float,
+                ddbc_sql_const.SQL_DOUBLE.value: float,
+                ddbc_sql_const.SQL_DECIMAL.value: decimal.Decimal,
+                ddbc_sql_const.SQL_NUMERIC.value: decimal.Decimal,
+                ddbc_sql_const.SQL_DATE.value: datetime.date,
+                ddbc_sql_const.SQL_TIMESTAMP.value: datetime.datetime,
+                ddbc_sql_const.SQL_TIME.value: datetime.time,
+                ddbc_sql_const.SQL_SS_TIME2.value: datetime.time,  # SQL Server TIME(n)
+                # ODBC 3.x date/time type codes
+                ddbc_sql_const.SQL_TYPE_DATE.value: datetime.date,
+                ddbc_sql_const.SQL_TYPE_TIME.value: datetime.time,
+                ddbc_sql_const.SQL_TYPE_TIMESTAMP.value: datetime.datetime,
+                ddbc_sql_const.SQL_TYPE_TIMESTAMP_WITH_TIMEZONE.value: datetime.datetime,
+                ddbc_sql_const.SQL_BIT.value: bool,
+                ddbc_sql_const.SQL_TINYINT.value: int,
+                ddbc_sql_const.SQL_SMALLINT.value: int,
+                ddbc_sql_const.SQL_BIGINT.value: int,
+                ddbc_sql_const.SQL_BINARY.value: bytes,
+                ddbc_sql_const.SQL_VARBINARY.value: bytes,
+                ddbc_sql_const.SQL_LONGVARBINARY.value: bytes,
+                ddbc_sql_const.SQL_GUID.value: uuid.UUID,
+                ddbc_sql_const.SQL_SS_UDT.value: bytes,
+                ddbc_sql_const.SQL_SS_XML.value: str,  # SQL Server XML type (-152)
+                ddbc_sql_const.SQL_DATETIME2.value: datetime.datetime,
+                ddbc_sql_const.SQL_SMALLDATETIME.value: datetime.datetime,
+                ddbc_sql_const.SQL_DATETIMEOFFSET.value: datetime.datetime,
+            }
+        return cls._type_map
+
+    @classmethod
+    def _get_python_type(cls, sql_code: int) -> type:
+        """Get the Python type for a SQL type code."""
+        return cls._get_type_map().get(sql_code, str)
+
+    def __eq__(self, other):
+        """Compare equal to both Python types and SQL integer codes."""
+        if isinstance(other, type):
+            return self.python_type == other
+        if isinstance(other, int):
+            return self.type_code == other
+        if isinstance(other, SQLTypeCode):
+            return self.type_code == other.type_code
+        return False
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def __hash__(self):
+        """
+        SQLTypeCode is intentionally unhashable because __eq__ allows
+        comparisons to both Python types and integer SQL codes, and
+        there is no single hash value that can be consistent with both.
+        """
+        raise TypeError(
+            "SQLTypeCode is unhashable. Use int(type_code) or type_code.type_code "
+            "as a dict key instead. Example: {int(desc[1]): handler}"
+        )
+
+    def __int__(self):
+        return self.type_code
+
+    def __repr__(self):
+        type_name = self.python_type.__name__ if self.python_type else "Unknown"
+        return f"SQLTypeCode({self.type_code}, {type_name})"
+
+    def __str__(self):
+        return str(self.type_code)
+
+
 class Cursor:  # pylint: disable=too-many-instance-attributes,too-many-public-methods
     """
     Represents a database cursor, which is used to manage the context of a fetch operation.
@@ -142,6 +252,9 @@ def __init__(self, connection: "Connection", timeout: int = 0) -> None:
         )
         self.messages = []  # Store diagnostic messages
 
+        # Store raw column metadata for converter lookups
+        self._column_metadata = None
+
     def _is_unicode_string(self, param: str) -> bool:
         """
         Check if a string contains non-ASCII characters.
@@ -724,6 +837,14 @@ def _reset_cursor(self) -> None:
             logger.debug("SQLFreeHandle succeeded")
 
         self._clear_rownumber()
+        self._column_metadata = None  # Clear metadata to prevent stale data
+        self.description = None  # Clear description for consistency
+
+        # Clear any result-set-specific caches to avoid stale mappings
+        if hasattr(self, "_cached_column_map"):
+            self._cached_column_map = None
+        if hasattr(self, "_cached_converter_map"):
+            self._cached_converter_map = None
 
         # Reinitialize the statement handle
         self._initialize_cursor()
@@ -756,6 +877,7 @@ def close(self) -> None:
             self.hstmt = None
             logger.debug("SQLFreeHandle succeeded")
         self._clear_rownumber()
+        self._column_metadata = None  # Clear metadata to prevent memory leaks
         self.closed = True
 
     def _check_closed(self) -> None:
@@ -942,8 +1064,12 @@ def _initialize_description(self, column_metadata: Optional[Any] = None) -> None
         """Initialize the description attribute from column metadata."""
         if not column_metadata:
             self.description = None
+            self._column_metadata = None  # Clear metadata too
             return
 
+        # Store raw metadata for converter map building
+        self._column_metadata = column_metadata
+
         description = []
         for _, col in enumerate(column_metadata):
             # Get column name - lowercase it if the lowercase flag is set
@@ -954,10 +1080,13 @@ def _initialize_description(self, column_metadata: Optional[Any] = None) -> None
                 column_name = column_name.lower()
 
             # Add to description tuple (7 elements as per PEP-249)
+            # Use SQLTypeCode for backwards-compatible type_code that works with both
+            # `desc[1] == str` (pandas) and `desc[1] == -9` (DB-API 2.0)
+            sql_type = col["DataType"]
             description.append(
                 (
                     column_name,  # name
-                    self._map_data_type(col["DataType"]),  # type_code
+                    SQLTypeCode(sql_type),  # type_code - dual compatible
                     None,  # display_size
                     col["ColumnSize"],  # internal_size
                     col["ColumnSize"],  # precision - should match ColumnSize
@@ -975,6 +1104,7 @@ def _build_converter_map(self):
         """
         if (
             not self.description
+            or not self._column_metadata
            or not hasattr(self.connection, "_output_converters")
            or not self.connection._output_converters
         ):
             return None
@@ -982,20 +1112,32 @@ def _build_converter_map(self):
 
         converter_map = []
 
-        for desc in self.description:
-            if desc is None:
-                converter_map.append(None)
-                continue
-            sql_type = desc[1]
+        for col_meta in self._column_metadata:
+            # Use the raw SQL type code from metadata, not the mapped Python type
+            sql_type = col_meta["DataType"]
+            python_type = SQLTypeCode._get_python_type(sql_type)
             converter = self.connection.get_output_converter(sql_type)
-            # If no converter found for the SQL type, try the WVARCHAR converter as a fallback
+
+            # Fallback: If no converter found for SQL type code, try the mapped Python type.
+            # This provides backward compatibility for code that registered converters by Python type.
             if converter is None:
-                from mssql_python.constants import ConstantsDDBC
+                converter = self.connection.get_output_converter(python_type)
 
-                converter = self.connection.get_output_converter(ConstantsDDBC.SQL_WVARCHAR.value)
+            # Additional fallback for string/bytes values:
+            # If a user has only registered a SQL_WVARCHAR converter and this column
+            # is mapped to str/bytes, try that converter so behavior matches the
+            # non-optimized Row._apply_output_converters path.
+            if converter is None and python_type in (str, bytes):
+                # Use the named constant for SQL_WVARCHAR instead of the magic number -9.
+                converter = self.connection.get_output_converter(ddbc_sql_const.SQL_WVARCHAR.value)
             converter_map.append(converter)
 
+        # If all entries are None, return None so that Row can fall back to the
+        # non-optimized path, preserving legacy behavior and fallbacks.
+        if not any(converter_map):
+            return None
+
         return converter_map
 
     def _get_column_and_converter_maps(self):
@@ -1022,41 +1164,6 @@ def _get_column_and_converter_maps(self):
 
         return column_map, converter_map
 
-    def _map_data_type(self, sql_type):
-        """
-        Map SQL data type to Python data type.
-
-        Args:
-            sql_type: SQL data type.
-
-        Returns:
-            Corresponding Python data type.
-        """
-        sql_to_python_type = {
-            ddbc_sql_const.SQL_INTEGER.value: int,
-            ddbc_sql_const.SQL_VARCHAR.value: str,
-            ddbc_sql_const.SQL_WVARCHAR.value: str,
-            ddbc_sql_const.SQL_CHAR.value: str,
-            ddbc_sql_const.SQL_WCHAR.value: str,
-            ddbc_sql_const.SQL_FLOAT.value: float,
-            ddbc_sql_const.SQL_DOUBLE.value: float,
-            ddbc_sql_const.SQL_DECIMAL.value: decimal.Decimal,
-            ddbc_sql_const.SQL_NUMERIC.value: decimal.Decimal,
-            ddbc_sql_const.SQL_DATE.value: datetime.date,
-            ddbc_sql_const.SQL_TIMESTAMP.value: datetime.datetime,
-            ddbc_sql_const.SQL_TIME.value: datetime.time,
-            ddbc_sql_const.SQL_BIT.value: bool,
-            ddbc_sql_const.SQL_TINYINT.value: int,
-            ddbc_sql_const.SQL_SMALLINT.value: int,
-            ddbc_sql_const.SQL_BIGINT.value: int,
-            ddbc_sql_const.SQL_BINARY.value: bytes,
-            ddbc_sql_const.SQL_VARBINARY.value: bytes,
-            ddbc_sql_const.SQL_LONGVARBINARY.value: bytes,
-            ddbc_sql_const.SQL_GUID.value: uuid.UUID,
-            # Add more mappings as needed
-        }
-        return sql_to_python_type.get(sql_type, str)
-
     @property
     def rownumber(self) -> int:
         """
@@ -1369,6 +1476,7 @@ def execute(  # pylint: disable=too-many-locals,too-many-branches,too-many-state
         except Exception as e:  # pylint: disable=broad-exception-caught
             # If describe fails, it's likely there are no results (e.g., for INSERT)
             self.description = None
+            self._column_metadata = None  # Clear metadata to prevent stale data
 
         # Reset rownumber for new result set (only for SELECT statements)
         if self.description:  # If we have column descriptions, it's likely a SELECT
@@ -1385,15 +1493,6 @@ def execute(  # pylint: disable=too-many-locals,too-many-branches,too-many-state
         self._cached_column_map = None
         self._cached_converter_map = None
 
-        # After successful execution, initialize description if there are results
-        column_metadata = []
-        try:
-            ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata)
-            self._initialize_description(column_metadata)
-        except Exception as e:
-            # If describe fails, it's likely there are no results (e.g., for INSERT)
-            self.description = None
-
         self._reset_inputsizes()  # Reset input sizes after execution
 
         # Return self for method chaining
         return self
@@ -2425,6 +2524,7 @@ def nextset(self) -> Union[bool, None]:
             logger.debug("nextset: No more result sets available")
             self._clear_rownumber()
             self.description = None
+            self._column_metadata = None  # Clear metadata to prevent stale data
             return False
 
         self._reset_rownumber()
@@ -2444,6 +2544,7 @@ def nextset(self) -> Union[bool, None]:
         except Exception as e:  # pylint: disable=broad-exception-caught
             # If describe fails, there might be no results in this result set
             self.description = None
+            self._column_metadata = None  # Clear metadata to prevent stale data
 
         logger.debug(
             "nextset: Moved to next result set - column_count=%d",
@@ -2756,7 +2857,13 @@ def __del__(self):
         Destructor to ensure the cursor is closed when it is no longer needed.
         This is a safety net to ensure resources are cleaned up even if close()
         was not called explicitly.
-        If the cursor is already closed, it will not raise an exception during cleanup.
+
+        Error handling:
+            This destructor performs best-effort cleanup only. Any exceptions raised
+            while closing the cursor are caught and, when possible, logged instead of
+            being propagated, because raising from __del__ can cause hard-to-debug
+            failures during garbage collection. During interpreter shutdown, logging
+            may be suppressed if the logging subsystem is no longer available.
         """
         if "closed" not in self.__dict__ or not self.closed:
             try:
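Because `SQLTypeCode.__hash__` deliberately raises, dispatch tables keyed by type code should use the raw integer, exactly as the error message suggests. A brief sketch with hypothetical handlers:

```python
handlers = {
    4: int,                   # SQL_INTEGER
    -9: lambda v: v.strip(),  # SQL_WVARCHAR
}

cursor.execute("SELECT database_id, name FROM sys.databases")
for desc, value in zip(cursor.description, cursor.fetchone()):
    handler = handlers.get(int(desc[1]))  # int(...) unwraps the SQLTypeCode
    if handler is not None:
        value = handler(value)
```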
diff --git a/mssql_python/mssql_python.pyi b/mssql_python/mssql_python.pyi
index dd3fd96a0..a446f69fa 100644
--- a/mssql_python/mssql_python.pyi
+++ b/mssql_python/mssql_python.pyi
@@ -81,6 +81,33 @@ def TimeFromTicks(ticks: int) -> datetime.time: ...
 def TimestampFromTicks(ticks: int) -> datetime.datetime: ...
 def Binary(value: Union[str, bytes, bytearray]) -> bytes: ...
 
+# SQLTypeCode - Dual-compatible type code for cursor.description
+class SQLTypeCode:
+    """
+    A type code that supports dual comparison with both SQL type integers and Python types.
+
+    This class is used in cursor.description[i][1] to provide backwards compatibility
+    with libraries like pandas (which compare with Python types like str, int, float)
+    while also supporting DB-API 2.0 style integer type code comparisons.
+
+    Examples:
+        >>> desc = cursor.description
+        >>> desc[0][1] == str  # True if column is string type
+        >>> desc[0][1] == 12  # True if SQL_VARCHAR
+        >>> int(desc[0][1])  # Returns the SQL type code as integer
+    """
+
+    type_code: int
+    python_type: type
+
+    def __init__(self, type_code: int, python_type: Optional[type] = None) -> None: ...
+    def __eq__(self, other: Any) -> bool: ...
+    def __ne__(self, other: Any) -> bool: ...
+    def __int__(self) -> int: ...
+    def __hash__(self) -> int: ...  # Raises TypeError with helpful message
+    def __repr__(self) -> str: ...
+    def __str__(self) -> str: ...
+
 # DB-API 2.0 Exception Hierarchy
 # https://www.python.org/dev/peps/pep-0249/#exceptions
 class Warning(Exception):
@@ -133,7 +160,7 @@ class Row:
     description: List[
         Tuple[
             str,
-            Any,
+            Union[SQLTypeCode, type],
             Optional[int],
             Optional[int],
             Optional[int],
@@ -163,11 +190,14 @@ class Cursor:
     """
 
     # DB-API 2.0 Required Attributes
+    # description is a sequence of 7-item tuples:
+    #   (name, type_code, display_size, internal_size, precision, scale, null_ok)
+    # type_code is SQLTypeCode which compares equal to both SQL integers and Python types
     description: Optional[
         List[
             Tuple[
                 str,
-                Any,
+                Union[SQLTypeCode, type],
                 Optional[int],
                 Optional[int],
                 Optional[int],
@@ -265,9 +295,13 @@ class Connection:
     ) -> None: ...
     def getdecoding(self, sqltype: int) -> Dict[str, Union[str, int]]: ...
     def set_attr(self, attribute: int, value: Union[int, str, bytes, bytearray]) -> None: ...
-    def add_output_converter(self, sqltype: int, func: Callable[[Any], Any]) -> None: ...
-    def get_output_converter(self, sqltype: Union[int, type]) -> Optional[Callable[[Any], Any]]: ...
-    def remove_output_converter(self, sqltype: Union[int, type]) -> None: ...
+    def add_output_converter(
+        self, sqltype: Union[int, SQLTypeCode, type], func: Callable[[Any], Any]
+    ) -> None: ...
+    def get_output_converter(
+        self, sqltype: Union[int, SQLTypeCode, type]
+    ) -> Optional[Callable[[Any], Any]]: ...
+    def remove_output_converter(self, sqltype: Union[int, SQLTypeCode, type]) -> None: ...
    def clear_output_converters(self) -> None: ...
    def execute(self, sql: str, *args: Any) -> Cursor: ...
    def batch_execute(
diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp
index 2cf04fe0d..069581623 100644
--- a/mssql_python/pybind/ddbc_bindings.cpp
+++ b/mssql_python/pybind/ddbc_bindings.cpp
@@ -27,6 +27,18 @@
 #define MAX_DIGITS_IN_NUMERIC 64
 #define SQL_MAX_NUMERIC_LEN 16
 #define SQL_SS_XML (-152)
+#define SQL_SS_UDT (-151)  // SQL Server User-Defined Types (geometry, geography, hierarchyid)
+#ifndef SQL_DATETIME2
+#define SQL_DATETIME2 (42)
+#endif
+#ifndef SQL_SMALLDATETIME
+#define SQL_SMALLDATETIME (58)
+#endif
+
+// NOTE: The following SQL Server-specific type constants MUST stay in sync with
+// the corresponding values in mssql_python/constants.py (ConstantsDDBC enum):
+// SQL_SS_TIME2, SQL_SS_XML, SQL_SS_UDT, SQL_DATETIME2, SQL_SMALLDATETIME, SQL_SS_TIMESTAMPOFFSET
+// (In Python, SQL_SS_TIMESTAMPOFFSET corresponds to ConstantsDDBC.SQL_DATETIMEOFFSET.)
 #define STRINGIFY_FOR_CASE(x) \
     case x:                   \
@@ -3004,6 +3016,11 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
                 }
                 break;
             }
+            case SQL_SS_UDT: {
+                LOG("SQLGetData: Streaming SQL Server UDT (e.g., geometry/geography/hierarchyid) for column %d", i);
+                row.append(FetchLobColumnData(hStmt, i, SQL_C_BINARY, false, true, ""));
+                break;
+            }
             case SQL_SS_XML: {
                 LOG("SQLGetData: Streaming XML for column %d", i);
                 row.append(FetchLobColumnData(hStmt, i, SQL_C_WCHAR, true, false, "utf-16le"));
@@ -3228,6 +3245,8 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
             }
             case SQL_TIMESTAMP:
             case SQL_TYPE_TIMESTAMP:
+            case SQL_DATETIME2:
+            case SQL_SMALLDATETIME:
             case SQL_DATETIME: {
                 SQL_TIMESTAMP_STRUCT timestampValue;
                 ret = SQLGetData_ptr(hStmt, i, SQL_C_TYPE_TIMESTAMP, &timestampValue,
@@ -3558,6 +3577,7 @@ SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& column
         case SQL_BINARY:
         case SQL_VARBINARY:
         case SQL_LONGVARBINARY:
+        case SQL_SS_UDT:  // geography, geometry, hierarchyid
            // TODO: handle variable length data correctly. This logic won't
            // suffice
            HandleZeroColumnSizeAtFetch(columnSize);
@@ -3686,6 +3706,7 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
        case SQL_BINARY:
        case SQL_VARBINARY:
        case SQL_LONGVARBINARY:
+        case SQL_SS_UDT:  // geography, geometry, hierarchyid
            columnProcessors[col] = ColumnProcessors::ProcessBinary;
            break;
        default:
@@ -3810,6 +3831,8 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
            }
            case SQL_TIMESTAMP:
            case SQL_TYPE_TIMESTAMP:
+            case SQL_DATETIME2:
+            case SQL_SMALLDATETIME:
            case SQL_DATETIME: {
                const SQL_TIMESTAMP_STRUCT& ts = buffers.timestampBuffers[col - 1][i];
                PyObject* datetimeObj = PythonObjectCache::get_datetime_class()(
@@ -3961,6 +3984,8 @@ size_t calculateRowSize(py::list& columnNames, SQLUSMALLINT numCols) {
            case SQL_TIMESTAMP:
            case SQL_TYPE_TIMESTAMP:
            case SQL_DATETIME:
+            case SQL_DATETIME2:
+            case SQL_SMALLDATETIME:
                rowSize += sizeof(SQL_TIMESTAMP_STRUCT);
                break;
            case SQL_BIGINT:
@@ -3989,6 +4014,15 @@ size_t calculateRowSize(py::list& columnNames, SQLUSMALLINT numCols) {
            case SQL_SS_TIMESTAMPOFFSET:
                rowSize += sizeof(DateTimeOffset);
                break;
+            case SQL_SS_UDT: {
+                SQLULEN effectiveSize = columnSize;
+                // Guard against SQL_NO_TOTAL or unrealistic sizes to avoid inflating/overflowing rowSize.
+                if (effectiveSize == SQL_NO_TOTAL || effectiveSize == 0 || effectiveSize > SQL_MAX_LOB_SIZE) {
+                    effectiveSize = SQL_MAX_LOB_SIZE;
+                }
+                rowSize += static_cast<size_t>(effectiveSize);
+                break;
+            }
            default:
                std::wstring columnName = columnMeta["ColumnName"].cast<std::wstring>();
                std::ostringstream errorString;
@@ -4043,7 +4077,7 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch
        if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR ||
             dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR || dataType == SQL_VARBINARY ||
-             dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML) &&
+             dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML || dataType == SQL_SS_UDT) &&
            (columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) {
            lobColumns.push_back(i + 1);  // 1-based
        }
@@ -4177,7 +4211,7 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows,
        if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR ||
             dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR || dataType == SQL_VARBINARY ||
-             dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML) &&
+             dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML || dataType == SQL_SS_UDT) &&
            (columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) {
            lobColumns.push_back(i + 1);  // 1-based
        }
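On the Python side, the net effect of the SQL_SS_UDT plumbing is that spatial columns are fetched as binary blobs. A hedged round-trip sketch (the table is hypothetical; the raw payload is SQL Server's internal UDT serialization, while `STAsBinary()` yields portable WKB):

```python
cursor.execute("SELECT geom, geom.STAsBinary() FROM spatial_demo")
raw_udt, wkb = cursor.fetchone()
assert isinstance(raw_udt, (bytes, bytearray))  # streamed via SQL_C_BINARY
```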
diff --git a/pyproject.toml b/pyproject.toml
index 538a4a992..b2c267c48 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.black]
 line-length = 100
-target-version = ['py38', 'py39', 'py310', 'py311']
+target-version = ['py38', 'py39', 'py310', 'py311', 'py312', 'py313']
 include = '\.pyi?$'
 extend-exclude = '''
 /(
diff --git a/tests/test_002_types.py b/tests/test_002_types.py
index 4828d72ea..a9e99f9f5 100644
--- a/tests/test_002_types.py
+++ b/tests/test_002_types.py
@@ -1267,3 +1267,215 @@ def test_utf8_4byte_sequence_complete_coverage():
         assert len(result) > 0, f"Invalid pattern should produce some output"
 
     assert True, "Complete 4-byte sequence coverage validated"
+
+
+# =============================================================================
+# SQLTypeCode Unit Tests (DB-API 2.0 + pandas compatibility)
+# =============================================================================
+
+
+class TestSQLTypeCode:
+    """
+    Unit tests for SQLTypeCode class.
+
+    SQLTypeCode provides dual compatibility:
+    - Compares equal to Python type objects (str, int, float, etc.) for pandas compatibility
+    - Compares equal to SQL integer codes for DB-API 2.0 compliance
+    """
+
+    def test_sqltypecode_import(self):
+        """Test that SQLTypeCode is importable from public API."""
+        from mssql_python import SQLTypeCode
+
+        assert SQLTypeCode is not None
+
+    def test_sqltypecode_equals_python_type_str(self):
+        """Test SQLTypeCode for SQL_WVARCHAR (-9) equals str."""
+        from mssql_python import SQLTypeCode
+
+        tc = SQLTypeCode(-9)  # SQL_WVARCHAR
+        assert tc == str, "SQLTypeCode(-9) should equal str"
+        assert not (tc != str), "SQLTypeCode(-9) should not be != str"
+
+    def test_sqltypecode_equals_python_type_int(self):
+        """Test SQLTypeCode for SQL_INTEGER (4) equals int."""
+        from mssql_python import SQLTypeCode
+
+        tc = SQLTypeCode(4)  # SQL_INTEGER
+        assert tc == int, "SQLTypeCode(4) should equal int"
+
+    def test_sqltypecode_equals_python_type_float(self):
+        """Test SQLTypeCode for SQL_REAL (7) equals float."""
+        from mssql_python import SQLTypeCode
+
+        tc = SQLTypeCode(7)  # SQL_REAL
+        assert tc == float, "SQLTypeCode(7) should equal float"
+
+    def test_sqltypecode_equals_python_type_bytes(self):
+        """Test SQLTypeCode for SQL_BINARY (-2) equals bytes."""
+        from mssql_python import SQLTypeCode
+
+        tc = SQLTypeCode(-2)  # SQL_BINARY
+        assert tc == bytes, "SQLTypeCode(-2) should equal bytes"
+
+    def test_sqltypecode_equals_sql_integer_code(self):
+        """Test SQLTypeCode equals its raw SQL integer code."""
+        from mssql_python import SQLTypeCode
+
+        tc = SQLTypeCode(4)  # SQL_INTEGER
+        assert tc == 4, "SQLTypeCode(4) should equal 4"
+        assert tc == SQLTypeCode(4).type_code, "SQLTypeCode(4) should equal its type_code"
+
+    def test_sqltypecode_equals_negative_sql_code(self):
+        """Test SQLTypeCode with negative SQL codes (e.g., SQL_WVARCHAR = -9)."""
+        from mssql_python import SQLTypeCode
+
+        tc = SQLTypeCode(-9)  # SQL_WVARCHAR
+        assert tc == -9, "SQLTypeCode(-9) should equal -9"
+
+    def test_sqltypecode_dual_compatibility(self):
+        """Test that SQLTypeCode equals both Python type AND SQL code simultaneously."""
+        from mssql_python import SQLTypeCode
+
+        tc = SQLTypeCode(4)  # SQL_INTEGER
+        # Must satisfy BOTH comparisons - this is the key feature
+        assert tc == int and tc == 4, "SQLTypeCode should equal both int and 4"
+
+    def test_sqltypecode_int_conversion(self):
+        """Test int(SQLTypeCode) returns raw SQL code."""
+        from mssql_python import SQLTypeCode
+
+        tc = SQLTypeCode(-9)
+        assert int(tc) == -9, "int(SQLTypeCode(-9)) should return -9"
+        tc2 = SQLTypeCode(4)
+        assert int(tc2) == 4, "int(SQLTypeCode(4)) should return 4"
+
+    def test_sqltypecode_unhashable(self):
+        """Test SQLTypeCode is intentionally unhashable due to eq/hash contract."""
+        from mssql_python import SQLTypeCode
+        import pytest
+
+        tc = SQLTypeCode(4)
+        # SQLTypeCode should not be hashable because __eq__ compares to both
+        # Python types and integers, which have different hash values.
+        # The __hash__ method raises TypeError with a helpful message.
+        with pytest.raises(TypeError) as exc_info:
+            hash(tc)
+        # Verify the error message provides guidance
+        assert "unhashable" in str(exc_info.value).lower()
+        assert "int(type_code)" in str(exc_info.value) or "type_code.type_code" in str(
+            exc_info.value
+        )
+
+    def test_sqltypecode_repr(self):
+        """Test SQLTypeCode has informative repr."""
+        from mssql_python import SQLTypeCode
+
+        tc = SQLTypeCode(4)
+        r = repr(tc)
+        assert "4" in r, "repr should contain the SQL code"
+        assert "SQLTypeCode" in r, "repr should contain class name"
+
+    def test_sqltypecode_type_code_property(self):
+        """Test SQLTypeCode.type_code returns raw SQL code."""
+        from mssql_python import SQLTypeCode
+
+        tc = SQLTypeCode(-9)
+        assert tc.type_code == -9
+        tc2 = SQLTypeCode(93)  # SQL_TYPE_TIMESTAMP
+        assert tc2.type_code == 93
+
+    def test_sqltypecode_python_type_property(self):
+        """Test SQLTypeCode.python_type returns mapped type."""
+        from mssql_python import SQLTypeCode
+
+        tc = SQLTypeCode(4)  # SQL_INTEGER
+        assert tc.python_type == int
+        tc2 = SQLTypeCode(-9)  # SQL_WVARCHAR
+        assert tc2.python_type == str
+
+    def test_sqltypecode_unknown_type_maps_to_str(self):
+        """Test unknown SQL codes map to str by default."""
+        from mssql_python import SQLTypeCode
+
+        tc = SQLTypeCode(99999)  # Unknown code
+        assert tc.python_type == str
+        assert tc == str  # Should still work for comparison
+
+    def test_sqltypecode_pandas_simulation(self):
+        """
+        Simulate pandas read_sql type checking behavior.
+
+        Pandas checks `cursor.description[i][1] == str` to determine
+        if a column should be treated as string data.
+        """
+        from mssql_python import SQLTypeCode
+
+        # Simulate a description tuple like pandas receives
+        description = [
+            ("name", SQLTypeCode(-9), None, None, None, None, None),  # nvarchar
+            ("age", SQLTypeCode(4), None, None, None, None, None),  # int
+            ("salary", SQLTypeCode(6), None, None, None, None, None),  # float
+        ]
+
+        # Pandas-style type checking
+        string_columns = []
+        for name, type_code, *rest in description:
+            if type_code == str:
+                string_columns.append(name)
+
+        assert string_columns == ["name"], "Only 'name' column should be detected as string"
+
+        # Verify other types work too
+        for name, type_code, *rest in description:
+            if type_code == int:
+                assert name == "age"
+            if type_code == float:
+                assert name == "salary"
+
+    def test_sqltypecode_dbapi_simulation(self):
+        """
+        Simulate DB-API 2.0 style type checking with integer codes.
+        """
+        from mssql_python import SQLTypeCode
+
+        # Simulate description
+        description = [
+            ("id", SQLTypeCode(4), None, None, None, None, None),  # SQL_INTEGER
+            ("data", SQLTypeCode(-9), None, None, None, None, None),  # SQL_WVARCHAR
+        ]
+
+        # DB-API style: check raw SQL code
+        for name, type_code, *rest in description:
+            if type_code == 4:  # SQL_INTEGER
+                assert name == "id"
+            if type_code == -9:  # SQL_WVARCHAR
+                assert name == "data"
+
+    def test_sqltypecode_equality_with_other_sqltypecode(self):
+        """Test SQLTypeCode equality with another SQLTypeCode."""
+        from mssql_python import SQLTypeCode
+
+        tc1 = SQLTypeCode(4)
+        tc2 = SQLTypeCode(4)
+        tc3 = SQLTypeCode(-9)
+
+        # Explicitly test __eq__ when comparing one SQLTypeCode instance to another
+        result1 = tc1.__eq__(tc2)  # Same type codes
+        result2 = tc1.__eq__(tc3)  # Different type codes
+        assert result1 is True, "Same code SQLTypeCodes should be equal"
+        assert result2 is False, "Different code SQLTypeCodes should not be equal"
+
+        # Also test via == operator
+        assert tc1 == tc2, "Same code SQLTypeCodes should be equal via =="
+        assert tc1 != tc3, "Different code SQLTypeCodes should not be equal via !="
+
+    def test_sqltypecode_inequality(self):
+        """Test SQLTypeCode inequality comparisons."""
+        from mssql_python import SQLTypeCode
+
+        tc = SQLTypeCode(4)
+        assert tc != str, "SQL_INTEGER should not equal str"
+        assert tc != float, "SQL_INTEGER should not equal float"
+        assert tc != 5, "SQLTypeCode(4) should not equal 5"
+        assert tc != -9, "SQLTypeCode(4) should not equal -9"
""") + """ + ) # Insert only 3 rows test_data = [(1, "First LOB data"), (2, "Second LOB data"), (3, "Third LOB data")] @@ -1057,12 +1064,14 @@ def test_fetchmany_empty_result_lob(cursor, db_connection): """Test fetchmany on empty result set with LOB columns""" try: cursor.execute("DROP TABLE IF EXISTS #test_fetchmany_lob_empty") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #test_fetchmany_lob_empty ( id INT PRIMARY KEY, lob_data NVARCHAR(MAX) ) - """) + """ + ) db_connection.commit() # Query empty table @@ -1085,12 +1094,14 @@ def test_fetchmany_very_large_lob(cursor, db_connection): """Test fetchmany with very large LOB column data""" try: cursor.execute("DROP TABLE IF EXISTS #test_fetchmany_large_lob") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #test_fetchmany_large_lob ( id INT PRIMARY KEY, large_lob NVARCHAR(MAX) ) - """) + """ + ) # Create very large data (10000 characters) large_data = "x" * 10000 @@ -1140,12 +1151,14 @@ def test_fetchmany_mixed_lob_sizes(cursor, db_connection): """Test fetchmany with mixed LOB sizes including empty and NULL""" try: cursor.execute("DROP TABLE IF EXISTS #test_fetchmany_mixed_lob") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #test_fetchmany_mixed_lob ( id INT PRIMARY KEY, mixed_lob NVARCHAR(MAX) ) - """) + """ + ) # Mix of sizes: empty, NULL, small, medium, large test_data = [ @@ -1197,7 +1210,7 @@ def test_fetchall(cursor): def test_fetchall_lob(cursor): - """Test fetching all rows""" + """Test fetching all rows with LOB columns""" cursor.execute("SELECT * FROM #pytest_all_data_types") rows = cursor.fetchall() assert isinstance(rows, list), "fetchall should return a list" @@ -1273,12 +1286,14 @@ def test_executemany_empty_strings(cursor, db_connection): """Test executemany with empty strings - regression test for Unix UTF-16 conversion issue""" try: # Create test table for empty string testing - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_empty_batch ( id INT, data NVARCHAR(50) ) - """) + """ + ) # Clear any existing data cursor.execute("DELETE FROM #pytest_empty_batch") @@ -1319,7 +1334,8 @@ def test_executemany_empty_strings_various_types(cursor, db_connection): """Test executemany with empty strings in different column types""" try: # Create test table with different string types - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_string_types ( id INT, varchar_col VARCHAR(50), @@ -1327,7 +1343,8 @@ def test_executemany_empty_strings_various_types(cursor, db_connection): text_col TEXT, ntext_col NTEXT ) - """) + """ + ) # Clear any existing data cursor.execute("DELETE FROM #pytest_string_types") @@ -1368,12 +1385,14 @@ def test_executemany_unicode_and_empty_strings(cursor, db_connection): """Test executemany with mix of Unicode characters and empty strings""" try: # Create test table - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_unicode_test ( id INT, data NVARCHAR(100) ) - """) + """ + ) # Clear any existing data cursor.execute("DELETE FROM #pytest_unicode_test") @@ -1418,12 +1437,14 @@ def test_executemany_large_batch_with_empty_strings(cursor, db_connection): """Test executemany with large batch containing empty strings""" try: # Create test table - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_large_batch ( id INT, data NVARCHAR(50) ) - """) + """ + ) # Clear any existing data cursor.execute("DELETE FROM #pytest_large_batch") @@ -1476,12 +1497,14 @@ def test_executemany_compare_with_execute(cursor, db_connection): """Test that 
executemany produces same results as individual execute calls""" try: # Create test table - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_compare_test ( id INT, data NVARCHAR(50) ) - """) + """ + ) # Test data with empty strings test_data = [ @@ -1534,13 +1557,15 @@ def test_executemany_edge_cases_empty_strings(cursor, db_connection): """Test executemany edge cases with empty strings and special characters""" try: # Create test table - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_edge_cases ( id INT, varchar_data VARCHAR(100), nvarchar_data NVARCHAR(100) ) - """) + """ + ) # Clear any existing data cursor.execute("DELETE FROM #pytest_edge_cases") @@ -1594,12 +1619,14 @@ def test_executemany_null_vs_empty_string(cursor, db_connection): """Test that executemany correctly distinguishes between NULL and empty string""" try: # Create test table - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_null_vs_empty ( id INT, data NVARCHAR(50) ) - """) + """ + ) # Clear any existing data cursor.execute("DELETE FROM #pytest_null_vs_empty") @@ -1664,12 +1691,14 @@ def test_executemany_binary_data_edge_cases(cursor, db_connection): """Test executemany with binary data and empty byte arrays""" try: # Create test table - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_binary_test ( id INT, binary_data VARBINARY(100) ) - """) + """ + ) # Clear any existing data cursor.execute("DELETE FROM #pytest_binary_test") @@ -1831,7 +1860,8 @@ def test_executemany_mixed_null_and_typed_values(cursor, db_connection): """Test executemany with randomly mixed NULL and non-NULL values across multiple columns and rows (50 rows, 10 columns).""" try: # Create table with 10 columns of various types - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_empty_params ( col1 INT, col2 VARCHAR(50), @@ -1844,7 +1874,8 @@ def test_executemany_mixed_null_and_typed_values(cursor, db_connection): col9 DATE, col10 REAL ) - """) + """ + ) # Generate 50 rows with randomly mixed NULL and non-NULL values across 10 columns data = [] @@ -1908,7 +1939,8 @@ def test_executemany_multi_column_null_arrays(cursor, db_connection): """Test executemany with multi-column NULL arrays (50 records, 8 columns).""" try: # Create table with 8 columns of various types - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_null_arrays ( col1 INT, col2 VARCHAR(100), @@ -1919,7 +1951,8 @@ def test_executemany_multi_column_null_arrays(cursor, db_connection): col7 BIGINT, col8 DATE ) - """) + """ + ) # Generate 50 rows with all NULL values across 8 columns data = [(None, None, None, None, None, None, None, None) for _ in range(50)] @@ -1939,12 +1972,14 @@ def test_executemany_multi_column_null_arrays(cursor, db_connection): assert null_count == 50, f"Expected 50 NULLs in col{col_num}, got {null_count}" # Verify no non-NULL values exist - cursor.execute(""" + cursor.execute( + """ SELECT COUNT(*) FROM #pytest_null_arrays WHERE col1 IS NOT NULL OR col2 IS NOT NULL OR col3 IS NOT NULL OR col4 IS NOT NULL OR col5 IS NOT NULL OR col6 IS NOT NULL OR col7 IS NOT NULL OR col8 IS NOT NULL - """) + """ + ) non_null_count = cursor.fetchone()[0] assert non_null_count == 0, f"Expected 0 non-NULL values, got {non_null_count}" @@ -1983,7 +2018,8 @@ def test_executemany_concurrent_null_parameters(db_connection): # Create table with db_connection.cursor() as cursor: - cursor.execute(f""" + cursor.execute( + f""" IF OBJECT_ID('{table_name}', 'U') IS NOT NULL DROP TABLE {table_name} @@ -1995,7 
+2031,8 @@ def test_executemany_concurrent_null_parameters(db_connection): col3 FLOAT, col4 DATETIME ) - """) + """ + ) db_connection.commit() # Execute multiple sequential insert operations @@ -2250,12 +2287,14 @@ def test_insert_data_for_join(cursor, db_connection): def test_join_operations(cursor): """Test join operations""" try: - cursor.execute(""" + cursor.execute( + """ SELECT e.name, d.department_name, p.project_name FROM #pytest_employees e JOIN #pytest_departments d ON e.department_id = d.department_id JOIN #pytest_projects p ON e.employee_id = p.employee_id - """) + """ + ) rows = cursor.fetchall() assert len(rows) == 3, "Join operation returned incorrect number of rows" assert rows[0] == [ @@ -2345,10 +2384,12 @@ def test_execute_stored_procedure_with_parameters(cursor): def test_execute_stored_procedure_without_parameters(cursor): """Test executing stored procedure without parameters""" try: - cursor.execute(""" + cursor.execute( + """ DECLARE @EmployeeID INT = 2 EXEC dbo.GetEmployeeProjects @EmployeeID - """) + """ + ) rows = cursor.fetchall() assert ( len(rows) == 1 @@ -2382,16 +2423,125 @@ def test_drop_tables_for_join(cursor, db_connection): def test_cursor_description(cursor): - """Test cursor description""" + """Test cursor description with SQLTypeCode for backwards compatibility.""" cursor.execute("SELECT database_id, name FROM sys.databases;") desc = cursor.description - expected_description = [ - ("database_id", int, None, 10, 10, 0, False), - ("name", str, None, 128, 128, 0, False), - ] - assert len(desc) == len(expected_description), "Description length mismatch" - for desc, expected in zip(desc, expected_description): - assert desc == expected, f"Description mismatch: {desc} != {expected}" + + from mssql_python.constants import ConstantsDDBC as ddbc_sql_const + + # Verify length + assert len(desc) == 2, "Description should have 2 columns" + + # Test 1: DB-API 2.0 compliant - compare with SQL type codes (integers) + assert desc[0][1] == ddbc_sql_const.SQL_INTEGER.value, "database_id should be SQL_INTEGER (4)" + assert desc[1][1] == ddbc_sql_const.SQL_WVARCHAR.value, "name should be SQL_WVARCHAR (-9)" + + # Test 2: Backwards compatible - compare with Python types (for pandas, etc.) + assert desc[0][1] == int, "database_id should also compare equal to Python int" + assert desc[1][1] == str, "name should also compare equal to Python str" + + # Test 3: Can convert to int to get raw SQL code + assert int(desc[0][1]) == 4, "int(type_code) should return SQL_INTEGER (4)" + assert int(desc[1][1]) == -9, "int(type_code) should return SQL_WVARCHAR (-9)" + + # Test 4: Verify other tuple elements + assert desc[0][0] == "database_id", "First column name should be database_id" + assert desc[1][0] == "name", "Second column name should be name" + + +def test_cursor_description_pandas_compatibility(cursor): + """ + Test that cursor.description type_code works with pandas-style type checking. + + Pandas and other libraries check `cursor.description[i][1] == str` to determine + column types. This test ensures SQLTypeCode maintains backwards compatibility. 
+ """ + cursor.execute("SELECT database_id, name FROM sys.databases;") + desc = cursor.description + + # Simulate what pandas does internally when reading SQL results + # pandas checks: if description[i][1] == str: treat as string column + type_map = {} + for col_desc in desc: + col_name = col_desc[0] + type_code = col_desc[1] + + # This is how pandas-like code typically checks types + if type_code == str: + type_map[col_name] = "string" + elif type_code == int: + type_map[col_name] = "integer" + elif type_code == float: + type_map[col_name] = "float" + elif type_code == bytes: + type_map[col_name] = "bytes" + else: + type_map[col_name] = "other" + + assert type_map["database_id"] == "integer", "database_id should be detected as integer" + assert type_map["name"] == "string", "name should be detected as string" + + +def test_cursor_description_datetime_types(cursor, db_connection): + """ + Regression test for Issue #352: Ensure DATE/datetime columns return correct ODBC type codes. + + This test verifies that cursor.description properly handles date/time columns, + returning the correct ODBC 3.x type codes while maintaining backwards compatibility + with Python datetime types for pandas-style comparisons. + """ + from mssql_python.constants import ConstantsDDBC + + try: + # Create a table with various date/time types + cursor.execute( + """ + CREATE TABLE #pytest_datetime_desc ( + id INT PRIMARY KEY, + date_col DATE, + time_col TIME, + datetime_col DATETIME, + datetime2_col DATETIME2 + ); + """ + ) + db_connection.commit() + + cursor.execute( + "SELECT id, date_col, time_col, datetime_col, datetime2_col FROM #pytest_datetime_desc;" + ) + desc = cursor.description + + assert len(desc) == 5, "Should have 5 columns in description" + + # Verify column names + assert desc[0][0] == "id", "First column should be 'id'" + assert desc[1][0] == "date_col", "Second column should be 'date_col'" + assert desc[2][0] == "time_col", "Third column should be 'time_col'" + assert desc[3][0] == "datetime_col", "Fourth column should be 'datetime_col'" + assert desc[4][0] == "datetime2_col", "Fifth column should be 'datetime2_col'" + + # Test 1: DB-API 2.0 compliant - verify SQL type codes as integers + # DATE should be SQL_TYPE_DATE (91) + assert ( + int(desc[1][1]) == ConstantsDDBC.SQL_TYPE_DATE.value + ), f"DATE column should have SQL_TYPE_DATE type code ({ConstantsDDBC.SQL_TYPE_DATE.value})" + # TIME should be SQL_SS_TIME2 (-154) or SQL_TYPE_TIME (92) + time_type_code = int(desc[2][1]) + assert time_type_code in ( + ConstantsDDBC.SQL_SS_TIME2.value, + ConstantsDDBC.SQL_TYPE_TIME.value, + ), f"TIME column should have SQL_SS_TIME2 or SQL_TYPE_TIME type code, got {time_type_code}" + + # Test 2: Backwards compatible - compare with Python types (for pandas, etc.) 
+ assert desc[1][1] == date, "DATE should compare equal to datetime.date" + assert desc[2][1] == time, "TIME should compare equal to datetime.time" + assert desc[3][1] == datetime, "DATETIME should compare equal to datetime.datetime" + assert desc[4][1] == datetime, "DATETIME2 should compare equal to datetime.datetime" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_datetime_desc;") + db_connection.commit() def test_parse_datetime(cursor, db_connection): @@ -2568,21 +2718,25 @@ def test_row_attribute_access(cursor, db_connection): """Test accessing row values by column name as attributes""" try: # Create test table with multiple columns - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_row_attr_test ( id INT PRIMARY KEY, name VARCHAR(50), email VARCHAR(100), age INT ) - """) + """ + ) db_connection.commit() # Insert test data - cursor.execute(""" + cursor.execute( + """ INSERT INTO #pytest_row_attr_test (id, name, email, age) VALUES (1, 'John Doe', 'john@example.com', 30) - """) + """ + ) db_connection.commit() # Test attribute access @@ -2678,13 +2832,15 @@ def test_row_comparison_with_list(cursor, db_connection): def test_row_string_representation(cursor, db_connection): """Test Row string and repr representations""" try: - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_row_test ( id INT PRIMARY KEY, text_col NVARCHAR(50), null_col INT ) - """) + """ + ) db_connection.commit() cursor.execute( @@ -2717,13 +2873,15 @@ def test_row_string_representation(cursor, db_connection): def test_row_column_mapping(cursor, db_connection): """Test Row column name mapping""" try: - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_row_test ( FirstColumn INT PRIMARY KEY, Second_Column NVARCHAR(50), [Complex Name!] 
INT ) - """) + """ + ) db_connection.commit() cursor.execute( @@ -3206,10 +3364,12 @@ def test_execute_rowcount_chaining(cursor, db_connection): assert count == 1, "INSERT should affect 1 row" # Test multiple INSERT rowcount chaining - count = cursor.execute(""" + count = cursor.execute( + """ INSERT INTO #test_chaining (id, value) VALUES (2, 'test2'), (3, 'test3'), (4, 'test4') - """).rowcount + """ + ).rowcount assert count == 3, "Multiple INSERT should affect 3 rows" # Test UPDATE rowcount chaining @@ -3444,7 +3604,8 @@ def test_cursor_next_with_different_data_types(cursor, db_connection): """Test next() functionality with various data types""" try: # Create test table with various data types - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #test_next_types ( id INT, name NVARCHAR(50), @@ -3453,7 +3614,8 @@ def test_cursor_next_with_different_data_types(cursor, db_connection): created_date DATE, created_time DATETIME ) - """) + """ + ) db_connection.commit() # Insert test data with different types @@ -3645,14 +3807,16 @@ def test_execute_chaining_compatibility_examples(cursor, db_connection): """Test real-world chaining examples""" try: # Create users table - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #users ( user_id INT IDENTITY(1,1) PRIMARY KEY, user_name NVARCHAR(50), last_logon DATETIME, status NVARCHAR(20) ) - """) + """ + ) db_connection.commit() # Insert test users @@ -4351,7 +4515,8 @@ def test_fetchval_different_data_types(cursor, db_connection): try: # Create test table with different data types drop_table_if_exists(cursor, "#pytest_fetchval_types") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_fetchval_types ( int_col INTEGER, float_col FLOAT, @@ -4363,14 +4528,17 @@ def test_fetchval_different_data_types(cursor, db_connection): date_col DATE, time_col TIME ) - """) + """ + ) # Insert test data - cursor.execute(""" + cursor.execute( + """ INSERT INTO #pytest_fetchval_types VALUES (123, 45.67, 89.12, 'ASCII text', N'Unicode text', 1, '2024-05-20 12:34:56', '2024-05-20', '12:34:56') - """) + """ + ) db_connection.commit() # Test different data types @@ -5668,21 +5836,25 @@ def test_cursor_rollback_data_consistency(cursor, db_connection): drop_table_if_exists(cursor, "#pytest_rollback_orders") drop_table_if_exists(cursor, "#pytest_rollback_customers") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_rollback_customers ( id INTEGER PRIMARY KEY, name VARCHAR(50) ) - """) + """ + ) - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_rollback_orders ( id INTEGER PRIMARY KEY, customer_id INTEGER, amount DECIMAL(10,2), FOREIGN KEY (customer_id) REFERENCES #pytest_rollback_customers(id) ) - """) + """ + ) cursor.commit() # Insert initial data @@ -6164,26 +6336,32 @@ def test_tables_setup(cursor, db_connection): cursor.execute("DROP VIEW IF EXISTS pytest_tables_schema.test_view") # Create regular table - cursor.execute(""" + cursor.execute( + """ CREATE TABLE pytest_tables_schema.regular_table ( id INT PRIMARY KEY, name VARCHAR(100) ) - """) + """ + ) # Create another table - cursor.execute(""" + cursor.execute( + """ CREATE TABLE pytest_tables_schema.another_table ( id INT PRIMARY KEY, description VARCHAR(200) ) - """) + """ + ) # Create a view - cursor.execute(""" + cursor.execute( + """ CREATE VIEW pytest_tables_schema.test_view AS SELECT id, name FROM pytest_tables_schema.regular_table - """) + """ + ) db_connection.commit() except Exception as e: @@ -6535,12 +6713,14 @@ def test_emoji_round_trip(cursor, 
db_connection): "1🚀' OR '1'='1", ] - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_emoji_test ( id INT IDENTITY PRIMARY KEY, content NVARCHAR(MAX) ); - """) + """ + ) db_connection.commit() for text in test_inputs: @@ -6692,14 +6872,16 @@ def test_empty_values_fetchmany(cursor, db_connection): try: # Create comprehensive test table drop_table_if_exists(cursor, "#pytest_fetchmany_empty") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_fetchmany_empty ( id INT, varchar_col VARCHAR(50), nvarchar_col NVARCHAR(50), binary_col VARBINARY(50) ) - """) + """ + ) db_connection.commit() # Insert multiple rows with empty values @@ -6824,7 +7006,8 @@ def test_batch_fetch_empty_values_no_assertion_failure(cursor, db_connection): try: # Create comprehensive test table drop_table_if_exists(cursor, "#pytest_batch_empty_assertions") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_batch_empty_assertions ( id INT, empty_varchar VARCHAR(100), @@ -6834,24 +7017,29 @@ def test_batch_fetch_empty_values_no_assertion_failure(cursor, db_connection): null_nvarchar NVARCHAR(100), null_binary VARBINARY(100) ) - """) + """ + ) db_connection.commit() # Insert rows with mix of empty and NULL values - cursor.execute(""" + cursor.execute( + """ INSERT INTO #pytest_batch_empty_assertions VALUES (1, '', '', 0x, NULL, NULL, NULL), (2, '', '', 0x, NULL, NULL, NULL), (3, '', '', 0x, NULL, NULL, NULL) - """) + """ + ) db_connection.commit() # Test fetchall - should not trigger any assertions about dataLen - cursor.execute(""" + cursor.execute( + """ SELECT empty_varchar, empty_nvarchar, empty_binary, null_varchar, null_nvarchar, null_binary FROM #pytest_batch_empty_assertions ORDER BY id - """) + """ + ) rows = cursor.fetchall() assert len(rows) == 3, "Should return 3 rows" @@ -6868,10 +7056,12 @@ def test_batch_fetch_empty_values_no_assertion_failure(cursor, db_connection): assert row[5] is None, f"Row {i+1} null_binary should be None" # Test fetchmany - should also not trigger assertions - cursor.execute(""" + cursor.execute( + """ SELECT empty_nvarchar, empty_binary FROM #pytest_batch_empty_assertions ORDER BY id - """) + """ + ) # Fetch in batches first_batch = cursor.fetchmany(2) @@ -6911,13 +7101,15 @@ def test_executemany_utf16_length_validation(cursor, db_connection): try: # Create test table with small column size to trigger validation drop_table_if_exists(cursor, "#pytest_utf16_validation") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_utf16_validation ( id INT, short_text NVARCHAR(5), -- Small column to test length validation medium_text NVARCHAR(10) -- Medium column for edge cases ) - """) + """ + ) db_connection.commit() # Test 1: Valid strings that should work on all platforms @@ -7063,12 +7255,14 @@ def test_binary_data_over_8000_bytes(cursor, db_connection): try: # Create test table with VARBINARY(MAX) to handle large data drop_table_if_exists(cursor, "#pytest_small_binary") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_small_binary ( id INT, large_binary VARBINARY(MAX) ) - """) + """ + ) # Test data that fits within both parameter and fetch limits (< 4096 bytes) medium_data = b"B" * 3000 # 3,000 bytes - under both limits @@ -7102,12 +7296,14 @@ def test_varbinarymax_insert_fetch(cursor, db_connection): try: # Create test table drop_table_if_exists(cursor, "#pytest_varbinarymax") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_varbinarymax ( id INT, binary_data VARBINARY(MAX) ) - """) + """ + ) # Prepare 
test data - use moderate sizes to guarantee LOB fetch path (line 867-868) efficiently test_data = [ @@ -7174,12 +7370,14 @@ def test_all_empty_binaries(cursor, db_connection): try: # Create test table drop_table_if_exists(cursor, "#pytest_all_empty_binary") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_all_empty_binary ( id INT, empty_binary VARBINARY(100) ) - """) + """ + ) # Insert multiple rows with only empty binary data test_data = [ @@ -7218,12 +7416,14 @@ def test_mixed_bytes_and_bytearray_types(cursor, db_connection): try: # Create test table drop_table_if_exists(cursor, "#pytest_mixed_binary_types") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_mixed_binary_types ( id INT, binary_data VARBINARY(100) ) - """) + """ + ) # Test data mixing bytes and bytearray for the same column test_data = [ @@ -7278,12 +7478,14 @@ def test_binary_mostly_small_one_large(cursor, db_connection): try: # Create test table drop_table_if_exists(cursor, "#pytest_mixed_size_binary") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_mixed_size_binary ( id INT, binary_data VARBINARY(MAX) ) - """) + """ + ) # Create large binary value within both parameter and fetch limits (< 4096 bytes) large_binary = b"X" * 3500 # 3,500 bytes - under both limits @@ -7343,12 +7545,14 @@ def test_varbinarymax_insert_fetch_null(cursor, db_connection): """Test insertion and retrieval of NULL value in VARBINARY(MAX) column.""" try: drop_table_if_exists(cursor, "#pytest_varbinarymax_null") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_varbinarymax_null ( id INT, binary_data VARBINARY(MAX) ) - """) + """ + ) # Insert a row with NULL for binary_data cursor.execute( @@ -7378,13 +7582,15 @@ def test_sql_double_type(cursor, db_connection): """Test SQL_DOUBLE type (FLOAT(53)) to cover line 3213 in dispatcher.""" try: drop_table_if_exists(cursor, "#pytest_double_type") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_double_type ( id INT PRIMARY KEY, double_col FLOAT(53), float_col FLOAT ) - """) + """ + ) # Insert test data with various double precision values test_data = [ @@ -7432,13 +7638,15 @@ def test_null_guid_type(cursor, db_connection): """Test NULL UNIQUEIDENTIFIER (GUID) to cover lines 3376-3377.""" try: drop_table_if_exists(cursor, "#pytest_null_guid") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_null_guid ( id INT PRIMARY KEY, guid_col UNIQUEIDENTIFIER, guid_nullable UNIQUEIDENTIFIER NULL ) - """) + """ + ) # Insert test data with NULL and non-NULL GUIDs test_guid = uuid.uuid4() @@ -7490,12 +7698,14 @@ def test_only_null_and_empty_binary(cursor, db_connection): try: # Create test table drop_table_if_exists(cursor, "#pytest_null_empty_binary") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_null_empty_binary ( id INT, binary_data VARBINARY(100) ) - """) + """ + ) # Test data with only NULL and empty values test_data = [ @@ -7818,7 +8028,8 @@ def test_money_smallmoney_insert_fetch(cursor, db_connection): """Test inserting and retrieving valid MONEY and SMALLMONEY values including boundaries and typical data""" try: drop_table_if_exists(cursor, "#pytest_money_test") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_money_test ( id INT IDENTITY PRIMARY KEY, m MONEY, @@ -7826,7 +8037,8 @@ def test_money_smallmoney_insert_fetch(cursor, db_connection): d DECIMAL(19,4), n NUMERIC(10,4) ) - """) + """ + ) db_connection.commit() # Max values @@ -7916,13 +8128,15 @@ def 
test_money_smallmoney_insert_fetch(cursor, db_connection): def test_money_smallmoney_null_handling(cursor, db_connection): """Test that NULL values for MONEY and SMALLMONEY are stored and retrieved correctly""" try: - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_money_test ( id INT IDENTITY PRIMARY KEY, m MONEY, sm SMALLMONEY ) - """) + """ + ) db_connection.commit() # Row with both NULLs @@ -7972,13 +8186,15 @@ def test_money_smallmoney_null_handling(cursor, db_connection): def test_money_smallmoney_roundtrip(cursor, db_connection): """Test inserting and retrieving MONEY and SMALLMONEY using decimal.Decimal roundtrip""" try: - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_money_test ( id INT IDENTITY PRIMARY KEY, m MONEY, sm SMALLMONEY ) - """) + """ + ) db_connection.commit() values = (decimal.Decimal("12345.6789"), decimal.Decimal("987.6543")) @@ -8002,13 +8218,15 @@ def test_money_smallmoney_boundaries(cursor, db_connection): """Test boundary values for MONEY and SMALLMONEY types are handled correctly""" try: drop_table_if_exists(cursor, "#pytest_money_test") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_money_test ( id INT IDENTITY PRIMARY KEY, m MONEY, sm SMALLMONEY ) - """) + """ + ) db_connection.commit() # Insert max boundary @@ -8048,13 +8266,15 @@ def test_money_smallmoney_boundaries(cursor, db_connection): def test_money_smallmoney_invalid_values(cursor, db_connection): """Test that invalid or out-of-range MONEY and SMALLMONEY values raise errors""" try: - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_money_test ( id INT IDENTITY PRIMARY KEY, m MONEY, sm SMALLMONEY ) - """) + """ + ) db_connection.commit() # Out of range MONEY @@ -8085,13 +8305,15 @@ def test_money_smallmoney_invalid_values(cursor, db_connection): def test_money_smallmoney_roundtrip_executemany(cursor, db_connection): """Test inserting and retrieving MONEY and SMALLMONEY using executemany with decimal.Decimal""" try: - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_money_test ( id INT IDENTITY PRIMARY KEY, m MONEY, sm SMALLMONEY ) - """) + """ + ) db_connection.commit() test_data = [ @@ -8125,13 +8347,15 @@ def test_money_smallmoney_roundtrip_executemany(cursor, db_connection): def test_money_smallmoney_executemany_null_handling(cursor, db_connection): """Test inserting NULLs into MONEY and SMALLMONEY using executemany""" try: - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_money_test ( id INT IDENTITY PRIMARY KEY, m MONEY, sm SMALLMONEY ) - """) + """ + ) db_connection.commit() rows = [ @@ -8189,12 +8413,14 @@ def test_uuid_insert_and_select_none(cursor, db_connection): table_name = "#pytest_uuid_nullable" try: cursor.execute(f"DROP TABLE IF EXISTS {table_name}") - cursor.execute(f""" + cursor.execute( + f""" CREATE TABLE {table_name} ( id UNIQUEIDENTIFIER, name NVARCHAR(50) ) - """) + """ + ) db_connection.commit() # Insert a row with None for the UUID @@ -8218,12 +8444,14 @@ def test_insert_multiple_uuids(cursor, db_connection): table_name = "#pytest_uuid_multiple" try: cursor.execute(f"DROP TABLE IF EXISTS {table_name}") - cursor.execute(f""" + cursor.execute( + f""" CREATE TABLE {table_name} ( id UNIQUEIDENTIFIER PRIMARY KEY, description NVARCHAR(50) ) - """) + """ + ) db_connection.commit() # Prepare test data @@ -8259,12 +8487,14 @@ def test_fetchmany_uuids(cursor, db_connection): table_name = "#pytest_uuid_fetchmany" try: cursor.execute(f"DROP TABLE IF EXISTS {table_name}") - cursor.execute(f""" 
+ cursor.execute( + f""" CREATE TABLE {table_name} ( id UNIQUEIDENTIFIER PRIMARY KEY, description NVARCHAR(50) ) - """) + """ + ) db_connection.commit() uuids_to_insert = {f"Item {i}": uuid.uuid4() for i in range(10)} @@ -8300,12 +8530,14 @@ def test_uuid_insert_with_none(cursor, db_connection): table_name = "#pytest_uuid_none" try: cursor.execute(f"DROP TABLE IF EXISTS {table_name}") - cursor.execute(f""" + cursor.execute( + f""" CREATE TABLE {table_name} ( id UNIQUEIDENTIFIER, name NVARCHAR(50) ) - """) + """ + ) db_connection.commit() cursor.execute(f"INSERT INTO {table_name} (id, name) VALUES (?, ?)", [None, "Alice"]) @@ -8401,12 +8633,14 @@ def test_executemany_uuid_insert_and_select(cursor, db_connection): try: # Drop and create a temporary table for the test cursor.execute(f"DROP TABLE IF EXISTS {table_name}") - cursor.execute(f""" + cursor.execute( + f""" CREATE TABLE {table_name} ( id UNIQUEIDENTIFIER PRIMARY KEY, description NVARCHAR(50) ) - """) + """ + ) db_connection.commit() # Generate data for insertion @@ -8456,12 +8690,14 @@ def test_executemany_uuid_roundtrip_fixed_value(cursor, db_connection): table_name = "#pytest_uuid_fixed" try: cursor.execute(f"DROP TABLE IF EXISTS {table_name}") - cursor.execute(f""" + cursor.execute( + f""" CREATE TABLE {table_name} ( id UNIQUEIDENTIFIER, description NVARCHAR(50) ) - """) + """ + ) db_connection.commit() fixed_uuid = uuid.UUID("12345678-1234-5678-1234-567812345678") @@ -8502,7 +8738,8 @@ def test_decimal_separator_with_multiple_values(cursor, db_connection): try: # Create test table - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_decimal_multi_test ( id INT PRIMARY KEY, positive_value DECIMAL(10, 2), @@ -8510,13 +8747,16 @@ def test_decimal_separator_with_multiple_values(cursor, db_connection): zero_value DECIMAL(10, 2), small_value DECIMAL(10, 4) ) - """) + """ + ) db_connection.commit() # Insert test data - cursor.execute(""" + cursor.execute( + """ INSERT INTO #pytest_decimal_multi_test VALUES (1, 123.45, -67.89, 0.00, 0.0001) - """) + """ + ) db_connection.commit() # Test with default separator first @@ -8553,19 +8793,23 @@ def test_decimal_separator_calculations(cursor, db_connection): try: # Create test table - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_decimal_calc_test ( id INT PRIMARY KEY, value1 DECIMAL(10, 2), value2 DECIMAL(10, 2) ) - """) + """ + ) db_connection.commit() # Insert test data - cursor.execute(""" + cursor.execute( + """ INSERT INTO #pytest_decimal_calc_test VALUES (1, 10.25, 5.75) - """) + """ + ) db_connection.commit() # Test with default separator @@ -8604,12 +8848,14 @@ def test_decimal_separator_function(cursor, db_connection): try: # Create test table - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_decimal_separator_test ( id INT PRIMARY KEY, decimal_value DECIMAL(10, 2) ) - """) + """ + ) db_connection.commit() # Insert test values with default separator (.) 
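The decimal-separator tests in this region all follow the same set-and-restore pattern around `mssql_python.setDecimalSeparator`. A minimal sketch of that pattern, assuming only the `connect` and `setDecimalSeparator` APIs these tests already exercise (the `conn_str` variable and the literal query are illustrative, not taken from the suite):

```python
import decimal
import mssql_python

conn = mssql_python.connect(conn_str)  # conn_str assumed to be defined elsewhere
cursor = conn.cursor()
try:
    # Switch separators, as the tests below do, then always restore the
    # default so a failing assertion cannot leak state into later tests.
    mssql_python.setDecimalSeparator(",")
    cursor.execute("SELECT CAST(123.45 AS DECIMAL(10, 2))")
    value = cursor.fetchone()[0]
    # DECIMAL columns still map to decimal.Decimal regardless of separator
    assert isinstance(value, decimal.Decimal)
finally:
    mssql_python.setDecimalSeparator(".")  # driver default
```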
@@ -8694,21 +8940,25 @@ def test_lowercase_attribute(cursor, db_connection): try: # Create a test table with mixed-case column names - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_lowercase_test ( ID INT PRIMARY KEY, UserName VARCHAR(50), EMAIL_ADDRESS VARCHAR(100), PhoneNumber VARCHAR(20) ) - """) + """ + ) db_connection.commit() # Insert test data - cursor.execute(""" + cursor.execute( + """ INSERT INTO #pytest_lowercase_test (ID, UserName, EMAIL_ADDRESS, PhoneNumber) VALUES (1, 'JohnDoe', 'john@example.com', '555-1234') - """) + """ + ) db_connection.commit() # First test with lowercase=False (default) @@ -8763,12 +9013,14 @@ def test_decimal_separator_function(cursor, db_connection): try: # Create test table - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_decimal_separator_test ( id INT PRIMARY KEY, decimal_value DECIMAL(10, 2) ) - """) + """ + ) db_connection.commit() # Insert test values with default separator (.) @@ -8850,7 +9102,8 @@ def test_decimal_separator_with_multiple_values(cursor, db_connection): try: # Create test table - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_decimal_multi_test ( id INT PRIMARY KEY, positive_value DECIMAL(10, 2), @@ -8858,13 +9111,16 @@ def test_decimal_separator_with_multiple_values(cursor, db_connection): zero_value DECIMAL(10, 2), small_value DECIMAL(10, 4) ) - """) + """ + ) db_connection.commit() # Insert test data - cursor.execute(""" + cursor.execute( + """ INSERT INTO #pytest_decimal_multi_test VALUES (1, 123.45, -67.89, 0.00, 0.0001) - """) + """ + ) db_connection.commit() # Test with default separator first @@ -8901,19 +9157,23 @@ def test_decimal_separator_calculations(cursor, db_connection): try: # Create test table - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_decimal_calc_test ( id INT PRIMARY KEY, value1 DECIMAL(10, 2), value2 DECIMAL(10, 2) ) - """) + """ + ) db_connection.commit() # Insert test data - cursor.execute(""" + cursor.execute( + """ INSERT INTO #pytest_decimal_calc_test VALUES (1, 10.25, 5.75) - """) + """ + ) db_connection.commit() # Test with default separator @@ -8984,13 +9244,10 @@ def test_decimal_separator_fetch_regression(cursor, db_connection): assert val == decimal.Decimal("99.99") finally: - # Reset separator to default just in case + # Reset separator to default mssql_python.setDecimalSeparator(".") - try: - cursor.execute("DROP TABLE IF EXISTS #TestDecimal") - db_connection.commit() - except Exception: - pass + cursor.execute("DROP TABLE IF EXISTS #TestDecimal") + db_connection.commit() def test_datetimeoffset_read_write(cursor, db_connection): @@ -9426,21 +9683,25 @@ def test_lowercase_attribute(cursor, db_connection): try: # Create a test table with mixed-case column names - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_lowercase_test ( ID INT PRIMARY KEY, UserName VARCHAR(50), EMAIL_ADDRESS VARCHAR(100), PhoneNumber VARCHAR(20) ) - """) + """ + ) db_connection.commit() # Insert test data - cursor.execute(""" + cursor.execute( + """ INSERT INTO #pytest_lowercase_test (ID, UserName, EMAIL_ADDRESS, PhoneNumber) VALUES (1, 'JohnDoe', 'john@example.com', '555-1234') - """) + """ + ) db_connection.commit() # First test with lowercase=False (default) @@ -9495,12 +9756,14 @@ def test_decimal_separator_function(cursor, db_connection): try: # Create test table - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_decimal_separator_test ( id INT PRIMARY KEY, decimal_value DECIMAL(10, 2) ) - """) + """ + ) 
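+        # Local (#-prefixed) temp tables are connection-scoped: concurrent test
+        # runs cannot collide, and SQL Server drops them when the connection closes.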
db_connection.commit() # Insert test values with default separator (.) @@ -9582,7 +9845,8 @@ def test_decimal_separator_with_multiple_values(cursor, db_connection): try: # Create test table - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_decimal_multi_test ( id INT PRIMARY KEY, positive_value DECIMAL(10, 2), @@ -9590,13 +9854,16 @@ def test_decimal_separator_with_multiple_values(cursor, db_connection): zero_value DECIMAL(10, 2), small_value DECIMAL(10, 4) ) - """) + """ + ) db_connection.commit() # Insert test data - cursor.execute(""" + cursor.execute( + """ INSERT INTO #pytest_decimal_multi_test VALUES (1, 123.45, -67.89, 0.00, 0.0001) - """) + """ + ) db_connection.commit() # Test with default separator first @@ -9633,19 +9900,23 @@ def test_decimal_separator_calculations(cursor, db_connection): try: # Create test table - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_decimal_calc_test ( id INT PRIMARY KEY, value1 DECIMAL(10, 2), value2 DECIMAL(10, 2) ) - """) + """ + ) db_connection.commit() # Insert test data - cursor.execute(""" + cursor.execute( + """ INSERT INTO #pytest_decimal_calc_test VALUES (1, 10.25, 5.75) - """) + """ + ) db_connection.commit() # Test with default separator @@ -9684,12 +9955,14 @@ def test_cursor_setinputsizes_basic(db_connection): # Create a test table cursor.execute("DROP TABLE IF EXISTS #test_inputsizes") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #test_inputsizes ( string_col NVARCHAR(100), int_col INT ) - """) + """ + ) # Set input sizes for parameters cursor.setinputsizes([(mssql_python.SQL_WVARCHAR, 100, 0), (mssql_python.SQL_INTEGER, 0, 0)]) @@ -9715,13 +9988,15 @@ def test_cursor_setinputsizes_with_executemany_float(db_connection): # Create a test table cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_float") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #test_inputsizes_float ( id INT, name NVARCHAR(50), price REAL /* Use REAL instead of DECIMAL */ ) - """) + """ + ) # Prepare data with float values data = [(1, "Item 1", 10.99), (2, "Item 2", 20.50), (3, "Item 3", 30.75)] @@ -9758,12 +10033,14 @@ def test_cursor_setinputsizes_reset(db_connection): # Create a test table cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_reset") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #test_inputsizes_reset ( col1 NVARCHAR(100), col2 INT ) - """) + """ + ) # Set input sizes for parameters cursor.setinputsizes([(mssql_python.SQL_WVARCHAR, 100, 0), (mssql_python.SQL_INTEGER, 0, 0)]) @@ -9798,12 +10075,14 @@ def test_cursor_setinputsizes_override_inference(db_connection): # Create a test table with specific types cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_override") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #test_inputsizes_override ( small_int SMALLINT, big_text NVARCHAR(MAX) ) - """) + """ + ) # Set input sizes that override the default inference # For SMALLINT, use a valid precision value (5 is typical for SMALLINT) @@ -9859,13 +10138,15 @@ def test_setinputsizes_parameter_count_mismatch_fewer(db_connection): # Create a test table cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_mismatch") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #test_inputsizes_mismatch ( col1 INT, col2 NVARCHAR(100), col3 FLOAT ) - """) + """ + ) # Set fewer input sizes than parameters cursor.setinputsizes( @@ -9908,12 +10189,14 @@ def test_setinputsizes_parameter_count_mismatch_more(db_connection): # Create a test table cursor.execute("DROP TABLE IF EXISTS 
#test_inputsizes_mismatch") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #test_inputsizes_mismatch ( col1 INT, col2 NVARCHAR(100) ) - """) + """ + ) # Set more input sizes than parameters cursor.setinputsizes( @@ -9948,7 +10231,8 @@ def test_setinputsizes_with_null_values(db_connection): # Create a test table with multiple data types cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_null") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #test_inputsizes_null ( int_col INT, string_col NVARCHAR(100), @@ -9956,7 +10240,8 @@ def test_setinputsizes_with_null_values(db_connection): date_col DATE, binary_col VARBINARY(100) ) - """) + """ + ) # Set input sizes for all columns cursor.setinputsizes( @@ -10259,15 +10544,18 @@ def test_procedures_setup(cursor, db_connection): ) # Create test stored procedures - cursor.execute(""" + cursor.execute( + """ CREATE OR ALTER PROCEDURE pytest_proc_schema.test_proc1 AS BEGIN SELECT 1 AS result END - """) + """ + ) - cursor.execute(""" + cursor.execute( + """ CREATE OR ALTER PROCEDURE pytest_proc_schema.test_proc2 @param1 INT, @param2 VARCHAR(50) OUTPUT @@ -10276,7 +10564,8 @@ def test_procedures_setup(cursor, db_connection): SELECT @param2 = 'Output ' + CAST(@param1 AS VARCHAR(10)) RETURN @param1 END - """) + """ + ) db_connection.commit() except Exception as e: @@ -10394,7 +10683,8 @@ def test_procedures_with_parameters(cursor, db_connection): """Test that procedures() correctly reports parameter information""" try: # Create a simpler procedure with basic parameters - cursor.execute(""" + cursor.execute( + """ CREATE OR ALTER PROCEDURE pytest_proc_schema.test_params_proc @in1 INT, @in2 VARCHAR(50) @@ -10402,7 +10692,8 @@ def test_procedures_with_parameters(cursor, db_connection): BEGIN SELECT @in1 AS value1, @in2 AS value2 END - """) + """ + ) db_connection.commit() # Get procedure info @@ -10436,23 +10727,28 @@ def test_procedures_result_set_info(cursor, db_connection): """Test that procedures() reports information about result sets""" try: # Create procedures with different result set patterns - cursor.execute(""" + cursor.execute( + """ CREATE OR ALTER PROCEDURE pytest_proc_schema.test_no_results AS BEGIN DECLARE @x INT = 1 END - """) + """ + ) - cursor.execute(""" + cursor.execute( + """ CREATE OR ALTER PROCEDURE pytest_proc_schema.test_one_result AS BEGIN SELECT 1 AS col1, 'test' AS col2 END - """) + """ + ) - cursor.execute(""" + cursor.execute( + """ CREATE OR ALTER PROCEDURE pytest_proc_schema.test_multiple_results AS BEGIN @@ -10460,7 +10756,8 @@ def test_procedures_result_set_info(cursor, db_connection): SELECT 'test' AS result2 SELECT GETDATE() AS result3 END - """) + """ + ) db_connection.commit() # Get procedure info for all test procedures @@ -10542,15 +10839,18 @@ def test_foreignkeys_setup(cursor, db_connection): cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") # Create parent table - cursor.execute(""" + cursor.execute( + """ CREATE TABLE pytest_fk_schema.customers ( customer_id INT PRIMARY KEY, customer_name VARCHAR(100) NOT NULL ) - """) + """ + ) # Create child table with foreign key - cursor.execute(""" + cursor.execute( + """ CREATE TABLE pytest_fk_schema.orders ( order_id INT PRIMARY KEY, order_date DATETIME NOT NULL, @@ -10559,18 +10859,23 @@ def test_foreignkeys_setup(cursor, db_connection): CONSTRAINT FK_Orders_Customers FOREIGN KEY (customer_id) REFERENCES pytest_fk_schema.customers (customer_id) ) - """) + """ + ) # Insert test data - cursor.execute(""" + cursor.execute( + """ INSERT 
INTO pytest_fk_schema.customers (customer_id, customer_name) VALUES (1, 'Test Customer 1'), (2, 'Test Customer 2') - """) + """ + ) - cursor.execute(""" + cursor.execute( + """ INSERT INTO pytest_fk_schema.orders (order_id, order_date, customer_id, total_amount) VALUES (101, GETDATE(), 1, 150.00), (102, GETDATE(), 2, 250.50) - """) + """ + ) db_connection.commit() except Exception as e: @@ -10798,17 +11103,20 @@ def test_foreignkeys_multiple_column_fk(cursor, db_connection): cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.product_variants") # Create parent table with composite primary key - cursor.execute(""" + cursor.execute( + """ CREATE TABLE pytest_fk_schema.product_variants ( product_id INT NOT NULL, variant_id INT NOT NULL, variant_name VARCHAR(100) NOT NULL, PRIMARY KEY (product_id, variant_id) ) - """) + """ + ) # Create child table with composite foreign key - cursor.execute(""" + cursor.execute( + """ CREATE TABLE pytest_fk_schema.order_details ( order_id INT NOT NULL, product_id INT NOT NULL, @@ -10818,7 +11126,8 @@ def test_foreignkeys_multiple_column_fk(cursor, db_connection): CONSTRAINT FK_OrderDetails_ProductVariants FOREIGN KEY (product_id, variant_id) REFERENCES pytest_fk_schema.product_variants (product_id, variant_id) ) - """) + """ + ) db_connection.commit() @@ -10883,23 +11192,27 @@ def test_primarykeys_setup(cursor, db_connection): cursor.execute("DROP TABLE IF EXISTS pytest_pk_schema.composite_pk_test") # Create table with simple primary key - cursor.execute(""" + cursor.execute( + """ CREATE TABLE pytest_pk_schema.single_pk_test ( id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, description VARCHAR(200) NULL ) - """) + """ + ) # Create table with composite primary key - cursor.execute(""" + cursor.execute( + """ CREATE TABLE pytest_pk_schema.composite_pk_test ( dept_id INT NOT NULL, emp_id INT NOT NULL, hire_date DATE NOT NULL, CONSTRAINT PK_composite_test PRIMARY KEY (dept_id, emp_id) ) - """) + """ + ) db_connection.commit() except Exception as e: @@ -11210,13 +11523,15 @@ def test_rowcount(cursor, db_connection): cursor.execute("INSERT INTO #pytest_test_rowcount (name) VALUES ('JohnDoe3');") assert cursor.rowcount == 1, "Rowcount should be 1 after third insert" - cursor.execute(""" + cursor.execute( + """ INSERT INTO #pytest_test_rowcount (name) VALUES ('JohnDoe4'), ('JohnDoe5'), ('JohnDoe6'); - """) + """ + ) assert cursor.rowcount == 3, "Rowcount should be 3 after inserting multiple rows" cursor.execute("SELECT * FROM #pytest_test_rowcount;") @@ -11251,26 +11566,31 @@ def test_specialcolumns_setup(cursor, db_connection): cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.identity_test") # Create table with primary key (for rowIdColumns) - cursor.execute(""" + cursor.execute( + """ CREATE TABLE pytest_special_schema.rowid_test ( id INT PRIMARY KEY, name NVARCHAR(100) NOT NULL, unique_col NVARCHAR(100) UNIQUE, non_unique_col NVARCHAR(100) ) - """) + """ + ) # Create table with rowversion column (for rowVerColumns) - cursor.execute(""" + cursor.execute( + """ CREATE TABLE pytest_special_schema.timestamp_test ( id INT PRIMARY KEY, name NVARCHAR(100) NOT NULL, last_updated ROWVERSION ) - """) + """ + ) # Create table with multiple unique identifiers - cursor.execute(""" + cursor.execute( + """ CREATE TABLE pytest_special_schema.multiple_unique_test ( id INT NOT NULL, code VARCHAR(10) NOT NULL, @@ -11278,16 +11598,19 @@ def test_specialcolumns_setup(cursor, db_connection): order_number VARCHAR(20) UNIQUE, CONSTRAINT PK_multiple_unique_test PRIMARY KEY 
(id, code) ) - """) + """ + ) # Create table with identity column - cursor.execute(""" + cursor.execute( + """ CREATE TABLE pytest_special_schema.identity_test ( id INT IDENTITY(1,1) PRIMARY KEY, name NVARCHAR(100) NOT NULL, last_modified DATETIME DEFAULT GETDATE() ) - """) + """ + ) db_connection.commit() except Exception as e: @@ -11406,12 +11729,14 @@ def test_rowid_columns_nullable(cursor, db_connection): """Test rowIdColumns with nullable parameter""" try: # First create a table with nullable unique column and non-nullable PK - cursor.execute(""" + cursor.execute( + """ CREATE TABLE pytest_special_schema.nullable_test ( id INT PRIMARY KEY, -- PK can't be nullable in SQL Server data NVARCHAR(100) NULL ) - """) + """ + ) db_connection.commit() # Test with nullable=True (default) @@ -11504,12 +11829,14 @@ def test_rowver_columns_nullable(cursor, db_connection): """Test rowVerColumns with nullable parameter (not expected to have effect)""" try: # First create a table with rowversion column - cursor.execute(""" + cursor.execute( + """ CREATE TABLE pytest_special_schema.nullable_rowver_test ( id INT PRIMARY KEY, ts ROWVERSION ) - """) + """ + ) db_connection.commit() # Test with nullable=True (default) @@ -11618,7 +11945,8 @@ def test_statistics_setup(cursor, db_connection): cursor.execute("DROP TABLE IF EXISTS pytest_stats_schema.empty_stats_test") # Create test table with various indexes - cursor.execute(""" + cursor.execute( + """ CREATE TABLE pytest_stats_schema.stats_test ( id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, @@ -11627,25 +11955,32 @@ def test_statistics_setup(cursor, db_connection): salary DECIMAL(10, 2) NULL, hire_date DATE NOT NULL ) - """) + """ + ) # Create a non-unique index - cursor.execute(""" + cursor.execute( + """ CREATE INDEX IX_stats_test_dept_date ON pytest_stats_schema.stats_test (department, hire_date) - """) + """ + ) # Create a unique index on multiple columns - cursor.execute(""" + cursor.execute( + """ CREATE UNIQUE INDEX UX_stats_test_name_dept ON pytest_stats_schema.stats_test (name, department) - """) + """ + ) # Create an empty table for testing - cursor.execute(""" + cursor.execute( + """ CREATE TABLE pytest_stats_schema.empty_stats_test ( id INT PRIMARY KEY, data VARCHAR(100) NULL ) - """) + """ + ) db_connection.commit() except Exception as e: @@ -11910,7 +12245,8 @@ def test_columns_setup(cursor, db_connection): cursor.execute("DROP TABLE IF EXISTS pytest_cols_schema.columns_special_test") # Create test table with various column types - cursor.execute(""" + cursor.execute( + """ CREATE TABLE pytest_cols_schema.columns_test ( id INT PRIMARY KEY, name NVARCHAR(100) NOT NULL, @@ -11922,10 +12258,12 @@ def test_columns_setup(cursor, db_connection): notes TEXT NULL, [computed_col] AS (name + ' - ' + CAST(id AS VARCHAR(10))) ) - """) + """ + ) # Create table with special column names and edge cases - fix the problematic column name - cursor.execute(""" + cursor.execute( + """ CREATE TABLE pytest_cols_schema.columns_special_test ( [ID] INT PRIMARY KEY, [User Name] NVARCHAR(100) NULL, @@ -11937,7 +12275,8 @@ def test_columns_setup(cursor, db_connection): [Column/With/Slashes] VARCHAR(20) NULL, [Column_With_Underscores] VARCHAR(20) NULL -- Changed from problematic nested brackets ) - """) + """ + ) db_connection.commit() except Exception as e: @@ -12401,21 +12740,25 @@ def test_lowercase_attribute(cursor, db_connection): try: # Create a test table with mixed-case column names - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_lowercase_test 
( ID INT PRIMARY KEY, UserName VARCHAR(50), EMAIL_ADDRESS VARCHAR(100), PhoneNumber VARCHAR(20) ) - """) + """ + ) db_connection.commit() # Insert test data - cursor.execute(""" + cursor.execute( + """ INSERT INTO #pytest_lowercase_test (ID, UserName, EMAIL_ADDRESS, PhoneNumber) VALUES (1, 'JohnDoe', 'john@example.com', '555-1234') - """) + """ + ) db_connection.commit() # First test with lowercase=False (default) @@ -12470,12 +12813,14 @@ def test_decimal_separator_function(cursor, db_connection): try: # Create test table - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_decimal_separator_test ( id INT PRIMARY KEY, decimal_value DECIMAL(10, 2) ) - """) + """ + ) db_connection.commit() # Insert test values with default separator (.) @@ -12557,7 +12902,8 @@ def test_decimal_separator_with_multiple_values(cursor, db_connection): try: # Create test table - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_decimal_multi_test ( id INT PRIMARY KEY, positive_value DECIMAL(10, 2), @@ -12565,13 +12911,16 @@ def test_decimal_separator_with_multiple_values(cursor, db_connection): zero_value DECIMAL(10, 2), small_value DECIMAL(10, 4) ) - """) + """ + ) db_connection.commit() # Insert test data - cursor.execute(""" + cursor.execute( + """ INSERT INTO #pytest_decimal_multi_test VALUES (1, 123.45, -67.89, 0.00, 0.0001) - """) + """ + ) db_connection.commit() # Test with default separator first @@ -12608,19 +12957,23 @@ def test_decimal_separator_calculations(cursor, db_connection): try: # Create test table - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_decimal_calc_test ( id INT PRIMARY KEY, value1 DECIMAL(10, 2), value2 DECIMAL(10, 2) ) - """) + """ + ) db_connection.commit() # Insert test data - cursor.execute(""" + cursor.execute( + """ INSERT INTO #pytest_decimal_calc_test VALUES (1, 10.25, 5.75) - """) + """ + ) db_connection.commit() # Test with default separator @@ -12657,12 +13010,14 @@ def test_executemany_with_uuids(cursor, db_connection): table_name = "#pytest_uuid_batch" try: cursor.execute(f"DROP TABLE IF EXISTS {table_name}") - cursor.execute(f""" + cursor.execute( + f""" CREATE TABLE {table_name} ( id UNIQUEIDENTIFIER, description NVARCHAR(50) ) - """) + """ + ) db_connection.commit() # Prepare test data: mix of UUIDs and None @@ -12810,11 +13165,13 @@ def test_date_string_parameter_binding(cursor, db_connection): table_name = "#pytest_date_string" try: drop_table_if_exists(cursor, table_name) - cursor.execute(f""" + cursor.execute( + f""" CREATE TABLE {table_name} ( a_column VARCHAR(20) ) - """) + """ + ) cursor.execute(f"INSERT INTO {table_name} (a_column) VALUES ('string1'), ('string2')") db_connection.commit() @@ -12841,11 +13198,13 @@ def test_time_string_parameter_binding(cursor, db_connection): table_name = "#pytest_time_string" try: drop_table_if_exists(cursor, table_name) - cursor.execute(f""" + cursor.execute( + f""" CREATE TABLE {table_name} ( time_col VARCHAR(22) ) - """) + """ + ) cursor.execute(f"INSERT INTO {table_name} (time_col) VALUES ('prefix_14:30:45_suffix')") db_connection.commit() @@ -12870,11 +13229,13 @@ def test_datetime_string_parameter_binding(cursor, db_connection): table_name = "#pytest_datetime_string" try: drop_table_if_exists(cursor, table_name) - cursor.execute(f""" + cursor.execute( + f""" CREATE TABLE {table_name} ( datetime_col VARCHAR(33) ) - """) + """ + ) cursor.execute( f"INSERT INTO {table_name} (datetime_col) VALUES ('prefix_2025-08-12T14:30:45_suffix')" ) @@ -13405,11 +13766,8 @@ def 
test_decimal_scientific_notation_to_varchar(cursor, db_connection, values, d ), f"{description}: Row {i} mismatch - expected {expected_val}, got {stored_val}" finally: - try: - cursor.execute(f"DROP TABLE {table_name}") - db_connection.commit() - except: - pass + cursor.execute(f"DROP TABLE IF EXISTS {table_name}") + db_connection.commit() SMALL_XML = "1" @@ -13511,13 +13869,212 @@ def test_xml_malformed_input(cursor, db_connection): ) db_connection.commit() - with pytest.raises(Exception): + with pytest.raises(mssql_python.Error): cursor.execute("INSERT INTO #pytest_xml_invalid (xml_col) VALUES (?);", INVALID_XML) finally: cursor.execute("DROP TABLE IF EXISTS #pytest_xml_invalid;") db_connection.commit() +# NOTE: Spatial type tests (geography, geometry, hierarchyid) have been moved to +# tests/test_017_spatial_types.py for better organization and maintainability. + + +# ==================== THREAD SAFETY TESTS ==================== + + +def test_column_metadata_thread_safety_concurrent_cursors(db_connection, conn_str): + """ + Test thread safety of _column_metadata with multiple cursors in concurrent threads. + + Validates: + - Multiple threads can safely create connections and cursors + - Each cursor's _column_metadata remains isolated and valid + - No race conditions between execute() setting metadata and fetchall() reading it + + This tests the _column_metadata instance attribute that is set during + _initialize_description() and read during _build_converter_map(). + + Note: Each thread uses its own connection because SQL Server doesn't support + Multiple Active Result Sets (MARS) by default. The test still validates that + _column_metadata works correctly under concurrent load. + """ + import threading + from mssql_python import connect + + # Track results and errors from each thread + results = {} + errors = [] + lock = threading.Lock() + + def worker(thread_id, table_suffix): + """Worker that creates connection, cursor, executes queries, and verifies metadata.""" + # Each thread gets its own connection (required - SQL Server doesn't support MARS) + thread_conn = None + cursor = None + try: + thread_conn = connect(conn_str) + cursor = thread_conn.cursor() + + try: + # Create a unique temp table for this thread + table_name = f"#pytest_thread_meta_{table_suffix}" + cursor.execute(f"DROP TABLE IF EXISTS {table_name};") + + # Create table with distinct column structure for this thread + cursor.execute( + f""" + CREATE TABLE {table_name} ( + thread_id INT, + col_{table_suffix}_a NVARCHAR(100), + col_{table_suffix}_b INT, + col_{table_suffix}_c FLOAT + ); + """ + ) + thread_conn.commit() + + # Insert test data + cursor.execute( + f""" + INSERT INTO {table_name} VALUES + ({thread_id}, 'data_{thread_id}_1', {thread_id * 100}, {thread_id * 1.5}), + ({thread_id}, 'data_{thread_id}_2', {thread_id * 200}, {thread_id * 2.5}); + """ + ) + thread_conn.commit() + + # Execute SELECT and verify description metadata is correct + cursor.execute(f"SELECT * FROM {table_name} ORDER BY col_{table_suffix}_b;") + + # Verify cursor has correct description for THIS query + desc = cursor.description + assert desc is not None, f"Thread {thread_id}: description should not be None" + assert len(desc) == 4, f"Thread {thread_id}: should have 4 columns" + + # Verify column names are correct for this thread's table + col_names = [d[0].lower() for d in desc] + expected_names = [ + "thread_id", + f"col_{table_suffix}_a", + f"col_{table_suffix}_b", + f"col_{table_suffix}_c", + ] + assert col_names == expected_names, 
f"Thread {thread_id}: column names should match" + + # Fetch all rows and verify data + rows = cursor.fetchall() + assert len(rows) == 2, f"Thread {thread_id}: should have 2 rows" + assert rows[0][0] == thread_id, f"Thread {thread_id}: thread_id column should match" + + # Verify _column_metadata is set (internal attribute) + assert ( + cursor._column_metadata is not None + ), f"Thread {thread_id}: _column_metadata should be set" + + # Clean up + cursor.execute(f"DROP TABLE IF EXISTS {table_name};") + thread_conn.commit() + + with lock: + results[thread_id] = { + "success": True, + "col_count": len(desc), + "row_count": len(rows), + } + + finally: + if cursor: + cursor.close() + if thread_conn: + thread_conn.close() + + except Exception as e: + with lock: + errors.append((thread_id, str(e))) + + # Create and start multiple threads + num_threads = 5 + threads = [] + + for i in range(num_threads): + t = threading.Thread(target=worker, args=(i, f"t{i}"), daemon=True) + threads.append(t) + + # Start all threads at roughly the same time + for t in threads: + t.start() + + # Wait for all threads to complete + for t in threads: + t.join(timeout=30) # 30 second timeout per thread + + # Verify threads actually finished (not just timed out) + hung_threads = [t for t in threads if t.is_alive()] + assert len(hung_threads) == 0, f"{len(hung_threads)} thread(s) still running after timeout" + + # Verify no errors occurred + assert len(errors) == 0, f"Thread errors occurred: {errors}" + + # Verify all threads completed successfully + assert len(results) == num_threads, f"Expected {num_threads} results, got {len(results)}" + + for thread_id, result in results.items(): + assert result["success"], f"Thread {thread_id} did not succeed" + assert result["col_count"] == 4, f"Thread {thread_id} had wrong column count" + assert result["row_count"] == 2, f"Thread {thread_id} had wrong row count" + + +def test_column_metadata_isolation_sequential_queries(cursor, db_connection): + """ + Test that _column_metadata is correctly updated between sequential queries. + + Verifies that each execute() call properly replaces the previous metadata, + ensuring no stale data leaks between queries. 
+ """ + try: + # Query 1: Simple 2-column query + cursor.execute("SELECT 1 as col_a, 'hello' as col_b;") + desc1 = cursor.description + meta1 = cursor._column_metadata + cursor.fetchall() + + assert len(desc1) == 2, "First query should have 2 columns" + assert meta1 is not None, "_column_metadata should be set" + + # Query 2: Different structure - 4 columns + cursor.execute("SELECT 1 as x, 2 as y, 3 as z, 4 as w;") + desc2 = cursor.description + meta2 = cursor._column_metadata + cursor.fetchall() + + assert len(desc2) == 4, "Second query should have 4 columns" + assert meta2 is not None, "_column_metadata should be set" + + # Verify the metadata was replaced, not appended + assert len(meta2) == 4, "_column_metadata should have 4 entries" + assert meta1 is not meta2, "_column_metadata should be a new object" + + # Query 3: Back to 2 columns with different names + cursor.execute("SELECT 'test' as different_name, 42.5 as another_col;") + desc3 = cursor.description + meta3 = cursor._column_metadata + cursor.fetchall() + + assert len(desc3) == 2, "Third query should have 2 columns" + assert len(meta3) == 2, "_column_metadata should have 2 entries" + + # Verify column names are from the new query + col_names = [d[0].lower() for d in desc3] + assert col_names == [ + "different_name", + "another_col", + ], "Column names should be from third query" + + except Exception as e: + pytest.fail(f"Column metadata isolation test failed: {e}") + + # ==================== CODE COVERAGE TEST CASES ==================== @@ -13738,12 +14295,14 @@ def test_column_metadata_error_handling(cursor): """Test column metadata retrieval error handling (Lines 1156-1167).""" # Execute a complex query that might stress metadata retrieval - cursor.execute(""" + cursor.execute( + """ SELECT CAST(1 as INT) as int_col, CAST('test' as NVARCHAR(100)) as nvarchar_col, CAST(NEWID() as UNIQUEIDENTIFIER) as guid_col - """) + """ + ) # This should exercise the metadata retrieval code paths # If there are any errors, they should be logged but not crash @@ -13859,12 +14418,14 @@ def test_row_uuid_processing_with_braces(cursor, db_connection): drop_table_if_exists(cursor, "#pytest_uuid_braces") # Create table with UNIQUEIDENTIFIER column - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_uuid_braces ( id INT IDENTITY(1,1), guid_col UNIQUEIDENTIFIER ) - """) + """ + ) # Insert a GUID with braces (this is how SQL Server often returns them) test_guid = "12345678-1234-5678-9ABC-123456789ABC" @@ -13908,12 +14469,14 @@ def test_row_uuid_processing_sql_guid_type(cursor, db_connection): drop_table_if_exists(cursor, "#pytest_sql_guid_type") # Create table with UNIQUEIDENTIFIER column - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_sql_guid_type ( id INT, guid_col UNIQUEIDENTIFIER ) - """) + """ + ) # Insert test data test_guid = "ABCDEF12-3456-7890-ABCD-1234567890AB" @@ -13959,12 +14522,14 @@ def test_row_output_converter_overflow_error(cursor, db_connection): try: # Create a table with integer column drop_table_if_exists(cursor, "#pytest_overflow_test") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_overflow_test ( id INT, small_int TINYINT -- TINYINT can only hold 0-255 ) - """) + """ + ) # Insert a valid value first cursor.execute("INSERT INTO #pytest_overflow_test (id, small_int) VALUES (?, ?)", [1, 100]) @@ -14014,12 +14579,14 @@ def test_row_output_converter_general_exception(cursor, db_connection): try: # Create a table with string column drop_table_if_exists(cursor, 
"#pytest_exception_test") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_exception_test ( id INT, text_col VARCHAR(50) ) - """) + """ + ) # Insert test data cursor.execute( @@ -14030,7 +14597,9 @@ def test_row_output_converter_general_exception(cursor, db_connection): # Create a custom output converter that will raise a general exception def failing_converter(value): - if value == "test_value": + # This driver passes string values as UTF-16LE encoded bytes to output + # converters. This test uses the same encoding for the comparison. + if value == "test_value".encode("utf-16-le"): raise RuntimeError("Custom converter error for testing") return value @@ -14070,12 +14639,14 @@ def test_row_cursor_log_method_availability(cursor, db_connection): try: # Create test data drop_table_if_exists(cursor, "#pytest_log_check") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_log_check ( id INT, value_col INT ) - """) + """ + ) cursor.execute("INSERT INTO #pytest_log_check (id, value_col) VALUES (?, ?)", [1, 42]) db_connection.commit() @@ -14103,7 +14674,8 @@ def test_all_numeric_types_with_nulls(cursor, db_connection): """Test NULL handling for all numeric types to ensure processor functions handle NULLs correctly""" try: drop_table_if_exists(cursor, "#pytest_all_numeric_nulls") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_all_numeric_nulls ( int_col INT, bigint_col BIGINT, @@ -14113,7 +14685,8 @@ def test_all_numeric_types_with_nulls(cursor, db_connection): real_col REAL, float_col FLOAT ) - """) + """ + ) db_connection.commit() # Insert row with all NULLs @@ -14155,14 +14728,16 @@ def test_lob_data_types(cursor, db_connection): """Test LOB (Large Object) data types to ensure LOB fallback paths are exercised""" try: drop_table_if_exists(cursor, "#pytest_lob_test") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_lob_test ( id INT, text_lob VARCHAR(MAX), ntext_lob NVARCHAR(MAX), binary_lob VARBINARY(MAX) ) - """) + """ + ) db_connection.commit() # Create large data that will trigger LOB handling @@ -14195,12 +14770,14 @@ def test_lob_char_column_types(cursor, db_connection): """Test LOB fetching specifically for CHAR/VARCHAR columns (covers lines 3313-3314)""" try: drop_table_if_exists(cursor, "#pytest_lob_char") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_lob_char ( id INT, char_lob VARCHAR(MAX) ) - """) + """ + ) db_connection.commit() # Create data large enough to trigger LOB path (>8000 bytes) @@ -14227,12 +14804,14 @@ def test_lob_wchar_column_types(cursor, db_connection): """Test LOB fetching specifically for WCHAR/NVARCHAR columns (covers lines 3358-3359)""" try: drop_table_if_exists(cursor, "#pytest_lob_wchar") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_lob_wchar ( id INT, wchar_lob NVARCHAR(MAX) ) - """) + """ + ) db_connection.commit() # Create unicode data large enough to trigger LOB path (>4000 characters for NVARCHAR) @@ -14259,12 +14838,14 @@ def test_lob_binary_column_types(cursor, db_connection): """Test LOB fetching specifically for BINARY/VARBINARY columns (covers lines 3384-3385)""" try: drop_table_if_exists(cursor, "#pytest_lob_binary") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_lob_binary ( id INT, binary_lob VARBINARY(MAX) ) - """) + """ + ) db_connection.commit() # Create binary data large enough to trigger LOB path (>8000 bytes) @@ -14291,14 +14872,16 @@ def test_zero_length_complex_types(cursor, db_connection): """Test zero-length data 
for complex types (covers lines 3531-3533)""" try: drop_table_if_exists(cursor, "#pytest_zero_length") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_zero_length ( id INT, empty_varchar VARCHAR(100), empty_nvarchar NVARCHAR(100), empty_binary VARBINARY(100) ) - """) + """ + ) db_connection.commit() # Insert empty (non-NULL) values @@ -14326,12 +14909,14 @@ def test_guid_with_nulls(cursor, db_connection): """Test GUID type with NULL values""" try: drop_table_if_exists(cursor, "#pytest_guid_nulls") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_guid_nulls ( id INT, guid_col UNIQUEIDENTIFIER ) - """) + """ + ) db_connection.commit() # Insert NULL GUID @@ -14358,12 +14943,14 @@ def test_datetimeoffset_with_nulls(cursor, db_connection): """Test DATETIMEOFFSET type with NULL values""" try: drop_table_if_exists(cursor, "#pytest_dto_nulls") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_dto_nulls ( id INT, dto_col DATETIMEOFFSET ) - """) + """ + ) db_connection.commit() # Insert NULL DATETIMEOFFSET @@ -14390,12 +14977,14 @@ def test_decimal_conversion_edge_cases(cursor, db_connection): """Test DECIMAL/NUMERIC type conversion including edge cases""" try: drop_table_if_exists(cursor, "#pytest_decimal_edge") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_decimal_edge ( id INT, dec_col DECIMAL(18, 4) ) - """) + """ + ) db_connection.commit() # Insert various decimal values including edge cases @@ -14516,7 +15105,8 @@ def test_all_numeric_types_with_nulls(cursor, db_connection): """Test NULL handling for all numeric types to ensure processor functions handle NULLs correctly""" try: drop_table_if_exists(cursor, "#pytest_all_numeric_nulls") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_all_numeric_nulls ( int_col INT, bigint_col BIGINT, @@ -14526,7 +15116,8 @@ def test_all_numeric_types_with_nulls(cursor, db_connection): real_col REAL, float_col FLOAT ) - """) + """ + ) db_connection.commit() # Insert row with all NULLs @@ -14568,14 +15159,16 @@ def test_lob_data_types(cursor, db_connection): """Test LOB (Large Object) data types to ensure LOB fallback paths are exercised""" try: drop_table_if_exists(cursor, "#pytest_lob_test") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_lob_test ( id INT, text_lob VARCHAR(MAX), ntext_lob NVARCHAR(MAX), binary_lob VARBINARY(MAX) ) - """) + """ + ) db_connection.commit() # Create large data that will trigger LOB handling @@ -14608,12 +15201,14 @@ def test_lob_char_column_types(cursor, db_connection): """Test LOB fetching specifically for CHAR/VARCHAR columns (covers lines 3313-3314)""" try: drop_table_if_exists(cursor, "#pytest_lob_char") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_lob_char ( id INT, char_lob VARCHAR(MAX) ) - """) + """ + ) db_connection.commit() # Create data large enough to trigger LOB path (>8000 bytes) @@ -14640,12 +15235,14 @@ def test_lob_wchar_column_types(cursor, db_connection): """Test LOB fetching specifically for WCHAR/NVARCHAR columns (covers lines 3358-3359)""" try: drop_table_if_exists(cursor, "#pytest_lob_wchar") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_lob_wchar ( id INT, wchar_lob NVARCHAR(MAX) ) - """) + """ + ) db_connection.commit() # Create unicode data large enough to trigger LOB path (>4000 characters for NVARCHAR) @@ -14672,12 +15269,14 @@ def test_lob_binary_column_types(cursor, db_connection): """Test LOB fetching specifically for BINARY/VARBINARY columns (covers 
lines 3384-3385)""" try: drop_table_if_exists(cursor, "#pytest_lob_binary") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_lob_binary ( id INT, binary_lob VARBINARY(MAX) ) - """) + """ + ) db_connection.commit() # Create binary data large enough to trigger LOB path (>8000 bytes) @@ -14704,14 +15303,16 @@ def test_zero_length_complex_types(cursor, db_connection): """Test zero-length data for complex types (covers lines 3531-3533)""" try: drop_table_if_exists(cursor, "#pytest_zero_length") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_zero_length ( id INT, empty_varchar VARCHAR(100), empty_nvarchar NVARCHAR(100), empty_binary VARBINARY(100) ) - """) + """ + ) db_connection.commit() # Insert empty (non-NULL) values @@ -14739,12 +15340,14 @@ def test_guid_with_nulls(cursor, db_connection): """Test GUID type with NULL values""" try: drop_table_if_exists(cursor, "#pytest_guid_nulls") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_guid_nulls ( id INT, guid_col UNIQUEIDENTIFIER ) - """) + """ + ) db_connection.commit() # Insert NULL GUID @@ -14771,12 +15374,14 @@ def test_datetimeoffset_with_nulls(cursor, db_connection): """Test DATETIMEOFFSET type with NULL values""" try: drop_table_if_exists(cursor, "#pytest_dto_nulls") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_dto_nulls ( id INT, dto_col DATETIMEOFFSET ) - """) + """ + ) db_connection.commit() # Insert NULL DATETIMEOFFSET @@ -14803,12 +15408,14 @@ def test_decimal_conversion_edge_cases(cursor, db_connection): """Test DECIMAL/NUMERIC type conversion including edge cases""" try: drop_table_if_exists(cursor, "#pytest_decimal_edge") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #pytest_decimal_edge ( id INT, dec_col DECIMAL(18, 4) ) - """) + """ + ) db_connection.commit() # Insert various decimal values including edge cases @@ -14929,14 +15536,16 @@ def test_fetchall_with_integrity_constraint(cursor, db_connection): try: # Setup table with unique constraint cursor.execute("DROP TABLE IF EXISTS #uniq_cons_test") - cursor.execute(""" + cursor.execute( + """ CREATE TABLE #uniq_cons_test ( id INTEGER NOT NULL IDENTITY, data VARCHAR(50) NULL, PRIMARY KEY (id), UNIQUE (data) ) - """) + """ + ) # Insert initial row - should work cursor.execute( diff --git a/tests/test_017_spatial_types.py b/tests/test_017_spatial_types.py new file mode 100644 index 000000000..51dd75377 --- /dev/null +++ b/tests/test_017_spatial_types.py @@ -0,0 +1,1241 @@ +""" +SQL Server Spatial Types Tests (geography, geometry, hierarchyid) + +This module contains tests for SQL Server's spatial and hierarchical data types: +- geography: Geodetic (round-earth) spatial data for GPS coordinates, regions +- geometry: Planar (flat-earth) spatial data for 2D shapes, coordinates +- hierarchyid: Tree structure data for org charts, file systems, etc. + +Tests include: +- Basic insert/fetch operations +- Various geometry types (Point, LineString, Polygon, etc.) 
+- NULL value handling +- LOB/streaming for large spatial values +- Output converters +- cursor.description metadata +- Error handling for invalid data +- Binary parameter round-trip tests +""" + +import pytest +import mssql_python +from mssql_python.constants import ConstantsDDBC + + +# ==================== GEOGRAPHY TYPE TESTS ==================== + +# Test geography data - Well-Known Text (WKT) format +POINT_WKT = "POINT(-122.34900 47.65100)" # Seattle coordinates +LINESTRING_WKT = "LINESTRING(-122.360 47.656, -122.343 47.656)" +POLYGON_WKT = "POLYGON((-122.358 47.653, -122.348 47.649, -122.348 47.658, -122.358 47.653))" +MULTIPOINT_WKT = "MULTIPOINT((-122.34900 47.65100), (-122.11100 47.67700))" +COLLECTION_WKT = "GEOMETRYCOLLECTION(POINT(-122.34900 47.65100))" + + +def test_geography_basic_insert_fetch(cursor, db_connection): + """Test insert and fetch of a basic geography Point value.""" + try: + cursor.execute( + "CREATE TABLE #pytest_geography_basic (id INT PRIMARY KEY IDENTITY(1,1), geo_col GEOGRAPHY NULL);" + ) + db_connection.commit() + + # Insert using STGeomFromText + cursor.execute( + "INSERT INTO #pytest_geography_basic (geo_col) VALUES (geography::STGeomFromText(?, 4326));", + POINT_WKT, + ) + db_connection.commit() + + # Fetch as binary (default behavior) + row = cursor.execute("SELECT geo_col FROM #pytest_geography_basic;").fetchone() + assert row[0] is not None, "Geography value should not be None" + assert isinstance(row[0], bytes), "Geography should be returned as bytes" + assert len(row[0]) > 0, "Geography binary should have content" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geography_basic;") + db_connection.commit() + + +def test_geography_as_text(cursor, db_connection): + """Test fetching geography as WKT text using STAsText().""" + try: + cursor.execute( + "CREATE TABLE #pytest_geography_text (id INT PRIMARY KEY IDENTITY(1,1), geo_col GEOGRAPHY NULL);" + ) + db_connection.commit() + + cursor.execute( + "INSERT INTO #pytest_geography_text (geo_col) VALUES (geography::STGeomFromText(?, 4326));", + POINT_WKT, + ) + db_connection.commit() + + # Fetch as text using STAsText() + row = cursor.execute( + "SELECT geo_col.STAsText() as wkt FROM #pytest_geography_text;" + ).fetchone() + # SQL Server normalizes WKT format (adds space, removes trailing zeros) + assert row[0] is not None, "Geography WKT should not be None" + assert row[0].startswith("POINT"), "Should be a POINT geometry" + assert "-122.349" in row[0] and "47.651" in row[0], "Should contain expected coordinates" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geography_text;") + db_connection.commit() + + +def test_geography_various_types(cursor, db_connection): + """Test insert and fetch of various geography types.""" + try: + cursor.execute( + "CREATE TABLE #pytest_geography_types (id INT PRIMARY KEY IDENTITY(1,1), geo_col GEOGRAPHY NULL, description NVARCHAR(100));" + ) + db_connection.commit() + + test_cases = [ + (POINT_WKT, "Point", "POINT"), + (LINESTRING_WKT, "LineString", "LINESTRING"), + (POLYGON_WKT, "Polygon", "POLYGON"), + (MULTIPOINT_WKT, "MultiPoint", "MULTIPOINT"), + (COLLECTION_WKT, "GeometryCollection", "GEOMETRYCOLLECTION"), + ] + + for wkt, desc, _ in test_cases: + cursor.execute( + "INSERT INTO #pytest_geography_types (geo_col, description) VALUES (geography::STGeomFromText(?, 4326), ?);", + (wkt, desc), + ) + db_connection.commit() + + # Fetch all and verify + rows = cursor.execute( + "SELECT geo_col.STAsText() as wkt, description FROM #pytest_geography_types 
ORDER BY id;" + ).fetchall() + + for i, (_, expected_desc, expected_type) in enumerate(test_cases): + assert rows[i][0] is not None, f"{expected_desc} WKT should not be None" + assert rows[i][0].startswith( + expected_type + ), f"{expected_desc} should start with {expected_type}" + assert rows[i][1] == expected_desc, "Description should match" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geography_types;") + db_connection.commit() + + +def test_geography_null_value(cursor, db_connection): + """Test insert and fetch of NULL geography values.""" + try: + cursor.execute( + "CREATE TABLE #pytest_geography_null (id INT PRIMARY KEY IDENTITY(1,1), geo_col GEOGRAPHY NULL);" + ) + db_connection.commit() + + cursor.execute("INSERT INTO #pytest_geography_null (geo_col) VALUES (?);", None) + db_connection.commit() + + row = cursor.execute("SELECT geo_col FROM #pytest_geography_null;").fetchone() + assert row[0] is None, "NULL geography should be returned as None" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geography_null;") + db_connection.commit() + + +def test_geography_fetchone(cursor, db_connection): + """Test fetchone with geography columns.""" + try: + cursor.execute( + "CREATE TABLE #pytest_geography_fetchone (id INT PRIMARY KEY IDENTITY(1,1), geo_col GEOGRAPHY NULL);" + ) + db_connection.commit() + + cursor.execute( + "INSERT INTO #pytest_geography_fetchone (geo_col) VALUES (geography::STGeomFromText(?, 4326));", + POINT_WKT, + ) + db_connection.commit() + + cursor.execute("SELECT geo_col FROM #pytest_geography_fetchone;") + row = cursor.fetchone() + assert row is not None, "fetchone should return a row" + assert isinstance(row[0], bytes), "Geography should be bytes" + + # Verify no more rows + assert cursor.fetchone() is None, "Should be no more rows" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geography_fetchone;") + db_connection.commit() + + +def test_geography_fetchmany(cursor, db_connection): + """Test fetchmany with geography columns.""" + try: + cursor.execute( + "CREATE TABLE #pytest_geography_fetchmany (id INT PRIMARY KEY IDENTITY(1,1), geo_col GEOGRAPHY NULL);" + ) + db_connection.commit() + + # Insert multiple rows + for i in range(5): + cursor.execute( + "INSERT INTO #pytest_geography_fetchmany (geo_col) VALUES (geography::STGeomFromText(?, 4326));", + POINT_WKT, + ) + db_connection.commit() + + cursor.execute("SELECT geo_col FROM #pytest_geography_fetchmany;") + rows = cursor.fetchmany(3) + assert isinstance(rows, list), "fetchmany should return a list" + assert len(rows) == 3, "fetchmany should return 3 rows" + for row in rows: + assert isinstance(row[0], bytes), "Each geography should be bytes" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geography_fetchmany;") + db_connection.commit() + + +def test_geography_fetchall(cursor, db_connection): + """Test fetchall with geography columns.""" + try: + cursor.execute( + "CREATE TABLE #pytest_geography_fetchall (id INT PRIMARY KEY IDENTITY(1,1), geo_col GEOGRAPHY NULL);" + ) + db_connection.commit() + + # Insert multiple rows + num_rows = 10 + for i in range(num_rows): + cursor.execute( + "INSERT INTO #pytest_geography_fetchall (geo_col) VALUES (geography::STGeomFromText(?, 4326));", + POINT_WKT, + ) + db_connection.commit() + + cursor.execute("SELECT geo_col FROM #pytest_geography_fetchall;") + rows = cursor.fetchall() + assert isinstance(rows, list), "fetchall should return a list" + assert len(rows) == num_rows, f"fetchall should return {num_rows} rows" + for row in rows: 
+ assert isinstance(row[0], bytes), "Each geography should be bytes" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geography_fetchall;") + db_connection.commit() + + +def test_geography_executemany(cursor, db_connection): + """Test batch insert (executemany) of multiple geography values.""" + try: + cursor.execute( + "CREATE TABLE #pytest_geography_batch (id INT PRIMARY KEY IDENTITY(1,1), geo_col GEOGRAPHY NULL, name NVARCHAR(50));" + ) + db_connection.commit() + + test_data = [ + (POINT_WKT, "Point1"), + (LINESTRING_WKT, "Line1"), + (POLYGON_WKT, "Poly1"), + ] + + # Insert both geography (from WKT) and name using executemany + cursor.executemany( + "INSERT INTO #pytest_geography_batch (geo_col, name) " + "VALUES (geography::STGeomFromText(?, 4326), ?);", + [(wkt, name) for wkt, name in test_data], + ) + db_connection.commit() + + rows = cursor.execute( + "SELECT geo_col, name FROM #pytest_geography_batch ORDER BY id;" + ).fetchall() + assert len(rows) == len(test_data), "Should have inserted all rows" + for (expected_wkt, expected_name), (geo_value, name_value) in zip(test_data, rows): + # Geography values should be returned as bytes, consistent with other geography tests + assert isinstance(geo_value, bytes), "Each geography value should be bytes" + assert name_value == expected_name, "Names should round-trip correctly" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geography_batch;") + db_connection.commit() + + +def test_geography_large_value_lob_streaming(cursor, db_connection): + """Test large geography values to verify LOB/streaming behavior.""" + try: + cursor.execute( + "CREATE TABLE #pytest_geography_large (id INT PRIMARY KEY IDENTITY(1,1), geo_col GEOGRAPHY NULL);" + ) + db_connection.commit() + + # Create a large but valid polygon with many vertices (not as extreme as 5000) + # This creates a polygon large enough to test LOB behavior but small enough to pass as parameter + large_polygon = ( + "POLYGON((" + + ", ".join([f"{-122.5 + i*0.0001} {47.5 + i*0.0001}" for i in range(100)]) + + ", -122.5 47.5))" + ) + + # Insert large polygon + cursor.execute( + "INSERT INTO #pytest_geography_large (geo_col) VALUES (geography::STGeomFromText(?, 4326));", + large_polygon, + ) + db_connection.commit() + + # Fetch the large geography + row = cursor.execute("SELECT geo_col FROM #pytest_geography_large;").fetchone() + assert row[0] is not None, "Large geography should not be None" + assert isinstance(row[0], bytes), "Large geography should be bytes" + # Just verify it's non-empty bytes (don't check for 8000 byte threshold as that varies) + assert len(row[0]) > 0, "Large geography should have content" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geography_large;") + db_connection.commit() + + +def test_geography_mixed_with_other_types(cursor, db_connection): + """Test geography columns mixed with other data types.""" + try: + cursor.execute( + """CREATE TABLE #pytest_geography_mixed ( + id INT PRIMARY KEY IDENTITY(1,1), + name NVARCHAR(100), + geo_col GEOGRAPHY NULL, + created_date DATETIME, + score FLOAT + );""" + ) + db_connection.commit() + + cursor.execute( + """INSERT INTO #pytest_geography_mixed (name, geo_col, created_date, score) + VALUES (?, geography::STGeomFromText(?, 4326), ?, ?);""", + ("Seattle", POINT_WKT, "2025-11-26", 95.5), + ) + db_connection.commit() + + row = cursor.execute( + "SELECT name, geo_col, created_date, score FROM #pytest_geography_mixed;" + ).fetchone() + assert row[0] == "Seattle", "Name should match" + assert 
isinstance(row[1], bytes), "Geography should be bytes" + assert row[3] == 95.5, "Score should match" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geography_mixed;") + db_connection.commit() + + +def test_geography_null_and_empty_mixed(cursor, db_connection): + """Test mix of NULL and valid geography values.""" + try: + cursor.execute( + "CREATE TABLE #pytest_geography_null_mixed (id INT PRIMARY KEY IDENTITY(1,1), geo_col GEOGRAPHY NULL);" + ) + db_connection.commit() + + cursor.execute("INSERT INTO #pytest_geography_null_mixed (geo_col) VALUES (?);", None) + cursor.execute( + "INSERT INTO #pytest_geography_null_mixed (geo_col) VALUES (geography::STGeomFromText(?, 4326));", + POINT_WKT, + ) + cursor.execute("INSERT INTO #pytest_geography_null_mixed (geo_col) VALUES (?);", None) + db_connection.commit() + + rows = cursor.execute( + "SELECT geo_col FROM #pytest_geography_null_mixed ORDER BY id;" + ).fetchall() + assert len(rows) == 3, "Should have 3 rows" + assert rows[0][0] is None, "First row should be NULL" + assert isinstance(rows[1][0], bytes), "Second row should be bytes" + assert rows[2][0] is None, "Third row should be NULL" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geography_null_mixed;") + db_connection.commit() + + +def test_geography_with_srid(cursor, db_connection): + """Test geography with different SRID values.""" + try: + cursor.execute( + "CREATE TABLE #pytest_geography_srid (id INT PRIMARY KEY IDENTITY(1,1), geo_col GEOGRAPHY NULL, srid INT);" + ) + db_connection.commit() + + # WGS84 (most common) + cursor.execute( + "INSERT INTO #pytest_geography_srid (geo_col, srid) VALUES (geography::STGeomFromText(?, 4326), 4326);", + POINT_WKT, + ) + db_connection.commit() + + row = cursor.execute( + "SELECT geo_col.STSrid as srid FROM #pytest_geography_srid;" + ).fetchone() + assert row[0] == 4326, "SRID should be 4326 (WGS84)" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geography_srid;") + db_connection.commit() + + +def test_geography_methods(cursor, db_connection): + """Test various geography methods (STDistance, STArea, etc.).""" + try: + cursor.execute( + "CREATE TABLE #pytest_geography_methods (id INT PRIMARY KEY IDENTITY(1,1), geo_col GEOGRAPHY NULL);" + ) + db_connection.commit() + + # Insert a polygon to test area + cursor.execute( + "INSERT INTO #pytest_geography_methods (geo_col) VALUES (geography::STGeomFromText(?, 4326));", + POLYGON_WKT, + ) + db_connection.commit() + + # Test STArea + row = cursor.execute( + "SELECT geo_col.STArea() as area FROM #pytest_geography_methods;" + ).fetchone() + assert row[0] is not None, "STArea should return a value" + assert row[0] > 0, "Polygon should have positive area" + + # Test STLength for linestring + cursor.execute( + "UPDATE #pytest_geography_methods SET geo_col = geography::STGeomFromText(?, 4326);", + LINESTRING_WKT, + ) + db_connection.commit() + + row = cursor.execute( + "SELECT geo_col.STLength() as length FROM #pytest_geography_methods;" + ).fetchone() + assert row[0] is not None, "STLength should return a value" + assert row[0] > 0, "LineString should have positive length" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geography_methods;") + db_connection.commit() + + +def test_geography_output_converter(cursor, db_connection): + """Test using output converter to process geography data.""" + try: + cursor.execute( + "CREATE TABLE #pytest_geography_converter (id INT PRIMARY KEY IDENTITY(1,1), geo_col GEOGRAPHY NULL);" + ) + db_connection.commit() + + 
cursor.execute( + "INSERT INTO #pytest_geography_converter (geo_col) VALUES (geography::STGeomFromText(?, 4326));", + POINT_WKT, + ) + db_connection.commit() + + # Define a converter that tracks if it was called + converted = [] + + def geography_converter(value): + if value is None: + return None + converted.append(True) + return value # Just return as-is for this test + + # Register the converter for SQL_SS_UDT type + db_connection.add_output_converter(ConstantsDDBC.SQL_SS_UDT.value, geography_converter) + + try: + # Fetch data - converter should be called + row = cursor.execute("SELECT geo_col FROM #pytest_geography_converter;").fetchone() + assert len(converted) > 0, "Converter should have been called" + assert isinstance(row[0], bytes), "Geography should still be bytes" + finally: + # Clean up converter - always remove even if assertions fail + db_connection.remove_output_converter(ConstantsDDBC.SQL_SS_UDT.value) + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geography_converter;") + db_connection.commit() + + +def test_geography_description_metadata(cursor, db_connection): + """Test cursor.description for geography columns.""" + try: + cursor.execute( + "CREATE TABLE #pytest_geography_desc (id INT PRIMARY KEY, geo_col GEOGRAPHY NULL);" + ) + db_connection.commit() + + cursor.execute("SELECT id, geo_col FROM #pytest_geography_desc;") + desc = cursor.description + + assert len(desc) == 2, "Should have 2 columns in description" + assert desc[0][0] == "id", "First column should be 'id'" + assert desc[1][0] == "geo_col", "Second column should be 'geo_col'" + + # Geography should be SQL_SS_UDT + assert ( + int(desc[1][1]) == ConstantsDDBC.SQL_SS_UDT.value + ), f"Geography column should have SQL_SS_UDT type code ({ConstantsDDBC.SQL_SS_UDT.value})" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geography_desc;") + db_connection.commit() + + +def test_geography_complex_operations(cursor, db_connection): + """Test complex geography operations with multiple geometries.""" + try: + cursor.execute( + """CREATE TABLE #pytest_geography_complex ( + id INT PRIMARY KEY IDENTITY(1,1), + geo1 GEOGRAPHY NULL, + geo2 GEOGRAPHY NULL + );""" + ) + db_connection.commit() + + # Insert two points + point1 = "POINT(-122.34900 47.65100)" # Seattle + point2 = "POINT(-73.98500 40.75800)" # New York + + cursor.execute( + """INSERT INTO #pytest_geography_complex (geo1, geo2) + VALUES (geography::STGeomFromText(?, 4326), geography::STGeomFromText(?, 4326));""", + (point1, point2), + ) + db_connection.commit() + + # Calculate distance between points + row = cursor.execute( + """SELECT geo1.STDistance(geo2) as distance_meters + FROM #pytest_geography_complex;""" + ).fetchone() + + assert row[0] is not None, "Distance should be calculated" + assert row[0] > 0, "Distance should be positive" + # Seattle to New York is approximately 3,900 km = 3,900,000 meters + assert row[0] > 3000000, "Distance should be over 3,000 km" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geography_complex;") + db_connection.commit() + + +def test_geography_binary_parameter_round_trip(cursor, db_connection): + """ + Test inserting and fetching geography data using binary parameters. + + This tests the round-trip of geography data when inserting the raw binary + representation directly (as opposed to using WKT text with STGeomFromText). 
+ """ + try: + cursor.execute( + "CREATE TABLE #pytest_geography_binary (id INT PRIMARY KEY IDENTITY(1,1), geo_col GEOGRAPHY NULL);" + ) + db_connection.commit() + + # First, insert using WKT and fetch the binary representation + cursor.execute( + "INSERT INTO #pytest_geography_binary (geo_col) VALUES (geography::STGeomFromText(?, 4326));", + POINT_WKT, + ) + db_connection.commit() + + row = cursor.execute("SELECT geo_col FROM #pytest_geography_binary;").fetchone() + original_binary = row[0] + assert isinstance(original_binary, bytes), "Should get binary geography" + + # Now insert the binary representation back using STGeomFromWKB + # (SQL Server can accept Well-Known Binary format) + cursor.execute( + "INSERT INTO #pytest_geography_binary (geo_col) VALUES (geography::STGeomFromWKB(?, 4326));", + original_binary, + ) + db_connection.commit() + + # Fetch both and compare + rows = cursor.execute( + "SELECT geo_col, geo_col.STAsText() FROM #pytest_geography_binary ORDER BY id;" + ).fetchall() + assert len(rows) == 2, "Should have 2 rows" + + # Both should produce the same WKT text representation + wkt1 = rows[0][1] + wkt2 = rows[1][1] + # Normalize WKT for comparison (SQL Server may format slightly differently) + assert "POINT" in wkt1 and "POINT" in wkt2, "Both should be POINT geometries" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geography_binary;") + db_connection.commit() + + +# ==================== GEOMETRY TYPE TESTS ==================== + +# Test geometry data - Well-Known Text (WKT) format (planar/2D coordinate system) +GEOMETRY_POINT_WKT = "POINT(100 200)" +GEOMETRY_LINESTRING_WKT = "LINESTRING(0 0, 100 100, 200 0)" +GEOMETRY_POLYGON_WKT = "POLYGON((0 0, 100 0, 100 100, 0 100, 0 0))" +GEOMETRY_MULTIPOINT_WKT = "MULTIPOINT((0 0), (100 100))" + + +def test_geometry_basic_insert_fetch(cursor, db_connection): + """Test insert and fetch of a basic geometry Point value.""" + try: + cursor.execute( + "CREATE TABLE #pytest_geometry_basic (id INT PRIMARY KEY IDENTITY(1,1), geom_col GEOMETRY NULL);" + ) + db_connection.commit() + + # Insert using STGeomFromText (no SRID needed for geometry) + cursor.execute( + "INSERT INTO #pytest_geometry_basic (geom_col) VALUES (geometry::STGeomFromText(?, 0));", + GEOMETRY_POINT_WKT, + ) + db_connection.commit() + + # Fetch as binary (default behavior) + row = cursor.execute("SELECT geom_col FROM #pytest_geometry_basic;").fetchone() + assert row[0] is not None, "Geometry value should not be None" + assert isinstance(row[0], bytes), "Geometry should be returned as bytes" + assert len(row[0]) > 0, "Geometry binary should have content" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geometry_basic;") + db_connection.commit() + + +def test_geometry_as_text(cursor, db_connection): + """Test fetching geometry as WKT text using STAsText().""" + try: + cursor.execute( + "CREATE TABLE #pytest_geometry_text (id INT PRIMARY KEY IDENTITY(1,1), geom_col GEOMETRY NULL);" + ) + db_connection.commit() + + cursor.execute( + "INSERT INTO #pytest_geometry_text (geom_col) VALUES (geometry::STGeomFromText(?, 0));", + GEOMETRY_POINT_WKT, + ) + db_connection.commit() + + # Fetch as text using STAsText() + row = cursor.execute( + "SELECT geom_col.STAsText() as wkt FROM #pytest_geometry_text;" + ).fetchone() + assert row[0] is not None, "Geometry WKT should not be None" + assert row[0].startswith("POINT"), "Should be a POINT geometry" + assert "100" in row[0] and "200" in row[0], "Should contain expected coordinates" + + finally: + cursor.execute("DROP 
TABLE IF EXISTS #pytest_geometry_text;") + db_connection.commit() + + +def test_geometry_various_types(cursor, db_connection): + """Test insert and fetch of various geometry types.""" + try: + cursor.execute( + "CREATE TABLE #pytest_geometry_types (id INT PRIMARY KEY IDENTITY(1,1), geom_col GEOMETRY NULL, description NVARCHAR(100));" + ) + db_connection.commit() + + test_cases = [ + (GEOMETRY_POINT_WKT, "Point", "POINT"), + (GEOMETRY_LINESTRING_WKT, "LineString", "LINESTRING"), + (GEOMETRY_POLYGON_WKT, "Polygon", "POLYGON"), + (GEOMETRY_MULTIPOINT_WKT, "MultiPoint", "MULTIPOINT"), + ] + + for wkt, desc, _ in test_cases: + cursor.execute( + "INSERT INTO #pytest_geometry_types (geom_col, description) VALUES (geometry::STGeomFromText(?, 0), ?);", + (wkt, desc), + ) + db_connection.commit() + + # Fetch all and verify + rows = cursor.execute( + "SELECT geom_col.STAsText() as wkt, description FROM #pytest_geometry_types ORDER BY id;" + ).fetchall() + + for i, (_, expected_desc, expected_type) in enumerate(test_cases): + assert rows[i][0] is not None, f"{expected_desc} WKT should not be None" + assert rows[i][0].startswith( + expected_type + ), f"{expected_desc} should start with {expected_type}" + assert rows[i][1] == expected_desc, "Description should match" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geometry_types;") + db_connection.commit() + + +def test_geometry_null_value(cursor, db_connection): + """Test insert and fetch of NULL geometry values.""" + try: + cursor.execute( + "CREATE TABLE #pytest_geometry_null (id INT PRIMARY KEY IDENTITY(1,1), geom_col GEOMETRY NULL);" + ) + db_connection.commit() + + cursor.execute("INSERT INTO #pytest_geometry_null (geom_col) VALUES (?);", None) + db_connection.commit() + + row = cursor.execute("SELECT geom_col FROM #pytest_geometry_null;").fetchone() + assert row[0] is None, "NULL geometry should be returned as None" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geometry_null;") + db_connection.commit() + + +def test_geometry_fetchall(cursor, db_connection): + """Test fetchall with geometry columns.""" + try: + cursor.execute( + "CREATE TABLE #pytest_geometry_fetchall (id INT PRIMARY KEY IDENTITY(1,1), geom_col GEOMETRY NULL);" + ) + db_connection.commit() + + # Insert multiple rows + num_rows = 5 + for i in range(num_rows): + cursor.execute( + "INSERT INTO #pytest_geometry_fetchall (geom_col) VALUES (geometry::STGeomFromText(?, 0));", + GEOMETRY_POINT_WKT, + ) + db_connection.commit() + + cursor.execute("SELECT geom_col FROM #pytest_geometry_fetchall;") + rows = cursor.fetchall() + assert isinstance(rows, list), "fetchall should return a list" + assert len(rows) == num_rows, f"fetchall should return {num_rows} rows" + for row in rows: + assert isinstance(row[0], bytes), "Each geometry should be bytes" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geometry_fetchall;") + db_connection.commit() + + +def test_geometry_methods(cursor, db_connection): + """Test various geometry methods (STArea, STLength, STDistance).""" + try: + cursor.execute( + "CREATE TABLE #pytest_geometry_methods (id INT PRIMARY KEY IDENTITY(1,1), geom_col GEOMETRY NULL);" + ) + db_connection.commit() + + # Insert a polygon to test area + cursor.execute( + "INSERT INTO #pytest_geometry_methods (geom_col) VALUES (geometry::STGeomFromText(?, 0));", + GEOMETRY_POLYGON_WKT, + ) + db_connection.commit() + + # Test STArea - 100x100 square = 10000 sq units + row = cursor.execute( + "SELECT geom_col.STArea() as area FROM #pytest_geometry_methods;" + 
).fetchone() + assert row[0] is not None, "STArea should return a value" + assert row[0] == 10000, "Square should have area of 10000" + + # Test STLength for linestring + cursor.execute( + "UPDATE #pytest_geometry_methods SET geom_col = geometry::STGeomFromText(?, 0);", + GEOMETRY_LINESTRING_WKT, + ) + db_connection.commit() + + row = cursor.execute( + "SELECT geom_col.STLength() as length FROM #pytest_geometry_methods;" + ).fetchone() + assert row[0] is not None, "STLength should return a value" + assert row[0] > 0, "LineString should have positive length" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geometry_methods;") + db_connection.commit() + + +def test_geometry_description_metadata(cursor, db_connection): + """Test cursor.description for geometry columns.""" + try: + cursor.execute( + "CREATE TABLE #pytest_geometry_desc (id INT PRIMARY KEY, geom_col GEOMETRY NULL);" + ) + db_connection.commit() + + cursor.execute("SELECT id, geom_col FROM #pytest_geometry_desc;") + desc = cursor.description + + assert len(desc) == 2, "Should have 2 columns in description" + assert desc[0][0] == "id", "First column should be 'id'" + assert desc[1][0] == "geom_col", "Second column should be 'geom_col'" + + # Geometry uses SQL_SS_UDT + assert ( + int(desc[1][1]) == ConstantsDDBC.SQL_SS_UDT.value + ), f"Geometry type should be SQL_SS_UDT ({ConstantsDDBC.SQL_SS_UDT.value})" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geometry_desc;") + db_connection.commit() + + +def test_geometry_mixed_with_other_types(cursor, db_connection): + """Test geometry columns mixed with other data types.""" + try: + cursor.execute( + """CREATE TABLE #pytest_geometry_mixed ( + id INT PRIMARY KEY IDENTITY(1,1), + name NVARCHAR(100), + geom_col GEOMETRY NULL, + area FLOAT + );""" + ) + db_connection.commit() + + cursor.execute( + """INSERT INTO #pytest_geometry_mixed (name, geom_col, area) + VALUES (?, geometry::STGeomFromText(?, 0), ?);""", + ("Square", GEOMETRY_POLYGON_WKT, 10000.0), + ) + db_connection.commit() + + row = cursor.execute("SELECT name, geom_col, area FROM #pytest_geometry_mixed;").fetchone() + assert row[0] == "Square", "Name should match" + assert isinstance(row[1], bytes), "Geometry should be bytes" + assert row[2] == 10000.0, "Area should match" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geometry_mixed;") + db_connection.commit() + + +def test_geometry_binary_parameter_round_trip(cursor, db_connection): + """ + Test inserting and fetching geometry data using binary parameters. + + This tests the round-trip of geometry data when inserting the raw binary + representation directly (as opposed to using WKT text with STGeomFromText). 
+ """ + try: + cursor.execute( + "CREATE TABLE #pytest_geometry_binary (id INT PRIMARY KEY IDENTITY(1,1), geom_col GEOMETRY NULL);" + ) + db_connection.commit() + + # First, insert using WKT and fetch the binary representation + cursor.execute( + "INSERT INTO #pytest_geometry_binary (geom_col) VALUES (geometry::STGeomFromText(?, 0));", + GEOMETRY_POINT_WKT, + ) + db_connection.commit() + + row = cursor.execute("SELECT geom_col FROM #pytest_geometry_binary;").fetchone() + original_binary = row[0] + assert isinstance(original_binary, bytes), "Should get binary geometry" + + # Now insert the binary representation back using STGeomFromWKB + cursor.execute( + "INSERT INTO #pytest_geometry_binary (geom_col) VALUES (geometry::STGeomFromWKB(?, 0));", + original_binary, + ) + db_connection.commit() + + # Fetch both and compare + rows = cursor.execute( + "SELECT geom_col, geom_col.STAsText() FROM #pytest_geometry_binary ORDER BY id;" + ).fetchall() + assert len(rows) == 2, "Should have 2 rows" + + # Both should produce the same WKT text representation + wkt1 = rows[0][1] + wkt2 = rows[1][1] + assert "POINT" in wkt1 and "POINT" in wkt2, "Both should be POINT geometries" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geometry_binary;") + db_connection.commit() + + +# ==================== HIERARCHYID TYPE TESTS ==================== + + +def test_hierarchyid_basic_insert_fetch(cursor, db_connection): + """Test insert and fetch of a basic hierarchyid value.""" + try: + cursor.execute( + "CREATE TABLE #pytest_hierarchyid_basic (id INT PRIMARY KEY IDENTITY(1,1), node HIERARCHYID NULL);" + ) + db_connection.commit() + + # Insert using hierarchyid::Parse + cursor.execute( + "INSERT INTO #pytest_hierarchyid_basic (node) VALUES (hierarchyid::Parse(?));", + "/1/2/3/", + ) + db_connection.commit() + + # Fetch as binary (default behavior) + row = cursor.execute("SELECT node FROM #pytest_hierarchyid_basic;").fetchone() + assert row[0] is not None, "Hierarchyid value should not be None" + assert isinstance(row[0], bytes), "Hierarchyid should be returned as bytes" + assert len(row[0]) > 0, "Hierarchyid binary should have content" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_hierarchyid_basic;") + db_connection.commit() + + +def test_hierarchyid_as_string(cursor, db_connection): + """Test fetching hierarchyid as string using ToString().""" + try: + cursor.execute( + "CREATE TABLE #pytest_hierarchyid_string (id INT PRIMARY KEY IDENTITY(1,1), node HIERARCHYID NULL);" + ) + db_connection.commit() + + cursor.execute( + "INSERT INTO #pytest_hierarchyid_string (node) VALUES (hierarchyid::Parse(?));", + "/1/2/3/", + ) + db_connection.commit() + + # Fetch as string using ToString() + row = cursor.execute( + "SELECT node.ToString() as path FROM #pytest_hierarchyid_string;" + ).fetchone() + assert row[0] is not None, "Hierarchyid string should not be None" + assert row[0] == "/1/2/3/", "Hierarchyid path should match" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_hierarchyid_string;") + db_connection.commit() + + +def test_hierarchyid_null_value(cursor, db_connection): + """Test insert and fetch of NULL hierarchyid values.""" + try: + cursor.execute( + "CREATE TABLE #pytest_hierarchyid_null (id INT PRIMARY KEY IDENTITY(1,1), node HIERARCHYID NULL);" + ) + db_connection.commit() + + cursor.execute("INSERT INTO #pytest_hierarchyid_null (node) VALUES (?);", None) + db_connection.commit() + + row = cursor.execute("SELECT node FROM #pytest_hierarchyid_null;").fetchone() + assert row[0] is 
None, "NULL hierarchyid should be returned as None" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_hierarchyid_null;") + db_connection.commit() + + +def test_hierarchyid_fetchall(cursor, db_connection): + """Test fetchall with hierarchyid columns.""" + try: + cursor.execute( + "CREATE TABLE #pytest_hierarchyid_fetchall (id INT PRIMARY KEY IDENTITY(1,1), node HIERARCHYID NULL);" + ) + db_connection.commit() + + # Insert multiple rows with different hierarchy levels + paths = ["/1/", "/1/1/", "/1/2/", "/2/", "/2/1/"] + for path in paths: + cursor.execute( + "INSERT INTO #pytest_hierarchyid_fetchall (node) VALUES (hierarchyid::Parse(?));", + path, + ) + db_connection.commit() + + cursor.execute("SELECT node FROM #pytest_hierarchyid_fetchall;") + rows = cursor.fetchall() + assert isinstance(rows, list), "fetchall should return a list" + assert len(rows) == len(paths), f"fetchall should return {len(paths)} rows" + for row in rows: + assert isinstance(row[0], bytes), "Each hierarchyid should be bytes" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_hierarchyid_fetchall;") + db_connection.commit() + + +def test_hierarchyid_methods(cursor, db_connection): + """Test various hierarchyid methods (GetLevel, GetAncestor, IsDescendantOf).""" + try: + cursor.execute( + "CREATE TABLE #pytest_hierarchyid_methods (id INT PRIMARY KEY IDENTITY(1,1), node HIERARCHYID NULL);" + ) + db_connection.commit() + + cursor.execute( + "INSERT INTO #pytest_hierarchyid_methods (node) VALUES (hierarchyid::Parse(?));", + "/1/2/3/", + ) + db_connection.commit() + + # Test GetLevel - /1/2/3/ is at level 3 + row = cursor.execute( + "SELECT node.GetLevel() as level FROM #pytest_hierarchyid_methods;" + ).fetchone() + assert row[0] == 3, "Level should be 3" + + # Test GetAncestor - parent of /1/2/3/ is /1/2/ + row = cursor.execute( + "SELECT node.GetAncestor(1).ToString() as parent FROM #pytest_hierarchyid_methods;" + ).fetchone() + assert row[0] == "/1/2/", "Parent should be /1/2/" + + # Test IsDescendantOf + row = cursor.execute( + "SELECT node.IsDescendantOf(hierarchyid::Parse('/1/')) as is_descendant FROM #pytest_hierarchyid_methods;" + ).fetchone() + assert row[0] == 1, "Node should be descendant of /1/" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_hierarchyid_methods;") + db_connection.commit() + + +def test_hierarchyid_description_metadata(cursor, db_connection): + """Test cursor.description for hierarchyid columns.""" + try: + cursor.execute( + "CREATE TABLE #pytest_hierarchyid_desc (id INT PRIMARY KEY, node HIERARCHYID NULL);" + ) + db_connection.commit() + + cursor.execute("SELECT id, node FROM #pytest_hierarchyid_desc;") + desc = cursor.description + + assert len(desc) == 2, "Should have 2 columns in description" + assert desc[0][0] == "id", "First column should be 'id'" + assert desc[1][0] == "node", "Second column should be 'node'" + + # Hierarchyid uses SQL_SS_UDT + assert ( + int(desc[1][1]) == ConstantsDDBC.SQL_SS_UDT.value + ), f"Hierarchyid type should be SQL_SS_UDT ({ConstantsDDBC.SQL_SS_UDT.value})" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_hierarchyid_desc;") + db_connection.commit() + + +def test_hierarchyid_tree_structure(cursor, db_connection): + """Test hierarchyid with a typical org chart tree structure.""" + try: + cursor.execute( + """CREATE TABLE #pytest_hierarchyid_tree ( + id INT PRIMARY KEY IDENTITY(1,1), + name NVARCHAR(100), + node HIERARCHYID NULL + );""" + ) + db_connection.commit() + + # Build an org chart + org_data = [ + ("CEO", "/"), 
+ ("VP Engineering", "/1/"), + ("VP Sales", "/2/"), + ("Dev Manager", "/1/1/"), + ("QA Manager", "/1/2/"), + ("Senior Dev", "/1/1/1/"), + ("Junior Dev", "/1/1/2/"), + ] + + for name, path in org_data: + cursor.execute( + "INSERT INTO #pytest_hierarchyid_tree (name, node) VALUES (?, hierarchyid::Parse(?));", + (name, path), + ) + db_connection.commit() + + # Query all descendants of VP Engineering + rows = cursor.execute( + """SELECT name, node.ToString() as path + FROM #pytest_hierarchyid_tree + WHERE node.IsDescendantOf(hierarchyid::Parse('/1/')) = 1 + ORDER BY node;""" + ).fetchall() + + assert len(rows) == 5, "Should have 5 employees under VP Engineering (including self)" + assert rows[0][0] == "VP Engineering", "First should be VP Engineering" + + # Query direct reports of Dev Manager + rows = cursor.execute( + """SELECT name, node.ToString() as path + FROM #pytest_hierarchyid_tree + WHERE node.GetAncestor(1) = hierarchyid::Parse('/1/1/') + ORDER BY node;""" + ).fetchall() + + assert len(rows) == 2, "Dev Manager should have 2 direct reports" + names = [r[0] for r in rows] + assert "Senior Dev" in names and "Junior Dev" in names + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_hierarchyid_tree;") + db_connection.commit() + + +def test_hierarchyid_mixed_with_other_types(cursor, db_connection): + """Test hierarchyid columns mixed with other data types.""" + try: + cursor.execute( + """CREATE TABLE #pytest_hierarchyid_mixed ( + id INT PRIMARY KEY IDENTITY(1,1), + name NVARCHAR(100), + node HIERARCHYID NULL, + salary DECIMAL(10,2) + );""" + ) + db_connection.commit() + + cursor.execute( + "INSERT INTO #pytest_hierarchyid_mixed (name, node, salary) VALUES (?, hierarchyid::Parse(?), ?);", + ("Manager", "/1/", 75000.00), + ) + db_connection.commit() + + row = cursor.execute("SELECT name, node, salary FROM #pytest_hierarchyid_mixed;").fetchone() + assert row[0] == "Manager", "Name should match" + assert isinstance(row[1], bytes), "Hierarchyid should be bytes" + assert float(row[2]) == 75000.00, "Salary should match" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_hierarchyid_mixed;") + db_connection.commit() + + +# ==================== SPATIAL TYPE ERROR HANDLING TESTS ==================== + + +def test_geography_invalid_wkt_parsing(cursor, db_connection): + """ + Test behavior when geography conversion/parsing fails with invalid WKT. + + SQL Server raises an error when attempting to create a geography from + invalid Well-Known Text (WKT) format. 
+ """ + cursor.execute( + "CREATE TABLE #pytest_geography_invalid (id INT PRIMARY KEY IDENTITY(1,1), geo_col GEOGRAPHY NULL);" + ) + db_connection.commit() + + try: + # Test 1: Invalid WKT format - missing closing parenthesis + invalid_wkt1 = "POINT(-122.34900 47.65100" # Missing closing paren + with pytest.raises(mssql_python.DatabaseError): + cursor.execute( + "INSERT INTO #pytest_geography_invalid (geo_col) VALUES (geography::STGeomFromText(?, 4326));", + invalid_wkt1, + ) + db_connection.rollback() + + # Test 2: Invalid WKT format - not a valid geometry type + invalid_wkt2 = "INVALIDTYPE(0 0)" + with pytest.raises(mssql_python.DatabaseError): + cursor.execute( + "INSERT INTO #pytest_geography_invalid (geo_col) VALUES (geography::STGeomFromText(?, 4326));", + invalid_wkt2, + ) + db_connection.rollback() + + # Test 3: Invalid coordinates for geography (latitude > 90) + # Geography uses geodetic coordinates where latitude must be between -90 and 90 + invalid_coords_wkt = "POINT(0 100)" # Latitude 100 is invalid + with pytest.raises(mssql_python.DatabaseError): + cursor.execute( + "INSERT INTO #pytest_geography_invalid (geo_col) VALUES (geography::STGeomFromText(?, 4326));", + invalid_coords_wkt, + ) + db_connection.rollback() + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geography_invalid;") + db_connection.commit() + + +def test_geometry_invalid_wkt_parsing(cursor, db_connection): + """ + Test behavior when geometry conversion/parsing fails with invalid WKT. + + Geometry (planar coordinates) is more lenient than geography but still + requires valid WKT format. + """ + cursor.execute( + "CREATE TABLE #pytest_geometry_invalid (id INT PRIMARY KEY IDENTITY(1,1), geom_col GEOMETRY NULL);" + ) + db_connection.commit() + + try: + # Test 1: Invalid WKT format - missing coordinates + invalid_wkt1 = "POINT()" + with pytest.raises(mssql_python.DatabaseError): + cursor.execute( + "INSERT INTO #pytest_geometry_invalid (geom_col) VALUES (geometry::STGeomFromText(?, 0));", + invalid_wkt1, + ) + db_connection.rollback() + + # Test 2: Invalid WKT format - incomplete polygon (not closed) + invalid_wkt2 = "POLYGON((0 0, 100 0, 100 100))" # Not closed (first/last points differ) + with pytest.raises(mssql_python.DatabaseError): + cursor.execute( + "INSERT INTO #pytest_geometry_invalid (geom_col) VALUES (geometry::STGeomFromText(?, 0));", + invalid_wkt2, + ) + db_connection.rollback() + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_geometry_invalid;") + db_connection.commit() + + +def test_hierarchyid_invalid_parsing(cursor, db_connection): + """ + Test behavior when hierarchyid parsing fails with invalid path. + """ + cursor.execute( + "CREATE TABLE #pytest_hierarchyid_invalid (id INT PRIMARY KEY IDENTITY(1,1), node HIERARCHYID NULL);" + ) + db_connection.commit() + + try: + # Test 1: Invalid hierarchyid format - letters where numbers expected + invalid_path1 = "/abc/" + with pytest.raises(mssql_python.DatabaseError): + cursor.execute( + "INSERT INTO #pytest_hierarchyid_invalid (node) VALUES (hierarchyid::Parse(?));", + invalid_path1, + ) + db_connection.rollback() + + # Test 2: Invalid hierarchyid format - missing leading slash + invalid_path2 = "1/2/" # Missing leading slash + with pytest.raises(mssql_python.DatabaseError): + cursor.execute( + "INSERT INTO #pytest_hierarchyid_invalid (node) VALUES (hierarchyid::Parse(?));", + invalid_path2, + ) + db_connection.rollback() + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_hierarchyid_invalid;") + db_connection.commit()