diff --git a/clickhouse/client.cpp b/clickhouse/client.cpp index a8c8ea64..abdf16f5 100644 --- a/clickhouse/client.cpp +++ b/clickhouse/client.cpp @@ -1,4 +1,5 @@ #include "client.h" +#include "clickhouse/error_codes.h" #include "clickhouse/version.h" #include "protocol.h" @@ -157,6 +158,12 @@ class Client::Impl { void SelectWithExternalData(Query query, const ExternalTables& external_tables); + void BeginSelect(const Query& query); + + std::optional ReceiveSelectBlock(); + + void EndSelect(); + void SendCancel(); void Insert(const std::string& table_name, const std::string& query_id, const Block& block); @@ -208,6 +215,14 @@ class Client::Impl { void InitializeStreams(std::unique_ptr&& socket); + void EnsureIdle(const char* action) const; + + uint64_t DrainQueryResponse(const char* context); + + void ResetSelectState(); + + std::optional TakeSelectBlock(); + inline size_t GetConnectionAttempts() const { return options_.endpoints.size() * options_.send_retries; @@ -258,7 +273,12 @@ class Client::Impl { ServerInfo server_info_; - bool inserting_; + bool inserting_ = false; + bool selecting_ = false; + bool discarding_select_data_ = false; + bool select_finished_ = false; + std::optional select_block_; + std::unique_ptr select_query_; }; ClientOptions modifyClientOptions(ClientOptions opts) @@ -289,6 +309,11 @@ Client::Impl::Impl(const ClientOptions& opts, } Client::Impl::~Impl() { + try { + EndSelect(); + } catch (...) { + } + try { EndInsert(); } catch (...) { @@ -296,9 +321,7 @@ Client::Impl::~Impl() { } void Client::Impl::ExecuteQuery(Query query) { - if (inserting_) { - throw ValidationError("cannot execute query while inserting"); - } + EnsureIdle("execute query"); EnsureNull en(static_cast(&query), &events_); @@ -315,9 +338,7 @@ void Client::Impl::ExecuteQuery(Query query) { void Client::Impl::SelectWithExternalData(Query query, const ExternalTables& external_tables) { - if (inserting_) { - throw ValidationError("cannot execute query while inserting"); - } + EnsureIdle("execute query"); if (server_info_.revision < DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES) { throw UnimplementedError("This version of ClickHouse server doesn't support temporary tables"); @@ -338,6 +359,89 @@ void Client::Impl::SelectWithExternalData(Query query, const ExternalTables& ext } } +void Client::Impl::BeginSelect(const Query& query) { + EnsureIdle("begin select"); + + if (options_.ping_before_query) { + RetryGuard([this]() { Ping(); }); + } + + select_query_ = std::make_unique(query); + select_block_.reset(); + select_finished_ = false; + selecting_ = true; + events_ = select_query_.get(); + + try { + SendQuery(*select_query_); + } catch (...) { + ResetSelectState(); + throw; + } +} + +std::optional Client::Impl::ReceiveSelectBlock() { + if (!selecting_) { + throw ValidationError("illegal call to ReceiveSelectBlock without first calling BeginSelect"); + } + + if (auto block = TakeSelectBlock()) { + return block; + } + + if (select_finished_) { + return std::nullopt; + } + + uint64_t server_packet = 0; + try { + while (ReceivePacket(&server_packet)) { + if (auto block = TakeSelectBlock()) { + return block; + } + } + } catch (...) { + select_finished_ = true; + throw; + } + + if (server_packet == ServerCodes::EndOfStream || server_packet == ServerCodes::Exception) { + select_finished_ = true; + return std::nullopt; + } + + select_finished_ = true; + throw ProtocolError(std::string{"unexpected packet from server while receiving select block, expected Data, EndOfStream or Exception, got: "} + + (server_packet ? std::to_string(server_packet) : "nothing")); +} + +void Client::Impl::EndSelect() { + if (!selecting_) { + return; + } + + if (select_finished_) { + ResetSelectState(); + return; + } + + try { + discarding_select_data_ = true; + SendCancel(); + DrainQueryResponse("receiving end of query"); + } catch (const ServerException& e) { + if (e.GetCode() != ErrorCodes::QUERY_WAS_CANCELLED) { + ResetSelectState(); + throw; + } + } catch (...) { + ResetSelectState(); + throw; + } + + ResetSelectState(); +} + void Client::Impl::SendBlockData(const Block& block) { if (compression_ == CompressionState::Enable) { std::unique_ptr compressed_output = std::make_unique(output_.get(), options_.max_compression_chunk_size, options_.compression_method); @@ -382,9 +486,7 @@ std::string NameToQueryString(const std::string &input) } void Client::Impl::Insert(const std::string& table_name, const std::string& query_id, const Block& block) { - if (inserting_) { - throw ValidationError("cannot execute query while inserting, use SendInsertData instead"); - } + EnsureIdle("insert"); if (options_.ping_before_query) { RetryGuard([this]() { Ping(); }); @@ -420,9 +522,7 @@ void Client::Impl::Insert(const std::string& table_name, const std::string& quer } Block Client::Impl::BeginInsert(Query query) { - if (inserting_) { - throw ValidationError("cannot execute query while inserting"); - } + EnsureIdle("begin insert"); EnsureNull en(static_cast(&query), &events_); @@ -469,23 +569,12 @@ void Client::Impl::EndInsert() { SendData(Block()); // Wait for EOS. - uint64_t eos_packet{0}; - while (ReceivePacket(&eos_packet)) { - ; - } - - if (eos_packet != ServerCodes::EndOfStream && eos_packet != ServerCodes::Exception - && eos_packet != ServerCodes::Log && options_.rethrow_exceptions) { - throw ProtocolError(std::string{"unexpected packet from server while receiving end of query, expected (expected Exception, EndOfStream or Log, got: "} - + (eos_packet ? std::to_string(eos_packet) : "nothing") + ")"); - } + DrainQueryResponse("receiving end of query"); inserting_ = false; } void Client::Impl::Ping() { - if (inserting_) { - throw ValidationError("cannot execute query while inserting"); - } + EnsureIdle("ping"); WireFormat::WriteUInt64(*output_, ClientCodes::Ping); output_->Flush(); @@ -501,6 +590,7 @@ void Client::Impl::Ping() { void Client::Impl::ResetConnection() { InitializeStreams(socket_factory_->connect(options_, current_endpoint_.value())); inserting_ = false; + ResetSelectState(); if (!Handshake()) { throw ProtocolError("fail to connect to " + options_.host); @@ -813,6 +903,17 @@ bool Client::Impl::ReceiveData() { } } + if (selecting_) { + if (discarding_select_data_) { + return true; + } + if (select_block_) { + throw ProtocolError("received unexpected data packet while previous select block is still pending"); + } + select_block_.emplace(std::move(block)); + return true; + } + if (events_) { events_->OnData(block); if (!events_->OnDataCancelable(block)) { @@ -876,6 +977,25 @@ void Client::Impl::SendCancel() { output_->Flush(); } +void Client::Impl::ResetSelectState() { + select_block_.reset(); + discarding_select_data_ = false; + select_finished_ = false; + selecting_ = false; + events_ = nullptr; + select_query_.reset(); +} + +std::optional Client::Impl::TakeSelectBlock() { + if (!select_block_) { + return std::nullopt; + } + + Block block = std::move(*select_block_); + select_block_.reset(); + return block; +} + void Client::Impl::SendQuery(const Query& query, bool finalize) { WireFormat::WriteUInt64(*output_, ClientCodes::Query); WireFormat::WriteString(*output_, query.GetQueryID()); @@ -1047,6 +1167,31 @@ void Client::Impl::InitializeStreams(std::unique_ptr&& socket) { std::swap(socket, socket_); } +void Client::Impl::EnsureIdle(const char* action) const { + if (inserting_) { + throw ValidationError(std::string("cannot ") + action + " while inserting"); + } + if (selecting_) { + throw ValidationError(std::string("cannot ") + action + " while selecting"); + } +} + +uint64_t Client::Impl::DrainQueryResponse(const char* context) { + uint64_t terminal_packet = 0; + while (ReceivePacket(&terminal_packet)) { + ; + } + + if (terminal_packet != ServerCodes::EndOfStream && terminal_packet != ServerCodes::Exception + && terminal_packet != ServerCodes::Log && options_.rethrow_exceptions) { + throw ProtocolError(std::string{"unexpected packet from server while "} + context + + ", expected Exception, EndOfStream or Log, got: " + + (terminal_packet ? std::to_string(terminal_packet) : "nothing")); + } + + return terminal_packet; +} + bool Client::Impl::SendHello() { WireFormat::WriteUInt64(*output_, ClientCodes::Hello); WireFormat::WriteString(*output_, std::string(CLIENT_NAME)); @@ -1196,6 +1341,30 @@ void Client::Select(const Query& query) { Execute(query); } +void Client::BeginSelect(const Query& query) { + impl_->BeginSelect(query); +} + +void Client::BeginSelect(const char* query) { + impl_->BeginSelect(Query(query)); +} + +void Client::BeginSelect(const std::string& query) { + impl_->BeginSelect(Query(query)); +} + +void Client::BeginSelect(const std::string& query, const std::string& query_id) { + impl_->BeginSelect(Query(query, query_id)); +} + +std::optional Client::ReceiveSelectBlock() { + return impl_->ReceiveSelectBlock(); +} + +void Client::EndSelect() { + impl_->EndSelect(); +} + void Client::SelectWithExternalData(const std::string& query, const ExternalTables& external_tables, SelectCallback cb) { impl_->SelectWithExternalData(Query(query).OnData(std::move(cb)), external_tables); } diff --git a/clickhouse/client.h b/clickhouse/client.h index 0486b1c8..15233b06 100644 --- a/clickhouse/client.h +++ b/clickhouse/client.h @@ -270,6 +270,18 @@ class Client { /// Alias for Execute. void Select(const Query& query); + /// Start a select query and consume result blocks with ReceiveSelectBlock. + void BeginSelect(const Query& query); + void BeginSelect(const char* query); + void BeginSelect(const std::string& query); + void BeginSelect(const std::string& query, const std::string& query_id); + + /// Receive the next block for a select session started by BeginSelect. + std::optional ReceiveSelectBlock(); + + /// End a select session started by BeginSelect. + void EndSelect(); + /// Intends for insert block of data into a table \p table_name. void Insert(const std::string& table_name, const Block& block); void Insert(const std::string& table_name, const std::string& query_id, const Block& block); diff --git a/ut/client_ut.cpp b/ut/client_ut.cpp index 6446fd3d..e2bff5e9 100644 --- a/ut/client_ut.cpp +++ b/ut/client_ut.cpp @@ -640,6 +640,134 @@ TEST_P(ClientCase, Numbers) { } } +TEST_P(ClientCase, StreamingSelect) { + try { + size_t num = 0; + + client_->BeginSelect("SELECT number, number FROM system.numbers LIMIT 1000"); + + while (auto block = client_->ReceiveSelectBlock()) { + if (block->GetColumnCount() == 0 || block->GetRowCount() == 0) { + continue; + } + + ASSERT_EQ(2u, block->GetColumnCount()); + + auto col = (*block)[0]->As(); + ASSERT_NE(nullptr, col); + + for (size_t i = 0; i < col->Size(); ++i, ++num) { + EXPECT_EQ(num, col->At(i)); + } + } + + EXPECT_FALSE(client_->ReceiveSelectBlock().has_value()); + EXPECT_EQ(1000u, num); + + client_->EndSelect(); + client_->EndSelect(); + } + catch (const clickhouse::ServerError & e) { + if (e.GetCode() == ErrorCodes::ACCESS_DENIED) + GTEST_SKIP() << e.what() << " : " << GetParam(); + else + throw; + } +} + +TEST_P(ClientCase, StreamingSelectPreservesQueryCallbacks) { + try { + Query query("SELECT * FROM system.numbers LIMIT 10;"); + + std::optional profile; + size_t total_rows = 0; + + query.OnProfile([&profile](const Profile& new_profile) { + profile = new_profile; + }); + + client_->BeginSelect(query); + while (auto block = client_->ReceiveSelectBlock()) { + total_rows += block->GetRowCount(); + } + client_->EndSelect(); + + EXPECT_EQ(10u, total_rows); + ASSERT_NE(profile, std::nullopt); + EXPECT_GE(profile->rows, 10u); + EXPECT_GE(profile->blocks, 1u); + EXPECT_GT(profile->bytes, 1u); + } + catch (const clickhouse::ServerError & e) { + if (e.GetCode() == ErrorCodes::ACCESS_DENIED) + GTEST_SKIP() << e.what() << " : " << GetParam(); + else + throw; + } +} + +TEST_P(ClientCase, StreamingSelectCanEndEarly) { + try { + client_->BeginSelect("SELECT number FROM system.numbers LIMIT 1000000"); + + ASSERT_TRUE(client_->ReceiveSelectBlock().has_value()); + + EXPECT_NO_THROW(client_->EndSelect()); + EXPECT_NO_THROW(client_->EndSelect()); + + size_t total_rows = 0; + client_->Select("SELECT 1", [&total_rows](const Block& next_block) { + total_rows += next_block.GetRowCount(); + }); + EXPECT_EQ(1u, total_rows); + } + catch (const clickhouse::ServerError & e) { + if (e.GetCode() == ErrorCodes::ACCESS_DENIED) + GTEST_SKIP() << e.what() << " : " << GetParam(); + else + throw; + } +} + +TEST_P(ClientCase, StreamingSelectAfterEosCanReuseConnection) { + try { + size_t total_rows = 0; + + client_->BeginSelect("SELECT number FROM system.numbers LIMIT 10"); + while (auto block = client_->ReceiveSelectBlock()) { + total_rows += block->GetRowCount(); + } + client_->EndSelect(); + + EXPECT_EQ(10u, total_rows); + + size_t reused_rows = 0; + client_->Select("SELECT 1", [&reused_rows](const Block& block) { + reused_rows += block.GetRowCount(); + }); + EXPECT_EQ(1u, reused_rows); + } + catch (const clickhouse::ServerError & e) { + if (e.GetCode() == ErrorCodes::ACCESS_DENIED) + GTEST_SKIP() << e.what() << " : " << GetParam(); + else + throw; + } +} + +TEST_P(ClientCase, StreamingSelectExceptionCanCleanupAndReuseConnection) { + client_->BeginSelect("SELECT missing_streaming_select_column FROM system.one"); + + EXPECT_THROW(client_->ReceiveSelectBlock(), ServerException); + EXPECT_NO_THROW(client_->EndSelect()); + + size_t reused_rows = 0; + client_->Select("SELECT 1", [&reused_rows](const Block& block) { + reused_rows += block.GetRowCount(); + }); + EXPECT_EQ(1u, reused_rows); +} + TEST_P(ClientCase, SimpleAggregateFunction) { const auto & server_info = client_->GetServerInfo(); if (versionNumber(server_info) < versionNumber(19, 9)) {