Skip to content

Commit 21d475b

Browse files
Address review: inline test methods, drop mixin pattern
Per review feedback from jprakash-db:
- Remove mixin classes (LargeWideResultSetMixin, etc.) — inline the test methods directly into the test classes in test_driver.py
- Remove the backward-compat LargeQueriesMixin alias (nothing uses it)
- Remove _LargeQueryRowHelper — replaced entirely by inlining
- Convert large_queries_mixin.py to just a fetch_rows() helper function

Co-authored-by: Isaac
Signed-off-by: Vikrant Puppala <vikrant.puppala@databricks.com>
1 parent ba8be42 commit 21d475b

File tree

2 files changed

+108
-170
lines changed

2 files changed

+108
-170
lines changed

tests/e2e/common/large_queries_mixin.py

Lines changed: 33 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -2,154 +2,43 @@
22
import math
33
import time
44

5-
import pytest
6-
75
log = logging.getLogger(__name__)
86

97

10-
class LargeQueriesFetchMixin:
11-
"""Shared fetch helper for large query test classes."""
12-
13-
def fetch_rows(self, cursor, row_count, fetchmany_size):
14-
"""
15-
A generator for rows. Fetches until the end or up to 5 minutes.
16-
"""
17-
# TODO: Remove fetchmany_size when we have fixed the performance issues with fetchone
18-
# in the Python client
19-
max_fetch_time = 5 * 60 # Fetch for at most 5 minutes
20-
21-
rows = self.get_some_rows(cursor, fetchmany_size)
22-
start_time = time.time()
23-
n = 0
24-
while rows:
25-
for row in rows:
26-
n += 1
27-
yield row
28-
if time.time() - start_time >= max_fetch_time:
29-
log.warning("Fetching rows timed out")
30-
break
31-
rows = self.get_some_rows(cursor, fetchmany_size)
32-
if not rows:
33-
# Read all the rows, row_count should match
34-
self.assertEqual(n, row_count)
35-
36-
num_fetches = max(math.ceil(n / 10000), 1)
37-
latency_ms = int((time.time() - start_time) * 1000 / num_fetches), 1
38-
print(
39-
"Fetched {} rows with an avg latency of {} per fetch, ".format(
40-
n, latency_ms
41-
)
42-
+ "assuming 10K fetch size."
def fetch_rows(test_case, cursor, row_count, fetchmany_size):
    """
    A generator for rows. Fetches until the end or up to 5 minutes.

    Yields each row from *cursor*. When the cursor is exhausted, asserts via
    *test_case* that exactly *row_count* rows were read, then prints the
    average per-fetch latency (assuming a 10K fetch size).
    """
    max_fetch_time = 5 * 60  # Fetch for at most 5 minutes

    rows = _get_some_rows(cursor, fetchmany_size)
    start_time = time.time()
    n = 0
    while rows:
        for row in rows:
            n += 1
            yield row
        if time.time() - start_time >= max_fetch_time:
            log.warning("Fetching rows timed out")
            break
        rows = _get_some_rows(cursor, fetchmany_size)
        if not rows:
            # Read all the rows, row_count should match
            test_case.assertEqual(n, row_count)

    num_fetches = max(math.ceil(n / 10000), 1)
    # BUG FIX: the original `int(...), 1` built a tuple (int, 1) instead of
    # clamping; use max(..., 1) so latency_ms is an integer >= 1.
    latency_ms = max(int((time.time() - start_time) * 1000 / num_fetches), 1)
    print(
        "Fetched {} rows with an avg latency of {} per fetch, ".format(
            n, latency_ms
        )
        + "assuming 10K fetch size."
    )
56-
@pytest.mark.parametrize("lz4_compression", [False, True])
57-
def test_query_with_large_wide_result_set(self, extra_params, lz4_compression):
58-
resultSize = 100 * 1000 * 1000 # 100 MB
59-
width = 8192 # B
60-
rows = resultSize // width
61-
cols = width // 36
62-
63-
# Set the fetchmany_size to get 10MB of data a go
64-
fetchmany_size = 10 * 1024 * 1024 // width
65-
# This is used by PyHive tests to determine the buffer size
66-
self.arraysize = 1000
67-
with self.cursor(extra_params) as cursor:
68-
cursor.connection.lz4_compression = lz4_compression
69-
uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)])
70-
cursor.execute(
71-
"SELECT id, {uuids} FROM RANGE({rows})".format(
72-
uuids=uuids, rows=rows
73-
)
74-
)
75-
assert lz4_compression == cursor.active_result_set.lz4_compressed
76-
for row_id, row in enumerate(
77-
self.fetch_rows(cursor, rows, fetchmany_size)
78-
):
79-
assert row[0] == row_id # Verify no rows are dropped in the middle.
80-
assert len(row[1]) == 36
81-
82-
83-
class LargeNarrowResultSetMixin(LargeQueriesFetchMixin):
84-
"""Test mixin for large narrow result set queries."""
85-
86-
@pytest.mark.parametrize(
87-
"extra_params",
88-
[
89-
{},
90-
{"use_sea": True},
91-
],
92-
)
93-
def test_query_with_large_narrow_result_set(self, extra_params):
94-
resultSize = 100 * 1000 * 1000 # 100 MB
95-
width = 8 # sizeof(long)
96-
rows = resultSize / width
97-
98-
# Set the fetchmany_size to get 10MB of data a go
99-
fetchmany_size = 10 * 1024 * 1024 // width
100-
# This is used by PyHive tests to determine the buffer size
101-
self.arraysize = 10000000
102-
with self.cursor(extra_params) as cursor:
103-
cursor.execute("SELECT * FROM RANGE({rows})".format(rows=rows))
104-
for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)):
105-
assert row[0] == row_id
106-
107-
108-
class LongRunningQueryMixin:
109-
"""Test mixin for long running queries."""
110-
111-
@pytest.mark.parametrize(
112-
"extra_params",
113-
[
114-
{},
115-
{"use_sea": True},
116-
],
117-
)
118-
def test_long_running_query(self, extra_params):
119-
"""Incrementally increase query size until it takes at least 3 minutes,
120-
and asserts that the query completes successfully.
121-
"""
122-
minutes = 60
123-
min_duration = 1 * minutes
124-
125-
duration = -1
126-
scale0 = 10000
127-
scale_factor = 50
128-
with self.cursor(extra_params) as cursor:
129-
while duration < min_duration:
130-
assert scale_factor < 4096, "Detected infinite loop"
131-
start = time.time()
132-
133-
cursor.execute(
134-
"""SELECT count(*)
135-
FROM RANGE({scale}) x
136-
JOIN RANGE({scale0}) y
137-
ON from_unixtime(x.id * y.id, "yyyy-MM-dd") LIKE "%not%a%date%"
138-
""".format(
139-
scale=scale_factor * scale0, scale0=scale0
140-
)
141-
)
142-
143-
(n,) = cursor.fetchone()
144-
assert n == 0
145-
146-
duration = time.time() - start
147-
current_fraction = duration / min_duration
148-
print("Took {} s with scale factor={}".format(duration, scale_factor))
149-
# Extrapolate linearly to reach 3 min and add 50% padding to push over the limit
150-
scale_factor = math.ceil(1.5 * scale_factor / current_fraction)
15137

15238

153-
# Keep backward-compatible alias that combines all three
154-
class LargeQueriesMixin(LargeWideResultSetMixin, LargeNarrowResultSetMixin, LongRunningQueryMixin):
155-
pass
39+
def _get_some_rows(cursor, fetchmany_size):
40+
row = cursor.fetchone()
41+
if row:
42+
return [row]
43+
else:
44+
return None

tests/e2e/test_driver.py

Lines changed: 75 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,7 @@
3939
)
4040
from databricks.sql.thrift_api.TCLIService import ttypes
4141
from tests.e2e.common.core_tests import CoreTestMixin, SmokeTestMixin
42-
from tests.e2e.common.large_queries_mixin import (
43-
LargeWideResultSetMixin,
44-
LargeNarrowResultSetMixin,
45-
LongRunningQueryMixin,
46-
)
42+
from tests.e2e.common.large_queries_mixin import fetch_rows
4743
from tests.e2e.common.timestamp_tests import TimestampTestsMixin
4844
from tests.e2e.common.decimal_tests import DecimalTestsMixin
4945
from tests.e2e.common.retry_test_mixins import (
@@ -142,27 +138,80 @@ def assertEqualRowValues(self, actual, expected):
142138
assert act[i] == exp[i]
143139

144140

145-
class _LargeQueryRowHelper:
146-
"""Shared helper for fetching rows one at a time in large query tests."""
147-
148-
def get_some_rows(self, cursor, fetchmany_size):
149-
row = cursor.fetchone()
150-
if row:
151-
return [row]
152-
else:
153-
return None
154-
155-
156-
class TestPySQLLargeWideResultSet(PySQLPytestTestCase, _LargeQueryRowHelper, LargeWideResultSetMixin):
157-
pass
158-
159-
160-
class TestPySQLLargeNarrowResultSet(PySQLPytestTestCase, _LargeQueryRowHelper, LargeNarrowResultSetMixin):
161-
pass
162-
163-
164-
class TestPySQLLongRunningQuery(PySQLPytestTestCase, LongRunningQueryMixin):
165-
pass
class TestPySQLLargeWideResultSet(PySQLPytestTestCase):
    """E2E test: fetch a ~100 MB result set of wide (8 KB) rows."""

    @pytest.mark.parametrize("extra_params", [{}, {"use_sea": True}])
    @pytest.mark.parametrize("lz4_compression", [False, True])
    def test_query_with_large_wide_result_set(self, extra_params, lz4_compression):
        result_size = 100 * 1000 * 1000  # 100 MB
        row_width = 8192  # bytes per row
        num_rows = result_size // row_width
        num_cols = row_width // 36  # each uuid() string is 36 characters

        # Fetch roughly 10 MB of data per batch.
        fetchmany_size = 10 * 1024 * 1024 // row_width
        # Buffer-size hint used by the PyHive-derived test harness.
        self.arraysize = 1000
        with self.cursor(extra_params) as cursor:
            cursor.connection.lz4_compression = lz4_compression
            uuid_cols = ", ".join(
                "uuid() uuid{}".format(i) for i in range(num_cols)
            )
            cursor.execute(
                "SELECT id, {uuids} FROM RANGE({rows})".format(
                    uuids=uuid_cols, rows=num_rows
                )
            )
            assert lz4_compression == cursor.active_result_set.lz4_compressed
            for row_id, row in enumerate(
                fetch_rows(self, cursor, num_rows, fetchmany_size)
            ):
                assert row[0] == row_id  # verify no rows dropped mid-stream
                assert len(row[1]) == 36
165+
166+
class TestPySQLLargeNarrowResultSet(PySQLPytestTestCase):
    """E2E test: fetch a ~100 MB result set of narrow (8-byte) rows."""

    @pytest.mark.parametrize("extra_params", [{}, {"use_sea": True}])
    def test_query_with_large_narrow_result_set(self, extra_params):
        resultSize = 100 * 1000 * 1000  # 100 MB
        width = 8  # sizeof(long)
        # BUG FIX: use integer division. `resultSize / width` is a float, so
        # the query text became "RANGE(12500000.0)". The wide-result test
        # already uses `//`; this makes the two consistent.
        rows = resultSize // width
        # Fetch roughly 10 MB of data per batch.
        fetchmany_size = 10 * 1024 * 1024 // width
        # Buffer-size hint used by the PyHive-derived test harness.
        self.arraysize = 10000000
        with self.cursor(extra_params) as cursor:
            cursor.execute("SELECT * FROM RANGE({rows})".format(rows=rows))
            for row_id, row in enumerate(
                fetch_rows(self, cursor, rows, fetchmany_size)
            ):
                assert row[0] == row_id
181+
182+
class TestPySQLLongRunningQuery(PySQLPytestTestCase):
    """E2E test: grow a join until a single query runs for >= 1 minute."""

    @pytest.mark.parametrize("extra_params", [{}, {"use_sea": True}])
    def test_long_running_query(self, extra_params):
        """Incrementally increase query size until it takes at least 1 minute,
        and asserts that the query completes successfully.
        """
        import math

        minutes = 60
        min_duration = 1 * minutes
        duration = -1
        scale0 = 10000
        scale_factor = 50
        with self.cursor(extra_params) as cursor:
            while duration < min_duration:
                # Guard against the extrapolation never reaching the target.
                assert scale_factor < 4096, "Detected infinite loop"
                start = time.time()
                cursor.execute(
                    """SELECT count(*)
                    FROM RANGE({scale}) x
                    JOIN RANGE({scale0}) y
                    ON from_unixtime(x.id * y.id, "yyyy-MM-dd") LIKE "%not%a%date%"
                    """.format(
                        scale=scale_factor * scale0, scale0=scale0
                    )
                )
                # The join condition never matches, so the count must be 0.
                (match_count,) = cursor.fetchone()
                assert match_count == 0
                duration = time.time() - start
                current_fraction = duration / min_duration
                print("Took {} s with scale factor={}".format(duration, scale_factor))
                # Extrapolate linearly toward the target, padded by 50%.
                scale_factor = math.ceil(1.5 * scale_factor / current_fraction)
166215

167216

168217
class TestPySQLCloudFetch(PySQLPytestTestCase):

0 commit comments

Comments
 (0)