diff --git a/lib/ClientConnection.cc b/lib/ClientConnection.cc index 4f7a1dd1..1d488d82 100644 --- a/lib/ClientConnection.cc +++ b/lib/ClientConnection.cc @@ -997,9 +997,14 @@ Future ClientConnection::newConsumerStats(uint6 lock.unlock(); LOG_ERROR(cnxString_ << " Client is not connected to the broker"); promise.setFailed(ResultNotConnected); + return promise.getFuture(); } pendingConsumerStatsMap_.insert(std::make_pair(requestId, promise)); lock.unlock(); + if (mockingRequests_.load(std::memory_order_acquire) && mockServer_ != nullptr && + mockServer_->sendRequest("CONSUMER_STATS", requestId)) { + return promise.getFuture(); + } sendCommand(Commands::newConsumerStats(consumerId, requestId)); return promise.getFuture(); } diff --git a/lib/ClientConnection.h b/lib/ClientConnection.h index aae53d23..b2770006 100644 --- a/lib/ClientConnection.h +++ b/lib/ClientConnection.h @@ -219,6 +219,8 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this promise; @@ -284,8 +286,6 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this inline AllocHandler customAllocReadHandler(Handler h) { return AllocHandler(readHandlerAllocator_, h); diff --git a/lib/MockServer.h b/lib/MockServer.h index bd413d33..2d830fc7 100644 --- a/lib/MockServer.h +++ b/lib/MockServer.h @@ -75,11 +75,18 @@ class MockServer : public std::enable_shared_from_this { } }); } - schedule(connection, request + std::to_string(requestId), iter->second, [connection, requestId] { - proto::CommandSuccess success; - success.set_request_id(requestId); - connection->handleSuccess(success); - }); + schedule(connection, request + std::to_string(requestId), iter->second, + [connection, request, requestId] { + if (request == "CONSUMER_STATS") { + proto::CommandConsumerStatsResponse response; + response.set_request_id(requestId); + connection->handleConsumerStatsResponse(response); + } else { + proto::CommandSuccess success; + success.set_request_id(requestId); + connection->handleSuccess(success); + } + }); return true; } else { return false; diff --git a/lib/MultiTopicsConsumerImpl.cc b/lib/MultiTopicsConsumerImpl.cc index 6e0ba86c..9c741faf 100644 --- a/lib/MultiTopicsConsumerImpl.cc +++ b/lib/MultiTopicsConsumerImpl.cc @@ -847,48 +847,47 @@ void MultiTopicsConsumerImpl::getBrokerConsumerStatsAsync(const BrokerConsumerSt Lock lock(mutex_); MultiTopicsBrokerConsumerStatsPtr statsPtr = std::make_shared(numberTopicPartitions_->load()); - LatchPtr latchPtr = std::make_shared(numberTopicPartitions_->load()); + auto latchPtr = std::make_shared(numberTopicPartitions_->load()); lock.unlock(); size_t i = 0; - consumers_.forEachValue([this, &latchPtr, &statsPtr, &i, callback](const ConsumerImplPtr& consumer) { - size_t index = i++; - auto weakSelf = weak_from_this(); - consumer->getBrokerConsumerStatsAsync([this, weakSelf, latchPtr, statsPtr, index, callback]( - Result result, const BrokerConsumerStats& stats) { - auto self = weakSelf.lock(); - if (self) { - handleGetConsumerStats(result, stats, latchPtr, statsPtr, index, callback); - } + auto failedResult = std::make_shared>(ResultOk); + consumers_.forEachValue( + [this, &latchPtr, &statsPtr, &i, callback, &failedResult](const ConsumerImplPtr& consumer) { + size_t index = i++; + auto weakSelf = weak_from_this(); + consumer->getBrokerConsumerStatsAsync( + [this, weakSelf, latchPtr, statsPtr, index, callback, failedResult]( + Result result, const BrokerConsumerStats& stats) { + auto self = weakSelf.lock(); + if (!self) { + return; + } + if (result == ResultOk) { + std::lock_guard lock{mutex_}; + statsPtr->add(stats, index); + } else { + // Store the first failed result as the final failed result + auto expected = ResultOk; + failedResult->compare_exchange_strong(expected, result); + } + if (--*latchPtr == 0) { + if (auto firstFailedResult = failedResult->load(std::memory_order_acquire); + firstFailedResult == ResultOk) { + callback(ResultOk, BrokerConsumerStats{statsPtr}); + } else { + // Fail the whole operation if any of the consumers failed + callback(firstFailedResult, {}); + } + } + }); }); - }); } void MultiTopicsConsumerImpl::getLastMessageIdAsync(const BrokerGetLastMessageIdCallback& callback) { callback(ResultOperationNotSupported, GetLastMessageIdResponse()); } -void MultiTopicsConsumerImpl::handleGetConsumerStats(Result res, - const BrokerConsumerStats& brokerConsumerStats, - const LatchPtr& latchPtr, - const MultiTopicsBrokerConsumerStatsPtr& statsPtr, - size_t index, - const BrokerConsumerStatsCallback& callback) { - Lock lock(mutex_); - if (res == ResultOk) { - latchPtr->countdown(); - statsPtr->add(brokerConsumerStats, index); - } else { - lock.unlock(); - callback(res, BrokerConsumerStats()); - return; - } - if (latchPtr->getCount() == 0) { - lock.unlock(); - callback(ResultOk, BrokerConsumerStats(statsPtr)); - } -} - std::shared_ptr MultiTopicsConsumerImpl::topicNamesValid(const std::vector& topics) { TopicNamePtr topicNamePtr = std::shared_ptr(); diff --git a/lib/MultiTopicsConsumerImpl.h b/lib/MultiTopicsConsumerImpl.h index b22227e3..dc628652 100644 --- a/lib/MultiTopicsConsumerImpl.h +++ b/lib/MultiTopicsConsumerImpl.h @@ -28,7 +28,6 @@ #include "ConsumerImpl.h" #include "ConsumerInterceptors.h" #include "Future.h" -#include "Latch.h" #include "LookupDataResult.h" #include "SynchronizedHashMap.h" #include "TestUtil.h" @@ -100,9 +99,6 @@ class MultiTopicsConsumerImpl : public ConsumerImplBase { uint64_t getNumberOfConnectedConsumer() override; void hasMessageAvailableAsync(const HasMessageAvailableCallback& callback) override; - void handleGetConsumerStats(Result, const BrokerConsumerStats&, const LatchPtr&, - const MultiTopicsBrokerConsumerStatsPtr&, size_t, - const BrokerConsumerStatsCallback&); // return first topic name when all topics name valid, or return null pointer static std::shared_ptr topicNamesValid(const std::vector& topics); void unsubscribeOneTopicAsync(const std::string& topic, const ResultCallback& callback); diff --git a/tests/ConsumerTest.cc b/tests/ConsumerTest.cc index f1bca77d..795613e0 100644 --- a/tests/ConsumerTest.cc +++ b/tests/ConsumerTest.cc @@ -40,6 +40,7 @@ #include "WaitUtils.h" #include "lib/ClientConnection.h" #include "lib/Future.h" +#include "lib/Latch.h" #include "lib/LogUtils.h" #include "lib/MessageIdUtil.h" #include "lib/MultiTopicsConsumerImpl.h" diff --git a/tests/MultiTopicsConsumerTest.cc b/tests/MultiTopicsConsumerTest.cc index d59b50dc..db3bc963 100644 --- a/tests/MultiTopicsConsumerTest.cc +++ b/tests/MultiTopicsConsumerTest.cc @@ -20,9 +20,13 @@ #include #include +#include +#include #include "ThreadSafeMessages.h" #include "lib/LogUtils.h" +#include "lib/MockServer.h" +#include "tests/PulsarFriend.h" static const std::string lookupUrl = "pulsar://localhost:6650"; @@ -142,3 +146,29 @@ TEST(MultiTopicsConsumerTest, testAcknowledgeInvalidMessageId) { client.close(); } + +TEST(MultiTopicsConsumerTest, testGetConsumerStatsFail) { + Client client{lookupUrl}; + std::vector topics{"testGetConsumerStatsFail0", "testGetConsumerStatsFail1"}; + Consumer consumer; + ASSERT_EQ(ResultOk, client.subscribe(topics, "sub", consumer)); + + auto connection = *PulsarFriend::getConnections(client).begin(); + auto mockServer = std::make_shared(connection); + connection->attachMockServer(mockServer); + + mockServer->setRequestDelay({{"CONSUMER_STATS", 3000}}); + auto future = std::async(std::launch::async, [&consumer]() { + BrokerConsumerStats stats; + return consumer.getBrokerConsumerStats(stats); + }); + // Trigger the `getBrokerConsumerStats` in a new thread + future.wait_for(std::chrono::milliseconds(100)); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + connection->handleKeepAliveTimeout(); + ASSERT_EQ(ResultDisconnected, future.get()); + + mockServer->close(); + client.close(); +}