Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/commands/commands.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "src/acl.h"
#include "src/commands/ft_search.h"
#include "src/metrics.h"
#include "src/query/content_resolution.h"
#include "src/query/fanout.h"
#include "src/query/search.h"
#include "src/schema_manager.h"
Expand Down Expand Up @@ -58,6 +59,10 @@ int Reply(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int argc) {
return ValkeyModule_ReplyWithError(
ctx, parameters->search_result.status.message().data());
}
if (parameters->GetContentProcessing() ==
query::ContentProcessing::kContentRequired) {
query::FetchContent(*parameters, ctx);
}
parameters->SendReply(ctx, parameters->search_result);
return VALKEYMODULE_OK;
}
Expand Down
1 change: 1 addition & 0 deletions src/commands/commands.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ struct QueryCommand : public query::SearchParameters {
//
void QueryCompleteBackground(std::unique_ptr<SearchParameters> self) override;
void QueryCompleteMainThread(std::unique_ptr<SearchParameters> self) override;
bool CanResolveContentInReply() const override { return true; }

std::optional<vmsdk::BlockedClient> blocked_client;

Expand Down
36 changes: 20 additions & 16 deletions src/query/content_resolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,34 +39,38 @@ void ResolveContent(std::unique_ptr<SearchParameters> params) {
// moved). Fall through to content fetch.
}

// 3. Content fetch + filter via ProcessNeighborsForReply
// 3. Content fetch + filter
auto ctx = vmsdk::MakeUniqueValkeyThreadSafeContext(nullptr);
FetchContent(*params, ctx.get());

// 4. Call QueryCompleteMainThread
params->QueryCompleteMainThread(std::move(params));
}

void FetchContent(SearchParameters& params, ValkeyModuleCtx* ctx) {
const auto& attribute_data_type =
params->index_schema->GetAttributeDataType();
size_t original_size = params->search_result.neighbors.size();
params.index_schema->GetAttributeDataType();
size_t original_size = params.search_result.neighbors.size();

std::optional<std::string> vector_identifier = std::nullopt;
if (!params->attribute_alias.empty()) {
auto id = params->index_schema->GetIdentifier(params->attribute_alias);
if (!params.attribute_alias.empty()) {
auto id = params.index_schema->GetIdentifier(params.attribute_alias);
if (id.ok()) {
vector_identifier = *id;
}
}

query::ProcessNeighborsForReply(ctx.get(), attribute_data_type,
params->search_result.neighbors, *params,
vector_identifier, params->sortby_parameter);
query::ProcessNeighborsForReply(ctx, attribute_data_type,
params.search_result.neighbors, params,
vector_identifier, params.sortby_parameter);

// 4. Adjust search_result.total_count for removed neighbors
size_t removed = original_size - params->search_result.neighbors.size();
if (params->search_result.total_count > removed) {
params->search_result.total_count -= removed;
// Adjust search_result.total_count for removed neighbors
size_t removed = original_size - params.search_result.neighbors.size();
if (params.search_result.total_count > removed) {
params.search_result.total_count -= removed;
} else {
params->search_result.total_count = 0;
params.search_result.total_count = 0;
}

// 5. Call QueryCompleteMainThread
params->QueryCompleteMainThread(std::move(params));
}

} // namespace valkey_search::query
7 changes: 7 additions & 0 deletions src/query/content_resolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include <memory>

#include "vmsdk/src/valkey_module_api/valkey_module.h"

namespace valkey_search::query {

struct SearchParameters;
Expand All @@ -19,6 +21,11 @@ struct SearchParameters;
// content fetching via ProcessNeighborsForReply, and final completion.
void ResolveContent(std::unique_ptr<SearchParameters> params);

// Fetches content for neighbors and adjusts total_count for any removed
// neighbors. Can be called directly from the Reply callback when content
// resolution is deferred (content_resolution_pending_ optimization).
void FetchContent(SearchParameters& params, ValkeyModuleCtx* ctx);

} // namespace valkey_search::query

#endif // VALKEY_SEARCH_SRC_QUERY_CONTENT_RESOLUTION_H_
10 changes: 10 additions & 0 deletions src/query/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,16 @@ absl::Status SearchAsync(std::unique_ptr<SearchParameters> parameters,
parameters->QueryCompleteBackground(std::move(parameters));
break;
case ContentProcessing::kContentRequired:
if (parameters->CanResolveContentInReply()) {
// Optimization: defer content resolution to the Reply callback,
// combining RunByMain + UnblockClient into a single step.
parameters->QueryCompleteBackground(std::move(parameters));
} else {
vmsdk::RunByMain([parameters = std::move(parameters)]() mutable {
ResolveContent(std::move(parameters));
});
}
break;
case ContentProcessing::kContentionCheckRequired:
vmsdk::RunByMain([parameters = std::move(parameters)]() mutable {
ResolveContent(std::move(parameters));
Expand Down
5 changes: 5 additions & 0 deletions src/query/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,11 @@ struct SearchParameters {
// resolution due to contention with in-flight mutations.
unsigned int content_resolution_blocked_{0};

// Returns true if content resolution can be deferred to the Reply callback.
// Only QueryCommand overrides this to return true, since it has a Reply
// callback that can perform content resolution on the main thread.
virtual bool CanResolveContentInReply() const { return false; }

// In CME, when a LocalResponderSearch is used, the neighbors of that
// operation get moved into this operation. But the neighbors has string_view
// references into the return_attributes of the owning operation. So in order
Expand Down
137 changes: 137 additions & 0 deletions testing/search_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/synchronization/notification.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
Expand All @@ -40,8 +41,10 @@
#include "src/query/predicate.h"
#include "src/utils/patricia_tree.h"
#include "src/utils/string_interning.h"
#include "src/valkey_search.h"
#include "testing/common.h"
#include "vmsdk/src/managed_pointers.h"
#include "vmsdk/src/thread_pool.h"
#include "vmsdk/src/type_conversions.h"

namespace valkey_search {
Expand Down Expand Up @@ -1353,5 +1356,139 @@ INSTANTIATE_TEST_SUITE_P(
absl::StrCat(distance_metric, "_", std::get<1>(info.param).test_name);
return test_name;
});
// Test SearchParameters subclass that tracks which QueryComplete* method
// is called.
class TrackingSearchParameters : public query::SearchParameters {
public:
TrackingSearchParameters(bool can_resolve_in_reply)
: can_resolve_in_reply_(can_resolve_in_reply) {
timeout_ms = 10000;
db_num = 0;
cancellation_token = cancel::Make(timeout_ms, nullptr);
}

bool CanResolveContentInReply() const override {
return can_resolve_in_reply_;
}

void QueryCompleteBackground(
std::unique_ptr<SearchParameters> self) override {
background_called_ = true;
// Release ownership to prevent double-delete
self.release();
done_.Notify();
}

void QueryCompleteMainThread(
std::unique_ptr<SearchParameters> self) override {
main_thread_called_ = true;
// Release ownership to prevent double-delete
self.release();
done_.Notify();
}

absl::Notification done_;
bool background_called_{false};
bool main_thread_called_{false};

private:
bool can_resolve_in_reply_;
};

class ContentResolutionPendingTest : public ValkeySearchTest {
protected:
void SetUp() override {
ValkeySearchTest::SetUp();
InitThreadPools(2, 2, 1);
}
};

// When CanResolveContentInReply() returns true and GetContentProcessing()
// returns kContentRequired, SearchAsync should call QueryCompleteBackground
// (skipping RunByMain + ResolveContent).
TEST_F(ContentResolutionPendingTest, OptimizedPathCallsBackground) {
auto index_schema = CreateIndexSchemaWithMultipleAttributes();
auto params = std::make_unique<TrackingSearchParameters>(true);
auto *params_ptr = params.get();

params->index_schema = index_schema;
params->index_schema_name = kIndexSchemaName;
params->attribute_alias = kVectorAttributeAlias;
params->score_as = vmsdk::MakeUniqueValkeyString(kScoreAs);
params->dialect = kDialect;
params->k = 1;
params->ef = kEfRuntime;
auto vectors = DeterministicallyGenerateVectors(1, kVectorDimensions, 10.0);
params->query =
std::string((char *)vectors[0].data(), vectors[0].size() * sizeof(float));

// Verify preconditions: this is a kContentRequired query
EXPECT_FALSE(params->no_content);
EXPECT_EQ(params->GetContentProcessing(),
query::ContentProcessing::kContentRequired);
EXPECT_TRUE(params->CanResolveContentInReply());

VMSDK_EXPECT_OK(query::SearchAsync(
std::move(params), ValkeySearch::Instance().GetReaderThreadPool(),
query::SearchMode::kLocal));

params_ptr->done_.WaitForNotification();

EXPECT_TRUE(params_ptr->background_called_);
EXPECT_FALSE(params_ptr->main_thread_called_);

delete params_ptr;
}

// When CanResolveContentInReply() returns false and GetContentProcessing()
// returns kContentRequired, SearchAsync should use RunByMain (via
// EventLoopAddOneShot) and NOT set content_resolution_pending_.
TEST_F(ContentResolutionPendingTest, NonOptimizedPathUsesRunByMain) {
auto index_schema = CreateIndexSchemaWithMultipleAttributes();
auto params = std::make_unique<TrackingSearchParameters>(false);

params->index_schema = index_schema;
params->index_schema_name = kIndexSchemaName;
params->attribute_alias = kVectorAttributeAlias;
params->score_as = vmsdk::MakeUniqueValkeyString(kScoreAs);
params->dialect = kDialect;
params->k = 1;
params->ef = kEfRuntime;
auto vectors = DeterministicallyGenerateVectors(1, kVectorDimensions, 10.0);
params->query =
std::string((char *)vectors[0].data(), vectors[0].size() * sizeof(float));

// Verify preconditions: kContentRequired but can't resolve in reply
EXPECT_EQ(params->GetContentProcessing(),
query::ContentProcessing::kContentRequired);
EXPECT_FALSE(params->CanResolveContentInReply());

// Capture the RunByMain callback without executing it (ResolveContent
// requires the main thread). Just verify EventLoopAddOneShot was called.
absl::Notification oneshot_called;
ValkeyModuleEventLoopOneShotFunc captured_fn = nullptr;
void *captured_data = nullptr;
EXPECT_CALL(*kMockValkeyModule, EventLoopAddOneShot(testing::_, testing::_))
.WillOnce(
[&](ValkeyModuleEventLoopOneShotFunc fn, void *data) {
captured_fn = fn;
captured_data = data;
oneshot_called.Notify();
return 0;
});

VMSDK_EXPECT_OK(query::SearchAsync(
std::move(params), ValkeySearch::Instance().GetReaderThreadPool(),
query::SearchMode::kLocal));

// Wait for EventLoopAddOneShot to be called, confirming RunByMain was used
oneshot_called.WaitForNotification();
EXPECT_NE(captured_fn, nullptr);

// Clean up the captured callback data (it's an absl::AnyInvocable<void()>*)
auto *fn = static_cast<absl::AnyInvocable<void()> *>(captured_data);
delete fn;
}

} // namespace
} // namespace valkey_search
Loading