Skip to content

Commit 69aab85

Browse files
committed
tfbuilder: avoid ucx tag send/receive operation on non-rma-mapped memory
1 parent a02f8f6 commit 69aab85

File tree

3 files changed

+120
-26
lines changed

3 files changed

+120
-26
lines changed

src/TfBuilder/TfBuilderInput.cxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ void TfBuilderInput::StfDeserializingThread()
238238
while (mState == RUNNING) {
239239

240240
std::unique_lock<std::mutex> lQueueLock(mStfMergerQueueLock);
241-
mStfMergerCondition.wait_for(lQueueLock, 10ms, [this]{ return mStfMergerRun.load(); });
241+
mStfMergerCondition.wait_for(lQueueLock, 5ms, [this]{ return mStfMergerRun.load(); });
242242
mStfMergerRun = false;
243243

244244
if (mStfMergeMap.empty()) {

src/TfBuilder/TfBuilderInputUCX.cxx

Lines changed: 51 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -150,14 +150,6 @@ bool TfBuilderInputUCX::start()
150150
return false;
151151
}
152152

153-
// start receiving thread pool
154-
for (unsigned i = 0; i < mThreadPoolSize; i++) {
155-
std::string lThreadName = "tfb_ucx_";
156-
lThreadName += std::to_string(i);
157-
158-
mThreadPool.emplace_back(std::move(create_thread_member(lThreadName.c_str(),
159-
&TfBuilderInputUCX::DataHandlerThread, this, i)));
160-
}
161153

162154
// Create the listener
163155
// Run the connection callback with pointer to us
@@ -284,6 +276,18 @@ bool TfBuilderInputUCX::map_data_region()
284276
mTimeFrameBuilder.mMemRes.mDataMemRes->set_ucx_address(lUcxMemPtr);
285277
ucp_data_region_set = true;
286278
DDDLOG("TfBuilderInputUCX::map_data_region(): mapped the data region size={}", lOrigSize);
279+
280+
// start receiving thread pool
281+
// NOTE: This must come after the region mapping. Threads are using mapped addresses
282+
for (unsigned i = 0; i < mThreadPoolSize; i++) {
283+
std::string lThreadName = "tfb_ucx_";
284+
lThreadName += std::to_string(i);
285+
286+
mThreadPool.emplace_back(std::move(
287+
create_thread_member(lThreadName.c_str(), &TfBuilderInputUCX::DataHandlerThread, this, i))
288+
);
289+
}
290+
287291
return true;
288292
}
289293

@@ -354,9 +358,17 @@ void TfBuilderInputUCX::DataHandlerThread(const unsigned pThreadIdx)
354358
// Deserialization object (stf ID)
355359
IovDeserializer lStfReceiver(mTimeFrameBuilder);
356360

357-
{ // warm up FMQ region caches for this thread
358-
mTimeFrameBuilder.newDataMessage(1);
359-
}
361+
// memory for meta-tag receive; increased later if needed
362+
std::uint64_t lMetaMemSize = std::uint64_t(2) << 20;
363+
FairMQMessagePtr lMetaMemMsg = nullptr;
364+
void *lMetaMemPtr = nullptr;
365+
366+
auto fAllocateMetaMessage = [&](std::uint64_t pSize) {
367+
lMetaMemSize = pSize;
368+
lMetaMemMsg = mTimeFrameBuilder.newDataMessage(pSize);
369+
lMetaMemPtr = mTimeFrameBuilder.mMemRes.mDataMemRes->get_ucx_ptr(lMetaMemMsg->GetData());
370+
};
371+
fAllocateMetaMessage(lMetaMemSize);
360372

361373
std::optional<std::string> lStfSenderIdOpt;
362374
std::vector<void*> lTxgPtrs;
@@ -392,19 +404,32 @@ void TfBuilderInputUCX::DataHandlerThread(const unsigned pThreadIdx)
392404
auto lStartLoop = clock::now();
393405

394406
// Receive STF iov and metadata
395-
const auto lStfMetaDataOtp = ucx::io::ucx_receive_string(lConn->worker);
396-
397-
if (!lStfMetaDataOtp.has_value()) {
398-
EDDLOG("DataHandlerThread {}: Failed to receive stf meta structure.", lStfSenderId);
407+
std::uint64_t lReqSize = 0;
408+
auto lRecvMetaSize = ucx::io::ucx_receive_tag(lConn->worker, lMetaMemPtr, lMetaMemSize, &lReqSize);
409+
if (lRecvMetaSize < 0) {
410+
EDDLOG("UCXDataHandlerThread: Failed to receive stf meta structure. from={}", lStfSenderId);
399411
continue;
412+
} if ((lRecvMetaSize == 0) && (lReqSize > lMetaMemSize)) {
413+
// memory too small
414+
while (lMetaMemSize < lReqSize) {
415+
lMetaMemSize *= 2;
416+
}
417+
// allocate larger buffer and continue
418+
fAllocateMetaMessage(lMetaMemSize);
419+
lRecvMetaSize = ucx::io::ucx_receive_tag_data(lConn->worker, lMetaMemPtr, lReqSize);
420+
if (lRecvMetaSize < 0) {
421+
EDDLOG("UCXDataHandlerThread: Failed to receive stf meta message. from={}", lStfSenderId);
422+
continue;
423+
}
424+
assert (lRecvMetaSize > 0 && std::uint64_t(lRecvMetaSize) == lReqSize);
400425
}
401426

402427
DDMON("tfbuilder", "recv.receive_meta_ms", since<std::chrono::milliseconds>(lStartLoop));
403428
lMetaDecodeStart = clock::now();
404429

405-
const auto &lStfMetaData = lStfMetaDataOtp.value();
406-
407-
lMeta.ParseFromString(lStfMetaData);
430+
if (!lMeta.ParseFromArray(lMetaMemPtr, lRecvMetaSize)) {
431+
EDDLOG("UCXDataHandlerThread: Failed to parse stf meta message. from={} size={}" ,lStfSenderId, lRecvMetaSize);
432+
}
408433

409434
lTfId = lMeta.stf_hdr_meta().stf_id();
410435

@@ -459,8 +484,14 @@ void TfBuilderInputUCX::DataHandlerThread(const unsigned pThreadIdx)
459484
}
460485

461486
// notify StfSender we completed
462-
std::string lOkStr = "OK";
463-
if (!ucx::io::ucx_send_string(lConn->worker, lConn->ucp_ep, lOkStr) ) {
487+
struct StringData {
488+
std::uint64_t mSize;
489+
char mMsg[4];
490+
} *lOk = reinterpret_cast<struct StringData*>(lMetaMemPtr);
491+
lOk->mSize = 2;
492+
std::memcpy(lOk->mMsg, "OK", 2);
493+
494+
if (!ucx::io::ucx_send_data(lConn->worker, lConn->ucp_ep, lOk->mMsg, &lOk->mSize) ) {
464495
EDDLOG_GRL(10000, "StfSender was NOT notified about transfer finish stf_sender={} tf_id={}", lStfSenderId, lTfId);
465496
}
466497

src/common/ucxtools/UCXSendRecv.h

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,6 @@ static constexpr ucp_tag_t STF_DONE_TAG = 1'000'000'000ULL;
3535

3636
#define make_ucp_req() (reinterpret_cast<char*>(alloca(UCX_REQUEST_SIZE)) + UCX_REQUEST_SIZE)
3737

38-
struct sync_tag_call_cookie {
39-
volatile bool mCompleted = false;
40-
};
41-
42-
4338
} /* ucx::impl */
4439

4540

@@ -270,6 +265,18 @@ bool ucx_send_string(dd_ucp_worker &worker, ucp_ep_h ep, const std::string &lStr
270265
return send_tag_blocking(worker, ep, lString.data(), size_buffer, impl::STRING_TAG);
271266
}
272267

268+
static inline
269+
bool ucx_send_data(dd_ucp_worker &worker, ucp_ep_h ep, const void *pData, const std::uint64_t *pSize)
270+
{
271+
const std::uint64_t pSizeOrig = *pSize;
272+
273+
if (!send_tag_blocking(worker, ep, pSize, sizeof(std::uint64_t), impl::STRING_SIZE_TAG) ) {
274+
return false;
275+
}
276+
// send actual string data
277+
return send_tag_blocking(worker, ep, pData, pSizeOrig, impl::STRING_TAG);
278+
}
279+
273280
static inline
274281
std::optional<std::string> ucx_receive_string(dd_ucp_worker &worker)
275282
{
@@ -288,6 +295,62 @@ std::optional<std::string> ucx_receive_string(dd_ucp_worker &worker)
288295
return lRetStr;
289296
}
290297

298+
static inline
299+
std::int64_t ucx_receive_tag(dd_ucp_worker &worker, void *pData, const std::size_t pSize, std::uint64_t *pReqSize)
300+
{
301+
assert (pSize >= sizeof(std::uint64_t));
302+
assert (pReqSize != nullptr);
303+
304+
// receive the size
305+
std::uint64_t size_rcv = 0;
306+
307+
if (!receive_tag_blocking(worker, pData, sizeof(std::uint64_t), impl::STRING_SIZE_TAG) ) {
308+
return -1;
309+
}
310+
311+
std::memcpy(&size_rcv, pData, sizeof(std::uint64_t));
312+
if (size_rcv > pSize) {
313+
// we need a larger buffer
314+
*pReqSize = size_rcv;
315+
return 0;
316+
}
317+
318+
if (!receive_tag_blocking(worker, pData, size_rcv, impl::STRING_TAG) ) {
319+
return -1;
320+
}
321+
*pReqSize = pSize;
322+
return size_rcv;
323+
}
324+
325+
static inline
326+
std::int64_t ucx_receive_tag_data(dd_ucp_worker &worker, void *pData, const std::uint64_t pReqSize)
327+
{
328+
if (!receive_tag_blocking(worker, pData, pReqSize, impl::STRING_TAG) ) {
329+
return -1;
330+
}
331+
332+
return pReqSize;
333+
}
334+
335+
336+
static inline
337+
std::optional<std::string> ucx_receive_string_data(dd_ucp_worker &worker)
338+
{
339+
// receive the size
340+
std::uint64_t size_rcv = 0;
341+
342+
if (!receive_tag_blocking(worker, &size_rcv, sizeof(std::uint64_t), impl::STRING_SIZE_TAG) ) {
343+
return std::nullopt;
344+
}
345+
346+
std::string lRetStr(size_rcv, 0);
347+
if (!receive_tag_blocking(worker, lRetStr.data(), lRetStr.size(), impl::STRING_TAG) ) {
348+
return std::nullopt;
349+
}
350+
351+
return lRetStr;
352+
}
353+
291354

292355
} /* ucx::io */
293356

0 commit comments

Comments
 (0)