Skip to content

Commit 7c84fd2

Browse files
authored
Support assigning vectors to multiple clusters to improve vector search quality and performance (#24390) (#30180)
1 parent 9e07300 commit 7c84fd2

35 files changed

+987
-164
lines changed

ydb/core/base/kmeans_clusters.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include "kmeans_clusters.h"
22

3+
#include <ydb/public/api/protos/ydb_table.pb.h>
4+
35
#include <library/cpp/dot_product/dot_product.h>
46
#include <library/cpp/l1_distance/l1_distance.h>
57
#include <library/cpp/l2_distance/l2_distance.h>
@@ -86,6 +88,14 @@ namespace {
8688
ValidateSettingInRange(name, result, minValue, maxValue, error);
8789
return result;
8890
}
91+
92+
double ParseDouble(const TString& name, const TString& value, TString& error) {
93+
double result = 0;
94+
if (!TryFromString(value, result)) {
95+
error = TStringBuilder() << "Invalid " << name << ": " << value;
96+
}
97+
return result;
98+
}
8999
}
90100

91101
// TODO(mbkkt) maybe compute floating sum in double? Needs benchmark
@@ -575,6 +585,18 @@ bool ValidateSettings(const Ydb::Table::KMeansTreeSettings& settings, TString& e
575585
return false;
576586
}
577587

588+
if (settings.has_overlap_clusters() &&
589+
settings.overlap_clusters() > settings.clusters()) {
590+
error = TStringBuilder() << "overlap_clusters should be less than or equal to clusters";
591+
return false;
592+
}
593+
594+
if (settings.has_overlap_ratio() &&
595+
settings.overlap_ratio() < 0) {
596+
error = TStringBuilder() << "overlap_ratio should be >= 0";
597+
return false;
598+
}
599+
578600
ui64 clustersPowLevels = 1;
579601
for (ui64 i = 0; i < settings.levels(); ++i) {
580602
clustersPowLevels *= settings.clusters();
@@ -650,6 +672,10 @@ bool FillSetting(Ydb::Table::KMeansTreeSettings& settings, const TString& name,
650672
settings.set_clusters(ParseUInt32(name, value, MinClusters, MaxClusters, error));
651673
} else if (nameLower =="levels") {
652674
settings.set_levels(ParseUInt32(name, value, MinLevels, MaxLevels, error));
675+
} else if (nameLower == "overlap_clusters") {
676+
settings.set_overlap_clusters(ParseUInt32(name, value, MinClusters, MaxClusters, error));
677+
} else if (nameLower == "overlap_ratio") {
678+
settings.set_overlap_ratio(ParseDouble(name, value, error));
653679
} else {
654680
error = TStringBuilder() << "Unknown index setting: " << name;
655681
return false;
@@ -683,4 +709,27 @@ void FilterOverlapRows(TVector<TSerializedCellVec>& rows, size_t distancePos, ui
683709
}
684710
}
685711

712+
// Selects the clusters a single row should be assigned to when overlap is enabled.
//
// `rowClusters` holds (cluster id, distance) candidates for one row. Inputs with
// zero or one candidate are returned untouched. Otherwise, on return the vector is
// sorted by ascending distance and truncated so that:
//   * it contains at most `overlapClusters` entries, and
//   * if `overlapRatio > 0`, it stops before the first entry (other than the nearest
//     one) whose distance exceeds a threshold derived from the smallest distance:
//     `best * overlapRatio`, or `best / overlapRatio` when `best` is negative.
//
// A non-positive (or NaN) `overlapRatio` disables the threshold step.
void FilterOverlapRows(TVector<std::pair<NTableIndex::NKMeans::TClusterId, double>>& rowClusters, ui32 overlapClusters, double overlapRatio) {
    if (rowClusters.size() <= 1) {
        return;
    }

    // Nearest cluster first; the lambda needs no captures.
    std::sort(rowClusters.begin(), rowClusters.end(),
        [](const auto& lhs, const auto& rhs) {
            return lhs.second < rhs.second;
        });

    if (rowClusters.size() > overlapClusters) {
        rowClusters.resize(overlapClusters);
    }

    // overlapClusters == 0 leaves nothing to compare against; bail out before
    // touching rowClusters[0], which would otherwise be an out-of-bounds read.
    if (rowClusters.empty() || !(overlapRatio > 0)) {
        return;
    }

    // NOTE(review): dividing for negative distances presumably accommodates metrics
    // that produce negated similarities, so the threshold still moves away from the
    // best value — confirm against the distance implementations.
    const double best = rowClusters.front().second;
    const double thresh = best < 0 ? best / overlapRatio : best * overlapRatio;

    // The nearest cluster is always kept; cut at the first farther entry over the threshold.
    const auto firstRejected = std::find_if(rowClusters.begin() + 1, rowClusters.end(),
        [thresh](const auto& entry) {
            return entry.second > thresh;
        });
    rowClusters.erase(firstRejected, rowClusters.end());
}
734+
686735
}

ydb/core/base/kmeans_clusters.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,14 @@
22

33
#include <ydb/core/scheme/scheme_tablecell.h>
44

5-
#include <ydb/public/api/protos/ydb_table.pb.h>
5+
namespace Ydb::Table {
6+
class VectorIndexSettings;
7+
class KMeansTreeSettings;
8+
}
9+
10+
namespace NKikimr::NTableIndex::NKMeans {
11+
using TClusterId = ui64;
12+
}
613

714
namespace NKikimr::NKMeans {
815

@@ -56,5 +63,6 @@ bool ValidateSettings(const Ydb::Table::VectorIndexSettings& settings, TString&
5663
bool ValidateSettings(const Ydb::Table::KMeansTreeSettings& settings, TString& error);
5764
bool FillSetting(Ydb::Table::KMeansTreeSettings& settings, const TString& name, const TString& value, TString& error);
5865
void FilterOverlapRows(TVector<TSerializedCellVec>& rows, size_t distancePos, ui32 overlapClusters, double overlapRatio);
66+
void FilterOverlapRows(TVector<std::pair<NTableIndex::NKMeans::TClusterId, double>>& rowClusters, ui32 overlapClusters, double overlapRatio);
5967

6068
}

ydb/core/base/table_index.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,19 @@ inline constexpr const char* PostingTable = "indexImplPostingTable";
7979
inline constexpr const char* BuildSuffix0 = "0build";
8080
inline constexpr const char* BuildSuffix1 = "1build";
8181
inline constexpr auto IsForeignType = Ydb::Type::BOOL;
82+
inline constexpr auto IsForeignTypeName = "Bool";
8283
inline constexpr const char* IsForeignColumn = "__ydb_foreign";
8384
inline constexpr auto DistanceType = Ydb::Type::DOUBLE;
85+
inline constexpr auto DistanceTypeName = "Double";
8486
inline constexpr const char* DistanceColumn = "__ydb_distance";
8587

8688
// Prefix table
8789
inline constexpr const char* PrefixTable = "indexImplPrefixTable";
8890
inline constexpr const char* IdColumnSequence = "__ydb_id_sequence";
8991

9092
inline constexpr const int DefaultKMeansRounds = 3;
93+
inline constexpr const int DefaultOverlapClusters = 1;
94+
inline constexpr const double DefaultOverlapRatio = 0;
9195

9296
inline constexpr TClusterId PostingParentFlag = (1ull << 63ull);
9397

ydb/core/kqp/common/kqp_yql.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,11 +735,32 @@ NNodes::TCoNameValueTupleList TKqpStreamLookupSettings::BuildNode(TExprContext&
735735
.Done());
736736
}
737737

738+
if (VectorTopDistinct) {
739+
settings.emplace_back(
740+
Build<TCoNameValueTuple>(ctx, pos)
741+
.Name().Build(VectorTopDistinctSettingName)
742+
.Done());
743+
}
744+
738745
return Build<TCoNameValueTupleList>(ctx, pos)
739746
.Add(settings)
740747
.Done();
741748
}
742749

750+
// Reports whether the settings tuple list contains an entry whose name equals
// VectorTopDistinctSettingName. Only the name is inspected; the value is ignored.
bool TKqpStreamLookupSettings::HasVectorTopDistinct(const NNodes::TCoNameValueTupleList& list) {
    bool found = false;
    for (const auto& entry : list) {
        if (entry.Name().Value() == VectorTopDistinctSettingName) {
            found = true;
            break;
        }
    }
    return found;
}
759+
760+
// Convenience overload: checks the settings list attached to a stream lookup
// table node for the VectorTopDistinct setting.
bool TKqpStreamLookupSettings::HasVectorTopDistinct(const NNodes::TKqlStreamLookupTable& node) {
    const auto lookupSettings = node.Settings();
    return HasVectorTopDistinct(lookupSettings);
}
763+
743764
TKqpStreamLookupSettings TKqpStreamLookupSettings::Parse(const NNodes::TCoNameValueTupleList& list) {
744765
TKqpStreamLookupSettings settings;
745766

@@ -775,6 +796,8 @@ TKqpStreamLookupSettings TKqpStreamLookupSettings::Parse(const NNodes::TCoNameVa
775796
} else if (name == VectorTopLimitSettingName) {
776797
YQL_ENSURE(tuple.Value().IsValid());
777798
settings.VectorTopLimit = tuple.Value().Cast().Ptr();
799+
} else if (name == VectorTopDistinctSettingName) {
800+
settings.VectorTopDistinct = true;
778801
} else {
779802
YQL_ENSURE(false, "Unknown KqpStreamLookup setting name '" << name << "'");
780803
}

ydb/core/kqp/common/kqp_yql.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ struct TKqpStreamLookupSettings {
6060
static constexpr TStringBuf VectorTopIndexSettingName = "VectorTopIndex";
6161
static constexpr TStringBuf VectorTopLimitSettingName = "VectorTopLimit";
6262
static constexpr TStringBuf VectorTopTargetSettingName = "VectorTopTarget";
63+
static constexpr TStringBuf VectorTopDistinctSettingName = "VectorTopDistinct";
6364

6465
// stream lookup strategy types
6566
static constexpr std::string_view LookupStrategyName = "LookupRows"sv;
@@ -76,11 +77,15 @@ struct TKqpStreamLookupSettings {
7677
TExprNode::TPtr VectorTopTarget;
7778
TExprNode::TPtr VectorTopLimit;
7879

80+
bool VectorTopDistinct = false;
81+
7982
NNodes::TCoNameValueTupleList BuildNode(TExprContext& ctx, TPositionHandle pos) const;
8083
static TKqpStreamLookupSettings Parse(const NNodes::TKqlStreamLookupTable& node);
8184
static TKqpStreamLookupSettings Parse(const NNodes::TKqlStreamLookupIndex& node);
8285
static TKqpStreamLookupSettings Parse(const NNodes::TKqpCnStreamLookup& node);
8386
static TKqpStreamLookupSettings Parse(const NNodes::TCoNameValueTupleList& node);
87+
static bool HasVectorTopDistinct(const NNodes::TKqlStreamLookupTable& node);
88+
static bool HasVectorTopDistinct(const NNodes::TCoNameValueTupleList& node);
8489
};
8590

8691
struct TKqpDeleteRowsIndexSettings {

ydb/core/kqp/executer_actor/kqp_tasks_graph.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,9 @@ void TKqpTasksGraph::BuildStreamLookupChannels(const TStageInfo& stageInfo, ui32
500500
auto target = ExtractPhyValue(stageInfo, in.GetTargetVector(), TxAlloc->HolderFactory, TxAlloc->TypeEnv, NUdf::TUnboxedValuePod());
501501
out.SetTargetVector(TString(target.AsStringRef()));
502502
out.SetLimit((ui32)ExtractPhyValue(stageInfo, in.GetLimit(), TxAlloc->HolderFactory, TxAlloc->TypeEnv, NUdf::TUnboxedValuePod()).Get<ui64>());
503+
for (auto& colIdx: in.GetDistinctColumns()) {
504+
out.AddDistinctColumns(colIdx);
505+
}
503506
}
504507

505508
TTransform streamLookupTransform;
@@ -520,6 +523,8 @@ void TKqpTasksGraph::BuildVectorResolveChannels(const TStageInfo& stageInfo, ui3
520523
auto* settings = GetMeta().Allocate<NKikimrTxDataShard::TKqpVectorResolveSettings>();
521524

522525
*settings->MutableIndexSettings() = vectorResolve.GetIndexSettings();
526+
settings->SetOverlapClusters(vectorResolve.GetOverlapClusters());
527+
settings->SetOverlapRatio(vectorResolve.GetOverlapRatio());
523528

524529
YQL_ENSURE(stageInfo.Meta.IndexMetas.size() == 1);
525530
const auto& levelTableInfo = stageInfo.Meta.IndexMetas.back().TableConstInfo;

ydb/core/kqp/opt/logical/kqp_opt_log_extract.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ TExprBase KqpApplyExtractMembersToReadTable(TExprBase node, TExprContext& ctx, c
7474
return node;
7575
}
7676

77+
auto slt = node.Maybe<TKqlStreamLookupTable>();
78+
if (slt && TKqpStreamLookupSettings::HasVectorTopDistinct(slt.Cast())) {
79+
return node;
80+
}
81+
7782
TCoAtomList columnsNode = TExprBase(node.Ptr()->Child(TKqlReadColumnsNodeIdx)).Cast<TCoAtomList>();
7883
auto usedColumns = GetUsedColumns(node, columnsNode, parentsMap, allowMultiUsage, ctx);
7984
if (!usedColumns) {

ydb/core/kqp/opt/logical/kqp_opt_log_indexes.cpp

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,7 @@ void VectorReadMain(
588588
const TIntrusivePtr<TKikimrTableMetadata> & mainTableMeta,
589589
const TCoAtomList& mainColumns,
590590
const TKqpStreamLookupSettings& pushdownSettings,
591+
bool withOverlap,
591592
TExprNodePtr& read)
592593
{
593594
const bool isCovered = CheckIndexCovering(mainColumns, postingTableMeta);
@@ -606,21 +607,48 @@ void VectorReadMain(
606607
.Columns(postingColumns)
607608
.Settings(isVectorCovered ? settingsNode : settings.BuildNode(ctx, pos))
608609
.Done().Ptr();
610+
}
609611

610-
read = Build<TKqlStreamLookupTable>(ctx, pos)
611-
.Table(mainTable)
612-
.LookupKeys(read)
613-
.Columns(mainColumns)
614-
.Settings(settingsNode)
615-
.Done().Ptr();
616-
} else {
617-
read = Build<TKqlStreamLookupTable>(ctx, pos)
618-
.Table(postingTable)
619-
.LookupKeys(read)
620-
.Columns(mainColumns)
621-
.Settings(settingsNode)
622-
.Done().Ptr();
612+
const auto& targetTable = isCovered ? postingTable : mainTable;
613+
614+
if (withOverlap) {
615+
// mainColumns must contain primary key columns for DistinctColumns pushdown
616+
THashSet<TStringBuf> cols;
617+
for (const auto& col: mainColumns) {
618+
cols.insert(col.Value());
619+
}
620+
TVector<TCoAtom> columnsWithKey;
621+
for (const auto& col: mainTableMeta->KeyColumnNames) {
622+
if (!cols.contains(col)) {
623+
columnsWithKey.push_back(Build<TCoAtom>(ctx, pos)
624+
.Value(col)
625+
.Done());
626+
}
627+
}
628+
if (columnsWithKey.size()) {
629+
for (const auto& col: mainColumns) {
630+
columnsWithKey.push_back(col);
631+
}
632+
read = Build<TKqlStreamLookupTable>(ctx, pos)
633+
.Table(targetTable)
634+
.LookupKeys(read)
635+
.Columns<TCoAtomList>().Add(columnsWithKey).Build()
636+
.Settings(settingsNode)
637+
.Done().Ptr();
638+
read = Build<TCoExtractMembers>(ctx, pos)
639+
.Input(read)
640+
.Members(mainColumns)
641+
.Done().Ptr();
642+
return;
643+
}
623644
}
645+
646+
read = Build<TKqlStreamLookupTable>(ctx, pos)
647+
.Table(targetTable)
648+
.LookupKeys(read)
649+
.Columns(mainColumns)
650+
.Settings(settingsNode)
651+
.Done().Ptr();
624652
}
625653

626654
void VectorTopMain(TExprContext& ctx, const TCoTopBase& top, TExprNodePtr& read) {
@@ -633,6 +661,10 @@ void VectorTopMain(TExprContext& ctx, const TCoTopBase& top, TExprNodePtr& read)
633661
.Done().Ptr();
634662
}
635663

664+
// FIXME Most of this rewriting should probably be handled in kqp/opt/physical
665+
// Logical optimizer should only rewrite it to something like TKqlReadTableVectorIndex
666+
// This would remove the need for skipping KqpApplyExtractMembersToReadTable based on settings.VectorTopDistinct
667+
636668
TExprBase DoRewriteTopSortOverKMeansTree(
637669
const TReadMatch& match, const TMaybeNode<TCoFlatMap>& flatMap, const TExprBase& lambdaArgs, const TExprBase& lambdaBody, const TCoTopBase& top,
638670
TExprContext& ctx, const TKqpOptimizeContext& kqpCtx,
@@ -697,6 +729,9 @@ TExprBase DoRewriteTopSortOverKMeansTree(
697729

698730
const auto levelTop = kqpCtx.Config->KMeansTreeSearchTopSize.Get().GetOrElse(1);
699731

732+
const auto& kmeansDesc = std::get<NKikimrKqp::TVectorIndexKmeansTreeDescription>(indexDesc.SpecializedIndexDescription);
733+
const bool withOverlap = kmeansDesc.settings().overlap_clusters() > 1;
734+
700735
TKqpStreamLookupSettings settings;
701736
settings.Strategy = EStreamLookupStrategyType::LookupRows;
702737
settings.VectorTopColumn = NTableIndex::NKMeans::CentroidColumn;
@@ -716,7 +751,8 @@ TExprBase DoRewriteTopSortOverKMeansTree(
716751

717752
settings.VectorTopColumn = indexDesc.KeyColumns.back();
718753
settings.VectorTopLimit = top.Count().Ptr();
719-
VectorReadMain(ctx, pos, postingTable, postingTableDesc->Metadata, mainTable, tableDesc.Metadata, mainColumns, settings, read);
754+
settings.VectorTopDistinct = true;
755+
VectorReadMain(ctx, pos, postingTable, postingTableDesc->Metadata, mainTable, tableDesc.Metadata, mainColumns, settings, withOverlap, read);
720756

721757
if (flatMap) {
722758
read = Build<TCoFlatMap>(ctx, flatMap.Cast().Pos())
@@ -827,6 +863,9 @@ TExprBase DoRewriteTopSortOverPrefixedKMeansTree(
827863

828864
const auto levelTop = kqpCtx.Config->KMeansTreeSearchTopSize.Get().GetOrElse(1);
829865

866+
const auto& kmeansDesc = std::get<NKikimrKqp::TVectorIndexKmeansTreeDescription>(indexDesc.SpecializedIndexDescription);
867+
const bool withOverlap = kmeansDesc.settings().overlap_clusters() > 1;
868+
830869
TKqpStreamLookupSettings settings;
831870
settings.Strategy = EStreamLookupStrategyType::LookupRows;
832871
settings.VectorTopColumn = NTableIndex::NKMeans::CentroidColumn;
@@ -849,7 +888,8 @@ TExprBase DoRewriteTopSortOverPrefixedKMeansTree(
849888

850889
settings.VectorTopColumn = indexDesc.KeyColumns.back();
851890
settings.VectorTopLimit = top.Count().Ptr();
852-
VectorReadMain(ctx, pos, postingTable, postingTableDesc->Metadata, mainTable, tableDesc.Metadata, mainColumns, settings, read);
891+
settings.VectorTopDistinct = true;
892+
VectorReadMain(ctx, pos, postingTable, postingTableDesc->Metadata, mainTable, tableDesc.Metadata, mainColumns, settings, withOverlap, read);
853893

854894
if (mainLambda) {
855895
read = Build<TCoMap>(ctx, flatMap.Pos())

ydb/core/kqp/query_compiler/kqp_query_compiler.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1670,15 +1670,21 @@ class TKqpQueryCompiler : public IKqpQueryCompiler {
16701670
*vectorTopK.MutableSettings() = kmeansDesc.GetSettings().Getsettings();
16711671

16721672
// Column index
1673+
THashMap<TStringBuf, ui32> readColumnIndexes;
16731674
ui32 columnIdx = 0;
16741675
for (const auto& column: streamLookupProto.GetColumns()) {
1675-
if (column == settings.VectorTopColumn) {
1676-
break;
1676+
readColumnIndexes[column] = columnIdx++;
1677+
}
1678+
YQL_ENSURE(readColumnIndexes.contains(settings.VectorTopColumn));
1679+
vectorTopK.SetColumn(readColumnIndexes.at(settings.VectorTopColumn));
1680+
1681+
// Unique columns - required when we read from the index posting table and overlap is enabled
1682+
if (settings.VectorTopDistinct) {
1683+
for (const auto& keyColumn: mainTable->Metadata->KeyColumnNames) {
1684+
YQL_ENSURE(readColumnIndexes.contains(keyColumn));
1685+
vectorTopK.AddDistinctColumns(readColumnIndexes.at(keyColumn));
16771686
}
1678-
columnIdx++;
16791687
}
1680-
YQL_ENSURE(columnIdx < streamLookup.Columns().Size());
1681-
vectorTopK.SetColumn(columnIdx);
16821688

16831689
// Limit - may be a parameter which will be linked later
16841690
TExprBase expr(settings.VectorTopLimit);
@@ -1992,6 +1998,8 @@ class TKqpQueryCompiler : public IKqpQueryCompiler {
19921998
// Index settings
19931999
auto& kmeansDesc = std::get<NKikimrKqp::TVectorIndexKmeansTreeDescription>(indexDesc->SpecializedIndexDescription);
19942000
*vectorResolveProto.MutableIndexSettings() = kmeansDesc.GetSettings().Getsettings();
2001+
vectorResolveProto.SetOverlapClusters(kmeansDesc.GetSettings().overlap_clusters());
2002+
vectorResolveProto.SetOverlapRatio(kmeansDesc.GetSettings().overlap_ratio());
19952003

19962004
// Main table
19972005
FillTablesMap(vectorResolve.Table(), tablesMap);

0 commit comments

Comments
 (0)