diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go index 5c3394764..836266d67 100644 --- a/cmd/epp/runner/runner.go +++ b/cmd/epp/runner/runner.go @@ -241,7 +241,7 @@ func (r *Runner) Run(ctx context.Context) error { } // --- Setup Datastore --- - epf, err := r.setupMetricsCollection(setupLog, r.featureGates[datalayer.FeatureGate]) + epf, err := r.setupMetricsCollection(setupLog, r.featureGates[datalayer.ExperimentalDatalayerFeatureGate]) if err != nil { return err } @@ -376,7 +376,7 @@ func (r *Runner) Run(ctx context.Context) error { MetricsStalenessThreshold: *metricsStalenessThreshold, Director: director, SaturationDetector: saturationDetector, - UseExperimentalDatalayerV2: r.featureGates[datalayer.FeatureGate], // pluggable data layer feature flag + UseExperimentalDatalayerV2: r.featureGates[datalayer.ExperimentalDatalayerFeatureGate], // pluggable data layer feature flag } if err := serverRunner.SetupWithManager(ctx, mgr); err != nil { setupLog.Error(err, "Failed to setup EPP controllers") @@ -467,8 +467,9 @@ func (r *Runner) parseConfigurationPhaseOne(ctx context.Context) (*configapi.End } } - loader.RegisterFeatureGate(datalayer.FeatureGate) + loader.RegisterFeatureGate(datalayer.ExperimentalDatalayerFeatureGate) loader.RegisterFeatureGate(flowcontrol.FeatureGate) + loader.RegisterFeatureGate(datalayer.PrepareDataPluginsFeatureGate) r.registerInTreePlugins() @@ -508,9 +509,15 @@ func (r *Runner) parseConfigurationPhaseTwo(ctx context.Context, rawConfig *conf // Add requestControl plugins r.requestControlConfig.AddPlugins(handle.GetAllPlugins()...) - // Sort prepare data plugins in DAG order (topological sort). Also check prepare data plugins for cycles. - if r.requestControlConfig.PrepareDataPluginGraph() != nil { - return nil, errors.New("failed to load the configuration - prepare data plugins have cyclic dependencies") + + // TODO(rahulgurnani): Remove feature gate check once prepare data plugins are stable. + if r.featureGates[datalayer.PrepareDataPluginsFeatureGate] { + // Sort prepare data plugins in DAG order (topological sort). Also check prepare data plugins for cycles. + if r.requestControlConfig.PrepareDataPluginGraph() != nil { + return nil, errors.New("failed to load the configuration - prepare data plugins have cyclic dependencies") + } + } else { + r.requestControlConfig.WithPrepareDataPlugins() } // Handler deprecated configuration options @@ -533,7 +540,7 @@ func (r *Runner) deprecatedConfigurationHelper(cfg *config.Config, logger logr.L if _, ok := os.LookupEnv(enableExperimentalDatalayerV2); ok { logger.Info("Enabling the experimental Data Layer V2 using environment variables is deprecated and will be removed in next version") - r.featureGates[datalayer.FeatureGate] = env.GetEnvBool(enableExperimentalDatalayerV2, false, logger) + r.featureGates[datalayer.ExperimentalDatalayerFeatureGate] = env.GetEnvBool(enableExperimentalDatalayerV2, false, logger) } if _, ok := os.LookupEnv(enableExperimentalFlowControlLayer); ok { logger.Info("Enabling the experimental Flow Control layer using environment variables is deprecated and will be removed in next version") diff --git a/pkg/epp/config/loader/configloader_test.go b/pkg/epp/config/loader/configloader_test.go index 82f57f4b2..50022647c 100644 --- a/pkg/epp/config/loader/configloader_test.go +++ b/pkg/epp/config/loader/configloader_test.go @@ -55,7 +55,7 @@ func TestLoadRawConfiguration(t *testing.T) { t.Parallel() // Register known feature gates for validation. - RegisterFeatureGate(datalayer.FeatureGate) + RegisterFeatureGate(datalayer.ExperimentalDatalayerFeatureGate) tests := []struct { name string @@ -87,7 +87,7 @@ func TestLoadRawConfiguration(t *testing.T) { }, }, }, - FeatureGates: configapi.FeatureGates{datalayer.FeatureGate}, + FeatureGates: configapi.FeatureGates{datalayer.ExperimentalDatalayerFeatureGate}, SaturationDetector: &configapi.SaturationDetector{ QueueDepthThreshold: 10, KVCacheUtilThreshold: 0.8, @@ -147,7 +147,7 @@ func TestInstantiateAndConfigure(t *testing.T) { // Not parallel because it modifies global plugin registry. registerTestPlugins(t) - RegisterFeatureGate(datalayer.FeatureGate) + RegisterFeatureGate(datalayer.ExperimentalDatalayerFeatureGate) tests := []struct { name string diff --git a/pkg/epp/datalayer/factory.go b/pkg/epp/datalayer/factory.go index 78765095c..c707d411d 100644 --- a/pkg/epp/datalayer/factory.go +++ b/pkg/epp/datalayer/factory.go @@ -26,7 +26,8 @@ import ( ) const ( - FeatureGate = "dataLayer" + ExperimentalDatalayerFeatureGate = "dataLayer" + PrepareDataPluginsFeatureGate = "prepareDataPlugins" ) // PoolInfo represents the DataStore information needed for endpoints. diff --git a/pkg/epp/datalayer/plugins/data_types.go b/pkg/epp/datalayer/plugins/data_types.go new file mode 100644 index 000000000..a737c7331 --- /dev/null +++ b/pkg/epp/datalayer/plugins/data_types.go @@ -0,0 +1,52 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package plugins + +import ( + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" +) + +const ( + PrefixCacheMatchInfoKey = "PrefixCacheMatchInfoKey" +) + +type PrefixCacheMatchInfo struct { + matchLength int + totalBlocks int +} + +func NewPrefixCacheMatchInfo(matchLen int, blockHashLen int) *PrefixCacheMatchInfo { + return &PrefixCacheMatchInfo{ + matchLength: matchLen, + totalBlocks: blockHashLen, + } +} + +func (p *PrefixCacheMatchInfo) MatchLength() int { + return p.matchLength +} + +func (p *PrefixCacheMatchInfo) TotalLength() int { + return p.totalBlocks +} + +func (p *PrefixCacheMatchInfo) Clone() datalayer.Cloneable { + return &PrefixCacheMatchInfo{ + matchLength: p.matchLength, + totalBlocks: p.totalBlocks, + } +} diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index 19e69cf2b..bd47044c3 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -345,6 +345,9 @@ func (d *Director) runPreRequestPlugins(ctx context.Context, request *scheduling func (d *Director) runPrepareDataPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error { + if len(d.requestControlPlugins.prepareDataPlugins) == 0 { + return nil + } return prepareDataPluginsWithTimeout(prepareDataTimeout, d.requestControlPlugins.prepareDataPlugins, ctx, request, pods) } diff --git a/pkg/epp/requestcontrol/request_control_config.go b/pkg/epp/requestcontrol/request_control_config.go index 8f08ac121..a454a4d1b 100644 --- a/pkg/epp/requestcontrol/request_control_config.go +++ b/pkg/epp/requestcontrol/request_control_config.go @@ -108,6 +108,9 @@ func (c *Config) AddPlugins(pluginObjects ...plugins.Plugin) { // PrepareDataPluginGraph creates data dependency graph and sorts the plugins in topological order. // If a cycle is detected, it returns an error. func (c *Config) PrepareDataPluginGraph() error { + if len(c.prepareDataPlugins) == 0 { + return nil + } dag := buildDAG(c.prepareDataPlugins) plugins, err := sortPlugins(dag, c.prepareDataPlugins) if err != nil { diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go index 2a1a3a8b2..2543aef1a 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go @@ -28,6 +28,7 @@ import ( k8stypes "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/log" + dplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" @@ -206,6 +207,30 @@ func (p *Plugin) WithName(name string) *Plugin { return p } +func (p *Plugin) Produces() map[string]any { + return map[string]any{dplugins.PrefixCacheMatchInfoKey: dplugins.PrefixCacheMatchInfo{}} +} + +func (p *Plugin) Consumes() map[string]any { + return map[string]any{} +} + +// PrepareRequestData hashes prompt, finds longest prefix match and stores it in pod as attribute. +func (p *Plugin) PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Pod) error { + hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config), p.config.MaxPrefixBlocksToMatch) + state := &SchedulingContextState{ + PrefixHashes: hashes, + PrefixCacheServers: p.matchLongestPrefix(ctx, hashes), + } + total := len(state.PrefixHashes) + + for _, pod := range pods { + matchLen := state.PrefixCacheServers[ServerID(pod.GetPod().NamespacedName)] + pod.Put(dplugins.PrefixCacheMatchInfoKey, dplugins.NewPrefixCacheMatchInfo(matchLen, total)) + } + return nil +} + // Score returns the scoring result for the given list of pods based on context. func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { // pre score step, hashing prompt and find longest prefix match. diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go index f0feeef68..82251b8fb 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go @@ -29,10 +29,16 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" + dplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) +// static check to ensure Plugin implements the PrepareDataPlugin interface. +var _ requestcontrol.PrepareDataPlugin = &Plugin{} + func TestPrefixPluginCompletion(t *testing.T) { config := Config{ BlockSize: 4, @@ -571,6 +577,67 @@ func randomPrompt(n int) string { return sb.String() } +func TestPrepareRequestData(t *testing.T) { + config := Config{ + BlockSize: 4, + MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks, + LRUCapacityPerServer: DefaultLRUCapacityPerServer, + } + plugin := New(context.Background(), config) + + pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: backendmetrics.NewMetricsState(), AttributeMap: datalayer.NewAttributes()} + pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, MetricsState: backendmetrics.NewMetricsState(), AttributeMap: datalayer.NewAttributes()} + pods := []types.Pod{pod1, pod2} + + // First request to populate cache. + req1 := &types.LLMRequest{ + RequestId: uuid.NewString(), + TargetModel: "test-model1", + Body: &types.LLMRequestBody{ + Completions: &types.CompletionsRequest{ + Prompt: "aaaabbbb", + }, + }, + } + _ = plugin.Score(context.Background(), types.NewCycleState(), req1, pods) + schedulingResult := &types.SchedulingResult{ + PrimaryProfileName: "default", + ProfileResults: map[string]*types.ProfileRunResult{ + "default": {TargetPods: []types.Pod{pod1}}, + }, + } + plugin.PreRequest(context.Background(), req1, schedulingResult) + plugin.wg.Wait() + + // Second request that shares a prefix. + req2 := &types.LLMRequest{ + RequestId: uuid.NewString(), + TargetModel: "test-model1", + Body: &types.LLMRequestBody{ + Completions: &types.CompletionsRequest{ + Prompt: "aaaacccc", + }, + }, + } + + err := plugin.PrepareRequestData(context.Background(), req2, pods) + assert.NoError(t, err) + + // Verify pod1 has the correct prefix match info + info1, ok := pod1.Get(dplugins.PrefixCacheMatchInfoKey) + assert.True(t, ok) + prefixInfo1 := info1.(*dplugins.PrefixCacheMatchInfo) + assert.Equal(t, 1, prefixInfo1.MatchLength()) // "aaaa" matches + assert.Equal(t, 2, prefixInfo1.TotalLength()) // "aaaacccc" -> 2 blocks + + // Verify pod2 has no match info + info2, ok := pod2.Get(dplugins.PrefixCacheMatchInfoKey) + assert.True(t, ok) + prefixInfo2 := info2.(*dplugins.PrefixCacheMatchInfo) + assert.Equal(t, 0, prefixInfo2.MatchLength()) // No match for pod2 + assert.Equal(t, 2, prefixInfo2.TotalLength()) +} + // BenchmarkPrefixPluginChatCompletionsStress is a stress test for chat completions with varying message counts and lengths func BenchmarkPrefixPluginChatCompletionsStress(b *testing.B) { blockSize := 8