From a68d3353470a604f57e4c2a4a1832a7b4f4171d6 Mon Sep 17 00:00:00 2001 From: Kaushik Iska Date: Wed, 1 Apr 2026 19:34:01 -0500 Subject: [PATCH] Add a pull-style streaming select API pg_clickhouse needs to consume select results one block at a time, but clickhouse-cpp only exposes a callback-driven select path today. That forces downstream users to layer coroutines or connection resets on top of the client when they need pull-style iteration. Add BeginSelect(), ReceiveSelectBlock(), and EndSelect() to mirror the existing multi-step insert workflow. The implementation reuses the existing query and packet handling code, keeps Query callbacks active for progress, profile, and log packets, and drains canceled queries so connections remain reusable. Add integration tests that cover full streaming iteration, preserved Query callbacks, early cleanup, end-of-stream reuse, and exception cleanup with subsequent reuse. --- clickhouse/client.cpp | 221 +++++++++++++++++++++++++++++++++++++----- clickhouse/client.h | 12 +++ ut/client_ut.cpp | 128 ++++++++++++++++++++++++ 3 files changed, 335 insertions(+), 26 deletions(-) diff --git a/clickhouse/client.cpp b/clickhouse/client.cpp index a8c8ea64..abdf16f5 100644 --- a/clickhouse/client.cpp +++ b/clickhouse/client.cpp @@ -1,4 +1,5 @@ #include "client.h" +#include "clickhouse/error_codes.h" #include "clickhouse/version.h" #include "protocol.h" @@ -157,6 +158,12 @@ class Client::Impl { void SelectWithExternalData(Query query, const ExternalTables& external_tables); + void BeginSelect(const Query& query); + + std::optional ReceiveSelectBlock(); + + void EndSelect(); + void SendCancel(); void Insert(const std::string& table_name, const std::string& query_id, const Block& block); @@ -208,6 +215,14 @@ class Client::Impl { void InitializeStreams(std::unique_ptr&& socket); + void EnsureIdle(const char* action) const; + + uint64_t DrainQueryResponse(const char* context); + + void ResetSelectState(); + + std::optional TakeSelectBlock(); + 
inline size_t GetConnectionAttempts() const { return options_.endpoints.size() * options_.send_retries; @@ -258,7 +273,12 @@ class Client::Impl { ServerInfo server_info_; - bool inserting_; + bool inserting_ = false; + bool selecting_ = false; + bool discarding_select_data_ = false; + bool select_finished_ = false; + std::optional select_block_; + std::unique_ptr select_query_; }; ClientOptions modifyClientOptions(ClientOptions opts) @@ -289,6 +309,11 @@ Client::Impl::Impl(const ClientOptions& opts, } Client::Impl::~Impl() { + try { + EndSelect(); + } catch (...) { + } + try { EndInsert(); } catch (...) { @@ -296,9 +321,7 @@ Client::Impl::~Impl() { } void Client::Impl::ExecuteQuery(Query query) { - if (inserting_) { - throw ValidationError("cannot execute query while inserting"); - } + EnsureIdle("execute query"); EnsureNull en(static_cast(&query), &events_); @@ -315,9 +338,7 @@ void Client::Impl::ExecuteQuery(Query query) { void Client::Impl::SelectWithExternalData(Query query, const ExternalTables& external_tables) { - if (inserting_) { - throw ValidationError("cannot execute query while inserting"); - } + EnsureIdle("execute query"); if (server_info_.revision < DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES) { throw UnimplementedError("This version of ClickHouse server doesn't support temporary tables"); @@ -338,6 +359,89 @@ void Client::Impl::SelectWithExternalData(Query query, const ExternalTables& ext } } +void Client::Impl::BeginSelect(const Query& query) { + EnsureIdle("begin select"); + + if (options_.ping_before_query) { + RetryGuard([this]() { Ping(); }); + } + + select_query_ = std::make_unique(query); + select_block_.reset(); + select_finished_ = false; + selecting_ = true; + events_ = select_query_.get(); + + try { + SendQuery(*select_query_); + } catch (...) 
{ + ResetSelectState(); + throw; + } +} + +std::optional Client::Impl::ReceiveSelectBlock() { + if (!selecting_) { + throw ValidationError("illegal call to ReceiveSelectBlock without first calling BeginSelect"); + } + + if (auto block = TakeSelectBlock()) { + return block; + } + + if (select_finished_) { + return std::nullopt; + } + + uint64_t server_packet = 0; + try { + while (ReceivePacket(&server_packet)) { + if (auto block = TakeSelectBlock()) { + return block; + } + } + } catch (...) { + select_finished_ = true; + throw; + } + + if (server_packet == ServerCodes::EndOfStream || server_packet == ServerCodes::Exception) { + select_finished_ = true; + return std::nullopt; + } + + select_finished_ = true; + throw ProtocolError(std::string{"unexpected packet from server while receiving select block, expected Data, EndOfStream or Exception, got: "} + + (server_packet ? std::to_string(server_packet) : "nothing")); +} + +void Client::Impl::EndSelect() { + if (!selecting_) { + return; + } + + if (select_finished_) { + ResetSelectState(); + return; + } + + try { + discarding_select_data_ = true; + SendCancel(); + DrainQueryResponse("receiving end of query"); + } catch (const ServerException& e) { + if (e.GetCode() != ErrorCodes::QUERY_WAS_CANCELLED) { + ResetSelectState(); + throw; + } + } catch (...) 
{ + ResetSelectState(); + throw; + } + + ResetSelectState(); +} + void Client::Impl::SendBlockData(const Block& block) { if (compression_ == CompressionState::Enable) { std::unique_ptr compressed_output = std::make_unique(output_.get(), options_.max_compression_chunk_size, options_.compression_method); @@ -382,9 +486,7 @@ std::string NameToQueryString(const std::string &input) } void Client::Impl::Insert(const std::string& table_name, const std::string& query_id, const Block& block) { - if (inserting_) { - throw ValidationError("cannot execute query while inserting, use SendInsertData instead"); - } + EnsureIdle("insert"); if (options_.ping_before_query) { RetryGuard([this]() { Ping(); }); @@ -420,9 +522,7 @@ void Client::Impl::Insert(const std::string& table_name, const std::string& quer } Block Client::Impl::BeginInsert(Query query) { - if (inserting_) { - throw ValidationError("cannot execute query while inserting"); - } + EnsureIdle("begin insert"); EnsureNull en(static_cast(&query), &events_); @@ -469,23 +569,12 @@ void Client::Impl::EndInsert() { SendData(Block()); // Wait for EOS. - uint64_t eos_packet{0}; - while (ReceivePacket(&eos_packet)) { - ; - } - - if (eos_packet != ServerCodes::EndOfStream && eos_packet != ServerCodes::Exception - && eos_packet != ServerCodes::Log && options_.rethrow_exceptions) { - throw ProtocolError(std::string{"unexpected packet from server while receiving end of query, expected (expected Exception, EndOfStream or Log, got: "} - + (eos_packet ? 
std::to_string(eos_packet) : "nothing") + ")"); - } + DrainQueryResponse("receiving end of query"); inserting_ = false; } void Client::Impl::Ping() { - if (inserting_) { - throw ValidationError("cannot execute query while inserting"); - } + EnsureIdle("ping"); WireFormat::WriteUInt64(*output_, ClientCodes::Ping); output_->Flush(); @@ -501,6 +590,7 @@ void Client::Impl::Ping() { void Client::Impl::ResetConnection() { InitializeStreams(socket_factory_->connect(options_, current_endpoint_.value())); inserting_ = false; + ResetSelectState(); if (!Handshake()) { throw ProtocolError("fail to connect to " + options_.host); @@ -813,6 +903,17 @@ bool Client::Impl::ReceiveData() { } } + if (selecting_) { + if (discarding_select_data_) { + return true; + } + if (select_block_) { + throw ProtocolError("received unexpected data packet while previous select block is still pending"); + } + select_block_.emplace(std::move(block)); + return true; + } + if (events_) { events_->OnData(block); if (!events_->OnDataCancelable(block)) { @@ -876,6 +977,25 @@ void Client::Impl::SendCancel() { output_->Flush(); } +void Client::Impl::ResetSelectState() { + select_block_.reset(); + discarding_select_data_ = false; + select_finished_ = false; + if (selecting_) { events_ = nullptr; } // ResetConnection() calls this even with no select active: only drop the sink BeginSelect installed + selecting_ = false; + select_query_.reset(); +} + +std::optional Client::Impl::TakeSelectBlock() { + if (!select_block_) { + return std::nullopt; + } + + Block block = std::move(*select_block_); + select_block_.reset(); + return block; +} + void Client::Impl::SendQuery(const Query& query, bool finalize) { WireFormat::WriteUInt64(*output_, ClientCodes::Query); WireFormat::WriteString(*output_, query.GetQueryID()); @@ -1047,6 +1167,31 @@ void Client::Impl::InitializeStreams(std::unique_ptr&& socket) { std::swap(socket, socket_); } +void Client::Impl::EnsureIdle(const char* action) const { + if (inserting_) { + throw ValidationError(std::string("cannot ") + action + " while inserting"); + } + if (selecting_) { + throw 
ValidationError(std::string("cannot ") + action + " while selecting"); + } +} + +uint64_t Client::Impl::DrainQueryResponse(const char* context) { + uint64_t terminal_packet = 0; + while (ReceivePacket(&terminal_packet)) { + ; + } + + if (terminal_packet != ServerCodes::EndOfStream && terminal_packet != ServerCodes::Exception + && terminal_packet != ServerCodes::Log && options_.rethrow_exceptions) { + throw ProtocolError(std::string{"unexpected packet from server while "} + context + + ", expected Exception, EndOfStream or Log, got: " + + (terminal_packet ? std::to_string(terminal_packet) : "nothing")); + } + + return terminal_packet; +} + bool Client::Impl::SendHello() { WireFormat::WriteUInt64(*output_, ClientCodes::Hello); WireFormat::WriteString(*output_, std::string(CLIENT_NAME)); @@ -1196,6 +1341,30 @@ void Client::Select(const Query& query) { Execute(query); } +void Client::BeginSelect(const Query& query) { + impl_->BeginSelect(query); +} + +void Client::BeginSelect(const char* query) { + impl_->BeginSelect(Query(query)); +} + +void Client::BeginSelect(const std::string& query) { + impl_->BeginSelect(Query(query)); +} + +void Client::BeginSelect(const std::string& query, const std::string& query_id) { + impl_->BeginSelect(Query(query, query_id)); +} + +std::optional Client::ReceiveSelectBlock() { + return impl_->ReceiveSelectBlock(); +} + +void Client::EndSelect() { + impl_->EndSelect(); +} + void Client::SelectWithExternalData(const std::string& query, const ExternalTables& external_tables, SelectCallback cb) { impl_->SelectWithExternalData(Query(query).OnData(std::move(cb)), external_tables); } diff --git a/clickhouse/client.h b/clickhouse/client.h index 0486b1c8..15233b06 100644 --- a/clickhouse/client.h +++ b/clickhouse/client.h @@ -270,6 +270,18 @@ class Client { /// Alias for Execute. void Select(const Query& query); + /// Start a select query and consume result blocks with ReceiveSelectBlock. 
+ void BeginSelect(const Query& query); + void BeginSelect(const char* query); + void BeginSelect(const std::string& query); + void BeginSelect(const std::string& query, const std::string& query_id); + + /// Receive the next block for a select session started by BeginSelect. + std::optional ReceiveSelectBlock(); + + /// End a select session started by BeginSelect. + void EndSelect(); + /// Intends for insert block of data into a table \p table_name. void Insert(const std::string& table_name, const Block& block); void Insert(const std::string& table_name, const std::string& query_id, const Block& block); diff --git a/ut/client_ut.cpp b/ut/client_ut.cpp index 6446fd3d..e2bff5e9 100644 --- a/ut/client_ut.cpp +++ b/ut/client_ut.cpp @@ -640,6 +640,134 @@ TEST_P(ClientCase, Numbers) { } } +TEST_P(ClientCase, StreamingSelect) { + try { + size_t num = 0; + + client_->BeginSelect("SELECT number, number FROM system.numbers LIMIT 1000"); + + while (auto block = client_->ReceiveSelectBlock()) { + if (block->GetColumnCount() == 0 || block->GetRowCount() == 0) { + continue; + } + + ASSERT_EQ(2u, block->GetColumnCount()); + + auto col = (*block)[0]->As(); + ASSERT_NE(nullptr, col); + + for (size_t i = 0; i < col->Size(); ++i, ++num) { + EXPECT_EQ(num, col->At(i)); + } + } + + EXPECT_FALSE(client_->ReceiveSelectBlock().has_value()); + EXPECT_EQ(1000u, num); + + client_->EndSelect(); + client_->EndSelect(); + } + catch (const clickhouse::ServerError & e) { + if (e.GetCode() == ErrorCodes::ACCESS_DENIED) + GTEST_SKIP() << e.what() << " : " << GetParam(); + else + throw; + } +} + +TEST_P(ClientCase, StreamingSelectPreservesQueryCallbacks) { + try { + Query query("SELECT * FROM system.numbers LIMIT 10;"); + + std::optional profile; + size_t total_rows = 0; + + query.OnProfile([&profile](const Profile& new_profile) { + profile = new_profile; + }); + + client_->BeginSelect(query); + while (auto block = client_->ReceiveSelectBlock()) { + total_rows += block->GetRowCount(); + } + 
client_->EndSelect(); + + EXPECT_EQ(10u, total_rows); + ASSERT_NE(profile, std::nullopt); + EXPECT_GE(profile->rows, 10u); + EXPECT_GE(profile->blocks, 1u); + EXPECT_GT(profile->bytes, 1u); + } + catch (const clickhouse::ServerError & e) { + if (e.GetCode() == ErrorCodes::ACCESS_DENIED) + GTEST_SKIP() << e.what() << " : " << GetParam(); + else + throw; + } +} + +TEST_P(ClientCase, StreamingSelectCanEndEarly) { + try { + client_->BeginSelect("SELECT number FROM system.numbers LIMIT 1000000"); + + ASSERT_TRUE(client_->ReceiveSelectBlock().has_value()); + + EXPECT_NO_THROW(client_->EndSelect()); + EXPECT_NO_THROW(client_->EndSelect()); + + size_t total_rows = 0; + client_->Select("SELECT 1", [&total_rows](const Block& next_block) { + total_rows += next_block.GetRowCount(); + }); + EXPECT_EQ(1u, total_rows); + } + catch (const clickhouse::ServerError & e) { + if (e.GetCode() == ErrorCodes::ACCESS_DENIED) + GTEST_SKIP() << e.what() << " : " << GetParam(); + else + throw; + } +} + +TEST_P(ClientCase, StreamingSelectAfterEosCanReuseConnection) { + try { + size_t total_rows = 0; + + client_->BeginSelect("SELECT number FROM system.numbers LIMIT 10"); + while (auto block = client_->ReceiveSelectBlock()) { + total_rows += block->GetRowCount(); + } + client_->EndSelect(); + + EXPECT_EQ(10u, total_rows); + + size_t reused_rows = 0; + client_->Select("SELECT 1", [&reused_rows](const Block& block) { + reused_rows += block.GetRowCount(); + }); + EXPECT_EQ(1u, reused_rows); + } + catch (const clickhouse::ServerError & e) { + if (e.GetCode() == ErrorCodes::ACCESS_DENIED) + GTEST_SKIP() << e.what() << " : " << GetParam(); + else + throw; + } +} + +TEST_P(ClientCase, StreamingSelectExceptionCanCleanupAndReuseConnection) { + client_->BeginSelect("SELECT missing_streaming_select_column FROM system.one"); + + EXPECT_THROW(client_->ReceiveSelectBlock(), ServerException); + EXPECT_NO_THROW(client_->EndSelect()); + + size_t reused_rows = 0; + client_->Select("SELECT 1", [&reused_rows](const 
Block& block) { + reused_rows += block.GetRowCount(); + }); + EXPECT_EQ(1u, reused_rows); +} + TEST_P(ClientCase, SimpleAggregateFunction) { const auto & server_info = client_->GetServerInfo(); if (versionNumber(server_info) < versionNumber(19, 9)) {